0) Environment

In [32]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

1) Qwen2Config 配置类

In [33]:
class Qwen2Config:
    """
    Plain container for Qwen2 model hyper-parameters.

    Every constructor argument is stored as an attribute of the same name;
    adjust the defaults as needed.
    """
    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        max_position_embeddings=2048,
        rope_theta=10000.0,
        attention_dropout=0.1,
        hidden_act="silu",          # MLP activation function (key into ACT2FN)
        attention_bias=False,       # whether the Q/K/V/O projections use a bias
        rms_norm_eps=1e-6,
        pad_token_id=0,
        num_key_value_heads=4,      # fewer KV heads than query heads => GQA
        _attn_implementation="eager"  # attention backend: eager, flash_attention_2, sdpa
    ):
        # Vocabulary and embedding width
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id
        self.hidden_size = hidden_size

        # Depth and attention layout
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self._attn_implementation = _attn_implementation

        # Rotary position embedding
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta

        # Feed-forward block and normalization
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.rms_norm_eps = rms_norm_eps

        # Runtime flag read by the model classes; not a constructor argument.
        self.gradient_checkpointing = False


2) 预训练模型基类 (简化版)

In [34]:
class Qwen2PreTrainedModel(nn.Module):
    """
    Minimal base class for Qwen2 models.

    Stores the config and provides ``post_init`` (weight initialization plus
    compatibility hooks), which subclasses call at the end of ``__init__``.
    """
    def __init__(self, config: Qwen2Config):
        super().__init__()
        self.config = config

    def init_weights(self):
        """
        Initialize weights with a simplified scheme.

        Every parameter with more than one dimension gets Xavier-uniform init;
        1-D parameters (biases, norm scales) are left untouched. Afterwards the
        ``padding_idx`` row of every ``nn.Embedding`` is reset to zero: Xavier
        init would otherwise overwrite it, and since ``nn.Embedding`` never
        updates the padding row, pad tokens would carry a fixed random vector
        for the whole training run.
        """
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

        with torch.no_grad():
            for module in self.modules():
                if isinstance(module, nn.Embedding) and module.padding_idx is not None:
                    module.weight[module.padding_idx].zero_()

    def _backward_compatibility_gradient_checkpointing(self):
        """
        Mirror the config's gradient-checkpointing flag onto the module.
        """
        self.gradient_checkpointing = self.config.gradient_checkpointing

    def post_init(self):
        """
        Hook to run after construction: initialize weights, then apply the
        compatibility settings.
        """
        self.init_weights()
        self._backward_compatibility_gradient_checkpointing()


3) Qwen2模型主体

In [35]:
class Qwen2Model(Qwen2PreTrainedModel):
    """
    Qwen2 decoder backbone: token embedding -> decoder layers -> final RMSNorm.

    Returns hidden states only; no language-model head is attached.
    """
    def __init__(self, config: Qwen2Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Token embedding (third positional argument is padding_idx)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)

        # One decoder layer per config.num_hidden_layers
        self.layers = nn.ModuleList(
            [Qwen2DecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
        )

        # Normalization applied to the output of the last layer
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Gradient-checkpointing flag (stored; forward does not use it yet)
        self.gradient_checkpointing = config.gradient_checkpointing

        # Weight initialization and compatibility hooks
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        output_hidden_states=False,
        output_attentions=False,
        use_cache=False,
        past_key_value=None
    ):
        """
        Run the decoder stack.

        Args:
            input_ids: token ids, [batch_size, seq_len].
            attention_mask: optional additive mask added to every layer's
                attention scores; must broadcast to
                [batch_size, num_heads, seq_len, kv_len] (e.g. a causal mask of
                shape [1, 1, seq_len, seq_len] with -inf above the diagonal).
            position_ids: optional [batch_size, seq_len]. When omitted and no
                cache is given, positions 0..seq_len-1 are used; with a cache it
                is passed on as None for the attention layers to resolve.
            output_hidden_states: also return each layer's input plus the final
                normalized output.
            output_attentions: also return each layer's attention weights.
            use_cache: also return each layer's present key/value.
            past_key_value: optional cache with one entry per layer, as returned
                by a previous call with use_cache=True.

        Returns:
            Tuple ``(last_hidden_state, [all_hidden_states], [all_attentions],
            [next_cache])``; bracketed items appear only when their flag is set.
            ``next_cache`` holds one present key/value entry per layer.

        Raises:
            ValueError: if input_ids is missing or past_key_value does not have
                exactly one entry per layer.
        """
        if input_ids is None:
            raise ValueError("Please provide input_ids")
        if past_key_value is not None and len(past_key_value) != len(self.layers):
            raise ValueError(
                f"past_key_value must hold one entry per layer: "
                f"expected {len(self.layers)}, got {len(past_key_value)}"
            )

        # Default positions 0..seq_len-1. With a cache the offset depends on the
        # cached length, so leave position_ids as None for the layers.
        if position_ids is None and past_key_value is None:
            bsz, seq_len = input_ids.shape
            position_ids = (
                torch.arange(seq_len, device=input_ids.device)
                .unsqueeze(0)         # [1, seq_len]
                .expand(bsz, -1)      # [bsz, seq_len]
            )

        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        next_cache = () if use_cache else None

        # token ids -> embeddings, (bsz, seq_len, hidden_size)
        hidden_states = self.embed_tokens(input_ids)

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=None if past_key_value is None else past_key_value[idx],
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            # Layer output layout: (hidden_states, [attn_weights], [present_key_value]).
            # The cache sits at index 1 or 2 depending on output_attentions, so
            # take it from the end.
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions += (layer_outputs[1],)
            if use_cache:
                next_cache += (layer_outputs[-1],)

        hidden_states = self.norm(hidden_states)

        # The final normalized output is the last entry of all_hidden_states
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = (hidden_states,)
        if output_hidden_states:
            outputs += (all_hidden_states,)
        if output_attentions:
            outputs += (all_attentions,)
        if use_cache:
            outputs += (next_cache,)

        return outputs


4) Qwen2DecoderLayer 解码层

In [36]:
class Qwen2DecoderLayer(nn.Module):
    """
    One pre-norm decoder block:
        h = x + SelfAttn(InputNorm(x))
        y = h + MLP(PostAttnNorm(h))
    """
    def __init__(self, config: Qwen2Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        # Submodules are created in a fixed order (attention, MLP, norms) so
        # parameter order and random init stay stable.
        attention_cls = QWEN2_ATTENTION_CLASSES[config._attn_implementation]
        self.self_attn = attention_cls(config, layer_idx)
        self.mlp = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=False
    ):
        """
        Returns ``(hidden_states, [attn_weights], [present_key_value])``; the
        bracketed items are included only when the matching flag is set.
        """
        # Attention sub-block: normalize, attend, add the residual back
        attn_input = self.input_layernorm(hidden_states)
        attn_output, attn_weights, present_key_value = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        after_attn = hidden_states + attn_output

        # Feed-forward sub-block: normalize, MLP, add the residual back
        mlp_output = self.mlp(self.post_attention_layernorm(after_attn))
        block_output = after_attn + mlp_output

        extras = []
        if output_attentions:
            extras.append(attn_weights)
        if use_cache:
            extras.append(present_key_value)
        return (block_output, *extras)


5) 注意力实现: Qwen2Attention (eager 版本)

In [37]:
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Expand key/value heads for grouped-query attention.

    Input  [bs, num_key_value_heads, seq_len, head_dim]
    Output [bs, num_key_value_heads * n_rep, seq_len, head_dim], where each KV
    head is repeated ``n_rep`` times consecutively (head ``i`` fills output
    heads ``i*n_rep .. i*n_rep + n_rep - 1``).
    """
    # Unpack first so a non-4-D input is rejected even when n_rep == 1
    bs, kv_heads, length, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Add a repeat axis and broadcast along it (expand creates no copy), then
    # fold it into the head axis.
    widened = hidden_states.unsqueeze(2).expand(bs, kv_heads, n_rep, length, dim)
    return widened.reshape(bs, kv_heads * n_rep, length, dim)


In [38]:
def rotate_half(x):
    """
    Rotate the two halves of the last dimension: ``(x1, x2) -> (-x2, x1)``.

    The split point is ``x.shape[-1] // 2``, so for an odd size the second
    half is one element longer.
    """
    split_at = x.shape[-1] // 2
    front, back = torch.split(x, [split_at, x.shape[-1] - split_at], dim=-1)
    return torch.cat((back.neg(), front), dim=-1)


In [39]:
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """
    Apply rotary position embedding to queries and keys.

    Args:
        q, k: [batch_size, num_heads, seq_len, head_dim] (k may have fewer heads).
        cos, sin: per-position tables, [max_seq_len, head_dim].
        position_ids: integer positions, [batch_size, seq_len].

    Returns:
        ``(q_embed, k_embed)`` with the same shapes as ``q`` and ``k``.
    """
    # Gather the per-token angles and add a broadcast head axis:
    # [batch_size, seq_len, head_dim] -> [batch_size, 1, seq_len, head_dim]
    cos_tok = cos[position_ids].unsqueeze(1)
    sin_tok = sin[position_ids].unsqueeze(1)

    def _rotate(t):
        return (t * cos_tok) + (rotate_half(t) * sin_tok)

    return _rotate(q), _rotate(k)


In [40]:
class Qwen2RotaryEmbedding(nn.Module):
    """
    Rotary position embedding (RoPE) cos/sin tables.

    ``forward`` returns ``(cos, sin)`` covering positions ``0 .. seq_len-1``,
    each ``[seq_len, 2 * ceil(dim / 2)]`` (i.e. ``[seq_len, dim]`` for even
    ``dim``). Tables are precomputed for ``max_position_embeddings`` positions
    and rebuilt when a longer length is requested.
    """
    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000.0,
        device=None
    ):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # inv_freq[i] = 1 / base^(2i / dim): one frequency per channel pair
        inv_freq = self._compute_inv_freq()
        if device is not None:
            inv_freq = inv_freq.to(device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Precompute the cos/sin tables
        self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32)

    def _compute_inv_freq(self):
        """Return the float32 inverse frequencies ``1 / base^(2i/dim)`` (on CPU)."""
        return 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2, dtype=torch.float32) / self.dim)
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """
        (Re)build the cos/sin tables for positions ``0 .. seq_len-1``.

        Positions and angles are always computed in float32. Half-precision
        types cannot represent large integers exactly (bfloat16 above 256,
        float16 above 2048), which would silently merge distinct positions.
        ``dtype`` only selects the storage dtype of the tables.
        """
        self.max_seq_len_cached = seq_len
        # Recompute instead of reading the buffer: module.to(half) casts buffers too.
        inv_freq = self._compute_inv_freq().to(device)
        t = torch.arange(seq_len, device=device, dtype=torch.float32)

        freqs = t.unsqueeze(-1) * inv_freq           # [seq_len, ceil(dim/2)]
        emb = torch.cat([freqs, freqs], dim=-1)      # [seq_len, 2 * ceil(dim/2)]

        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """
        Args:
            x: tensor whose dtype/device the result follows; when ``seq_len`` is
               None its dim 2 is taken as the length
               (e.g. [batch_size, num_heads, seq_len, head_dim]).
            seq_len: number of positions needed.

        Returns:
            ``(cos, sin)`` for positions ``0 .. seq_len-1``, cast to ``x.dtype``.
        """
        if seq_len is None:
            seq_len = x.shape[2]

        # Grow the tables when needed; keep them float32 regardless of x.dtype
        # so later full-precision calls are not degraded.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32)

        return (
            self.cos_cached[:seq_len].to(x.dtype),
            self.sin_cached[:seq_len].to(x.dtype),
        )


In [41]:
class Qwen2Attention(nn.Module):
    """
    Causal multi-head self-attention with RoPE and grouped-query attention.

    Queries use ``num_attention_heads`` heads; keys/values use
    ``num_key_value_heads`` heads, each shared by ``num_key_value_groups``
    query heads (GQA). With equal head counts this is standard multi-head
    attention.
    """
    def __init__(self, config: Qwen2Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size必须能被num_heads整除，"
                f"当前 hidden_size={self.hidden_size}, num_heads={self.num_heads}"
            )
        # Every KV head must serve a whole number of query heads, otherwise
        # repeat_kv cannot line K/V up with Q.
        if self.num_key_value_heads <= 0 or self.num_heads % self.num_key_value_heads != 0:
            raise ValueError(
                f"num_attention_heads must be a positive multiple of num_key_value_heads, "
                f"got num_attention_heads={self.num_heads}, "
                f"num_key_value_heads={self.num_key_value_heads}"
            )
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.attention_dropout = config.attention_dropout
        # Decoder self-attention; forward builds a causal mask when none is given
        self.is_causal = True

        # Q/K/V projections (K and V are narrower under GQA)
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        # Output projection
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)

        # RoPE tables
        self.rotary_emb = Qwen2RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=False,
    ):
        """
        Args:
            hidden_states: [bsz, q_len, hidden_size].
            attention_mask: optional additive mask broadcastable to
                [bsz, num_heads, q_len, kv_len] with kv_len = past_len + q_len.
                When None, a causal mask is built (for q_len > 1).
            position_ids: optional [bsz, q_len]; defaults to
                past_len .. past_len + q_len - 1. Values must be < kv_len.
            past_key_value: optional ``(key, value)`` from a previous call's
                ``present_key_value``; each [bsz, num_key_value_heads,
                past_len, head_dim], keys already rotated.
            output_attentions: return the attention probabilities.
            use_cache: return ``present_key_value`` (past + current K/V,
                stored before GQA head repetition).

        Returns:
            ``(attn_output [bsz, q_len, hidden_size], attn_weights or None,
            present_key_value or None)``.
        """
        bsz, q_len, _ = hidden_states.size()

        # Project and split heads:
        # Q -> [bsz, num_heads, q_len, head_dim]; K, V -> [bsz, num_kv_heads, q_len, head_dim]
        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Number of positions already held in the cache (0 without one)
        past_len = 0 if past_key_value is None else past_key_value[0].shape[-2]
        kv_seq_len = past_len + q_len

        if position_ids is None:
            position_ids = (
                torch.arange(past_len, kv_seq_len, device=hidden_states.device)
                .unsqueeze(0)         # [1, q_len]
                .expand(bsz, -1)      # [bsz, q_len]
            )

        # RoPE on the new tokens only; cached keys were rotated when produced
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            past_key, past_value = past_key_value
            key_states = torch.cat([past_key, key_states], dim=2)
            value_states = torch.cat([past_value, value_states], dim=2)

        # Cache the compact (un-repeated) K/V so the next step can concatenate
        present_key_value = (key_states, value_states) if use_cache else None

        # GQA: repeat K/V heads to match the query heads -> [bsz, num_heads, kv_len, head_dim]
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Scaled dot-product scores: [bsz, num_heads, q_len, kv_len]
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is None and self.is_causal and q_len > 1:
            # Query i sits at absolute position past_len + i and must not see
            # later keys: mask columns j > past_len + i.
            attention_mask = torch.full(
                (q_len, kv_seq_len),
                float("-inf"),
                dtype=attn_weights.dtype,
                device=attn_weights.device,
            ).triu(past_len + 1)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Softmax in float32 for stability, then back to the compute dtype
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)

        # Weighted sum of values, merge heads, output projection
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.num_heads * self.head_dim)
        attn_output = self.o_proj(attn_output)

        return attn_output, (attn_weights if output_attentions else None), present_key_value


6) Qwen2MLP (MLP部分)

In [42]:
# Activation name -> function, looked up via config.hidden_act.
ACT2FN = {
    act_name: getattr(torch.nn.functional, act_name)
    for act_name in ("silu", "relu", "gelu")
}

In [43]:
class Qwen2MLP(nn.Module):
    """
    Gated feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``
    (SwiGLU when the activation is SiLU). No projection uses a bias.
    """
    def __init__(self, config: Qwen2Config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        # Two parallel projections up to the intermediate width...
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        # ...and one projection back down to hidden_size
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)


7) Qwen2RMSNorm (RMSNorm实现)

In [44]:
class Qwen2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Root-mean-square layer normalization over the last dimension:
            y = weight * x / sqrt(mean(x^2) + eps)
        ``weight`` starts at ones.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        out_dtype = hidden_states.dtype
        # Mean square is computed in float32; multiplying by the float32
        # inverse RMS promotes half-precision inputs to float32 as well.
        mean_square = hidden_states.float().pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        scaled = hidden_states * inv_rms
        return (self.weight * scaled).to(out_dtype)


8) 额外：登记Attention实现类

In [45]:
class Qwen2FlashAttention2(Qwen2Attention):
    """
    Placeholder for a FlashAttention-2 backend. Not implemented yet: it
    inherits both the constructor and the eager forward of Qwen2Attention
    unchanged.
    """

class Qwen2SdpaAttention(Qwen2Attention):
    """
    Placeholder for a torch scaled-dot-product-attention backend. Not
    implemented yet: it inherits both the constructor and the eager forward
    of Qwen2Attention unchanged.
    """

# Maps config._attn_implementation to the attention class used by each
# decoder layer. The two non-eager entries are currently eager subclasses.
QWEN2_ATTENTION_CLASSES = dict(
    eager=Qwen2Attention,
    flash_attention_2=Qwen2FlashAttention2,
    sdpa=Qwen2SdpaAttention,
)

9) 演示：随机输入，跑一遍模型

In [46]:
if __name__ == "__main__":
    # Small demo config; 4 query heads share 2 KV heads to exercise GQA.
    config = Qwen2Config(
        vocab_size=1000,
        hidden_size=256,
        num_hidden_layers=4,
        num_attention_heads=4,
        intermediate_size=1024,
        max_position_embeddings=128,
        rope_theta=10000.0,
        attention_dropout=0.1,
        hidden_act="silu",
        attention_bias=True,
        rms_norm_eps=1e-6,
        pad_token_id=0,
        num_key_value_heads=2,
        _attn_implementation="eager"
    )
    model = Qwen2Model(config)

    # Random token ids: batch=2, seq_len=10
    input_ids = torch.randint(0, config.vocab_size, (2, 10))
    seq_len = input_ids.size(1)

    # Additive causal mask: 0 on/below the diagonal, -inf above it.
    # Shape [1, 1, seq_len, seq_len] broadcasts over batch and heads.
    future = torch.ones(seq_len, seq_len, dtype=torch.bool).triu(diagonal=1)
    causal_mask = torch.zeros(seq_len, seq_len).masked_fill(future, float("-inf"))[None, None]

    outputs = model(
        input_ids=input_ids,
        attention_mask=causal_mask,
        output_hidden_states=True,
        output_attentions=False,
        use_cache=False
    )

    # outputs[0]: final hidden states; outputs[1]: per-layer hidden states
    print("last hidden_states shape:", outputs[0].shape)
    if len(outputs) > 1:
        layer_states = outputs[1]
        print("number of hidden_states (including embedding):", len(layer_states))
        for idx, hs in enumerate(layer_states):
            print(f"  hidden_states[{idx}].shape = {hs.shape}")

    # A Linear head on outputs[0] would turn these hidden states into logits;
    # this demo only prints shapes.


last hidden_states shape: torch.Size([2, 10, 256])
number of hidden_states (including embedding): 5
  hidden_states[0].shape = torch.Size([2, 10, 256])
  hidden_states[1].shape = torch.Size([2, 10, 256])
  hidden_states[2].shape = torch.Size([2, 10, 256])
  hidden_states[3].shape = torch.Size([2, 10, 256])
  hidden_states[4].shape = torch.Size([2, 10, 256])
