I've been able to solve an OOM on a TPU v3-8 with an ugly hack that I don't understand.
I feel like it has to do with flushing the memory.
Problem you have encountered:
When running my training script on a TPU v3-8, I get RuntimeError: RESOURCE_EXHAUSTED.
What you expected to happen:
Due to my quick hack (see below), it should run with no problem.
Logs, error messages, etc:
2021-11-03 15:31:06.338128: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2085] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Attempting to reserve 11.35G at the bottom of memory. That was not possible. There are 12.18G free, 0B reserved, and 10.52G reservable.
Epoch ... (1/6): 0%| | 0/6 [03:02<?, ?it/s]
Traceback (most recent call last):
File "/home/koush/dalle-mini/dev/seq2seq/run_seq2seq_flax.py", line 991, in <module>
main()
File "/home/koush/dalle-mini/dev/seq2seq/run_seq2seq_flax.py", line 962, in main
state, train_metric = p_train_step(state, batch)
File "/home/koush/.pyenv/versions/dev/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/koush/.pyenv/versions/dev/lib/python3.9/site-packages/jax/_src/api.py", line 1946, in cache_miss
out_tree, out_flat = f_pmapped_(*args, **kwargs)
File "/home/koush/.pyenv/versions/dev/lib/python3.9/site-packages/jax/_src/api.py", line 1825, in f_pmapped
out = pxla.xla_pmap(
File "/home/koush/.pyenv/versions/dev/lib/python3.9/site-packages/jax/core.py", line 1698, in bind
return call_bind(self, fun, *args, **params)
File "/home/koush/.pyenv/versions/dev/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/home/koush/.pyenv/versions/dev/lib/python3.9/site-packages/jax/core.py", line 1701, in process
return trace.process_map(self, fun, tracers, params)
File "/home/koush/.pyenv/versions/dev/lib/python3.9/site-packages/jax/core.py", line 627, in process_call
return primitive.impl(f, *tracers, **params)
File "/home/koush/.pyenv/versions/dev/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 723, in xla_pmap_impl
return compiled_fun(*args)
File "/home/koush/.pyenv/versions/dev/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 1264, in execute_replicated
out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: RESOURCE_EXHAUSTED: Attempting to reserve 11.35G at the bottom of memory. That was not possible. There are 12.18G free, 0B reserved, and 10.52G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/koush/dalle-mini/dev/seq2seq/run_seq2seq_flax.py", line 991, in <module>
main()
File "/home/koush/dalle-mini/dev/seq2seq/run_seq2seq_flax.py", line 962, in main
state, train_metric = p_train_step(state, batch)
File "/home/koush/.pyenv/versions/dev/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 1264, in execute_replicated
out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
RuntimeError: RESOURCE_EXHAUSTED: Attempting to reserve 11.35G at the bottom of memory. That was not possible. There are 12.18G free, 0B reserved, and 10.52G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
How do I solve it?
If I add print(model.params) here, the script runs with no problem.
Note:
model parameters are initialized by default (in the huggingface/transformers library)
then we replace them with parameters from a checkpoint
the memory gets fragmented, and somehow printing the parameters fixes it
Steps to reproduce:
git clone https://github.com/borisdayma/dalle-mini and make sure you're on commit 0cc04f208d6218aa63165ed732f41a013f4a8698
pip install -e path_to_repo
cd dev/seq2seq
python run_seq2seq_flax.py --dataset_repo_or_path dalle-mini/encoded-vqgan_imagenet_f16_16384 --train_file data/train/*.jsonl --validation_file data/valid/*.jsonl --len_train 129832935 --len_eval 171505 --eval_steps 1000 --from_checkpoint dalle-mini/dalle-mini/model-1e6bsdiv:latest --streaming --normalize_text --output_dir output --per_device_train_batch_size 56 --per_device_eval_batch_size 56 --preprocessing_num_workers 80 --warmup_steps 5000 --gradient_accumulation_steps 8 --do_train --do_eval --adafactor --num_train_epochs 6 --log_model --learning_rate 0.005
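For reference, the hack boils down to something like the following sketch (the model class and checkpoint name are illustrative placeholders, not the actual dalle-mini setup; the real code lives in dev/seq2seq/run_seq2seq_flax.py):

```python
from transformers import FlaxBartForConditionalGeneration

# from_pretrained first materializes default-initialized parameters and then
# overwrites them with the checkpoint weights, as described in the Note above.
# The model class and checkpoint are placeholders for illustration.
model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large")

# The ugly hack: printing the params forces every weight to be fetched to the
# host for display, which (per the report above) somehow leaves TPU memory in
# a usable, less fragmented state before p_train_step runs.
print(model.params)
```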
Another solution I had considered is to move all the weights to CPU and then move them back to TPU. Is there a cleaner way to handle TPU memory allocation?
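That round trip might look roughly like this (a sketch only; roundtrip_params is a hypothetical helper, and whether the re-allocation actually defragments memory is an assumption, not verified behavior):

```python
import jax

def roundtrip_params(params):
    """Hypothetical helper: copy a param pytree TPU -> host -> TPU."""
    # device -> host: every buffer becomes a NumPy array in host memory
    host_params = jax.device_get(params)
    # host -> device: re-allocate the buffers in one pass, which might let the
    # allocator lay them out contiguously instead of in fragmented gaps
    return jax.device_put(host_params, jax.devices()[0])
```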
This is quite odd for sure. Fragmentation, combined with being close to the memory limit, could of course result in errors that appear almost randomly. One thing you could try is to initialize the model on CPU with jax.jit(model.init, backend="cpu"). The params are moved to TPU automatically during training or during replication of the state (e.g. jax_utils.replicate).
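For concreteness, the suggested pattern looks roughly like this (a sketch with a toy Flax module standing in for the real seq2seq model):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax import jax_utils

class Toy(nn.Module):
    """Stand-in for the actual model; only the init pattern matters here."""
    @nn.compact
    def __call__(self, x):
        return nn.Dense(16)(x)

model = Toy()

# Initialize on CPU so the throwaway default weights never occupy TPU memory.
params = jax.jit(model.init, backend="cpu")(jax.random.PRNGKey(0),
                                            jnp.ones((1, 8)))

# The params only land on TPU later, e.g. when the train state is replicated
# across the 8 local devices before pmap'd training.
replicated_params = jax_utils.replicate(params)
```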