
Fix EP + DeepSpeed ZeRO-3 loading via accelerate launch #45548

Open
AmineDiro wants to merge 1 commit into huggingface:main from AmineDiro:fix-deepspeed-ep-init

Conversation


@AmineDiro AmineDiro commented Apr 21, 2026

Issue

Expert Parallelism (DistributedConfig(enable_expert_parallel=True)) hangs during model loading when launched through accelerate launch with a DeepSpeed ZeRO-3 config. EP works on its own (via torchrun) and ZeRO-3 works on its own — but the two conflict inside from_pretrained because every ZeRO-3 code path is gated on a single env-driven flag (is_deepspeed_zero3_enabled()), and EP needs the non-ZeRO-3 path at every one of those gates.
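As a rough illustration of why one flag controls every gate, the env-driven check behaves like a process-global value that every call site consults. The names below are hypothetical stand-ins, not the actual transformers internals (the real implementation in transformers.integrations.deepspeed keeps a module-level weakref to the active HfDeepSpeedConfig):

```python
# Hypothetical sketch of the single env-driven gate described above.
_active_ds_config = None  # stand-in for the module-level global weakref

def set_hf_deepspeed_config(config: dict) -> None:
    """Record the active DeepSpeed config (what accelerate launch effectively does)."""
    global _active_ds_config
    _active_ds_config = config

def is_deepspeed_zero3_enabled() -> bool:
    """True at every call site whenever the recorded config requests ZeRO stage 3."""
    return (
        _active_ds_config is not None
        and _active_ds_config.get("zero_optimization", {}).get("stage") == 3
    )
```

Because the flag is global and process-wide, there is no per-gate way to say "take the ZeRO-3 path here but not there": once accelerate launch sets it, all ZeRO-3 branches fire.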

When you run EP through accelerate launch with DeepSpeed ZeRO-3, the environment variable makes is_deepspeed_zero3_enabled() return True everywhere. Every gate takes the ZeRO-3 path. But EP is fundamentally incompatible with ZeRO-3's initialization, as they shard weights in completely different ways:

  • ZeRO-3: creates lazy partitioned params via deepspeed.zero.Init(), then loads weights through GatheredParameters (all-gather before writing, re-partition after).
  • EP: creates a model on the meta device, registers sharding hooks via distribute_model(), then loads weights through the standard path where shard_and_distribute_module slices each expert tensor by EP rank.

These two can't coexist in the same loading flow. ZeRO-3's lazy params break EP's sharding hooks. EP's meta tensors break ZeRO-3's GatheredParameters (which expects ds_id, ds_shape attributes).
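To make the EP side concrete, here is a minimal pure-Python sketch (hypothetical helper name, not the actual transformers API) of the per-rank slicing that shard_and_distribute_module performs along the expert dimension:

```python
def ep_expert_indices(num_experts: int, ep_world_size: int, ep_rank: int) -> list:
    """Expert indices owned by one EP rank, assuming an even split.

    Hypothetical helper illustrating the slicing scheme, not the real API.
    """
    assert num_experts % ep_world_size == 0, "even split assumed"
    per_rank = num_experts // ep_world_size
    start = ep_rank * per_rank
    return list(range(start, start + per_rank))

# e.g. 32 experts over EP world size 4 leaves 8 experts per rank,
# which is the per-rank shape the test script below checks.
```

ZeRO-3, by contrast, flattens and partitions every parameter across all ranks regardless of its role, which is why the two sharding layouts cannot be produced by the same loading flow.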

Fix

This PR routes EP through the standard (non-zero3) path inside from_pretrained, lets distribute_model() shard experts as usual, and then lets deepspeed.initialize() wrap the already-loaded, already-sharded model afterward.

  1. get_init_context accepts distributed_config; when EP+DS, it uses the meta device (not zero.Init, not real tensors). Meta allocation is free, and init_weights() is skipped since checkpoint weights overwrite everything anyway.
  2. from_pretrained clears the device_map set by initialize_tensor_parallelism when EP+DS. EP needs all ranks to read all shard files for the hooks, so we skip the accelerate dispatch split.
  3. _load_pretrained_model, when EP+DS, skips the zero3 loading branch and uses the standard convert_and_load_state_dict_in_model path, passing model.tp_plan (the property, which returns the EP plan when EP is on) instead of model._tp_plan.
  4. _move_missing_keys_from_meta_to_device, when EP+DS, does not early-return; it runs the standard path to move meta buffers (inv_freq, etc.) to CPU.
  5. _initialize_missing_keys, when EP+DS, uses the standard initialize_weights() (no GatheredParameters, since params are real/empty, not ZeRO-3-partitioned).
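All five steps hinge on the same predicate. A sketch of that gating condition (hypothetical helper; it only assumes DistributedConfig exposes enable_expert_parallel, as described above):

```python
def is_ep_with_zero3(distributed_config, zero3_enabled: bool) -> bool:
    """Hypothetical sketch of the PR's gate: when expert parallelism is
    requested while the global ZeRO-3 flag is set, from_pretrained takes the
    standard (non-zero3) path and deepspeed.initialize() wraps the
    EP-sharded model afterwards."""
    ep_enabled = bool(
        distributed_config is not None
        and getattr(distributed_config, "enable_expert_parallel", False)
    )
    return ep_enabled and zero3_enabled
```

Note that EP alone (zero3_enabled False) and ZeRO-3 alone (no distributed_config) both leave the predicate False, so neither standalone path changes behavior.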

Test

Minimal EP + DeepSpeed ZeRO-3 verification. Simulates accelerate launch by constructing an HfDeepSpeedConfig so that
is_deepspeed_zero3_enabled() returns True, the same signal that made from_pretrained hang before the fix.

Run on 4xH100:

import os
import torch
import torch.distributed as dist
import deepspeed
from deepspeed import comm as ds_comm

from transformers import AutoModelForCausalLM
from transformers.integrations.deepspeed import HfDeepSpeedConfig
from transformers.distributed.configuration_utils import DistributedConfig

local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])

dist.init_process_group("nccl")
torch.cuda.set_device(local_rank)

ds_config = {
    "train_batch_size": world_size,
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 3, "overlap_comm": True, "contiguous_gradients": True},
}
_dschf = HfDeepSpeedConfig(ds_config)  # strong ref; the global weakref dies if GC'd

mesh = dist.init_device_mesh("cuda", (world_size,))
model = AutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",
    dtype=torch.bfloat16,
    distributed_config=DistributedConfig(enable_expert_parallel=True),
    device_mesh=mesh,
    attn_implementation="eager",
)
model = model.to(f"cuda:{local_rank}")

if dist.get_rank() == 0:
    w = model.model.layers[0].mlp.experts.gate_up_proj
    print(f"expert shape per rank (post-EP, pre-DS): {tuple(w.shape)}  (32 experts / EP=4 = 8)")

ds_comm.init_distributed("nccl")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
engine, optimizer, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config=ds_config)

x = torch.randint(0, 1000, (1, 8), device=f"cuda:{local_rank}")
out = engine(input_ids=x, labels=x.clone(), use_cache=False)
engine.backward(out.loss)
engine.step()

if dist.get_rank() == 0:
    print(f"loss={out.loss.item():.4f}")

dist.destroy_process_group()

Before submitting

  • I confirm that this is not a pure code agent PR.
  • Did you read the contributor guideline, Pull Request section?

Who can review?

@3outeille @ArthurZucker (distributed / TP / EP implementation)

Route EP through the standard (non-zero3) loading path when both EP
and is_deepspeed_zero3_enabled() are active, then let deepspeed.initialize()
wrap the EP-sharded model afterwards.

- Add PreTrainedModel.has_ep property; use it in tp_plan
- get_init_context: meta device for EP+DS (not zero.Init)
- from_pretrained: clear device_map for EP+DS
- _load_pretrained_model: skip zero3 path for EP+DS, pass model.tp_plan
- _move_missing_keys_from_meta_to_device: do not early-return for EP+DS
- _initialize_missing_keys: standard init (no GatheredParameters) for EP+DS
- configuration_utils: strip distributed_config from serialized config
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
