
Enable universal checkpoint for zero stage 1 #4516

Merged: 14 commits merged into master on Oct 25, 2023

Conversation

@tjruwase (Contributor) commented on Oct 16, 2023

Generalize universal checkpointing in DS:

  1. Enable for zero stage 1
  2. Move conversion script into DS

Tested with Megatron-DS GPT using companion PR microsoft/Megatron-DeepSpeed#265

Fix #2921

  • readthedocs
  • [ ] Tutorial: Defer
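
For orientation, a minimal usage sketch of the two description items above. It is not taken from this PR: it assumes the relocated conversion script lives at deepspeed/checkpoint/ds_to_universal.py with --input_folder/--output_folder options, and that universal loading is switched on through a "checkpoint": {"load_universal": true} entry in the DeepSpeed config; treat those names as assumptions.

# Step 1 (offline, hypothetical invocation): convert a saved ZeRO stage 1
# checkpoint into the universal layout, e.g.
#   python deepspeed/checkpoint/ds_to_universal.py \
#       --input_folder  checkpoints/global_step1000 \
#       --output_folder checkpoints/global_step1000_universal

# Step 2: resume training, telling DeepSpeed the checkpoint is universal.
import torch
import deepspeed

model = torch.nn.Linear(8, 8)  # toy model purely for illustration

ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "zero_optimization": {"stage": 1},       # this PR extends universal checkpointing to stage 1
    "checkpoint": {"load_universal": True},  # assumed name of the universal-load switch
}

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config=ds_config)

# The converted checkpoint is then loaded like a regular one;
# the directory and tag below are placeholders.
engine.load_checkpoint("checkpoints", tag="global_step1000_universal")

The point of the universal format is that the resumed run can use a different parallelism layout than the run that produced the checkpoint.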

@tjruwase (Contributor, Author) commented:

@stas00, FYI

@tjruwase changed the title from "Enable uni_ckpt for z1" to "Enable universal checkpoint for zero stage 1" on Oct 16, 2023
@stas00 (Contributor) commented on Oct 16, 2023

Amazing! Thank you for starting to work on this super-essential feature, Tunji!

Review thread on deepspeed/runtime/zero/stage_1_and_2.py (outdated, resolved)
@tjruwase added this pull request to the merge queue on Oct 25, 2023
Merged via the queue into master with commit 8fdd9b3 on Oct 25, 2023; 15 checks passed
baodii pushed a commit to baodii/DeepSpeed that referenced this pull request on Nov 7, 2023:
* Enable uni_ckpt for z1

* Remove logging fix to separate PR. Relocate conversion script to avoid logging circular import issue

* Formatting fix

* PR feedback

* Handle replicated params

* Detect bf16_optimizer

* Docs

* Fix docs
mauryaavinash95 pushed a commit to mauryaavinash95/DeepSpeed that referenced this pull request on Feb 17, 2024 (same commit message as above).

Review thread on the checkpoint conversion script (code context from the diff):

os.makedirs(param_base_path, exist_ok=True)

cnt += 1
counter = f"{dp_index:0>2d}"  # zero-pads dp_index to a minimum width of 2 digits
@rgtjf (Junfeng) commented:

Dear @tjruwase,

I'm looking at a run where the maximum dp_index is 127. Because the shard suffixes are compared as strings, "127" sorts before "13", so the tensor sorting at line 144 of the script may put the shards in the wrong order. Could this ordering cause problems?

I'd appreciate your insight on this.

Best regards,
Junfeng

@rgtjf commented:

A specific example is as follows:

./temp/model.transformer_encoder.layers.19.self_attn.in_proj_weight => fp32.100, torch.Size([2187604])
./temp/model.transformer_encoder.layers.19.self_attn.in_proj_weight => fp32.99, torch.Size([48144044])
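
For illustration (not part of the original thread), a minimal sketch of the ordering pitfall. The list of names below is made up; only the fp32.<dp_index> suffix convention and the 2-digit zero-padding come from the code context above.

# Shard suffixes are produced with f"{dp_index:0>2d}", so indices below 100
# are zero-padded to two digits while larger indices are not padded further.
names = [f"fp32.{i:0>2d}" for i in (2, 13, 99, 100)]

print(sorted(names))
# ['fp32.02', 'fp32.100', 'fp32.13', 'fp32.99']  <- string sort puts 100 before 13 and 99

# One possible fix: sort on the numeric part of the suffix instead.
print(sorted(names, key=lambda n: int(n.rsplit(".", 1)[1])))
# ['fp32.02', 'fp32.13', 'fp32.99', 'fp32.100']

This reproduces the fp32.100 / fp32.99 misordering shown in the example above; the actual fix is tracked in the ticket referenced later in this thread (#5283).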

@tjruwase (Contributor, Author) replied:

@rgtjf, thanks for sharing this issue. Do you mind creating a new ticket for it? I can see that line 130 may not generalize to larger scales. It would be great if you could share more details in a new ticket. Thanks!

@rgtjf replied:

@tjruwase I've opened #5283 to track this issue. If any details are missing or more information is needed, please let me know.

Linked issue closed by this pull request: [REQUEST] universal checkpoint for ZeRO - 1,2,3 (#2921)