Add a script to reshard FSDP checkpoints #459

Merged (1 commit, Nov 14, 2022)

Conversation

@tangbinh (Contributor) commented Oct 26, 2022

Summary

We add a new script to reshard raw FSDP checkpoints as part of our efforts to consolidate the checkpoint resharding logic. This script is a bit more general than some of the existing ones:

  • Compared to reshard_mp.py, it allows us to optionally unflatten model weights so that the output is compatible with the generator interface when ddp-backend is set to pytorch_ddp (see the sketch below).
  • Compared to the consolidate_shard_weights and build_unflat_state_dict functions from FSDP (the former is used in stitch_fsdp_ckpt.py), it supports both unsharding and resharding model weights and optimizer states.
  • Compared to checkpoint_utils.py, which is used in convert_to_singleton.py, it doesn't require instantiating FSDP instances and avoids the various requirements that come with them (DDP, vocab files, configs, etc.). We also decouple the filename handling to make it a bit more flexible.

Note that this script doesn't include the logic for model parallel resharding. We should probably have a separate script for it, which can be used together with this one.
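
For reference, the unflattening mentioned above essentially amounts to concatenating the flat FSDP parameter shards, dropping any padding, and splitting the result back into named tensors using the recorded shapes. The sketch below is illustrative only; the argument names (param_names, param_shapes, param_numels, pad) are hypothetical and do not reflect the script's actual metadata schema.

import torch

def unflatten_weights(flat_shards, param_names, param_shapes, param_numels, pad=0):
    # Concatenate the flat shards back into a single flat parameter tensor.
    flat = torch.cat([shard.flatten() for shard in flat_shards])
    if pad > 0:
        # Drop the padding that was added so the flat parameter divides evenly across ranks.
        flat = flat[:-pad]
    # Split by per-parameter element counts and restore the original shapes.
    chunks = torch.split(flat, param_numels)
    return {name: chunk.view(shape) for name, chunk, shape in zip(param_names, chunks, param_shapes)}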

Testing

  • Run the script to merge the sharded checkpoints of the 2.7B-parameter model into one shard for each model parallel part, then load the resharded checkpoints with the interactive CLI:
for j in {0..3}; do
    python -m metaseq.scripts.reshard_fsdp \
    --input-glob-pattern "/data/gpt-z/models/gptz/2.7B/raw/checkpoint_last-model_part-$j-shard*.pt" \
    --output-shard-name "/shared/home/binhtang/checkpoints/opt-2.7b/reshard-model_part-$j.pt" \
    --num-output-shards 1 --skip-optimizer-state True --unflatten-weights True;
done
python -m metaseq.cli.interactive_cli
> what is the meaning of life?
To be happy.
  • Run the script to reshard the 6.7B-parameter model checkpoint for each model parallel part from 256 shards to 1 shard and from 1 shard back to 256 shards. The sharded checkpoints we get back are almost identical to the original ones, except for some rank-specific data that are lost during the first conversion because only the rank-0 copies are kept (e.g. optimizer_history, extra_state, cfg.distributed_training.distributed_rank).
for j in {0..1}; do
    python -m metaseq.scripts.reshard_fsdp \
    --input-glob-pattern "/data/gpt-z/models/gptz/6.7B/raw/checkpoint_last-model_part-$j-shard*.pt" \
    --output-shard-name "/shared/home/binhtang/checkpoints/opt-6.7b/reshard-model_part-$j.pt" \
    --num-output-shards 1 --skip-optimizer-state False --unflatten-weights False;
done

for j in {0..1}; do
    python -m metaseq.scripts.reshard_fsdp \
    --input-glob-pattern "/shared/home/binhtang/checkpoints/opt-6.7b/reshard-model_part-$j.pt" \
    --output-shard-name "/shared/home/binhtang/checkpoints/opt-6.7b-reshard/checkpoint_last-model_part-$j-shard{i}.pt" \
    --num-output-shards 256 --skip-optimizer-state False --unflatten-weights False;
done
import torch
for i in range(256):
    before = torch.load(f"/data/gpt-z/models/gptz/6.7B/raw/checkpoint_last-model_part-0-shard{i}.pt", map_location=torch.device("cpu"))
    after = torch.load(f"/shared/home/binhtang/checkpoints/opt-6.7b-reshard/checkpoint_last-model_part-0-shard{i}.pt", map_location=torch.device("cpu"))
    assert all(torch.allclose(before["model"][k], after["model"][k]) for k in before["model"].keys())
    assert before["shard_metadata"] == after["shard_metadata"]
    # Compare the Adam moments shard by shard (assumes the standard optimizer state_dict layout).
    opt_before, opt_after = before["last_optimizer_state"]["state"], after["last_optimizer_state"]["state"]
    assert all(torch.allclose(x[key], y[key]) for x, y in zip(opt_before.values(), opt_after.values()) for key in ("exp_avg", "exp_avg_sq"))

@tangbinh changed the title from "Reshard fsdp" to "Add a script to reshard FSDP checkpoints" on Oct 26, 2022
@tangbinh marked this pull request as ready for review on October 26, 2022 20:20
@tangbinh linked an issue on Oct 26, 2022 that may be closed by this pull request
@suchenzang (Contributor)

Before digging in... are we now able to delete at least one of the existing resharding/consolidate scripts 😅 ?

@tangbinh (Contributor, Author)

> Before digging in... are we now able to delete at least one of the existing resharding/consolidate scripts 😅 ?

I'm not sure about the deprecation practice for our repo, but I think we can deprecate reshard_mp.py and its related Bash scripts (reshard_mp_launch.sh and reshard_mp_launch_no_slurm.sh) after this PR. The introduced script also overlaps heavily with some functions in checkpoint_utils.py, and I wonder if it makes sense to combine them as well (perhaps after more thorough testing).

The other scripts (reshard_model_parallel.py, convert_to_singleton.py, consolidate_fsdp_shards.py, stitch_fsdp_ckpt.py) have the MP resharding logic and can be addressed later on (I'm working on another script for MP resharding only). What do you think?

@suchenzang (Contributor)

I would've just yolo deleted but having a semblance of deprecation practice here seems like the right thing to do - would just move all the ones we can deprecate now into some deprecated directory, then delete them in a few months / whenever we remember to clean things out again, lol.

@suchenzang (Contributor)

Since I'm depending mostly on the test plan to stamp this... can you add a quick update to a README somewhere mentioning how / when to use this script (and a placeholder for the model parallel resharding later on)?

@tangbinh (Contributor, Author) commented Oct 27, 2022

> would just move all the ones we can deprecate now into some deprecated directory, then delete them in a few months

Good idea! Perhaps we can also add some logging statements in those scripts to let people know about the deprecation?
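
For instance, the deprecation notice could be as simple as the snippet below (illustrative only; the exact wording and placement would be decided during the cleanup):

import logging

logger = logging.getLogger(__name__)

# Hypothetical notice added near the top of the deprecated scripts.
logger.warning(
    "reshard_mp.py is deprecated; please use `python -m metaseq.scripts.reshard_fsdp` instead."
)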

> can you add a quick update to a README somewhere mentioning how / when to use this script

I'll add more docstrings to this script then. It doesn't seem to fit well into any of the existing READMEs.

@ruanslv (Contributor) left a comment

Great work! Could we also test this with output shards > 1? Even better if output shards > input shards, which is something we'll need for the big zucchini run.

@tangbinh (Contributor, Author) commented Oct 31, 2022

> Could we also test this with output shards > 1? Even better if output shards > input shards, which is something we'll need for big zucchini run

Please see the updated test plan for the 256 shards → 1 shard → 256 shards conversion. In contrast to the previous scripts that throw away shard metadata, we're able to chain these resharding operations by updating the metadata after each step.

@suchenzang (Contributor)

One last sanity check: could we check that the 256 shard -> 1 shard conversion yields a model that generates the same output given the same input? And same for 1 shard -> 256 shard (need some... numerical equivalence checks to make sure we didn't accidentally jumble up the parameters 😅 ). In one of these cases, would also be good to check when we include optimizer state with a reshard operation and confirm that we can successfully resume training from a resharded checkpoint (and that training continues to look "sane").

We can note the above in a separate task to tackle and just merge this PR to wrap up this bit for now. The above testing would be super important to confirm / write a test case around, but I don't want to block on that.

@tangbinh (Contributor, Author) commented Nov 2, 2022

> could we check that the 256 shard -> 1 shard conversion yields a model that generates the same output given the same input? And same for 1 shard -> 256 shard.

@suchenzang For some reason, the sample inputs are different when I change the number of FSDP shards from 8 to 4 to 8 for the 125M model, so I hacked the code a bit (fixing the random seed and the sample batch, as shown below) to confirm that the different checkpoints indeed yield the same outputs. It'd be good to turn this into a test, but we can probably do that later on.

torch.manual_seed(0)
sample = {"net_input": {"src_tokens": torch.randint(0, 50000, size=(32, 2048), device=torch.cuda.current_device())}}

# start with 8 FSDP shards
2022-11-02 20:29:13 | INFO | metaseq.trainer | Loaded optim_state for /shared/home/binhtang/src/metaseq/checkpoints/opt-125m/test_v0.zero2.adam.ngpu16/checkpoint_2000-shard0.pt
2022-11-02 20:29:13 | INFO | metaseq.trainer | Loaded checkpoint /shared/home/binhtang/src/metaseq/checkpoints/opt-125m/test_v0.zero2.adam.ngpu16/checkpoint_2000-shard0.pt (epoch 1 @ 2000 updates)
2022-11-02 20:30:34 | INFO | train_inner | {"epoch": 1, "actv_norm": "67.167", "pos_norm": "0.246", "tok_norm": "1.084", "emb_norm": "1.042", "loss": "18.552", "ppl": "384346", "wps": "1522.9", "ups": "0", "wpb": "524288", "bsz": "256", "num_updates": "2001", "lr": "0.000598785", "gnorm": "9.072", "clip": "100", "loss_scale": "2", "scale_window": "32", "train_wall": "72", "cuda_gb_allocated": "32.4", "cuda_gb_reserved": "44", "cuda_gb_free": "46.8", "wall": "0"}

# convert from 8 shards into 4 shards
2022-11-02 20:33:00 | INFO | metaseq.trainer | Loaded optim_state for /shared/home/binhtang/checkpoints/opt-125m-reshard/test_v0.zero2.adam.ngpu16/checkpoint_2000-shard0.pt
2022-11-02 20:33:00 | INFO | metaseq.trainer | Loaded checkpoint /shared/home/binhtang/checkpoints/opt-125m-reshard/test_v0.zero2.adam.ngpu16/checkpoint_2000-shard0.pt (epoch 1 @ 2000 updates)
2022-11-02 20:34:12 | INFO | train_inner | {"epoch": 1, "actv_norm": "67.167", "pos_norm": "0.246", "tok_norm": "1.084", "emb_norm": "1.042", "loss": "18.552", "ppl": "384346", "wps": "782.1", "ups": "0", "wpb": "262144", "bsz": "128", "num_updates": "2001", "lr": "0.000598785", "gnorm": "9.072", "clip": "100", "loss_scale": "2", "scale_window": "32", "train_wall": "66", "cuda_gb_allocated": "32.4", "cuda_gb_reserved": "44.1", "cuda_gb_free": "46.8", "wall": "0"}

# convert from 4 shards back to 8 shards
2022-11-02 20:38:45 | INFO | metaseq.trainer | Loaded optim_state for /shared/home/binhtang/checkpoints/opt-125m-reshard-back/test_v0.zero2.adam.ngpu16/checkpoint_2000-shard0.pt
2022-11-02 20:38:45 | INFO | metaseq.trainer | Loaded checkpoint /shared/home/binhtang/checkpoints/opt-125m-reshard-back/test_v0.zero2.adam.ngpu16/checkpoint_2000-shard0.pt (epoch 1 @ 2000 updates)
2022-11-02 20:39:54 | INFO | train_inner | {"epoch": 1, "actv_norm": "67.167", "pos_norm": "0.246", "tok_norm": "1.084", "emb_norm": "1.042", "loss": "18.552", "ppl": "384346", "wps": "1580.3", "ups": "0", "wpb": "524288", "bsz": "256", "num_updates": "2001", "lr": "0.000598785", "gnorm": "9.072", "clip": "100", "loss_scale": "2", "scale_window": "32", "train_wall": "59", "cuda_gb_allocated": "32.4", "cuda_gb_reserved": "43.9", "cuda_gb_free": "46.8", "wall": "0"}
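
For context, the hacked comparison boils down to something like the sketch below (illustrative only: load_model_from_checkpoint is a placeholder rather than an actual metaseq helper, and it is assumed to return a callable that maps net_input to logits):

# With the seed and sample fixed above, models loaded from differently-sharded
# checkpoints should produce numerically identical outputs.
# `load_model_from_checkpoint` is a hypothetical helper, not a metaseq API.
model_8_shards = load_model_from_checkpoint("/path/to/8-shard/checkpoint_2000-shard0.pt")
model_4_shards = load_model_from_checkpoint("/path/to/4-shard/checkpoint_2000-shard0.pt")

with torch.no_grad():
    logits_before = model_8_shards(**sample["net_input"])
    logits_after = model_4_shards(**sample["net_input"])

assert torch.allclose(logits_before, logits_after, atol=1e-5)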

> In one of these cases, would also be good to check when we include optimizer state with a reshard operation and confirm that we can successfully resume training from a resharded checkpoint (and that training continues to look "sane").

I'm trying to verify this now and will post results back here.

@tangbinh (Contributor, Author) commented Nov 2, 2022

Since the existing scripts reshard_mp.py, reshard_mp_launch.sh and reshard_mp_launch_no_slurm.sh are also duplicated in metaseq-internal, I've decided to simply remove them and update some instructions in the docs (see the updated PR).

@tangbinh (Contributor, Author)

Training appears to continue as expected after we load the resharded checkpoints. For example, here is the loss curve after we reshard a checkpoint for the 125M model from 4 shards to 8 shards and reload the new checkpoint from step 2000:
[Screenshot: loss curve after resuming from the resharded checkpoint at step 2000]

@suchenzang (Contributor) left a comment

:shipit:

Successfully merging this pull request may close these issues: Create a consolidated resharding logic

4 participants