reshard_model_parallel.py simply calls the reshard_megatron_parts function from stitch_fsdp_ckpt.py. It appears to assume that the weights are unflattened and that the corresponding model keys are present in the checkpoint, neither of which holds for the released checkpoints.
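To make the flattened-weights problem concrete, here is a minimal sketch of what "unflattening" involves. The function and metadata format below are hypothetical (real FSDP checkpoints store their flattening metadata in their own structures): it assumes a single flat 1-D tensor plus a recorded list of (name, shape) pairs.

```python
import torch

def unflatten_params(flat, names_and_shapes):
    """Split one flat 1-D tensor back into named parameter tensors.

    `names_and_shapes` is a hypothetical list of (key, shape) pairs
    recorded when the parameters were flattened; this is only a sketch
    of the idea, not the actual FSDP metadata layout.
    """
    out, offset = {}, 0
    for name, shape in names_and_shapes:
        numel = 1
        for d in shape:
            numel *= d
        # Each parameter occupies a contiguous slice of the flat tensor.
        out[name] = flat[offset:offset + numel].view(shape)
        offset += numel
    assert offset == flat.numel(), "metadata does not match flat tensor size"
    return out
```

A consolidated script would need either this metadata in the checkpoint or a model instance to read shapes from, which is part of why the current scripts diverge.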
convert_to_singleton.py does work with the released checkpoints, but it carries additional requirements unrelated to the resharding logic (bpe_vocab, bpe_merges, etc.), since it launches DDP and instantiates an FSDP object via the LegacyTask. That approach may not be very flexible (see this issue).
What would the checkpoint inputs for the consolidated script look like? Do we need to support all use cases mentioned above?
🚀 Feature Request
We run a single model shard on each GPU, with a combination of data and model parallelism. We have a few different ways of doing resharding today (i.e., converting a model from X shards to Y shards); these should be consolidated into a single solution that we trust.
Supported model-parallel sizes for input and output can be restricted to 1/2/4/8 (the size may differ between input and output), and the overall number of shards can be restricted to powers of 2 to start.
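The core X-shards-to-Y-shards operation can be sketched as merge-then-split. This is an illustrative sketch only, assuming a weight that was partitioned evenly along one dimension (as for a Megatron column-parallel layer); the power-of-2 restriction guarantees the merged size divides evenly into the output shard count.

```python
import torch

def reshard(parts, num_out, dim=0):
    """Reshard a tensor from len(parts) shards to num_out shards.

    Sketch under the assumption that the tensor was split evenly
    along `dim`; restricting shard counts to powers of 2 ensures
    the merged dimension divides evenly by num_out.
    """
    assert num_out > 0 and num_out & (num_out - 1) == 0, \
        "shard counts restricted to powers of 2 to start"
    full = torch.cat(parts, dim=dim)                   # X shards -> full tensor
    return list(torch.chunk(full, num_out, dim=dim))   # full tensor -> Y shards
```

Row-parallel weights would use a different split dimension, and fused QKV projections need interleaving-aware handling, which is exactly the logic a single trusted script should own.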
Motivation
We have three ways of resharding today:
Let's consolidate and clean up the code. This will be useful for having a single code path for loading models later: #78