Training a Siamese (bi-encoder) transformer model with gradient checkpointing throws an error #23801
Comments
cc @ArthurZucker and @younesbelkada
@sachinya00 What does your code look like, including training setup and training args?
I've updated the post with the code to reproduce the issue.
Hey, thanks for providing a reproduction script.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
This issue is very similar to the one below, and I'm not able to make it work even with `_set_static_graph()`.
System Info
PyTorch Lightning: 1.6.5
torch: 1.13.0
Python: 3.8
CUDA: 11.4
GPUs: 4x NVIDIA A100-SXM4-40GB
transformers: 4.24.0
Reproduction
After adding `model.gradient_checkpointing_enable()` to the training code, training throws the error below:

```
RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop. 2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
```
The workaround is to add `use_reentrant=False` to the `torch.utils.checkpoint.checkpoint` call in the file below:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py#L600
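Concretely, the edit amounts to something like the following sketch (the surrounding arguments in `modeling_bert.py` vary by version, so treat this as illustrative rather than exact):

```python
# Inside BertEncoder.forward (transformers 4.24), when gradient
# checkpointing is enabled:
layer_outputs = torch.utils.checkpoint.checkpoint(
    create_custom_forward(layer_module),
    hidden_states,
    attention_mask,
    layer_head_mask,
    encoder_hidden_states,
    encoder_attention_mask,
    use_reentrant=False,  # the added flag: non-reentrant checkpointing
)
```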
What's the best way to fix this, instead of adding the flag manually in the source code?
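For reference, here are two ways to get the same effect without editing `modeling_bert.py`. Both are sketches: the `use_reentrant` argument requires torch >= 1.11, and the `gradient_checkpointing_kwargs` option only exists in newer transformers releases (v4.35+), not in the 4.24 reported here:

```python
import functools
import torch.utils.checkpoint

# Option 1: monkey-patch the checkpoint function so every call that
# transformers makes defaults to non-reentrant checkpointing. This works
# because modeling_bert.py calls torch.utils.checkpoint.checkpoint through
# the module attribute.
torch.utils.checkpoint.checkpoint = functools.partial(
    torch.utils.checkpoint.checkpoint, use_reentrant=False
)

# Option 2: upgrade transformers (v4.35+) and pass the flag directly:
# model.gradient_checkpointing_enable(
#     gradient_checkpointing_kwargs={"use_reentrant": False}
# )
```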
Expected behavior
Adding `model.gradient_checkpointing_enable()` shouldn't throw any error.

Code to reproduce
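The original reproduction script wasn't captured here. As a stand-in, below is a minimal hypothetical sketch of the failing pattern: a shared (Siamese) encoder run twice per training step under DDP with the default reentrant checkpointing, so the checkpointed layers take part in two reentrant backward passes. The model name, loss, and inputs are illustrative, not the reporter's actual code:

```python
# Hypothetical minimal repro (not the reporter's script).
# Launch with: torchrun --nproc_per_node=2 repro.py
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import AutoModel, AutoTokenizer

def main():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModel.from_pretrained("bert-base-uncased").cuda()
    model.gradient_checkpointing_enable()  # reentrant by default in 4.24
    model = DDP(model, device_ids=[rank])

    queries = tokenizer(["a query"], return_tensors="pt").to(rank)
    docs = tokenizer(["a matching document"], return_tensors="pt").to(rank)

    # Siamese/bi-encoder forward: the SAME weights encode both inputs,
    # so each checkpointed segment is replayed in two nested backwards.
    q_emb = model(**queries).last_hidden_state[:, 0]
    d_emb = model(**docs).last_hidden_state[:, 0]
    loss = (1 - torch.cosine_similarity(q_emb, d_emb)).mean()
    loss.backward()  # raises "Expected to mark a variable ready only once"

if __name__ == "__main__":
    main()
```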