Avoid OOM on TPU #1658

Closed · borisdayma opened this issue Nov 3, 2021 · 2 comments

borisdayma commented Nov 3, 2021

Hi,

I've been able to solve an OOM on a TPU v3-8 with an ugly hack that I don't understand.
I feel like it has to do with flushing the memory.

Problem you have encountered:

When running my training script on a TPU v3-8, I get RuntimeError: RESOURCE_EXHAUSTED.

What you expected to happen:

The script should run without OOM, since it does run with no problem once my quick hack (see below) is applied.

Logs, error messages, etc:

2021-11-03 15:31:06.338128: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2085] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Attempting to reserve 11.35G at the bottom of memory. That was not possible. There are 12.18G free, 0B reserved, and 10.52G reservable.
Epoch ... (1/6):   0%|                                                                                                                                                                                | 0/6 [03:02<?, ?it/s]
Traceback (most recent call last):                                                                                                                                                                                          
  File "/home/koush/dalle-mini/dev/seq2seq/run_seq2seq_flax.py", line 991, in <module>
    main()
  File "/home/koush/dalle-mini/dev/seq2seq/run_seq2seq_flax.py", line 962, in main
    state, train_metric = p_train_step(state, batch)
  File "/home/koush/.pyenv/versions/dev/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/koush/.pyenv/versions/dev/lib/python3.9/site-packages/jax/_src/api.py", line 1946, in cache_miss
    out_tree, out_flat = f_pmapped_(*args, **kwargs)
  File "/home/koush/.pyenv/versions/dev/lib/python3.9/site-packages/jax/_src/api.py", line 1825, in f_pmapped
    out = pxla.xla_pmap(
  File "/home/koush/.pyenv/versions/dev/lib/python3.9/site-packages/jax/core.py", line 1698, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/koush/.pyenv/versions/dev/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/koush/.pyenv/versions/dev/lib/python3.9/site-packages/jax/core.py", line 1701, in process
    return trace.process_map(self, fun, tracers, params)
  File "/home/koush/.pyenv/versions/dev/lib/python3.9/site-packages/jax/core.py", line 627, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/koush/.pyenv/versions/dev/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 723, in xla_pmap_impl
    return compiled_fun(*args)
  File "/home/koush/.pyenv/versions/dev/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 1264, in execute_replicated
    out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: RESOURCE_EXHAUSTED: Attempting to reserve 11.35G at the bottom of memory. That was not possible. There are 12.18G free, 0B reserved, and 10.52G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/koush/dalle-mini/dev/seq2seq/run_seq2seq_flax.py", line 991, in <module>
    main()
  File "/home/koush/dalle-mini/dev/seq2seq/run_seq2seq_flax.py", line 962, in main
    state, train_metric = p_train_step(state, batch)
  File "/home/koush/.pyenv/versions/dev/lib/python3.9/site-packages/jax/interpreters/pxla.py", line 1264, in execute_replicated
    out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
RuntimeError: RESOURCE_EXHAUSTED: Attempting to reserve 11.35G at the bottom of memory. That was not possible. There are 12.18G free, 0B reserved, and 10.52G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).

How I solve it:

  • add print(model.params) here (see the sketch below)
  • the model now trains 🤯
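
For concreteness, a minimal sketch of the hack, assuming `model` is the Flax model whose weights were just replaced with the checkpoint parameters (the exact insertion point is the line linked as "here" above, which is not reproduced in this thread):

```python
# The 'ugly hack': print (and thereby pull to host) the full parameter pytree
# right after loading the checkpoint, before the first call to p_train_step.
print(model.params)
```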

Note:

  • I tested it a few times to check that it was not just a non-deterministic error
  • I also created a bash script that tries running the training script (with no breakpoint) 100 times, and the error always happens
  • The script runs with no problem if I decrease the batch size

Steps to reproduce:

  • git clone https://github.com/borisdayma/dalle-mini, make sure you're on commit 0cc04f208d6218aa63165ed732f41a013f4a8698
  • pip install -e path_to_repo
  • at root of repo, cd dev/seq2seq
  • python run_seq2seq_flax.py --dataset_repo_or_path dalle-mini/encoded-vqgan_imagenet_f16_16384 --train_file data/train/*.jsonl --validation_file data/valid/*.jsonl --len_train 129832935 --len_eval 171505 --eval_steps 1000 --from_checkpoint dalle-mini/dalle-mini/model-1e6bsdiv:latest --streaming --normalize_text --output_dir output --per_device_train_batch_size 56 --per_device_eval_batch_size 56 --preprocessing_num_workers 80 --warmup_steps 5000 --gradient_accumulation_steps 8 --do_train --do_eval --adafactor --num_train_epochs 6 --log_model --learning_rate 0.005
borisdayma (Author) commented:

So I think the issue is due to the following:

  • model parameters are initialized by default (in the huggingface/transformers library)
  • then we replace them with parameters from a checkpoint
  • the memory gets fragmented; somehow, printing the parameters fixes it

Another solution I had considered is to move all the weights to CPU and then back to TPU (see the sketch below). Is there a cleaner way to handle TPU memory allocation?
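
For reference, a minimal sketch of that CPU round-trip idea, assuming `params` is the checkpoint parameter pytree currently living on the TPU (the variable name is a placeholder, not taken from run_seq2seq_flax.py):

```python
import jax
import numpy as np

# Pull every parameter array back to host memory as NumPy; once the old
# pytree is no longer referenced, its TPU buffers can be freed.
params = jax.tree_map(lambda x: np.asarray(x), params)

# Push the whole pytree back onto the default (TPU) backend in one go,
# letting the allocator lay the buffers out fresh.
params = jax.device_put(params)
```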

jheek (Member) commented Nov 29, 2021

This is quite odd for sure. Fragmentation, combined with being close to the memory limit, could of course result in errors that appear almost randomly. One thing you could try is to initialize the model on CPU with jax.jit(model.init, backend="cpu"). The params are moved to TPU automatically during training or during replication of the state (e.g. jax_utils.replicate).
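
A minimal sketch of that suggestion, assuming a Flax module bound to `model` with a standard `init(rng, inputs)` signature and a placeholder `dummy_input` of the right shape (both names are assumptions, not taken from the training script):

```python
import jax
from flax import jax_utils

rng = jax.random.PRNGKey(0)

# Compile and run the initializer on the CPU backend so the throwaway
# default-initialized parameters never occupy TPU memory.
init_on_cpu = jax.jit(model.init, backend="cpu")
params = init_on_cpu(rng, dummy_input)

# Replicating the params (or the full train state) across local devices is
# what actually moves them onto the TPU cores.
params = jax_utils.replicate(params)
```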

google locked and limited conversation to collaborators on Nov 29, 2021
jheek closed this as completed on Nov 29, 2021

This issue was moved to a discussion. You can continue the conversation there.
