
fit_generator converges slower and tensorboard's OOM problem #4797

Closed
rabintang opened this issue Dec 22, 2016 · 8 comments

rabintang commented Dec 22, 2016

Problem 1:
I changed my code from fit to fit_generator, but I found it converges more slowly. With fit it converges by the 5th epoch, whereas with fit_generator it still has not converged by the 10th epoch.
Moreover, when the validation data is also supplied through a generator, the performance becomes much worse. Does my generator function have errors?

import numpy as np

def batch_iter(x, y, batch_size):
    """
    Generates a batch iterator for a dataset.
    """
    data_size = len(y)
    num_batches_per_epoch = int(data_size / batch_size) if data_size % batch_size == 0 \
        else int(data_size / batch_size) + 1
    while 1:
        # Shuffle the data at each epoch
        shuffle_indices = np.random.permutation(np.arange(data_size))
        x_shuffled = x[shuffle_indices]
        y_shuffled = y[shuffle_indices]
        for batch_num in range(num_batches_per_epoch):
            start_index = batch_num * batch_size
            end_index = min((batch_num + 1) * batch_size, data_size)
            yield (x_shuffled[start_index:end_index], y_shuffled[start_index:end_index])
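
For reference, here is a minimal sketch of how this generator might be driven, assuming a Keras 1.x-style fit_generator where samples_per_epoch counts samples rather than batches (model, x_train, y_train, x_val and y_val are placeholder names, not code from the original script). If samples_per_epoch is set smaller than the dataset, each "epoch" sees fewer samples than fit would, which by itself can make convergence look slower.

batch_size = 64

train_gen = batch_iter(x_train, y_train, batch_size)
val_gen = batch_iter(x_val, y_val, batch_size)

model.fit_generator(
    train_gen,
    samples_per_epoch=len(y_train),  # one full pass over the training data per epoch
    nb_epoch=10,
    validation_data=val_gen,
    nb_val_samples=len(y_val),       # number of validation samples to draw per epoch
)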

Problem 2:
After adding the TensorBoard callback, training always runs into an OOM (out of memory) error, even when I switch to fit_generator. It seems to allocate a block of memory as large as the training data.

Epoch 00000: val_acc improved from -inf to 0.12241, saving model to models/weights-best-label9-00-0.12.hdf5
Traceback (most recent call last):
File "trainGraph.py", line 165, in
class_weight=class_weights)
File "/home/tensorflow/.local/lib/python2.7/site-packages/keras/models.py", line 934, in fit_generator
initial_epoch=initial_epoch)
File "/home/tensorflow/.local/lib/python2.7/site-packages/keras/engine/training.py", line 1555, in fit_generator
callbacks.on_epoch_end(epoch, epoch_logs)
File "/home/tensorflow/.local/lib/python2.7/site-packages/keras/callbacks.py", line 43, in on_epoch_end
callback.on_epoch_end(epoch, logs)
File "/home/tensorflow/.local/lib/python2.7/site-packages/keras/callbacks.py", line 561, in on_epoch_end
result = self.sess.run([self.merged], feed_dict=feed_dict)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 766, in run
run_metadata_ptr)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 964, in _run
feed_dict_string, options, run_metadata)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1014, in _do_run
target_list, options, run_metadata)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1034, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.ResourceExhaustedError: OOM when allocating tensor with shape[45551,2392,20]
[[Node: Gather = Gather[Tindices=DT_INT32, Tparams=DT_FLOAT, validate_indices=true, _device="/job:localhost/replica:0/task:0/gpu:0"](embedding_1_W/read, _recv_embedding_input_1_0/_119)]]
[[Node: cond/Merge/_123 = _Recvclient_terminated=false, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/gpu:0", send_device_incarnation=1, tensor_name="edge_31_cond/Merge", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"]]

@Vladimir-Yashin
Copy link

Vladimir-Yashin commented Dec 24, 2016

Please take a look:

https://github.com/fchollet/keras/blob/master/keras/callbacks.py#L551:

# TODO: implement batched calls to sess.run
# (current call will likely go OOM on GPU)

Also check out the rest of the code in that function.
The Tensorboard callback passes the validation data through the network again at the end of each epoch, and it doesn't use batches, so if your validation data is bigger than a few samples it is very likely to cause an OOM on the GPU.
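
A toy illustration of the difference (1.x-style TensorFlow, just to show the feed pattern; the graph and array below are stand-ins, not the Keras callback's actual code):

import numpy as np
import tensorflow as tf  # 1.x-style API, matching the era of this issue

x = tf.placeholder(tf.float32, shape=[None, 10])
h = tf.layers.dense(x, 32)
tf.summary.histogram('h', h)
merged = tf.summary.merge_all()

validation_x = np.random.rand(100000, 10).astype('float32')  # stand-in for validation_data

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Unbatched, as the callback does today: activations for every validation
    # sample must fit on the GPU at the same time.
    summary_str = sess.run(merged, feed_dict={x: validation_x})

    # Batched alternative: the same summaries, one slice at a time, but this
    # produces one Summary per slice rather than a single aggregate.
    batch_size = 1000
    for start in range(0, len(validation_x), batch_size):
        summary_str = sess.run(merged, feed_dict={x: validation_x[start:start + batch_size]})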

I've submitted a pull request with the fix:
#4834

PS: I recommend opening a separate issue for problem 1.

Vladimir-Yashin pushed a commit to Vladimir-Yashin/keras that referenced this issue Dec 25, 2016
With the current implementation, the Tensorboard callback passes the whole
validation_data through sess.run() at once, which causes an OOM on the GPU
for bigger datasets, or at least leads to a much higher memory footprint.

If the validation data were split into batches, we would need to:
- split data into batches
- pass each batch through sess.run and save the result as summary_str
  (serialized Summary object)
- somehow take apart these objects and average all the histograms manually
- prepare aggregated Summary object and write to Tensorboard log file

Instead of doing that, my approach is simpler:
- sample batch_size worth of data points from validation_data
- run sess.run() once

This may lead to a few problems:
- histograms won't be 100% accurate since not all data is taken into
  account
- histograms will vary slightly even if the weights didn't change between
  epochs, simply because the Tensorboard callback picks a different set of
  samples each time it runs
- the smaller the batch_size is, the more pronounced those effects are
- when validation_data is smaller than batch_size some samples are going
  to be used multiple times and some others may not be used at all

However, the benefit is worth it: the Tensorboard callback won't lead to a
huge memory footprint and won't cause an OOM crash when the whole
validation_data doesn't fit into GPU memory.
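
The sampling step described above could look roughly like this (a sketch of the idea only, not the code from the pull request; sample_validation_batch and the list-of-arrays layout are assumptions):

import numpy as np

def sample_validation_batch(validation_arrays, batch_size):
    """Pick batch_size random rows from each array in validation_arrays.

    If there are fewer samples than batch_size, sample with replacement,
    which mirrors the reuse caveat listed above.
    """
    n = len(validation_arrays[0])
    idx = np.random.choice(n, size=batch_size, replace=n < batch_size)
    return [arr[idx] for arr in validation_arrays]

# sampled = sample_validation_batch([x_val, y_val], batch_size=128)
# feed only `sampled`, instead of the full validation set, to the single
# sess.run() call that produces the TensorBoard summaries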
Vladimir-Yashin added a commit to Vladimir-Yashin/keras that referenced this issue Dec 25, 2016
This is the fix for the problem described in keras-team#4797 (problem 2).

ahundt commented Dec 30, 2016

I've also noticed the GPU running out of memory when TensorBoard is enabled, at titu1994/DenseNet#2; the problem went away when I disabled TensorBoard.

Is there a way to fix this with fewer downsides than #4834, i.e. reducing memory utilization while keeping the TensorBoard charts more consistent?

@Vladimir-Yashin

@ahundt: The only way I can think of is to split the validation data into batches and then pass each batch through the model.

With this approach we're going to get a TensorBoard Summary object for each batch, so if you have, say, 1000 samples in validation_data and a batch_size of 100, this would result in 10 Summaries, and each Summary object would be shown separately in the TensorBoard Histograms and Distributions tabs.

I bet this is something nobody wants.

We could probably average across multiple Summary objects to form a single Summary object, where every value is the average of the corresponding values from the batch-specific summaries, but this is a much more complex approach. There are no examples of doing it that way, and generally the Summary object is not meant to be tampered with.

I checked a few code examples just to see how this problem is usually resolved, and nobody aggregates the statistics manually.
The official Tensorflow example passes the whole validation set through the model at once:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py

It should be possible to pass something like a "batch_size for TensorBoard" parameter that is bigger than batch_size, so that you'd get more accurate stats, but it wouldn't really resolve the underlying problem that you need more GPU memory with TensorBoard enabled. A rough sketch of that idea follows.
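
This is essentially the same subsampling as in the commit sketch above, just with a larger, dedicated sample size (tb_batch_size is a hypothetical parameter, not an existing Keras option):

import numpy as np

def tensorboard_sample(x_val, y_val, tb_batch_size=2048):
    """Draw a larger, but still bounded, random sample for the histogram pass,
    independent of the training batch_size. A bigger sample gives more stable
    histograms at the cost of more GPU memory during the summary sess.run()."""
    n = len(y_val)
    idx = np.random.choice(n, size=min(tb_batch_size, n), replace=False)
    return x_val[idx], y_val[idx]

# x_tb, y_tb = tensorboard_sample(x_val, y_val, tb_batch_size=4096)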


ophiry commented Mar 26, 2017

This is still an issue in Keras 2 and tf.contrib.keras.

Does the deeper integration in tf.contrib.keras make it easier to solve? Is there some kind of workaround?

@Vladimir-Yashin

The current workaround is to disable the TensorBoard callback's histograms.
If I remember correctly, validation is done in batches everywhere except in the TensorBoard callback.
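
Concretely, the workaround is just to leave histogram_freq at 0, which keeps the scalar logging (loss and metrics) but skips the per-epoch histogram pass that feeds the whole validation set through the graph (the log directory below is a placeholder):

from keras.callbacks import TensorBoard

# Scalars are still written; only the memory-hungry histogram/summary pass
# over validation_data is skipped.
tb = TensorBoard(log_dir='./logs', histogram_freq=0, write_graph=True)

# model.fit_generator(..., callbacks=[tb])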

@entelechie

Is anyone going to fix the TensorBoard callback so that it computes its summaries in batches?

stale bot added the stale label Jul 26, 2017

stale bot commented Jul 26, 2017

This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs, but feel free to re-open a closed issue if needed.

stale bot closed this as completed Aug 25, 2017

wgmg165 commented Jan 4, 2018

I also get this error when I train the same model on a group of datasets: "ResourceExhaustedError: OOM when allocating tensor with shape". I can't fix it, please help me!
