# Transformer高级概念示例

## Transformer各组件变体

### 自注意力变体示例

#### 标准多头自注意力

In [None]:
import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    """Standard multi-head self-attention.

    Queries, keys and values are all projected from the same input, split
    into ``num_heads`` heads of width ``d_model // num_heads``, attended with
    scaled dot-product attention, merged and passed through an output
    projection.

    Args:
        d_model: Feature width of the input and output; must be divisible
            by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads

        # One full-width projection each for Q, K, V, plus the output mixer.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Attend over ``x`` of shape [batch, seq, d_model]; output has the same shape."""
        bsz, n_tokens, width = x.size()

        def split_heads(t):
            # [batch, seq, d_model] -> [batch, heads, seq, d_head]
            return t.view(bsz, n_tokens, self.num_heads, self.d_head).transpose(1, 2)

        queries = split_heads(self.W_q(x))
        keys = split_heads(self.W_k(x))
        values = split_heads(self.W_v(x))

        # Scaled dot-product attention: logits are [batch, heads, seq, seq].
        logits = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.d_head)
        weights = torch.softmax(logits, dim=-1)
        per_head = torch.matmul(weights, values)  # [batch, heads, seq, d_head]

        # Put the heads back side by side: [batch, seq, d_model].
        merged = per_head.transpose(1, 2).contiguous().view(bsz, n_tokens, width)
        return self.W_o(merged)

# Usage example: push one random batch through the layer and show the shape.
d_model = 512
num_heads = 8
seq_len = 64
batch_size = 2
x = torch.rand(batch_size, seq_len, d_model)
mha = MultiHeadAttention(d_model, num_heads)
output = mha(x)
# Expected: [batch_size, seq_len, d_model]
print("MHA输出形状:", output.shape)

#### GQA示例

GQA 将 $ n_h $ 个查询头分成 $ n_g $ 组，每组共享一组 $ K_j, V_j $

In [None]:
import torch
import torch.nn as nn
import math

class GroupedQueryAttention(nn.Module):
    """Grouped-query self-attention (GQA).

    There are ``num_heads`` query heads but only ``num_groups`` key/value
    heads; each key/value head is shared by ``num_heads // num_groups``
    consecutive query heads.

    Args:
        d_model: Feature width of the input and output; must be divisible
            by ``num_heads``.
        num_heads: Number of query heads; must be divisible by ``num_groups``.
        num_groups: Number of key/value heads.
    """

    def __init__(self, d_model, num_heads, num_groups):
        super().__init__()

        # Every head needs an equal slice of d_model, and every group an
        # equal number of query heads.
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        assert num_heads % num_groups == 0, "num_heads must be divisible by num_groups"

        self.d_model = d_model
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.d_head = d_model // num_heads
        self.heads_per_group = num_heads // num_groups

        # K and V are projected to num_groups heads only, which is where the
        # parameter saving over full multi-head attention comes from.
        kv_width = self.d_head * num_groups
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, kv_width)
        self.W_v = nn.Linear(d_model, kv_width)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Attend over ``x`` of shape [batch, seq, d_model]; output has the same shape."""
        bsz, n_tokens, width = x.size()

        # Project, then fold the feature axis into (heads, d_head).
        queries = self.W_q(x).view(bsz, n_tokens, self.num_heads, self.d_head).transpose(1, 2)
        keys = self.W_k(x).view(bsz, n_tokens, self.num_groups, self.d_head).transpose(1, 2)
        values = self.W_v(x).view(bsz, n_tokens, self.num_groups, self.d_head).transpose(1, 2)

        # Repeat each K/V group along the head axis so group g serves query
        # heads g*heads_per_group .. (g+1)*heads_per_group - 1, e.g. groups
        # [0, 1, 2, 3] with 2 heads per group -> [0, 0, 1, 1, 2, 2, 3, 3].
        keys = keys.repeat_interleave(self.heads_per_group, dim=1)
        values = values.repeat_interleave(self.heads_per_group, dim=1)

        # Scaled dot-product attention: logits are [batch, heads, seq, seq].
        logits = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.d_head)
        weights = torch.softmax(logits, dim=-1)
        per_head = torch.matmul(weights, values)  # [batch, heads, seq, d_head]

        # Merge heads back to [batch, seq, d_model] and mix them.
        merged = per_head.transpose(1, 2).contiguous().view(bsz, n_tokens, width)
        return self.W_o(merged)

# Usage example: model width 512, 8 query heads sharing 4 key/value groups,
# sequence length 64, batch size 2.
d_model = 512
num_heads = 8
num_groups = 4
seq_len = 64
batch_size = 2
# Build the layer first, then draw a random input batch.
gqa = GroupedQueryAttention(d_model, num_heads, num_groups)
x = torch.rand(batch_size, seq_len, d_model)
# Forward pass; the output is expected to be [batch_size, seq_len, d_model].
output = gqa(x)
print("GQA的输出形状:", output.shape)

#### MQA示例

MQA 使用单一 $ K, V $，所有 $ n_h $ 个查询头共享

In [None]:
import torch
import torch.nn as nn
import math

class MultiQueryAttention(nn.Module):
    """Multi-query self-attention (MQA).

    All ``num_heads`` query heads attend against one shared key head and one
    shared value head, each of width ``d_model // num_heads``.

    Args:
        d_model: Feature width of the input and output; must be divisible
            by ``num_heads``.
        num_heads: Number of query heads.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, self.d_head)  # the single shared key head
        self.W_v = nn.Linear(d_model, self.d_head)  # the single shared value head
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Attend over ``x`` of shape [batch, seq, d_model]; output has the same shape."""
        bsz, n_tokens, width = x.size()

        queries = self.W_q(x)     # [batch, seq, d_model]
        shared_k = self.W_k(x)    # [batch, seq, d_head]
        shared_v = self.W_v(x)    # [batch, seq, d_head]

        # Split only the queries into heads; K and V get a broadcast head
        # axis (expand creates a view, the data is not copied).
        queries = queries.view(bsz, n_tokens, self.num_heads, self.d_head).transpose(1, 2)
        keys = shared_k[:, None].expand(-1, self.num_heads, -1, -1)
        values = shared_v[:, None].expand(-1, self.num_heads, -1, -1)

        # Scaled dot-product attention: logits are [batch, heads, seq, seq].
        logits = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.d_head)
        weights = torch.softmax(logits, dim=-1)
        per_head = torch.matmul(weights, values)  # [batch, heads, seq, d_head]

        # Merge heads back to [batch, seq, d_model] and mix them.
        merged = per_head.transpose(1, 2).contiguous().view(bsz, n_tokens, width)
        return self.W_o(merged)

# Usage example: build the layer, then run one random batch through it.
d_model = 512
num_heads = 8
seq_len = 64
batch_size = 2
mqa = MultiQueryAttention(d_model, num_heads)
x = torch.rand(batch_size, seq_len, d_model)
output = mqa(x)
# Expected: [batch_size, seq_len, d_model]
print("MQA的输出形状:", output.shape)

#### MLA示例

- MLA通过下投影生成低维潜在向量 $ C^{KV}, C^Q \in \mathbb{R}^{L \times d_{\text{latent}}} $，仅缓存 $ C^{KV} $
- 解耦RoPE将潜在向量分为非位置（nope）和位置（rope）部分，简化实现仅对 rope 部分应用旋转
  - RoPE指的是Rotary Position Embedding（旋转位置编码），是一种在Transformer模型中广泛使用的相对位置编码方法
  - MLA中，RoPE被进一步优化为解耦RoPE（Decoupled RoPE），以适应MLA的低秩压缩机制，确保位置信息在压缩后的潜在空间中能够有效融入注意力计算
- 计算复杂度增加（下投影和上投影），内存复杂度为 $ O(L \cdot d_{\text{latent}}) $，远低于MHA和GQA

**投影层**

- 下投影 $ W_{\text{dkv}}, W_{\text{dq}} $ 将高维输入压缩到低维潜在空间，减少内存占用。
- 上投影 $ W_{\text{uq}}, W_{\text{uk}}, W_{\text{uv}} $ 恢复多头维度，保持表达能力。
- 输出投影 $ W_o $ 整合多头输出。

**前向传播**

- 下投影：将输入 $ x $ 压缩为 $ C_{\text{kv}}, C_q $，形状为 $ [\text{batch\_size}, \text{seq\_len}, d_{\text{latent}}] $。
- 应用 RoPE：对 $ C_{\text{kv}}, C_q $ 应用解耦 RoPE，融入位置信息。
- 上投影：将潜在向量恢复为多头 $ Q, K, V $，形状为 $ [\text{batch\_size}, \text{num\_heads}, \text{seq\_len}, d_{\text{head}}] $。
- 注意力计算：执行缩放点积注意力，生成上下文向量。
- 输出整合：拼接多头输出并通过 $ W_o $ 投影，返回最终输出和 $ C_{\text{kv}} $（用于推理缓存）。

In [None]:
import torch
import torch.nn as nn
import math

class MultiHeadLatentAttention(nn.Module):
    """Simplified multi-head latent attention (MLA) with a decoupled RoPE slice.

    The input is down-projected into two ``d_latent``-wide latents (one for
    queries, one shared by keys and values).  The last ``d_rope`` channels of
    each latent are rotated by position; the first ``d_latent - d_rope``
    channels are left untouched.  The latents are then up-projected to
    ``d_model`` and used for ordinary multi-head attention.

    Args:
        d_model: Feature width of the input and output; must be divisible
            by ``num_heads``.
        num_heads: Number of attention heads.
        d_latent: Width of the compressed latents; must be >= ``d_rope``.
        d_rope: Number of latent channels that receive the rotation; must
            be even because channels are rotated in pairs.
    """

    def __init__(self, d_model, num_heads, d_latent, d_rope=64):
        super().__init__()

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        assert d_latent >= d_rope, "d_latent must be at least d_rope"
        assert d_rope % 2 == 0, "d_rope must be even for RoPE"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.d_latent = d_latent
        self.d_rope = d_rope                  # rotated ("rope") channels
        self.d_nope = d_latent - d_rope       # position-free ("nope") channels

        # Down-projections: d_model -> d_latent.
        self.W_dkv = nn.Linear(d_model, d_latent)  # shared by keys and values
        self.W_dq = nn.Linear(d_model, d_latent)

        # Up-projections: d_latent -> d_model (all heads side by side).
        self.W_uq = nn.Linear(d_latent, d_model)
        self.W_uk = nn.Linear(d_latent, d_model)
        self.W_uv = nn.Linear(d_latent, d_model)

        self.W_o = nn.Linear(d_model, d_model)

        # One rotation frequency per rotated channel pair.  Kept as a plain
        # tensor attribute and moved to the input's device on use.
        self.freqs = self._init_rope_frequencies(d_rope // 2)

    def _init_rope_frequencies(self, dim):
        """Return the ``dim`` inverse frequencies ``10000 ** (-i / dim)``, i = 0..dim-1."""
        exponents = torch.arange(0, dim, 1).float() / dim
        return 1.0 / (10000 ** exponents)

    def apply_decoupled_rope(self, x, positions):
        """Rotate the trailing ``d_rope`` channels of ``x`` by position.

        Args:
            x: Latent tensor of shape [batch, seq, d_latent].
            positions: 1-D tensor with one position index per sequence step.

        Returns:
            Tensor shaped like ``x``: the first ``d_nope`` channels are copied
            through, the rest are rotated.
        """
        _bsz, _n_tokens, _width = x.size()  # unpacking also requires x to be rank 3
        half = self.d_rope // 2

        passthrough = x[..., :self.d_nope]
        rotary = x[..., self.d_nope:]

        # angle[p, i] = position p * frequency i  -> [seq, d_rope // 2]
        inv_freq = self.freqs.to(x.device)
        angles = positions[:, None].float() * inv_freq[None, :]
        cos, sin = torch.cos(angles), torch.sin(angles)

        # Channel i of the first half is paired with channel i of the second half.
        first, second = rotary[..., :half], rotary[..., half:]
        rotated = torch.cat(
            [first * cos - second * sin, first * sin + second * cos],
            dim=-1,
        )

        return torch.cat([passthrough, rotated], dim=-1)

    def forward(self, x):
        """Run MLA over ``x`` of shape [batch, seq, d_model].

        Returns:
            ``(output, C_kv)``: the attention output [batch, seq, d_model]
            and the rotated key/value latent [batch, seq, d_latent].
        """
        bsz, n_tokens, width = x.size()
        positions = torch.arange(n_tokens, device=x.device)

        # Compress, then inject position into the rope slice of each latent.
        latent_kv = self.apply_decoupled_rope(self.W_dkv(x), positions)
        latent_q = self.apply_decoupled_rope(self.W_dq(x), positions)

        def expand_to_heads(proj, latent):
            # [batch, seq, d_latent] -> [batch, heads, seq, d_head]
            return proj(latent).view(bsz, n_tokens, self.num_heads, self.d_head).transpose(1, 2)

        queries = expand_to_heads(self.W_uq, latent_q)
        keys = expand_to_heads(self.W_uk, latent_kv)
        values = expand_to_heads(self.W_uv, latent_kv)

        # Scaled dot-product attention: logits are [batch, heads, seq, seq].
        logits = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.d_head)
        weights = torch.softmax(logits, dim=-1)
        per_head = torch.matmul(weights, values)

        merged = per_head.transpose(1, 2).contiguous().view(bsz, n_tokens, width)
        return self.W_o(merged), latent_kv

# Usage example: 512-wide model, 8 heads, 128-wide latent of which 64
# channels are rotated, sequence length 64, batch size 2.
d_model = 512
num_heads = 8
d_latent = 128
d_rope = 64
seq_len = 64
batch_size = 2
mla = MultiHeadLatentAttention(d_model, num_heads, d_latent, d_rope)
x = torch.rand(batch_size, seq_len, d_model)
output, C_kv = mla(x)
# Expected: [batch_size, seq_len, d_model]
print("MLA的输出形状:", output.shape)
# Expected: [batch_size, seq_len, d_latent]
print("MLA KV cache的形状:", C_kv.shape)

### 局部注意力机制

In [None]:
import torch
import torch.nn as nn
import math

class LocalAttention(nn.Module):
    """Sliding-window (local) multi-head self-attention.

    Token ``i`` attends only to tokens ``i - window_size .. i + window_size``
    (clipped to the sequence), i.e. at most ``2 * window_size + 1`` keys.

    Args:
        d_model: Feature width of the input and output; must be divisible
            by ``num_heads``.
        num_heads: Number of attention heads.
        window_size: Half-width of the window (tokens on each side).
    """

    def __init__(self, d_model, num_heads, window_size):
        super().__init__()

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.window_size = window_size

        # Full-width Q/K/V projections plus the output mixer.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Attend over ``x`` of shape [batch, seq, d_model]; output has the same shape."""
        bsz, n_tokens, width = x.size()
        head_shape = (bsz, n_tokens, self.num_heads, self.d_head)

        # Each of these is [batch, heads, seq, d_head].
        queries = self.W_q(x).view(*head_shape).transpose(1, 2)
        keys = self.W_k(x).view(*head_shape).transpose(1, 2)
        values = self.W_v(x).view(*head_shape).transpose(1, 2)

        # Filled in one query position at a time.
        per_head = torch.zeros_like(queries)
        scale = math.sqrt(self.d_head)

        for pos in range(n_tokens):
            # Window bounds for this query, clipped to the sequence ends;
            # `hi` is exclusive, so the window holds hi - lo keys.
            lo = max(0, pos - self.window_size)
            hi = min(n_tokens, pos + self.window_size + 1)

            query = queries[:, :, pos:pos + 1, :]      # [batch, heads, 1, d_head]
            window_k = keys[:, :, lo:hi, :]            # [batch, heads, hi - lo, d_head]
            window_v = values[:, :, lo:hi, :]          # [batch, heads, hi - lo, d_head]

            # Softmax is taken over the window only.
            logits = torch.matmul(query, window_k.transpose(-2, -1)) / scale
            weights = torch.softmax(logits, dim=-1)    # [batch, heads, 1, hi - lo]

            per_head[:, :, pos:pos + 1, :] = torch.matmul(weights, window_v)

        # Merge heads back to [batch, seq, d_model] and mix them.
        merged = per_head.transpose(1, 2).contiguous().view(bsz, n_tokens, width)
        return self.W_o(merged)

# Usage example: model width 512, 8 heads, window half-width 16,
# sequence length 64, batch size 2.
d_model = 512
num_heads = 8
window_size = 16
seq_len = 64
batch_size = 2
# Build the layer first, then draw a random input batch.
local_attn = LocalAttention(d_model, num_heads, window_size)
x = torch.rand(batch_size, seq_len, d_model)
# Forward pass; the output is expected to be [batch_size, seq_len, d_model].
output = local_attn(x)
print("Local Attention的输出形状:", output.shape)

## 层归一化和残差连接

### 层归一化

### 残差连接

#### 双残差连接

将输入特征分成两条路径分别进行处理，一个负责网络内部计算，另一个负责保证梯度稳定。

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# 定义Decoder-Only Transformer的层，融入标准的双残差ResiDual设计
class ResiDualTransformerBlock(nn.Module):
    """Decoder-style block with two inner residual additions and one outer one.

    Args:
        d_model: Embedding (hidden) width.
        n_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward network.
        dropout: Dropout rate used inside attention, after attention and at
            the end of the feed-forward network.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        # Self-attention over [batch, seq, d_model] inputs (batch_first).
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        # Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout.
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.ln1 = nn.LayerNorm(d_model)  # normalises the block input
        self.ln2 = nn.LayerNorm(d_model)  # normalises the post-attention sum
        self.ln3 = nn.LayerNorm(d_model)  # normalises the post-FFN sum
        # Applied to the attention output only.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        """Apply the block to ``x`` of shape (batch_size, seq_len, d_model).

        Args:
            x: Input tensor.
            attn_mask: Optional mask passed unchanged to
                ``nn.MultiheadAttention`` (e.g. a causal mask).

        Returns:
            Tensor with the same shape as ``x``, computed as::

                h  = ln1(x)
                y1 = h  + dropout(attn(h, h, h))
                y2 = y1 + ffn(ln2(y1))
                y  = x  + ln3(y2)
        """
        normed = self.ln1(x)

        # First inner residual: around self-attention, based on the
        # normalised input rather than the raw one.
        attended, _ = self.self_attn(normed, normed, normed, attn_mask=attn_mask)
        after_attn = normed + self.dropout(attended)

        # Second inner residual: around the feed-forward network.
        after_ffn = after_attn + self.ffn(self.ln2(after_attn))

        # Outer residual: raw input plus the normalised inner result.
        return x + self.ln3(after_ffn)

# Example: build a ResiDualTransformerBlock and run one causally masked forward pass.
if __name__ == "__main__":
    # Sizes for the smoke test
    batch_size = 2
    seq_len = 10
    d_model = 512
    n_heads = 8
    d_ff = 2048

    # Build the block
    block = ResiDualTransformerBlock(d_model, n_heads, d_ff)

    # Random input batch
    x = torch.randn(batch_size, seq_len, d_model)

    # Additive causal mask (decoder-only): -inf strictly above the diagonal,
    # 0 on and below it.  Fill with -inf first and let triu zero the rest;
    # scaling a 0/1 matrix by -inf instead yields NaN at every allowed
    # position (0 * -inf) and turns the whole output into NaN.
    attn_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)

    # Forward pass
    output = block(x, attn_mask=attn_mask)

    # Expected shape: (batch_size, seq_len, d_model)
    print("输出形状:", output.shape)

#### 并行结构

将MSA子层和FFN子层并行放置，让它们同时处理输入。

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# 定义一个带有并行结构的Transformer层
class ParallelTransformerBlock(nn.Module):
    """Transformer block whose attention and feed-forward branches run in parallel.

    Both branches read the same layer-normalised input; their outputs are
    summed with that normalised input, passed through dropout, and added to
    the raw input.

    Args:
        d_model: Embedding (hidden) width.
        n_heads: Number of attention heads.
        d_ff: Hidden width of the feed-forward network.
        dropout: Dropout rate used inside attention, at the end of the
            feed-forward network, and by ``self.dropout``.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super(ParallelTransformerBlock, self).__init__()

        # Single LayerNorm shared by both branches.
        self.ln = nn.LayerNorm(d_model)

        # Self-attention over [batch, seq, d_model] inputs (batch_first).
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)

        # Feed-forward network: Linear -> GELU -> Linear -> Dropout.
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )

        # Used on the attention output and on the fused branch sum.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attn_mask=None):
        """Apply the block to ``x`` of shape (batch_size, seq_len, d_model).

        Args:
            x: Input tensor.
            attn_mask: Optional mask passed unchanged to
                ``nn.MultiheadAttention`` (e.g. a causal mask).

        Returns:
            Tensor with the same shape as ``x``, computed as::

                h = ln(x)
                y = x + dropout(h + dropout(attn(h, h, h)) + ffn(h))
        """
        # Step 1: normalise once; both branches consume the same tensor.
        x_ln = self.ln(x)

        # Step 2a: attention branch.
        attn_out, _ = self.self_attn(x_ln, x_ln, x_ln, attn_mask=attn_mask)
        attn_out = self.dropout(attn_out)

        # Step 2b: feed-forward branch.  self.ffn already ends in nn.Dropout,
        # so no further dropout is applied here; doing so would drop the
        # branch twice and raise its effective drop rate above `dropout`.
        ffn_out = self.ffn(x_ln)

        # Steps 3-4: fuse the branches with the normalised input, then regularise.
        parallel_out = self.dropout(x_ln + attn_out + ffn_out)

        # Step 5: outer residual connection with the raw input.
        return x + parallel_out

# Example: build a ParallelTransformerBlock and run one causally masked forward pass.
if __name__ == "__main__":
    # Sizes for the smoke test
    batch_size = 2
    seq_len = 10
    d_model = 512
    n_heads = 8
    d_ff = 2048

    # Build the block
    block = ParallelTransformerBlock(d_model, n_heads, d_ff)

    # Random input batch
    x = torch.randn(batch_size, seq_len, d_model)

    # Additive causal mask (decoder-only): -inf strictly above the diagonal,
    # 0 on and below it.  Fill with -inf first and let triu zero the rest;
    # scaling a 0/1 matrix by -inf instead yields NaN at every allowed
    # position (0 * -inf) and turns the whole output into NaN.
    attn_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)

    # Forward pass
    output = block(x, attn_mask=attn_mask)

    # Expected shape: (batch_size, seq_len, d_model)
    print("输出形状:", output.shape)

#### 随机深度

对于残差连接 $ y = x + F(x) $，在训练期间以一定概率随机丢弃残差分支中的 $ F(x) $。

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class StochasticDepth(nn.Module):
    """Residual addition that randomly drops the residual branch per sample.

    In training mode each sample (index along dim 0) keeps its residual with
    probability ``1 - drop_prob``; kept residuals are scaled by
    ``1 / (1 - drop_prob)`` so the expected output matches eval mode.  In
    eval mode, or when ``drop_prob == 0``, this is a plain ``x + residual``.
    """

    def __init__(self, drop_prob: float):
        """
        drop_prob: probability of dropping the residual branch
            (0 <= drop_prob <= 1; 1 means the branch is always dropped
            during training).
        """
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x, residual):
        """
        x: main (skip) branch, batch along dim 0, any rank >= 1
        residual: sub-layer output f(x), broadcastable with x

        Returns x + residual (eval), or x + mask * residual / keep_prob (train).
        """
        if not self.training or self.drop_prob == 0.0:
            # Eval mode or nothing to drop -> ordinary residual connection.
            return x + residual

        keep_prob = 1 - self.drop_prob
        if keep_prob == 0.0:
            # The branch is always dropped.  Returning early avoids the
            # residual / 0 rescale, which would give inf * 0 = NaN.
            return x

        # One Bernoulli draw per sample, with singleton dims for every other
        # axis so the mask broadcasts correctly whatever the rank of x.
        mask_shape = (x.shape[0],) + (1,) * (x.dim() - 1)
        mask = torch.empty(mask_shape, device=x.device).bernoulli_(keep_prob)
        # Match the residual's dtype so half/bfloat16 inputs are not silently
        # promoted to float32 by the multiplication below.
        mask = mask.to(residual.dtype)

        # Rescale kept residuals so the expectation is unchanged.
        residual = residual / keep_prob
        return x + residual * mask


# ===== 使用示例：Transformer Block =====
class TransformerBlock(nn.Module):
    """Pre-LayerNorm Transformer block with stochastic depth on both residuals.

    Args:
        d_model: Embedding (hidden) width.
        nhead: Number of attention heads.
        dim_ff: Hidden width of the feed-forward network.
        drop_prob: Drop probability handed to both StochasticDepth modules.
    """

    def __init__(self, d_model, nhead, dim_ff, drop_prob=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dim_ff),
            nn.ReLU(),
            nn.Linear(dim_ff, d_model)
        )
        # One stochastic-depth residual per sub-layer.
        self.sd1 = StochasticDepth(drop_prob)
        self.sd2 = StochasticDepth(drop_prob)

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, seq, d_model); same shape out."""
        # Normalise once and reuse the result as query, key and value.
        # Calling self.ln1(x) separately for each argument would run the same
        # LayerNorm three times and keep three copies in the autograd graph.
        normed = self.ln1(x)
        attn_out, _ = self.attn(normed, normed, normed)
        x = self.sd1(x, attn_out)

        # Feed-forward sub-layer with its own stochastic-depth residual.
        ffn_out = self.ffn(self.ln2(x))
        x = self.sd2(x, ffn_out)

        return x


# ===== Smoke test =====
if __name__ == "__main__":
    torch.manual_seed(42)

    batch_size = 2
    seq_len = 10
    d_model = 512
    x = torch.randn(batch_size, seq_len, d_model)

    block = TransformerBlock(d_model=d_model, nhead=4, dim_ff=64, drop_prob=0.2)

    # Training mode: residual branches may be dropped at random.
    block.train()
    out = block(x)
    print("Output shape (train):", out.shape)

    # Eval mode: residual branches are always kept.
    block.eval()
    out_eval = block(x)
    print("Output shape (eval):", out_eval.shape)


## Transformer MoE模型示例

### 简单MoE模型

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# 多头注意力机制
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with an optional mask.

    Args:
        d_model: Feature width of the input and output; must be divisible
            by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # width of one head

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Return ``(softmax(Q K^T / sqrt(d_k)) V, attention weights)``.

        Positions where ``mask == 0`` are set to -inf before the softmax,
        so they receive zero weight.
        """
        logits = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            logits = logits.masked_fill(mask == 0, float('-inf'))

        weights = F.softmax(logits, dim=-1)
        return torch.matmul(weights, V), weights

    def forward(self, x, mask=None):
        """Attend over ``x``; returns ``(output, attention weights)``.

        ``x`` is viewed as [batch, seq, num_heads, d_k], so its last dimension
        must be ``d_model``.
        """
        n_batch = x.size(0)

        def to_heads(projected):
            # [batch, seq, d_model] -> [batch, heads, seq, d_k]
            return projected.view(n_batch, -1, self.num_heads, self.d_k).transpose(1, 2)

        queries = to_heads(self.W_q(x))
        keys = to_heads(self.W_k(x))
        values = to_heads(self.W_v(x))

        per_head, attention_scores = self.scaled_dot_product_attention(queries, keys, values, mask)

        # Merge heads back to [batch, seq, d_model] and mix them.
        merged = per_head.transpose(1, 2).contiguous().view(n_batch, -1, self.d_model)
        return self.W_o(merged), attention_scores

# 混合专家系统 (MOE)
class MoEFeedForward(nn.Module):
    """Mixture-of-experts feed-forward layer with top-k gating.

    A linear gate scores every expert per token; the ``top_k`` highest
    softmax scores are used (without renormalisation) to weight the outputs
    of the corresponding experts.  Every expert is still evaluated on every
    token (dense compute); the gate only decides the mixing weights.

    Args:
        d_model: Feature width of the input and output.
        d_ff: Hidden width of each expert's feed-forward network.
        num_experts: Number of experts.
        top_k: Number of experts mixed per token.
    """

    def __init__(self, d_model, d_ff, num_experts, top_k=2):
        super(MoEFeedForward, self).__init__()
        self.num_experts = num_experts
        self.top_k = top_k

        # Gating network: one logit per expert.
        self.gate = nn.Linear(d_model, num_experts)

        # Experts: independent two-layer ReLU feed-forward networks.
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.ReLU(),
                nn.Linear(d_ff, d_model)
            ) for _ in range(num_experts)
        ])

    def forward(self, x):
        """Mix expert outputs for ``x`` of shape [batch, seq, d_model]; same shape out."""
        batch_size, seq_len, d_model = x.size()  # also requires x to be rank 3

        # Per-token probability over experts: [batch, seq, num_experts].
        gate_scores = F.softmax(self.gate(x), dim=-1)

        # Scores and expert ids of the top_k experts: [batch, seq, top_k].
        top_k_scores, top_k_indices = gate_scores.topk(self.top_k, dim=-1)

        output = torch.zeros_like(x)

        # Evaluate each expert exactly once.  Its mixing weight for a token
        # is the gate score of whichever top-k slot selected it (topk
        # indices are distinct, so at most one slot matches), or 0 if the
        # token did not select it.  Looping over slots first would re-run
        # every expert on the full input top_k times for the same result.
        for expert_id, expert in enumerate(self.experts):
            chosen = (top_k_indices == expert_id).to(top_k_scores.dtype)
            weight = (top_k_scores * chosen).sum(dim=-1, keepdim=True)  # [batch, seq, 1]
            output = output + weight * expert(x)

        return output

# Transformer Decoder 层
class TransformerDecoderLayer(nn.Module):
    """Pre-LayerNorm decoder layer: self-attention followed by a MoE feed-forward.

    Args:
        d_model: Embedding (hidden) width.
        num_heads: Number of attention heads.
        d_ff: Hidden width of each MoE expert.
        num_experts: Number of MoE experts.
        dropout: Dropout rate applied to each sub-layer output.
    """

    def __init__(self, d_model, num_heads, d_ff, num_experts, dropout=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.moe = MoEFeedForward(d_model, d_ff, num_experts)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Return ``(layer output, attention scores)`` for input ``x``."""
        # Attention sub-layer: normalise, attend, residual add.
        attn_output, attn_scores = self.self_attention(self.norm1(x), mask)
        x = x + self.dropout(attn_output)

        # MoE sub-layer: normalise, mix experts, residual add.
        x = x + self.dropout(self.moe(self.norm2(x)))

        return x, attn_scores

# 完整的 Transformer Decoder-Only 模型
class TransformerDecoderOnly(nn.Module):
    """Decoder-only Transformer language model built from MoE decoder layers.

    Args:
        vocab_size: Number of token ids.
        d_model: Embedding (hidden) width.
        num_heads: Attention heads per layer.
        d_ff: Hidden width of each MoE expert.
        num_layers: Number of stacked decoder layers.
        num_experts: Experts per MoE layer.
        max_seq_len: Longest sequence the positional table covers.
        dropout: Dropout rate for the embeddings and inside each layer.
    """

    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, num_experts, max_seq_len, dropout=0.1):
        super(TransformerDecoderOnly, self).__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Registered as a non-persistent buffer so the table follows the
        # module through .to()/.cuda()/.half() without being added to the
        # state_dict (checkpoints keep the same keys as before).
        self.register_buffer(
            "pos_encoding",
            self.create_positional_encoding(max_seq_len, d_model),
            persistent=False,
        )
        self.layers = nn.ModuleList([
            TransformerDecoderLayer(d_model, num_heads, d_ff, num_experts, dropout)
            for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def create_positional_encoding(self, max_seq_len, d_model):
        """Build the sinusoidal position table of shape [1, max_seq_len, d_model].

        Even channels hold ``sin(pos * w_i)`` and odd channels ``cos(pos * w_i)``
        with ``w_i = 10000 ** (-2i / d_model)``.
        """
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_seq_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd channel than even channel,
        # so trim the cosine half to fit (a no-op when d_model is even).
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        return pe.unsqueeze(0)  # [1, max_seq_len, d_model]

    def create_causal_mask(self, seq_len):
        """Return a [seq_len, seq_len] bool mask, True where attention is allowed
        (key position <= query position)."""
        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
        return ~mask  # [seq_len, seq_len]

    def forward(self, x):
        """Map token ids ``x`` of shape [batch, seq] to logits.

        Returns:
            ``(logits, attention_scores)``: logits of shape
            [batch, seq, vocab_size] and a list with one attention tensor
            per decoder layer.
        """
        batch_size, seq_len = x.size()
        mask = self.create_causal_mask(seq_len).to(x.device)

        # Token embeddings scaled by sqrt(d_model), plus positional encoding.
        x = self.embedding(x) * math.sqrt(self.d_model)
        x = x + self.pos_encoding[:, :seq_len, :].to(x.device)
        x = self.dropout(x)

        # Run the decoder stack, collecting each layer's attention scores.
        attention_scores = []
        for layer in self.layers:
            x, attn = layer(x, mask)
            attention_scores.append(attn)

        # Project hidden states to vocabulary logits.
        output = self.fc_out(x)
        return output, attention_scores

# 示例使用
def main():
    """Build a small MoE decoder-only model and run one forward pass on random ids."""
    # Model hyper-parameters, keyed by the constructor's parameter names.
    config = {
        "vocab_size": 10000,
        "d_model": 512,
        "num_heads": 8,
        "d_ff": 2048,
        "num_layers": 6,
        "num_experts": 4,
        "max_seq_len": 100,
        "dropout": 0.1,
    }
    model = TransformerDecoderOnly(**config)

    # Random token ids standing in for a real batch.
    batch_size, seq_len = 32, 50
    input_ids = torch.randint(0, config["vocab_size"], (batch_size, seq_len))

    # Forward pass; logits are expected to be [batch_size, seq_len, vocab_size].
    output, attention_scores = model(input_ids)
    print(f"Output shape: {output.shape}")
    print(f"Number of attention score tensors: {len(attention_scores)}")


if __name__ == "__main__":
    main()