Exception when using threads #9424
I believe this is happening because you are using the default graph in each thread. Try creating a graph in the thread and using that instead of the default graph.
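A minimal sketch of that suggestion, assuming the TF1-style tf.Graph/tf.Session API used elsewhere in this thread (the worker function, shapes and thread count are illustrative, not the original code):

import threading
import numpy as np
import tensorflow as tf
from keras.layers import Input, LSTM
from keras.models import Model

def worker(index):
    # Give each thread its own graph and session instead of the
    # process-wide default graph; the model is also built inside the thread.
    graph = tf.Graph()
    with graph.as_default(), tf.Session(graph=graph) as sess:
        inp = Input(batch_shape=(1, 1, 5))
        net = Model(inputs=inp, outputs=LSTM(1)(inp))
        net.compile(optimizer='rmsprop', loss='mse')
        print(index, net.predict(np.ones((1, 1, 5))))

threads = [threading.Thread(target=worker, args=(i,)) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()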
If I create a graph for each thread, simply by using
on line
I think this may be due to the fact that the networks are created outside of the threads, using a different graph. However, if I try creating them inside the threads, each in a different graph, i.e.
The following exception is raised, again on line
Also, I realize 64 is an excessive number; that value is only needed to reproduce the issue I experience in my real code, where I have 16 threads, each with 6 networks, of which one is selected based on some condition and used to make a prediction at every iteration.
The exact use case is confusing me a little. I'm not understanding the purpose behind setting weights on a global network from within the threads instantaneously. I modified your code - I may not have carried over the functionality you were after regarding this global network, but this trains one global model and then uses its values to predict in the threads.

import threading
import numpy as np
from keras.layers import Input, LSTM
from keras.models import Model, model_from_json
import tensorflow as tf
n_threads = 64
def make_rnet():
    inp = Input(batch_shape=(1, 1, 5))
    out = LSTM(1)(inp)
    return Model(
        inputs=inp,
        outputs=out
    )
def thread_fn(index, architecture, weights):
"""
:param index:
:type global_net: Model
"""
print("thread-%s" % index)
with tf.Session(graph = tf.Graph()) as sess:
# Build model.
net = model_from_json(architecture)
net.set_weights(weights)
net.compile(optimizer='rmsprop', loss='mse')
in_shape = [int(d) for d in net.input.shape]
out_shape = [int(d) for d in net.output.shape]
# Test prediction.
predictions = net.predict(np.ones(shape=in_shape))
print(predictions)
# Test fit on random data.
x = np.random.random(size=in_shape)
y = np.ones(shape=out_shape)
net.fit(x,y, verbose=0, batch_size=1)
# Train a global network.
global_net = make_rnet()
global_net.compile(optimizer='rmsprop', loss='mse')
x = np.random.random(
    size=[int(d) for d in global_net.input.shape]
)
y = np.ones(
    shape=[int(d) for d in global_net.output.shape]
)
global_net.fit(x, y)
# Get the network in a portable format.
architecture = global_net.to_json()
weights = global_net.get_weights()
# Generate threads.
threads = [
    threading.Thread(target=thread_fn, args=(i, architecture, weights))
    for i in range(n_threads)
]
print("starting %s threads..." % n_threads)
for t in threads:
    t.start()
for t in threads:
    t.join()
print("threads terminated.") Alternatively, if you do need to get data from the threads back out perhaps a queue implementation would work? import json
import threading
import queue
import numpy as np
from keras.layers import Input, LSTM
from keras.models import Model, model_from_json
import tensorflow as tf
queue = queue.Queue()
n_threads = 32
def make_rnet():
    inp = Input(batch_shape=(1, 1, 5))
    out = LSTM(1)(inp)
    return Model(
        inputs=inp,
        outputs=out
    )
def thread_fn(index, queue, architecture, weights):
"""
:param index:
:type global_net: Model
"""
print("thread-%s" % index)
with tf.Session(graph = tf.Graph()) as sess:
# Build model.
net = model_from_json(architecture)
net.set_weights(weights)
net.compile(optimizer='rmsprop', loss='mse')
in_shape = [int(d) for d in net.input.shape]
out_shape = [int(d) for d in net.output.shape]
# Test prediction.
predictions = net.predict(np.ones(shape=in_shape))
print(predictions)
# Test fit on random data.
x = np.random.random(size=in_shape)
y = np.ones(shape=out_shape)
net.fit(x,y, verbose=0, batch_size=1)
# Enqueue new weights.
queue.put({
'weights': net.get_weights(),
'index': index
})
# Train a global network.
global_net = make_rnet()
global_net.compile(optimizer='rmsprop', loss='mse')
x = np.random.random(
    size=[int(d) for d in global_net.input.shape]
)
y = np.ones(
    shape=[int(d) for d in global_net.output.shape]
)
global_net.fit(x, y)
# Get the network in a portable format.
architecture = global_net.to_json()
weights = global_net.get_weights()
# Generate threads.
threads = [
    threading.Thread(target=thread_fn, args=(i, queue, architecture, weights))
    for i in range(n_threads)
]
print("starting %s threads..." % n_threads)
for t in threads:
    t.start()
while True:
    # Setting weights from thread.
    item = queue.get()
    print("Callback from thread-%s" % item['index'])
    global_net.set_weights(item['weights'])
    queue.task_done()
Thank you for your help. I like your queue implementation; I was also thinking I don't really need a global network, only global weights to synchronize the threads to. I'm not sure I understand why Keras is not throwing any exception in your thread implementation: I see you create a new session with a new graph in each of them, but the session doesn't appear to be used in the thread code. Still, if I remove that "with" statement, exceptions are raised about tensors not being elements of the correct graph.
Hi, okay, thank you for the context. Yes, agreed, I don't think you need a global network then; just keep the main weights in the main thread. The reason it is not erroring within the thread is this line:
with tf.Session(graph=tf.Graph()) as sess:
From TensorFlow's documentation:
Without the "with" statement, the model in each thread gets built against the process-wide default graph instead of its own, which is why you then see exceptions about tensors not being elements of the correct graph.
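For reference, a rough sketch of the difference (my reading of the TF1 semantics: the default graph and default session are tracked per thread, and entering a Session as a context manager installs both; the helper functions below are illustrative, not part of the code above):

import tensorflow as tf
from keras.models import model_from_json

def build_with_per_thread_defaults(architecture):
    # Equivalent in spirit to `with tf.Session(graph=tf.Graph()) as sess:` -
    # the new graph and its session are made this thread's defaults, so the
    # model's tensors are created in, and evaluated by, the same graph/session.
    graph = tf.Graph()
    sess = tf.Session(graph=graph)
    with graph.as_default(), sess.as_default():
        return model_from_json(architecture)

def build_without_defaults(architecture):
    # Without the context managers, the model is built in the process-wide
    # default graph; running it through a session bound to a different graph
    # is what raises "Tensor ... is not an element of this graph".
    return model_from_json(architecture)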
Thank you!
Good! Thanks!
I have a problem with using Keras models in threads. I have been able to reproduce the issue in an MWE that I report below.
Here and in the real code, I create the networks outside the threads; then, inside the threads, I use a network to make predictions and to fit.
To make it work with as few as 2 threads, I had to use the workarounds described in:
#5896
and
#6124
(i.e. they amount to using
with graph.as_default():
and net._make_train_function()
); a rough sketch of this pattern is included at the end of this post.
The program executes without any problem with 32 threads; however, when that number is increased to 64, exceptions are raised (see below).
Some of the exceptions raised with 64 threads are:
In the real code, far fewer threads (8) are needed for this behaviour to occur, maybe because each thread uses many networks (6) and they have far more parameters (~3 million).
Why is this happening?
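For context, a rough sketch of the workaround pattern from those two issues, under the assumption that the network is built and compiled in the main thread on the default graph (the worker function, shapes and thread count are illustrative, not my real code):

import threading
import numpy as np
import tensorflow as tf
from keras.layers import Input, LSTM
from keras.models import Model

# Build and compile the network in the main thread, on the default graph.
inp = Input(batch_shape=(1, 1, 5))
net = Model(inputs=inp, outputs=LSTM(1)(inp))
net.compile(optimizer='rmsprop', loss='mse')
net._make_train_function()        # workaround: build the train function up front
graph = tf.get_default_graph()    # the graph the model lives in

def worker(index):
    # Workaround: re-enter the model's graph inside the thread.
    with graph.as_default():
        x = np.ones((1, 1, 5))
        y = np.ones((1, 1))
        print(index, net.predict(x))
        net.fit(x, y, verbose=0, batch_size=1)

threads = [threading.Thread(target=worker, args=(i,)) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()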