Fix UnboundLocalError for tp_plan_alt when tp_plan is empty#44540

Merged
3outeille merged 4 commits into huggingface:main from YangKai0616:fix-tp_plan
Mar 11, 2026

Conversation

@YangKai0616
Copy link
Contributor

Per the title, an UnboundLocalError is raised when tp_plan is empty, at this point in the code:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/test_moe_tp_ep.py", line 6, in <module>
[rank0]:     model = AutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b" , dtype=torch.bfloat16, tp_plan="auto", use_kernels=True)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/transformers/src/transformers/models/auto/auto_factory.py", line 381, in from_pretrained
[rank0]:     return model_class.from_pretrained(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/transformers/src/transformers/modeling_utils.py", line 4137, in from_pretrained
[rank0]:     loading_info, disk_offload_index = cls._load_pretrained_model(model, state_dict, checkpoint_files, load_config)
[rank0]:                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/transformers/src/transformers/modeling_utils.py", line 4256, in _load_pretrained_model
[rank0]:     loading_info, disk_offload_index = convert_and_load_state_dict_in_model(
[rank0]:                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/workspace/transformers/src/transformers/core_model_loading.py", line 1178, in convert_and_load_state_dict_in_model
[rank0]:     if matched_tp_pattern := tp_plan_alt.search(renamed_key):
[rank0]:                              ^^^^^^^^^^^
[rank0]: UnboundLocalError: cannot access local variable 'tp_plan_alt' where it is not associated with a value
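The failure is the classic pattern where a local variable is bound only inside a conditional branch and then read unconditionally. A minimal sketch of how this plays out (hypothetical names loosely mirroring the code in core_model_loading.py, not the actual transformers implementation):

```python
import re


def convert_keys(tp_plan, keys):
    """Sketch: compile a combined pattern from the TP plan, then match keys."""
    if tp_plan:
        # tp_plan_alt is only bound when the plan is non-empty
        tp_plan_alt = re.compile("|".join(re.escape(k) for k in tp_plan))
    matches = []
    for key in keys:
        # with an empty tp_plan, this read raises UnboundLocalError
        if matched := tp_plan_alt.search(key):
            matches.append(matched.group(0))
    return matches
```

Calling `convert_keys({}, [...])` reproduces the crash, while a non-empty plan works as intended.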

Reproduction script and command:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time


model = AutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b" , dtype=torch.bfloat16, tp_plan="auto", use_kernels=True)
print(model._tp_plan)

tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
messages = [
    {"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)

for i in range(5):
    # distributed run
    s1 = time.time()
    outputs = model.generate(**inputs.to(model.device), max_new_tokens=100, do_sample=False)
    s2 = time.time()
    outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
    print(outputs[0])
    print(s2-s1)
Command:

torchrun --nproc-per-node=4 test_moe_tp_ep.py

This PR fixes the issue by adding a non-empty tp_plan check to the weight-sharding condition logic.
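A minimal sketch of that kind of guard (hypothetical helper names, not the actual transformers code): the compiled pattern is only consulted when a plan exists, so the name is never read unbound.

```python
import re


def convert_keys_fixed(tp_plan, keys):
    """Sketch of the fix: guard every use of tp_plan_alt on a non-empty plan."""
    if tp_plan:
        tp_plan_alt = re.compile("|".join(re.escape(k) for k in tp_plan))
    matches = []
    for key in keys:
        # short-circuit: tp_plan_alt is only evaluated when tp_plan is non-empty
        if tp_plan and (matched := tp_plan_alt.search(key)):
            matches.append(matched.group(0))
    return matches


# empty plan no longer crashes; it simply shards nothing
print(convert_keys_fixed({}, ["model.layers.0.mlp.gate_proj.weight"]))  # -> []
```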

Hi @ArthurZucker , please help review. Thanks!

@YangKai0616
Copy link
Contributor Author

Hi @3outeille , could you please help review this PR? This is a general question. Thanks!

@3outeille
Copy link
Member

lgtm thanks

@3outeille 3outeille enabled auto-merge March 11, 2026 13:23
@3outeille 3outeille added this pull request to the merge queue Mar 11, 2026
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Merged via the queue into huggingface:main with commit 1723c81 Mar 11, 2026
28 checks passed