
Improve universal checkpoint #5289

Merged
merged 21 commits into microsoft:master on Mar 28, 2024

Conversation

tohtana
Contributor

@tohtana tohtana commented Mar 17, 2024

This PR includes the following improvements regarding universal checkpoints.

  • Restoring step

A universal checkpoint saves the training step count taken from the engine. In
#5263, we changed the loader to always restore this count into the optimizer's per-parameter state (`optimizer_state['state'][param]['step']`) and into the param groups. However, this approach does not restore the optimizer's state and param groups precisely, because different optimizers handle the step count differently.

Torch's Adam does not keep `step` in its param groups and only uses `optimizer_state['state'][param]['step']`. Apex's fused Adam only uses `step` in its param groups. DeepSpeed's fused Adam creates `step` in its param groups but never updates it; it only uses `optimizer_state['state'][param]['step']`.
Consequently, this leads to discrepancies between the restored and original optimizer states and param groups.
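
As a quick illustration of the torch behavior (plain PyTorch, not code from this PR), the sketch below shows that for torch's Adam the step count lives only in the per-parameter state, not in the param groups:

```python
import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# One dummy training step so the optimizer state is populated.
model(torch.randn(2, 4)).sum().backward()
opt.step()

sd = opt.state_dict()
print("param_group keys:", sorted(sd["param_groups"][0].keys()))  # no 'step' entry here
print("per-param step:", sd["state"][0]["step"])                  # step count stored per parameter
```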

This PR modifies the restoration process so that the step values in the optimizer's state and param groups match those in the original setup, aligning the restored optimizer state and param groups with the originals.
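
As a minimal sketch of that idea (a hypothetical helper, not the PR's actual implementation): write the checkpointed step count only into the locations a given optimizer already uses, so the restored structure matches the original.

```python
def restore_step(optimizer, saved_step):
    """Hypothetical helper: restore the checkpointed step count only into the
    locations this optimizer already uses, leaving its structure unchanged."""
    for group in optimizer.param_groups:
        if "step" in group:                      # e.g. Apex / DeepSpeed fused Adam
            group["step"] = saved_step
        for param in group["params"]:
            pstate = optimizer.state.get(param, {})
            if "step" in pstate:                 # e.g. torch Adam, DeepSpeed fused Adam
                # Newer torch versions store 'step' as a tensor; a real
                # implementation would preserve that type.
                pstate["step"] = saved_step
```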

  • Unit tests of DP size scaling

This PR also adds unit tests to verify universal checkpointing. They run training with DP, save a checkpoint, and convert it to a universal checkpoint. They then load that checkpoint with a different DP size and validate that the parameters and the all-gathered (ZeRO stage 1/2) optimizer states match.
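
For illustration only, here is a rough sketch of that flow; the real tests use DeepSpeed's distributed test harness, and the `ds_to_universal.py` invocation in the trailing comment (and its flags) is an assumption that should be checked against the script:

```python
import torch
import deepspeed


def run_phase(ds_config, ckpt_dir, tag, load_universal=False):
    """Train-and-save or load, depending on the phase (names are illustrative)."""
    model = torch.nn.Linear(8, 8)
    engine, _, _, _ = deepspeed.initialize(
        model=model, model_parameters=model.parameters(), config=ds_config)

    if load_universal:
        # For this phase, ds_config would also set {"checkpoint": {"load_universal": true}}.
        engine.load_checkpoint(ckpt_dir, tag=tag)
        return engine

    for _ in range(10):
        loss = engine(torch.randn(4, 8, device=engine.device)).sum()
        engine.backward(loss)
        engine.step()
    engine.save_checkpoint(ckpt_dir, tag=tag)
    return engine


# Phase 1 (launched with DP size N): run_phase(...) to train and save, then convert offline:
#   python ds_to_universal.py --input_folder <ckpt_dir>/<tag> --output_folder <ckpt_dir>/<tag>_universal
# Phase 2 (launched with a different DP size M): run_phase(..., load_universal=True) and compare
# parameters and the all-gathered (ZeRO 1/2) optimizer states against values captured in phase 1.
```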

  • Fix a bug in loading with `load_optimizer_states=False`

The loader did not load parameters from a universal checkpoint when `load_optimizer_states=False`. c8c0498 fixes this issue.
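
Continuing the sketch above (reusing its hypothetical `engine`, `ckpt_dir`, and `tag` names), the call affected by this fix is the standard `DeepSpeedEngine.load_checkpoint` path:

```python
# With the fix, module parameters are restored from the universal checkpoint
# even though optimizer (and here also LR scheduler) states are skipped.
load_path, client_state = engine.load_checkpoint(
    ckpt_dir,
    tag=tag,
    load_optimizer_states=False,
    load_lr_scheduler_states=False,
)
```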

@tohtana tohtana marked this pull request as ready for review March 18, 2024 04:14
@tohtana tohtana added this pull request to the merge queue Mar 27, 2024
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Mar 27, 2024
@tohtana tohtana added this pull request to the merge queue Mar 27, 2024
@loadams loadams removed this pull request from the merge queue due to a manual request Mar 27, 2024
@tohtana tohtana enabled auto-merge March 28, 2024 08:08
@tohtana tohtana added this pull request to the merge queue Mar 28, 2024
Merged via the queue into microsoft:master with commit c56a4b9 Mar 28, 2024
14 checks passed
rraminen pushed a commit to ROCm/DeepSpeed that referenced this pull request May 9, 2024