Universal checkpoint for zero stage 3 #5475

Open. Wants to merge 7 commits into base: master.

Conversation

xylian86

This PR enables universal checkpointing for ZeRO Stage 3.

Notes:

  • The current implementation supports data parallelism.
  • Support for Stage 3 universal checkpoints with tensor-slicing model parallelism is under development.
  • Pipeline parallelism is not supported by ZeRO Stage 3, so it is not covered by this universal checkpoint implementation.

In this PR:

  • I've updated deepspeed/checkpoint/ds_to_universal.py to support converting ZeRO checkpoints into universal checkpoints.
  • I've updated deepspeed/runtime/zero/stage3.py so the Stage 3 optimizer can load universal checkpoints (a rough usage sketch follows below).
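
For context, here is a rough sketch of the end-to-end flow from a user's perspective. The script flags and the load_universal config key are assumptions taken from the existing ZeRO-1/2 universal-checkpoint workflow; the exact Stage 3 interface is what this PR defines.

    # Sketch only; flag names and config keys are assumed from the ZeRO-1/2 workflow.
    #
    # 1) Convert a ZeRO Stage 3 checkpoint into the universal format:
    #      python deepspeed/checkpoint/ds_to_universal.py \
    #          --input_folder  /path/to/zero3_checkpoint/global_step100 \
    #          --output_folder /path/to/universal_checkpoint
    #
    # 2) Resume training (possibly at a different DP degree) with universal
    #    loading enabled in the DeepSpeed config:
    ds_config = {
        "zero_optimization": {"stage": 3},
        "checkpoint": {"load_universal": True},  # assumed key, mirroring ZeRO 1/2
    }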

@xylian86 (Author) commented Apr 29, 2024 via email

@tjruwase tjruwase requested review from samadejacobs, tohtana and lekurile and removed request for mrwyattii May 2, 2024 23:05
@tohtana (Contributor) left a comment:

We have a test for universal checkpointing. It currently covers DP scaling only, but it would be good to exercise the ZeRO-3 feature with it as well. You can just add "3" to the test argument.
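
Illustrative sketch of what that change might look like, assuming the existing test parametrizes the ZeRO stage via pytest (the actual test name and fixture layout may differ):

    import pytest

    # Hypothetical parametrization: extending the existing DP-scaling test to
    # Stage 3 could be as small as adding "3" to the list of covered stages.
    @pytest.mark.parametrize("zero_stage", [1, 2, 3])  # "3" added for this PR
    def test_universal_checkpoint_dp_scaling(zero_stage):
        ...  # existing test body, unchanged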

def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights, param_shapes):
self.load_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder, param_shapes)

def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, param_shapes):
@tohtana (Contributor):

We have a similar function for ZeRO 1/2. Can we extract the common parts to avoid duplication?

@xylian86 (Author):

Hi @tohtana, thanks for the suggestion!
For the two similar functions for ZeRO 1/2 and ZeRO 3, the sharding mechanisms differ (one shards a flat buffer across layers, while the other partitions each parameter within a layer), so the implementations vary considerably and it is hard to extract common elements. Please let me know if you have additional suggestions.
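
To make the contrast concrete, here is a purely conceptual sketch (not DeepSpeed's actual API): in ZeRO 1/2 each rank owns one contiguous slice of a single flat fp32 buffer spanning all layers, while in ZeRO 3 every parameter is partitioned across ranks individually, so the loader has to re-slice state per parameter.

    import torch

    # Conceptual sketch only; names and partitioning details are illustrative.

    def load_hp_state_zero12(flat_hp_tensor, rank_offset, rank_numel):
        # ZeRO 1/2: one contiguous slice per rank out of a flat buffer
        # that spans all layers.
        return flat_hp_tensor.narrow(0, rank_offset, rank_numel)

    def load_hp_state_zero3(universal_tensors, param_shapes, rank, world_size):
        # ZeRO 3: each parameter is partitioned across ranks on its own,
        # so every parameter's merged tensor must be re-sliced for this rank.
        state = {}
        for name in param_shapes:
            flat = universal_tensors[name].reshape(-1)
            part = (flat.numel() + world_size - 1) // world_size
            state[name] = flat[rank * part:(rank + 1) * part]  # last-rank padding omitted
        return state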

deepspeed/checkpoint/ds_to_universal.py (resolved review thread)
return int(text) if text.isdigit() else text


def natural_keys(text):
Contributor:

Thank you for introducing this interesting approach.
Thank you for introducing this interesting approach.
We have similar sorting in _merge_zero_shards, but it uses a different approach, and it is not good to have two different sorting implementations for the same purpose. Can you replace that one with natural_keys?

@xylian86 (Author):

Thanks for the suggestions!
This natural_keys function is actually reused from zero_to_fp32.py.

You’re right; it’s not ideal to have two different implementations for the same function. How about I replace the one in _merge_zero_shards with this natural_keys?
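
For reference, the helper in question is the standard natural-sort recipe; the sketch below follows the pattern used in zero_to_fp32.py and shows the kind of numeric shard-name sorting it would replace in _merge_zero_shards.

    import re

    def atoi(text):
        # Turn digit runs into ints so that "rank_10" sorts after "rank_2".
        return int(text) if text.isdigit() else text

    def natural_keys(text):
        # "zero_pp_rank_10_mp_rank_00" -> ["zero_pp_rank_", 10, "_mp_rank_", 0, ""]
        return [atoi(c) for c in re.split(r'(\d+)', text)]

    shards = ["shard_10.pt", "shard_2.pt", "shard_1.pt"]
    print(sorted(shards, key=natural_keys))  # ['shard_1.pt', 'shard_2.pt', 'shard_10.pt']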

total_params += 1

partitioned_numel, _ = _zero_partitioned_param_info(unpartitioned_numel, world_size)
print(key_tensors[0])
Contributor:

Remove debug statement!

@xylian86 (Author):

Thanks for the suggestion, and sorry for leaving the debug statement in. I have removed this line in the new commit.

# self.persistent_parameters[0].all_gather(self.persistent_parameters) # this will be done in checkpoint_event_epilogue() so remove it to prevent double all_gather

def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights, param_shapes):
self.load_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder, param_shapes)
Contributor:

Where is this function implemented? Is the name wrong?

@xylian86 (Author):

Yes, I found that the function name was not updated when the code was merged. I have fixed it in the new commit.
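
Presumably the fix makes the wrapper delegate to the helper defined earlier in the diff; a sketch of the corrected call:

    # Sketch of the corrected delegation, using the helper name shown in the diff above.
    def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states,
                                   load_from_fp32_weights, param_shapes):
        self.load_hp_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder, param_shapes)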
