Improving memory utilization of Z2+MoE #2079
@@ -59,8 +59,9 @@ def split_params_grads_into_shared_and_expert_params(
     return shared_grads, expert_grads


-def split_params_into_different_moe_groups_for_optimizer(
-        param_groups: Tuple[Dict]) -> Tuple[Dict]:
+def split_params_into_different_moe_groups_for_optimizer(param_groups: Tuple[Dict],
@tjruwase, please let me know how we would want to offer the user a way to set the max_group_size
As discussed, let's add an moe section in ds_config. Perhaps, @awan-10 could help with the design.
@awan-10 Perhaps the moe section can also contain a flag to toggle expert slicing too?
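To make the `max_group_size` discussion concrete, here is a minimal sketch of capping the size of parameter groups. It is not the code from this PR: the helper name `split_by_max_group_size`, the greedy packing strategy, and the use of plain integer sizes in place of real parameters are all assumptions for illustration.

```python
from typing import List, Sequence


def split_by_max_group_size(param_sizes: Sequence[int],
                            max_group_size: int) -> List[List[int]]:
    """Greedily pack parameter indices into groups (hypothetical helper).

    Each group's total size stays <= max_group_size, except that a single
    parameter larger than the cap is placed in a group of its own.
    """
    groups: List[List[int]] = []
    current: List[int] = []
    current_size = 0
    for idx, size in enumerate(param_sizes):
        # Close the current group if adding this parameter would exceed the cap.
        if current and current_size + size > max_group_size:
            groups.append(current)
            current, current_size = [], 0
        current.append(idx)
        current_size += size
    if current:
        groups.append(current)
    return groups


# Five parameters with the given element counts, capped at 8 elements per group.
print(split_by_max_group_size([4, 4, 4, 10, 2], max_group_size=8))
# prints [[0, 1], [2], [3], [4]]
```

A smaller cap produces more, smaller groups, so any per-group temporary buffer is bounded by the cap rather than by the total expert parameter count.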
Summary
While running experiments, I found that an inordinate amount of memory is used in the gradient upscaling step when the expert-to-GPU ratio is high (for example, 1 or 1/2).
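One generic way to bound the peak memory of a step like this is to process a large buffer in fixed-size tiles, so that the temporary copy never exceeds one tile. The sketch below illustrates that idea only. It uses plain Python lists instead of tensors, and the function name `convert_in_tiles` and its signature are hypothetical, not taken from DeepSpeed.

```python
from typing import List


def convert_in_tiles(src: List[float], dst: List[float], tile_size: int) -> int:
    """Copy src into dst tile by tile (hypothetical helper).

    Returns the length of the largest temporary buffer that was created,
    which is at most tile_size regardless of len(src).
    """
    peak_temp = 0
    for start in range(0, len(src), tile_size):
        # Temporary converted copy of one tile only, not the whole buffer.
        tile = [float(x) for x in src[start:start + tile_size]]
        peak_temp = max(peak_temp, len(tile))
        # Same-length slice assignment writes the tile into place.
        dst[start:start + len(tile)] = tile
    return peak_temp


src = list(range(10))
dst = [0.0] * len(src)
print(convert_in_tiles(src, dst, tile_size=4))  # prints 4
```

With untiled conversion the temporary would be as large as the whole buffer; with tiling it is capped at `tile_size`, at the cost of more loop iterations.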
This PR does two things, one in each of these files:
deepspeed/moe/utils.py
deepspeed/runtime/zero/stage_1_and_2.py
Highlight Result
Prior to this PR, a 6.7B base model with 16 experts ran out of memory (OOM) on 32 A100 GPUs (40 GB each).
With the changes, I am able to run the same model with a peak memory utilization of 31.3 GB. Since the earlier run needed more than the 40 GB available, the saving for this model is at least (40 - 31.3) / 40 = 21.75% of memory.
Train loss curve for reference
Sanity Checks
Train loss curves before and after the changes match
Batch Times
To create the most pathological case, I set the global batch size to 8. Even so, there is no penalty in batch times.
Memory Consumption
The memory saved is expected to increase with increasing model sizes.