[BUG] load_checkpoint should load directly to gpu #1971
Comments
As this problem is recurrent for HF Transformers users, meanwhile I shared a hack to stagger checkpoint loading for those who need it here: If you're not using the HF Trainer you can patch deepspeed's ... (much later edit: this idea actually doesn't work because of the ...)
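(Not the actual patch referenced above, which per the later edit reportedly did not pan out, but as a rough illustration of what "staggering" checkpoint loading means: a minimal sketch assuming a `torch.distributed` process group is already initialized; `staggered_load`, `load_fn`, `ckpt_path`, and `group_size` are hypothetical names, not deepspeed APIs.)

```python
import torch.distributed as dist

def staggered_load(ckpt_path, load_fn, group_size=2):
    """Illustrative only: let a few ranks load at a time to cap peak CPU memory.

    Each "wave" of `group_size` ranks calls `load_fn` while the others wait at a
    barrier, trading extra wall-clock time for a lower simultaneous host-RAM peak.
    """
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    state = None
    for wave_start in range(0, world_size, group_size):
        if wave_start <= rank < wave_start + group_size:
            state = load_fn(ckpt_path)  # e.g. a wrapper around the existing load path
        dist.barrier()  # wait for the current wave to finish before the next starts
    return state
```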
Any update? Followed the suggestion here to make a large swapfile, but the loading takes forever ...
Describe the bug
Currently, HF Transformers integration users can finetune a model and save the checkpoint with their given resources. However, resuming from that same checkpoint requires much more peak CPU memory, which can be huge for large models and prevents users from resuming their finetuning. (The current workaround is to add a huge swap file.)
To Reproduce
I reproduced it as part of this bug report: huggingface/transformers#17258
The full reproduction steps are here: huggingface/transformers#17258 (comment)
I also verified that `torch.load` doesn't load everything into CPU memory when `map_location="cpu"`: huggingface/transformers#17258 (comment)

I tracked the issue down to deepspeed loading those potentially huge ZeRO checkpoints (70GB for gpt-j-6B) into CPU memory first:

DeepSpeed/deepspeed/runtime/engine.py, line 2748 (at commit 5208eb7)
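For context, here is a minimal sketch of the pattern being pointed at (illustrative only, not DeepSpeed's actual code; the file names and shard layout are hypothetical): every ZeRO shard is deserialized with `map_location="cpu"` and kept around, so the host-memory peak approaches the combined size of all shards before anything is moved to the GPU.

```python
import glob
import torch

# Hypothetical layout of per-rank ZeRO optimizer-state shards for one checkpoint.
shard_paths = sorted(glob.glob("global_step1000/zero_pp_rank_*_optim_states.pt"))

# Each shard is materialized in host RAM and retained in the list, so peak CPU
# memory grows toward the total checkpoint size (e.g. ~70GB for gpt-j-6B)
# rather than the size of a single param or optimizer state.
zero_state_dicts = [torch.load(p, map_location="cpu") for p in shard_paths]
```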
Expected behavior
`save_checkpoint` and `load_checkpoint` should require approximately the same amount of memory. Loading should be lean and need no CPU memory beyond the size of the largest param or optim state, since `torch.load` copies params via CPU.

With upcoming models like 176B, the current implementation just won't work, as it would require several TBs of CPU memory to load a ZeRO checkpoint.
@tjruwase, @jeffra