Add a script to reshard FSDP checkpoints #459
Conversation
Before digging in... are we now able to delete at least one of the existing resharding/consolidate scripts? 😅
I'm not sure about the deprecation practice for our repo, but I think we can deprecate reshard_mp.py and its related Bash scripts (reshard_mp_launch.sh and reshard_mp_launch_no_slurm.sh) after this PR. The introduced script also overlaps heavily with some functions in checkpoint_utils.py, and I wonder if it makes sense to combine them as well (perhaps after more thorough testing). The other scripts (reshard_model_parallel.py, convert_to_singleton.py, consolidate_fsdp_shards.py, stitch_fsdp_ckpt.py) have the MP resharding logic and can be addressed later on (I'm working on another script for MP resharding only). What do you think?
I would've just yolo deleted, but having a semblance of deprecation practice here seems like the right thing to do - would just move all the ones we can deprecate now into some deprecated directory, then delete them in a few months / whenever we remember to clean things out again, lol.
Since I'm depending mostly on the test plan to stamp this... can you add a quick update to a README somewhere mentioning how/when to use this script (and a placeholder for the model parallel resharding later on)?
Good idea! Perhaps we can also add some logging statements in those scripts to let people know about the deprecation?
I'll add more docstrings to this script then; it doesn't seem to fit well into any of the existing READMEs.
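As an illustration, a minimal sketch of such a deprecation notice, assuming the deprecated scripts use Python's standard logging (the wording and placement are assumptions for this example):

```python
import logging

logger = logging.getLogger(__name__)

# Hypothetical notice at the top of reshard_mp.py's main():
logger.warning(
    "reshard_mp.py is deprecated and will be removed in a future release; "
    "please use the new FSDP resharding script introduced in this PR instead."
)
```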
Great work! Could we also test this with output shards > 1? Even better if output shards > input shards, which is something we'll need for the big zucchini run.
Please see the updated test plan for the 256 shards → 1 shard → 256 shards conversion. In contrast to the previous scripts that throw away shard metadata, we're able to chain these resharding operations by updating the metadata after each step.
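To illustrate how chaining works, here is a minimal sketch of resharding a consolidated flat FSDP parameter while recording per-shard padding so a later step can invert the split (the function and metadata fields below are hypothetical, not the script's actual format):

```python
import torch

def reshard_flat_param(flat_param: torch.Tensor, num_output_shards: int):
    """Split a 1-D flat parameter into equal-sized shards, recording the
    padding added to the last shard so the split can be undone later."""
    numel = flat_param.numel()
    shard_size = (numel + num_output_shards - 1) // num_output_shards
    shards, metadata = [], []
    for rank in range(num_output_shards):
        chunk = flat_param[rank * shard_size : (rank + 1) * shard_size]
        pad = shard_size - chunk.numel()  # only the last shard may be short
        if pad > 0:
            chunk = torch.cat([chunk, chunk.new_zeros(pad)])
        shards.append(chunk)
        # Keeping this metadata per shard is what allows chaining
        # 256 -> 1 -> 256 conversions without losing information.
        metadata.append({"unpadded_numel": chunk.numel() - pad, "padding": pad})
    return shards, metadata
```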
One last sanity check: could we check that the 256 shard -> 1 shard conversion yields a model that generates the same output given the same input? And the same for 1 shard -> 256 shards (need some... numerical equivalence checks to make sure we didn't accidentally jumble up the parameters 😅). In one of these cases, it would also be good to include optimizer state in a reshard operation and confirm that we can successfully resume training from the resharded checkpoint (and that training continues to look "sane"). We can note the above in a separate task to tackle and just merge this PR to wrap up this bit for now. The above testing would be super important to confirm / write a test case around, but I don't want to block on that.
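A rough sketch of the equivalence check being asked for, assuming both checkpoints can be loaded into model objects; the forward-call convention here is a stand-in, not an actual metaseq API:

```python
import torch

def check_same_outputs(model_a, model_b, vocab_size=50000, seq_len=128):
    """Feed both models an identical random batch and compare outputs."""
    torch.manual_seed(0)
    src_tokens = torch.randint(0, vocab_size, size=(4, seq_len))
    with torch.no_grad():
        out_a = model_a(src_tokens)
        out_b = model_b(src_tokens)
    # Allow a little numerical noise rather than requiring bitwise equality.
    assert torch.allclose(out_a, out_b, atol=1e-5), "outputs diverge after resharding"
```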
@suchenzang For some reason, the sample inputs are different when I change the number of FSDP shards from 8 to 4 and back to 8 for the 125M model, so I hacked the code a bit to confirm that the different checkpoints indeed yield the same outputs. It'd be good to make this into a test, but we can probably do that later on (a rough sketch follows the logs below).

```python
import torch

torch.manual_seed(0)
sample = {
    "net_input": {
        "src_tokens": torch.randint(
            0, 50000, size=(32, 2048), device=torch.cuda.current_device()
        )
    }
}
```
# start with 8 FSDP shards
2022-11-02 20:29:13 | INFO | metaseq.trainer | Loaded optim_state for /shared/home/binhtang/src/metaseq/checkpoints/opt-125m/test_v0.zero2.adam.ngpu16/checkpoint_2000-shard0.pt
2022-11-02 20:29:13 | INFO | metaseq.trainer | Loaded checkpoint /shared/home/binhtang/src/metaseq/checkpoints/opt-125m/test_v0.zero2.adam.ngpu16/checkpoint_2000-shard0.pt (epoch 1 @ 2000 updates)
2022-11-02 20:30:34 | INFO | train_inner | {"epoch": 1, "actv_norm": "67.167", "pos_norm": "0.246", "tok_norm": "1.084", "emb_norm": "1.042", "loss": "18.552", "ppl": "384346", "wps": "1522.9", "ups": "0", "wpb": "524288", "bsz": "256", "num_updates": "2001", "lr": "0.000598785", "gnorm": "9.072", "clip": "100", "loss_scale": "2", "scale_window": "32", "train_wall": "72", "cuda_gb_allocated": "32.4", "cuda_gb_reserved": "44", "cuda_gb_free": "46.8", "wall": "0"}
# convert from 8 shards into 4 shards
2022-11-02 20:33:00 | INFO | metaseq.trainer | Loaded optim_state for /shared/home/binhtang/checkpoints/opt-125m-reshard/test_v0.zero2.adam.ngpu16/checkpoint_2000-shard0.pt
2022-11-02 20:33:00 | INFO | metaseq.trainer | Loaded checkpoint /shared/home/binhtang/checkpoints/opt-125m-reshard/test_v0.zero2.adam.ngpu16/checkpoint_2000-shard0.pt (epoch 1 @ 2000 updates)
2022-11-02 20:34:12 | INFO | train_inner | {"epoch": 1, "actv_norm": "67.167", "pos_norm": "0.246", "tok_norm": "1.084", "emb_norm": "1.042", "loss": "18.552", "ppl": "384346", "wps": "782.1", "ups": "0", "wpb": "262144", "bsz": "128", "num_updates": "2001", "lr": "0.000598785", "gnorm": "9.072", "clip": "100", "loss_scale": "2", "scale_window": "32", "train_wall": "66", "cuda_gb_allocated": "32.4", "cuda_gb_reserved": "44.1", "cuda_gb_free": "46.8", "wall": "0"}
# convert from 4 shards back to 8 shards
2022-11-02 20:38:45 | INFO | metaseq.trainer | Loaded optim_state for /shared/home/binhtang/checkpoints/opt-125m-reshard-back/test_v0.zero2.adam.ngpu16/checkpoint_2000-shard0.pt
2022-11-02 20:38:45 | INFO | metaseq.trainer | Loaded checkpoint /shared/home/binhtang/checkpoints/opt-125m-reshard-back/test_v0.zero2.adam.ngpu16/checkpoint_2000-shard0.pt (epoch 1 @ 2000 updates)
2022-11-02 20:39:54 | INFO | train_inner | {"epoch": 1, "actv_norm": "67.167", "pos_norm": "0.246", "tok_norm": "1.084", "emb_norm": "1.042", "loss": "18.552", "ppl": "384346", "wps": "1580.3", "ups": "0", "wpb": "524288", "bsz": "256", "num_updates": "2001", "lr": "0.000598785", "gnorm": "9.072", "clip": "100", "loss_scale": "2", "scale_window": "32", "train_wall": "59", "cuda_gb_allocated": "32.4", "cuda_gb_reserved": "43.9", "cuda_gb_free": "46.8", "wall": "0"}
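Following up on the note above about turning this into a test: a rough sketch of a round-trip check over the raw checkpoint files, assuming the usual `model` entry in the saved state dict (paths and the helper name are placeholders, not the script's actual API):

```python
import torch

def assert_round_trip_identical(original_paths, round_tripped_paths):
    """Check that resharding 8 -> 4 -> 8 reproduces every original shard."""
    for orig_path, new_path in zip(original_paths, round_tripped_paths):
        orig = torch.load(orig_path, map_location="cpu")
        new = torch.load(new_path, map_location="cpu")
        for key, tensor in orig["model"].items():
            # Weights should round-trip bit-for-bit, since resharding only
            # moves values between shards without transforming them.
            assert torch.equal(tensor, new["model"][key]), f"mismatch in {key}"
```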
I'm trying to verify this now and will post results back here.
Since the existing scripts reshard_mp.py, reshard_mp_launch.sh and reshard_mp_launch_no_slurm.sh are also duplicated in …
Summary
We add a new script to reshard raw FSDP checkpoints as part of our efforts to consolidate the checkpoint resharding logic. This script is a bit more general than some of the existing ones:
- It also handles checkpoints saved when `ddp-backend` is set to `pytorch_ddp`.
- Unlike the `consolidate_shard_weights` and `build_unflat_state_dict` functions from FSDP (the former is used in stitch_fsdp_ckpt.py), it supports both unsharding and resharding of model weights and optimizer states.
- Unlike `convert_to_singleton.py`, it doesn't require instantiating FSDP instances and avoids the various requirements that come with them (DDP, vocab files, configs, etc.). We also decouple the filename handling to make it a bit more flexible.

Note that this script doesn't include the logic for model parallel resharding. We should probably have a separate script for it, which can be used together with this one.
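As a rough illustration of why no FSDP instantiation is needed: unsharding boils down to slicing the consolidated flat parameter by each original tensor's numel and reshaping, given the shapes stored alongside the checkpoint (the metadata layout below is illustrative, not the actual FSDP format):

```python
import torch

def unflatten(flat_param: torch.Tensor, param_metadata):
    """Recover named tensors from a consolidated 1-D flat parameter.

    param_metadata is an iterable of (name, shape) pairs describing the
    original tensors in the order they were flattened.
    """
    state_dict, offset = {}, 0
    for name, shape in param_metadata:
        numel = 1
        for dim in shape:
            numel *= dim
        state_dict[name] = flat_param[offset : offset + numel].view(shape)
        offset += numel
    return state_dict
```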
Testing
Please see the updated test plan in the conversation above. The remaining checkpoint entries (`optimizer_history`, `extra_state`, `cfg.distributed_training.distributed_rank`) are handled alongside the model weights and optimizer states.
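A small sanity check along these lines, loading one resharded shard and confirming the listed entries survived (the filename is a placeholder, and the attribute access assumes an OmegaConf-style cfg object):

```python
import torch

ckpt = torch.load("checkpoint_2000-shard0.pt", map_location="cpu")
for key in ("optimizer_history", "extra_state", "cfg"):
    assert key in ckpt, f"missing checkpoint entry: {key}"
# The saved rank should reflect this shard's position in the new layout.
print(ckpt["cfg"].distributed_training.distributed_rank)
```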