In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="5"
device = "cuda" # the device to load the model onto

In [2]:
from transformers import Qwen2ForCausalLM, Qwen2Config, AutoTokenizer, GenerationConfig
import torch

# ============== 1. 修正模型类定义 ==============
class MultiHeadQwen2(Qwen2ForCausalLM):
    def __init__(self, config, num_heads=3):
        # 先初始化原始模型结构
        super().__init__(config)
        
        # 删除原单一头并创建多头
        del self.lm_head
        self.num_heads = num_heads
        self.lm_heads = torch.nn.ModuleList([
            torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            for _ in range(num_heads)
        ])
        
        # 注册自定义参数
        self.current_head = 0
        self.head_switch_freq = 0

    def prepare_inputs_for_generation(self, *args, **kwargs):
        # 从generation_config获取参数
        gen_config = kwargs.get("generation_config", self.generation_config)
        self.current_head = getattr(gen_config, "head_idx", 0)
        self.head_switch_freq = getattr(gen_config, "head_switch_freq", 0)
        return super().prepare_inputs_for_generation(*args, **kwargs)

    def forward(self, input_ids=None, **kwargs):
        outputs = super().forward(
            input_ids=input_ids,
            output_hidden_states=True,
            **kwargs
        )
        
        # 动态切换逻辑
        if self.head_switch_freq > 0 and not self.training:
            seq_len = input_ids.shape[-1]
            if seq_len % self.head_switch_freq == 0:
                self.current_head = (self.current_head + 1) % self.num_heads
        
        hidden_states = outputs.hidden_states[-1]
        logits = self.lm_heads[self.current_head](hidden_states)
        return (logits,) + outputs[1:]

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# ============== 2. 修正模型加载方法 ==============
def load_model_with_heads(checkpoint, num_heads=3):
    # 单独加载配置
    config = Qwen2Config.from_pretrained(checkpoint)
    
    # 手动初始化模型
    model = MultiHeadQwen2(config, num_heads=num_heads)
    
    # 加载预训练权重（跳过不匹配的 lm_head）
    pretrained = Qwen2ForCausalLM.from_pretrained(checkpoint)
    model.load_state_dict(pretrained.state_dict(), strict=False)
    
    # 复制原始head到所有新头
    original_head = pretrained.lm_head.state_dict()
    for head in model.lm_heads:
        head.load_state_dict(original_head)
    
    return model.to(device)

In [4]:
# ============== 3. 使用示例 ==============
checkpoint = "/data/cuiluyi/resources/models/Qwen/Qwen2.5-Math-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = load_model_with_heads(checkpoint, num_heads=3)

In [5]:
prompt = "Find the value of $x$ that satisfies the equation $4x+5 = 6x+7$."

# CoT
messages = [
    {"role": "system", "content": "Please reason step by step, and put your final answer within \\boxed{}."},
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(device)

In [10]:
pretrained = Qwen2ForCausalLM.from_pretrained(checkpoint)

In [None]:
# 示例1：固定头生成
generated_ids = model.generate(
    **model_inputs,
    # generation_config=GenerationConfig(
    #     max_new_tokens=512,
    #     do_sample=True,
    #     temperature=0.7,
    #     # 自定义参数
    #     head_idx=1  # 固定使用第二个头
    # )
)

generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

ValueError: The following `model_kwargs` are not used by the model: ['attention_mask'] (note: typos in the generate arguments will also show up in this list)

: 

In [None]:
# 示例2：动态切换头（每3个token切换一次）
output_dynamic = model.generate(
    tokenizer("Solve 5x-3=12:", return_tensors="pt").to(device),
    generation_config=GenerationConfig(
        max_new_tokens=50,
        do_sample=True,
        top_p=0.9,
        # 自定义参数
        head_switch_freq=3  # 切换频率
    )
)
print(tokenizer.decode(output_dynamic[0]))