Fix EP + DeepSpeed ZeRO-3 loading via accelerate launch #45548
Open
AmineDiro wants to merge 1 commit into huggingface:main
Conversation
Route EP through the standard (non-zero3) loading path when both EP and is_deepspeed_zero3_enabled() are active, then let deepspeed.initialize() wrap the EP-sharded model afterwards.

- Add PreTrainedModel.has_ep property; use it in tp_plan
- get_init_context: meta device for EP+DS (not zero.Init)
- from_pretrained: clear device_map for EP+DS
- _load_pretrained_model: skip zero3 path for EP+DS, pass model.tp_plan
- _move_missing_keys_from_meta_to_device: do not early-return for EP+DS
- _initialize_missing_keys: standard init (no GatheredParameters) for EP+DS
- configuration_utils: strip distributed_config from serialized config
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Issue
Expert Parallelism (DistributedConfig(enable_expert_parallel=True)) hangs during model loading when launched through accelerate launch with a DeepSpeed ZeRO-3 config. EP works on its own (via torchrun) and ZeRO-3 works on its own, but the two conflict inside from_pretrained because every ZeRO-3 code path is gated on a single env-driven flag (is_deepspeed_zero3_enabled()), and EP needs the non-ZeRO-3 path at every one of those gates.

When you run EP through accelerate launch with DeepSpeed ZeRO-3, the environment variable makes is_deepspeed_zero3_enabled() return True everywhere, so every gate takes the ZeRO-3 path. But EP is fundamentally incompatible with ZeRO-3's initialization, because they shard weights in completely different ways:

- ZeRO-3 initializes parameters lazily via deepspeed.zero.Init(), then loads weights through GatheredParameters (all-gather before writing, re-partition after).
- EP shards the model via distribute_model(), then loads weights through the standard path, where shard_and_distribute_module slices each expert tensor by EP rank.

These two can't coexist in the same loading flow. ZeRO-3's lazy params break EP's sharding hooks. EP's meta tensors break ZeRO-3's GatheredParameters (which expects ds_id, ds_shape attributes).
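The gating conflict can be sketched in pure Python. The function name and structure below are illustrative stand-ins for the transformers internals, not the actual implementation:

```python
# Sketch of the single env-driven gate inside from_pretrained. Before the
# fix, the flag alone decided the branch, so EP launched under a DeepSpeed
# ZeRO-3 config was forced down the ZeRO-3 path and hung. The fix lets EP
# take precedence at every such gate.

def choose_loading_path(zero3_env_enabled: bool, has_ep: bool) -> str:
    """Decide which loading path a hypothetical from_pretrained takes."""
    if zero3_env_enabled and not has_ep:
        return "zero3"      # zero.Init + GatheredParameters
    return "standard"       # distribute_model + shard_and_distribute_module

# EP overrides the flag: standard path even when the ZeRO-3 env config is set.
assert choose_loading_path(zero3_env_enabled=True, has_ep=True) == "standard"
assert choose_loading_path(zero3_env_enabled=True, has_ep=False) == "zero3"
assert choose_loading_path(zero3_env_enabled=False, has_ep=True) == "standard"
```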
Fix

This PR routes EP through the standard (non-ZeRO-3) path inside from_pretrained, lets distribute_model() shard experts as usual, and then lets deepspeed.initialize() wrap the already-loaded, already-sharded model afterward.

- get_init_context accepts distributed_config; when EP+DS, use the meta device (not zero.Init, not real tensors). Meta allocation is free, and init_weights() is skipped; checkpoint weights overwrite everything anyway.
- from_pretrained clears the device_map set by initialize_tensor_parallelism when EP+DS. EP needs all ranks to read all shard files for the hooks, so we skip the accelerate dispatch split.
- _load_pretrained_model, when EP+DS, skips the ZeRO-3 loading branch and uses the standard convert_and_load_state_dict_in_model path, passing model.tp_plan (the property, which returns the EP plan when EP is on) instead of model._tp_plan.
- _move_missing_keys_from_meta_to_device, when EP+DS, does not early-return; it runs the standard path to move meta buffers (inv_freq, etc.) to CPU.
- _initialize_missing_keys, when EP+DS, uses the standard initialize_weights() (no GatheredParameters, since params are real/empty, not ZeRO-3-partitioned).
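The meta-device routing for EP+DS can be illustrated with a minimal sketch. This is a hypothetical simplification, not the real get_init_context (which lives in transformers and also builds the deepspeed.zero.Init() context for the plain ZeRO-3 case):

```python
import contextlib

import torch


def get_init_context_sketch(zero3_enabled: bool, has_ep: bool):
    """Hypothetical simplification of the init-context routing: EP+DS
    gets a plain meta-device context instead of deepspeed.zero.Init()."""
    if has_ep and zero3_enabled:
        # EP + DS: allocate on the meta device -- free, and fine because
        # init_weights() is skipped and checkpoint weights overwrite all.
        return torch.device("meta")
    # Non-distributed default; the ZeRO-3-only branch is elided here.
    return contextlib.nullcontext()


with get_init_context_sketch(zero3_enabled=True, has_ep=True):
    layer = torch.nn.Linear(4, 4)
assert layer.weight.device.type == "meta"  # no real memory was allocated
```

Using torch.device as a context manager (PyTorch 2.0+) makes every tensor created inside it land on the meta device, which is what makes the allocation free.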
Test

Minimal EP + DeepSpeed ZeRO-3 verification. Simulates accelerate launch by setting HfDeepSpeedConfig so that is_deepspeed_zero3_enabled() returns True, which is the signal that made from_pretrained hang before the fix.

Run on 4xH100:
Before submitting
Did you read the contributor guideline, Pull Request section?
Who can review?
@3outeille @ArthurZucker (distributed / TP / EP implementation)