
[BUG] load_checkpoint should load directly to gpu #1971

Open

stas00 opened this issue May 21, 2022 · 3 comments
Labels: bug (Something isn't working), training


stas00 commented May 21, 2022

Describe the bug

Currently, HF Transformers integration users can finetune a model and save the checkpoint with a given set of resources. However, resuming from that same checkpoint requires much more peak CPU memory, which can be huge for large models and prevents users from resuming their finetuning. (The current workaround is to add a huge swap file.)

To Reproduce

I reproduced it as part of this bug report: huggingface/transformers#17258

The full reproduction steps are here: huggingface/transformers#17258 (comment)

I also verified that torch.load doesn't load everything into CPU memory when map_location="cpu": huggingface/transformers#17258 (comment)

and I tracked the issue down to deepspeed loading those potentially huge ZeRO checkpoints (70GB for gpt-j-6) into CPU memory first:

```python
_state = torch.load(ckpt_name, map_location='cpu')
```

Expected behavior

save_checkpoint and load_checkpoint should require approximately the same amount of memory. Loading should be lean and not need any CPU memory beyond the size of the largest param or optimizer state, since torch.load copies params via the CPU anyway.

With upcoming models like 176B, the current implementation just won't work, as it would require several TBs of CPU memory to load a ZeRO checkpoint.
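
For illustration only, a minimal sketch of what loading straight onto each rank's GPU could look like; the checkpoint path and variable names here are hypothetical, not deepspeed's actual code:

```python
import os
import torch

# Hypothetical per-rank checkpoint path, just for illustration
ckpt_name = "checkpoints/global_step1000/zero_pp_rank_0_mp_rank_00_optim_states.pt"

# Load straight onto this rank's GPU instead of staging the whole
# shard in CPU memory with map_location='cpu'
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
state = torch.load(ckpt_name, map_location=f"cuda:{local_rank}")
```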

@tjruwase, @jeffra


stas00 commented Jun 9, 2022

As this problem is recurrent for HF Transformers users, in the meantime I shared a hack to stagger checkpoint loading for those who need it here:
huggingface/transformers#17534 (comment)

If you're not using the HF Trainer, you can patch deepspeed's load_checkpoint directly using similar code; you just need the rank number the deepspeed way there, or you can get it from int(os.environ.get("LOCAL_RANK", "0")). A rough sketch of the staggering idea is shown below.
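
The sketch below is a hypothetical standalone function illustrating the staggering idea, not the actual patch from the linked comment:

```python
import os
import torch
import torch.distributed as dist

def staggered_torch_load(ckpt_name):
    """Sketch of the staggering idea only: load the checkpoint one local
    rank at a time so the per-node CPU memory isn't hit by all processes
    at once. Assumes the default process group is already initialized."""
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    state = None
    for turn in range(torch.cuda.device_count()):  # processes per node
        if turn == local_rank:
            state = torch.load(ckpt_name, map_location="cpu")
        dist.barrier()  # wait for the current rank before the next one loads
    return state
```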

Much later edit: this idea actually doesn't work because of the barrier calls, so staggering is not possible, since the first process won't free up its CPU memory until all the other processes have loaded the checkpoint.

@desperadoola

Any update?

I followed the suggestion here to make a large swapfile, but the loading takes forever...

@desperadoola


Changing 'pin_memory' to False and following #3629 solved the problem. Now we can resume training from a FALCON-40B checkpoint with 1TB of CPU memory.
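
For anyone else hitting this: the pin_memory switch referred to above lives in the ZeRO offload sections of the DeepSpeed config. A minimal fragment (shown as a Python dict; the values are just an example, not the poster's exact config):

```python
# Fragment of a DeepSpeed config (as a dict, e.g. for passing to the
# HF Trainer or deepspeed.initialize); only the relevant keys are shown.
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "cpu", "pin_memory": False},
        "offload_optimizer": {"device": "cpu", "pin_memory": False},
    },
}
```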
