
Conversation

jiqing-feng (Contributor) commented Aug 29, 2025

The indices were masked with 0, but 0 is also recognized as a valid expert (experts[0]). We need a new value reserved specifically for masking.

I ran gpt-oss with EP=2 and found that both rank0 and rank1 computed expert 0:
[screenshot: with EP=2, both rank0 and rank1 dispatch tokens to expert 0]

After this PR, the masking index is num_experts (16), and index 16 is skipped:
[screenshot: off-rank positions now hold the masking index 16, which is skipped]
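
For context, here is a minimal sketch of the failure mode and of the fix, assuming an EP=2 split of 16 experts; the variable names and the unique-then-loop dispatch are illustrative, not the exact gpt_oss code:

import torch

num_experts = 16
experts_per_rank = 8                        # EP=2
rank = 1
lo, hi = rank * experts_per_rank, (rank + 1) * experts_per_rank

router_indices = torch.tensor([[0, 5], [9, 12]])   # top-k expert ids per token
on_rank = (router_indices >= lo) & (router_indices < hi)

# Before this PR: off-rank entries were masked with 0, which collides with the
# real expert 0, so 0 shows up in every rank's indices and every rank runs it.
masked_bad = torch.where(on_rank, router_indices, torch.zeros_like(router_indices))

# After this PR: off-rank entries are masked with num_experts (16), an index no
# real expert owns, so it can be skipped unambiguously.
masked_good = torch.where(on_rank, router_indices, torch.full_like(router_indices, num_experts))

for expert_idx in torch.unique(masked_good).tolist():
    if expert_idx == num_experts:           # sentinel: not a real expert, skip
        continue
    token_mask = masked_good == expert_idx  # select this expert's tokens and compute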

Rocketknight1 (Member)

cc @ArthurZucker I think?

jiqing-feng (Contributor, Author)

run-slow: gpt_oss

jiqing-feng (Contributor, Author)

I also need to change all the other MoE models, like Mixtral, after this is verified.

ArthurZucker (Collaborator)

You should not need to for now, because only gpt_oss supports EP!

jiqing-feng (Contributor, Author) commented Sep 1, 2025

To reproduce the error, run this command and script on a CPU-only machine:
mpirun -np 2 --map-by ppr:1:numa --bind-to numa -genv MASTER_ADDR=127.0.0.1 -genv MASTER_PORT=29500 -genv OMP_NUM_THREADS=32 python tp_hf.py

import os
import torch
import torch.distributed as dist
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers.distributed import DistributedConfig


model_id = "lmsys/gpt-oss-20b-bf16"
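# Map the MPI launcher's PMI_* variables to the RANK/WORLD_SIZE names torch.distributed expects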
os.environ['RANK'] = str(os.environ.get('PMI_RANK', 0))
os.environ['LOCAL_RANK'] = str(os.environ.get('PMI_RANK', 0))
os.environ['WORLD_SIZE'] = str(os.environ.get('PMI_SIZE', 1))
rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])

def main(rank, world_size) -> None:
    is_tp = world_size > 1
    model_kwargs = dict(torch_dtype=torch.bfloat16)
    if is_tp:
        model_kwargs["tp_plan"] = "auto"
        # model_kwargs["distributed_config"] = DistributedConfig(enable_expert_parallel=1)
    else:
        model_kwargs["device_map"] = "cpu"

    # Load the model; with tp_plan="auto" it is sharded across the ranks
    config = AutoConfig.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, config=config, **model_kwargs)
    if dist.is_initialized():
        print("Backend:", dist.get_backend())
    else:
        print("Distributed process group is not initialized.")

    # Prepare input tokens
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    prompt = "Once upon a time, there existed a little girl, who liked to have adventures. She wanted to go to places and meet new people, and have fun."
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, do_sample=False, max_new_tokens=32)

    if rank == 0:
        print(tokenizer.batch_decode(outputs, skip_special_tokens=True))


if __name__ == "__main__":
    rank = int(os.environ["RANK"]) if "RANK" in os.environ else 0
    world_size = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    main(rank, world_size)

Output before this PR:

["Once upon a time, there existed a little girl, who liked to have adventures. She wanted to go to places and meet new people, and have fun. She was a very good friend, and she was a very good\n\nIt sounds like you're sharing a story or a prompt! If you'd like to continue the story"]

Output after this PR:

['Once upon a time, there existed a little girl, who liked to have adventures. She wanted to go to places and meet new people, and have fun. She was a very good friend, and she was a good friend. She was a good friend. She was a good\n\nIt seems like your text got cut']

Output without EP (running python tp_hf.py directly):

['Once upon a time, there existed a little girl, who liked to have adventures. She wanted to go to places and meet new people, and have fun. She was a very good friend, and she was a good friend. She was a good friend. She was a good\n\nIt seems like your message got cut']

After this PR, the output is nearly identical to the output without EP.

cc @SunMarc @ArthurZucker

jiqing-feng marked this pull request as ready for review, September 1, 2025 02:10
jiqing-feng (Contributor, Author) commented Sep 2, 2025

Hi @ArthurZucker @SunMarc @Rocketknight1, would you please review this PR? I've included the code to reproduce the issue above. You can run it with CPU-only torch, without transformers' custom kernels. To install CPU-only torch: pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu

jiqing-feng (Contributor, Author)

I also added base_model_ep_plan in configuration_gpt_oss.py so that distributed_config = DistributedConfig(enable_expert_parallel=1) can work. Otherwise no parallelism is applied, since ep_plan is None.

jiqing-feng (Contributor, Author)

Hi @SunMarc, would you please review this PR? We have other tasks on the gpt-oss model that are blocked by it. Waiting for your review, thanks!

jiqing-feng force-pushed the gpt-oss-ep branch 2 times, most recently from b6687d1 to cce133f, September 5, 2025 08:30
jiqing-feng (Contributor, Author)

I read this blog and followed its instructions for distributed_config=DistributedConfig(enable_expert_parallel=1), but ep_plan is None in the gpt-oss model.

I don't know how CUDA performs EP, maybe because CUDA uses kernels, but CPU definitely needs an ep_plan to enable EP. So I added base_model_ep_plan in configuration_gpt_oss.py.
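
By analogy with base_model_tp_plan, an ep_plan maps module-name patterns to parallelization strategies. A hypothetical sketch of the kind of entry this adds (the patterns and strategy strings below are assumptions for illustration, not necessarily the merged code):

# Hypothetical: a config-level plan, mirroring how base_model_tp_plan is declared.
# Without such a plan, DistributedConfig(enable_expert_parallel=1) finds
# ep_plan = None and applies no expert parallelism at all.
base_model_ep_plan = {
    "layers.*.mlp.router": "ep_router",      # assumed strategy name for routing
    "layers.*.mlp.experts": "grouped_gemm",  # assumed: shard experts across ranks
}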

ArthurZucker (Collaborator)

Sorry I was off for a week!

jiqing-feng (Contributor, Author)

> Sorry I was off for a week!

@ArthurZucker no worries. This is a bug fix, and the case can easily be reproduced on CPU. Please review this PR; we have a blog post pending on it being merged, so we can release.

jiqing-feng (Contributor, Author)

The failing CI is unrelated to my changes.

ArthurZucker (Collaborator)

Don't worry, I'll do the last review today and merge!

ArthurZucker (Collaborator) left a comment

Very nice, let's just add a small explanation in the doc and good to go!


[For maintainers] Suggested jobs to run (before merge)

run-slow: gpt_oss, mxfp4

jiqing-feng (Contributor, Author)

Hi @ArthurZucker, I have addressed your comments. Please review. Thanks!

ArthurZucker (Collaborator)

Thanks 🤗

ArthurZucker merged commit 3340ccb into huggingface:main, Sep 10, 2025
19 of 21 checks passed
vijayabhaskar-ev pushed a commit to vijayabhaskar-ev/transformers that referenced this pull request Oct 2, 2025
* fix out shape
* fix router indice
* fix mod
* fix masking
* fix typo
* fix typo
* fix format
* add safety cheking
* fix checking
* enable 1 expert per rank
* fix skip
* add ep plan in config
* add update ep plan
* fix typo
* rm ep_plan and add comments

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
yuchenxie4645 pushed a commit to yuchenxie4645/transformers that referenced this pull request Oct 4, 2025