
[FSDP] use all_gather for 10X OSD consolidation speedup #595

Merged: 23 commits into master from all-gather-impl on Apr 13, 2021

Conversation

@sshleifer (Contributor) commented on Apr 8, 2021:

TLDR: Using all_gather instead of broadcast for optimizer state consolidation appears to be a speed win without a memory cost.

Approach:

  • For tensor state in the OSD (optimizer state dict) we use all_gather (see the sketch after this list).
  • For non-tensor metadata (loss_scale, param_groups, num_padded) we use broadcast.
  • The same OSD (for a 300M-param model) takes 110 ms to consolidate instead of 2300 ms; the speedup is usually around 10X.
  • We assume recipient_rank=0, since there are no other callers.
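
To make the approach concrete, here is a minimal sketch (hypothetical helper names, not the PR's actual code), assuming torch.distributed is initialized and that FSDP has already padded each rank's shard to a uniform size:

```python
# Hedged sketch of the approach: all_gather for tensor shards, an object
# broadcast for small non-tensor metadata. Helper names are hypothetical.
from typing import Any, Dict, List

import torch
import torch.distributed as dist


def gather_tensor_shards(shard: torch.Tensor) -> List[torch.Tensor]:
    """Every rank contributes its shard and gets the full list of shards back
    in one collective, instead of world_size point-to-point broadcasts."""
    world_size = dist.get_world_size()
    buffers = [torch.zeros_like(shard) for _ in range(world_size)]
    dist.all_gather(buffers, shard)  # assumes equal-sized shards (FSDP pads)
    return buffers


def broadcast_metadata(meta: Dict[str, Any], src_rank: int) -> Dict[str, Any]:
    """Non-tensor metadata (loss_scale, param_groups, num_padded) is tiny, so an
    object broadcast from its owning rank is cheap."""
    obj = [meta if dist.get_rank() == src_rank else None]
    dist.broadcast_object_list(obj, src=src_rank)
    return obj[0]
```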

Evidence of Win:

  • This appears not to use extra GPU RAM and saves a lot of time: in the unit tests, consolidating a larger MoE goes from 2300 ms to 110 ms.
  • For a large fairseq 2.2B-param model (1.2T config), consolidation takes 80 s instead of 800 s, with less CUDA usage.
  • The fairseq resumption test (resuming training from a checkpoint should result in the same loss as training from scratch) passes.

@facebook-github-bot added the CLA Signed label on Apr 8, 2021
@sshleifer changed the title from [FSDP/Prototype] to [FSDP/Prototype] use all gather for OSD consolidation on Apr 8, 2021

return non_tensor_state, tensor_state

def _gather_optim_state(self, sd_state: Dict[int, Dict[str, Any]]) -> Dict[int, Dict[str, List]]:
@sshleifer (PR author): This is the new _all_gather logic.
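
For orientation, a hedged sketch of the kind of split that produces the `return non_tensor_state, tensor_state` above (hypothetical code, not the PR's implementation): tensor values are routed to the state that gets all_gathered, everything else to the state that gets broadcast.

```python
# Hypothetical sketch: split one rank's optimizer state (keyed by param id) into
# tensor state (to all_gather) and non-tensor state (to broadcast).
from typing import Any, Dict, Tuple

import torch


def split_optim_state(
    osd_state: Dict[int, Dict[str, Any]]
) -> Tuple[Dict[int, Dict[str, Any]], Dict[int, Dict[str, torch.Tensor]]]:
    non_tensor_state: Dict[int, Dict[str, Any]] = {}
    tensor_state: Dict[int, Dict[str, torch.Tensor]] = {}
    for param_id, param_state in osd_state.items():
        non_tensor_state[param_id] = {}
        tensor_state[param_id] = {}
        for key, value in param_state.items():
            if torch.is_tensor(value):
                tensor_state[param_id][key] = value  # e.g. exp_avg, exp_avg_sq
            else:
                non_tensor_state[param_id][key] = value  # e.g. step counts
    return non_tensor_state, tensor_state
```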

Base automatically changed from move-placeholder-cpu to master April 8, 2021 23:07
@sshleifer closed this on Apr 9, 2021
@sshleifer reopened this on Apr 9, 2021
@sshleifer marked this pull request as ready for review on April 9, 2021 01:21
@sshleifer changed the title from [FSDP/Prototype] use all gather for OSD consolidation to [FSDP] use all_gather for 10X OSD consolidation speedup on Apr 9, 2021
@@ -627,15 +627,18 @@ def __init__(self, group, wrapper_config, checkpoint_act=False, delay_before_fre

# "expert" params are different on each rank
torch.manual_seed(42 + group.rank())
d_expert = 16
expert = nn.Linear(d_expert, 4)
d_expert = 23
@sshleifer (PR author): make sure we unpad expert params correctly.
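
For background on why an odd expert size exercises the unpad path, here is a hedged, self-contained illustration (made-up numbers, not the test's code): FSDP pads each flat shard so it divides evenly across ranks, and the recipient has to strip that padding when rebuilding the full tensor.

```python
# Hedged illustration of padding/unpadding during consolidation; not fairscale code.
import torch
import torch.nn.functional as F

world_size = 4
flat_param = torch.arange(10, dtype=torch.float32)   # 10 elements: not divisible by 4

num_padded = (-flat_param.numel()) % world_size       # 2 padding elements needed
padded = F.pad(flat_param, (0, num_padded))           # length 12 -> 4 shards of 3
shards = list(padded.chunk(world_size))               # what each rank would hold

# After all_gather, the recipient concatenates the shards and strips the padding;
# the per-rank num_padded counts travel with the broadcast metadata.
full = torch.cat(shards)[: flat_param.numel()]
assert torch.equal(full, flat_param)
```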

@sshleifer (PR author): Planning to merge 10am PT tomorrow, barring further comments.

@min-xu-ai (Contributor) left a review: I don't have time to fully review this, but this looks good at a high level.

"""Return the last known global optimizer state. The returned state is compatible with Pytorch, in that the
sharded properties are not exposed. Multiple parameter groups are not yet supported.

This should be called only on the root FSDP instance.
Nested FSDP instances are supported as long as they have the same world_size as the parent or world_size=1.
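
Assuming the docstring above belongs to fairscale's gather_full_optim_state_dict, usage would look roughly like this (a hedged sketch, not taken from the PR; assumes torch.distributed is initialized and that only the recipient rank 0 receives the consolidated dict):

```python
# Hedged usage sketch of OSD consolidation on the root FSDP instance.
import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

model = FSDP(torch.nn.Linear(16, 4))                    # root FSDP instance
optim = torch.optim.Adam(model.parameters(), lr=1e-4)

# ... train for a while ...

full_osd = model.gather_full_optim_state_dict(optim)    # collective: call on every rank
if dist.get_rank() == 0:
    torch.save(full_osd, "optim_state.pt")               # sharded properties not exposed
```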
Reviewer comment on the docstring above: nice

@min-xu-ai (Contributor): cc @QuentinDuval @prigoyal FYI

@min-xu-ai (Contributor):

> Planning to merge 10am PT tomorrow, barring further comments.

Sorry for the delay, I was out for most of the last week.

new_sd = {"state": new_state, "param_groups": copy.deepcopy(sd["param_groups"])}
for k in sd.keys():  # if there are extra keys, like loss_scale, don't delete them
    if k not in {"state", "param_groups", "uncollected_local_ids", "param_id_map"}:
Reviewer (Contributor): I'm slightly uneasy about this falling out of sync with line 160. Thoughts on some way to enforce parity between them?

@sshleifer (PR author): module-level constant + comment + assert:

# These return keys are used by fairseq. To change, add @sshleifer as a reviewer.
UNFLAT_RETURN_KEYS = {"state", "param_groups", "uncollected_local_ids", "param_id_map"}
...
assert set(unflat_optim_state_dict.keys()) == UNFLAT_RETURN_KEYS 

@myleott (Contributor) left a review: LGTM!

@sshleifer added the FSDP (FullyShardedDataParallel, zero-3) label on Apr 13, 2021
@sshleifer merged commit a82825d into master on Apr 13, 2021
@sshleifer deleted the all-gather-impl branch on April 13, 2021 at 15:21