# Transformer Decoder Implementation

This notebook implements the **Decoder** part of the Transformer, including:
- `DecoderLayer`: a single decoder layer (Masked Self-Attention + Cross-Attention + FFN)
- `Decoder`: the full decoder (Embedding + Positional Encoding + N `DecoderLayer`s + Linear)

---

## 1. Core Theory

### 1.1 DecoderLayer vs EncoderLayer

| Component | EncoderLayer | DecoderLayer |
|:-----|:-------------|:-------------|
| Self-Attention | 1 (no causal mask) | 1 (**with causal mask**) |
| Cross-Attention | **None** | **1** (Encoder-Decoder Attention) |
| Feed Forward | 1 | 1 |
| Add & Norm | 2 | **3** |

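**Why does the Decoder need an extra attention layer?**

- **Masked Self-Attention**: lets the Decoder attend to the **tokens generated so far** (autoregression)
- **Cross-Attention**: lets the Decoder attend to the **Encoder output** (source-sequence information)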

```
DecoderLayer structure:

    Target Input
         |
         v
  +-----------------+
  | Masked Self-Attn|  <-- Q, K, V all come from the target
  +-----------------+
         |
    Add & Norm
         |
         v
  +-----------------+
  | Cross-Attention |  <-- Q from the Decoder, K/V from the Encoder (memory)
  +-----------------+
         |
    Add & Norm
         |
         v
  +-----------------+
  |  Feed Forward   |
  +-----------------+
         |
    Add & Norm
         |
         v
      Output
```
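
The cross-attention step in the diagram can be sketched in terms of tensor shapes. The snippet below is a simplified illustration, assuming a single head and no learned projections (the notebook's `MultiHeadAttention` adds both); the sizes are arbitrary. It shows that the attention weights have one row per target position and one column per source position.

```python
import math

import torch

torch.manual_seed(0)
batch, tgt_len, src_len, d_model = 2, 4, 6, 8  # arbitrary illustrative sizes

tgt = torch.randn(batch, tgt_len, d_model)     # decoder-side states -> queries
memory = torch.randn(batch, src_len, d_model)  # encoder output -> keys and values

# Single head, no learned projections: only the attention arithmetic.
scores = torch.matmul(tgt, memory.transpose(-2, -1)) / math.sqrt(d_model)
weights = torch.softmax(scores, dim=-1)  # (batch, tgt_len, src_len)
context = torch.matmul(weights, memory)  # (batch, tgt_len, d_model)

print(weights.shape)  # torch.Size([2, 4, 6])
print(context.shape)  # torch.Size([2, 4, 8])
```

The output keeps the target length, which is why the result can be added back to the decoder stream in the following Add & Norm step.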

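### 1.2 Look-Ahead Mask (Causal Mask)

**Why is a causal mask needed?**

During training, the Decoder receives the complete target sequence as input. When producing the output at position $t$, however, the model **must not see information from positions $t+1, t+2, \dots$**; otherwise it would be "cheating".

**Solution**: use an upper-triangular mask that sets the attention scores of future positions to $-\infty$.

**Upper-triangular mask** (seq_len=5):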

```
Query\Key    t=0   t=1   t=2   t=3   t=4
  t=0        0    -inf  -inf  -inf  -inf
  t=1        0     0    -inf  -inf  -inf
  t=2        0     0     0    -inf  -inf
  t=3        0     0     0     0    -inf
  t=4        0     0     0     0     0

0    = may attend (nonzero weight after softmax)
-inf = may not attend (weight 0 after softmax)
```

**Effect**: position $t$ can attend only to positions $0, 1, \dots, t$ and cannot "peek" at the future.
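
To make the effect concrete, the sketch below applies the additive mask from the diagram to rows of equal raw scores and runs a softmax, using only the Python standard library. The equal scores and the helper name `softmax` are assumptions chosen for illustration; with equal scores, each row spreads its weight uniformly over the positions it is allowed to see.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats (entries may be -inf)."""
    row_max = max(row)
    exps = [math.exp(x - row_max) for x in row]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


seq_len = 5
raw_scores = [[1.0] * seq_len for _ in range(seq_len)]  # arbitrary equal scores

# Additive causal mask: 0 on and below the diagonal, -inf strictly above it.
mask = [
    [0.0 if key <= query else float("-inf") for key in range(seq_len)]
    for query in range(seq_len)
]

weights = [
    softmax([score + bias for score, bias in zip(score_row, mask_row)])
    for score_row, mask_row in zip(raw_scores, mask)
]

for row in weights:
    print([round(w, 2) for w in row])
```

Row $t=0$ puts all of its weight on position 0, row $t=1$ splits it 0.5/0.5 over positions 0 and 1, and so on; every future position receives exactly 0.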

---

## 2. Imports

In [None]:
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

---

## 3. Helper Components
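
For reference, the two formulas that the helper modules in this section implement can be written out as follows. This is a reading of the code in this notebook rather than an independent specification. $M$ is notation introduced here for the additive form of the mask (0 where attention is allowed, $-\infty$ where it is blocked), which is equivalent to the `masked_fill` call in `MultiHeadAttention`. The attention formula is per head and ignores dropout.

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^{\top}}{\sqrt{d_k}} + M\right) V
$$

$$
PE_{(pos,\,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)
$$

In `PositionalEncoding`, `div_term` computes the factor $10000^{-2i/d_{model}}$ as $\exp\left(-2i \cdot \ln(10000) / d_{model}\right)$, with `torch.arange(0, d_model, 2)` supplying the values $2i$.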

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention.

    Positions where `mask` equals 1 (or True) are excluded from attention.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1) -> None:
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.scale = 1.0 / math.sqrt(self.d_k)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = query.size(0)

        q = self.W_q(query).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        k = self.W_k(key).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        v = self.W_v(value).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 1, float("-inf"))

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context = torch.matmul(attn_weights, v)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.W_o(context)

        return output, attn_weights

In [None]:
class PositionalEncoding(nn.Module):
    """Positional encoding (sine/cosine)."""

    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)

        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.pe[:, : x.size(1), :]
        return self.dropout(x)

In [None]:
class PositionwiseFeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1) -> None:
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear2(self.dropout(F.relu(self.linear1(x))))

---

## 4. Mask Generation Function

In [None]:
def generate_square_subsequent_mask(seq_len: int) -> torch.Tensor:
    """Generate a causal mask (strictly upper-triangular matrix).

    Args:
        seq_len: Sequence length.

    Returns:
        mask: (seq_len, seq_len); 1 strictly above the diagonal, 0 elsewhere.
    """
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
    return mask

In [None]:
# Visualize the causal mask
seq_len = 5
causal_mask = generate_square_subsequent_mask(seq_len)
print(f"Causal mask (seq_len={seq_len}):")
print(causal_mask)
print("\n1 = masked (cannot attend), 0 = can attend")

---

## 5. DecoderLayer Implementation

In [None]:
class DecoderLayer(nn.Module):
    """A single Transformer Decoder layer.

    Structure: Masked Self-Attn -> Add & Norm -> Cross-Attn -> Add & Norm -> FFN -> Add & Norm

    Args:
        d_model: Model dimension.
        n_heads: Number of attention heads.
        d_ff: Hidden dimension of the FFN.
        dropout: Dropout probability.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1) -> None:
        super().__init__()
        # Sub-layer 1: Masked Self-Attention
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)

        # Sub-layer 2: Cross-Attention (Encoder-Decoder Attention)
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)

        # Sub-layer 3: Feed Forward
        self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout3 = nn.Dropout(dropout)

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass.

        Args:
            tgt: Target sequence (batch, tgt_seq, d_model).
            memory: Encoder output (batch, src_seq, d_model).
            tgt_mask: Causal mask (tgt_seq, tgt_seq) or (batch, 1, tgt_seq, tgt_seq).
            memory_mask: Encoder padding mask (batch, 1, 1, src_seq).

        Returns:
            Output tensor (batch, tgt_seq, d_model).
        """
        # Sub-layer 1: Masked Self-Attention + Add & Norm
        self_attn_output, _ = self.self_attn(tgt, tgt, tgt, tgt_mask)
        tgt = self.norm1(tgt + self.dropout1(self_attn_output))

        # Sub-layer 2: Cross-Attention + Add & Norm
        # Q comes from the Decoder (tgt); K/V come from the Encoder (memory)
        cross_attn_output, _ = self.cross_attn(tgt, memory, memory, memory_mask)
        tgt = self.norm2(tgt + self.dropout2(cross_attn_output))

        # Sub-layer 3: Feed Forward + Add & Norm
        ffn_output = self.ffn(tgt)
        tgt = self.norm3(tgt + self.dropout3(ffn_output))

        return tgt

---

## 6. Assembling the Full Decoder
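
The `Decoder` in this section defines `create_padding_mask`, but `forward` builds only the causal mask when `tgt_mask` is `None`. One possible way to also hide padded target positions is to OR the two masks together, relying on broadcasting. The snippet below is a minimal sketch under that assumption; the example token IDs and the variable names `causal`, `padding`, and `combined` are made up for illustration.

```python
import torch

pad_idx = 0  # assumed padding index, matching the default used in this notebook
tgt = torch.tensor([[5, 7, 9, 0, 0],
                    [3, 4, 6, 8, 2]])  # the first sequence ends with two padding tokens
seq_len = tgt.size(1)

# Causal mask: True strictly above the diagonal -> (1, 1, seq, seq)
causal = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
causal = causal.unsqueeze(0).unsqueeze(0)

# Padding mask over key positions: True where the key is padding -> (batch, 1, 1, seq)
padding = (tgt == pad_idx).unsqueeze(1).unsqueeze(2)

# Broadcast logical OR -> (batch, 1, seq, seq); True means "do not attend"
combined = causal | padding

print(combined.shape)  # torch.Size([2, 1, 5, 5])
print(combined[0, 0].int())
```

With right-padded batches, every row keeps at least position 0 unmasked, so the softmax never sees a fully masked row. A tensor like `combined` has the `(batch, 1, tgt_seq, tgt_seq)` shape that `DecoderLayer.forward` documents for `tgt_mask`, so it could be passed as `tgt_mask`; outputs at padded query positions would still need to be ignored in the loss.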

In [None]:
class Decoder(nn.Module):
    """Transformer Decoder.

    Structure: Embedding -> Positional Encoding -> N x DecoderLayer -> Linear

    Args:
        vocab_size: Vocabulary size.
        d_model: Model dimension.
        n_layers: Number of DecoderLayers.
        n_heads: Number of attention heads.
        d_ff: Hidden dimension of the FFN.
        max_len: Maximum sequence length.
        dropout: Dropout probability.
        pad_idx: Index of the padding token.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        n_layers: int,
        n_heads: int,
        d_ff: int,
        max_len: int = 5000,
        dropout: float = 0.1,
        pad_idx: int = 0,
    ) -> None:
        super().__init__()
        self.d_model = d_model
        self.pad_idx = pad_idx

        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)

        self.layers = nn.ModuleList(
            [DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)]
        )

        # Project the output to vocabulary size
        self.output_projection = nn.Linear(d_model, vocab_size)

        self.scale = math.sqrt(d_model)

    def create_causal_mask(self, tgt: torch.Tensor) -> torch.Tensor:
        """Create a causal mask.

        Args:
            tgt: Target token IDs (batch, seq).

        Returns:
            mask: (1, 1, seq, seq); 1 strictly above the diagonal.
        """
        seq_len = tgt.size(1)
        mask = generate_square_subsequent_mask(seq_len)
        return mask.unsqueeze(0).unsqueeze(0).to(tgt.device)

    def create_padding_mask(self, tgt: torch.Tensor) -> torch.Tensor:
        """Create a padding mask of shape (batch, 1, 1, seq); True marks padding tokens."""
        mask = (tgt == self.pad_idx).unsqueeze(1).unsqueeze(2)
        return mask.to(tgt.device)

    def forward(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Forward pass.

        Args:
            tgt: Target token IDs (batch, tgt_seq).
            memory: Encoder output (batch, src_seq, d_model).
            tgt_mask: Optional causal mask; generated automatically if None.
            memory_mask: Optional Encoder padding mask.

        Returns:
            logits: (batch, tgt_seq, vocab_size).
        """
        if tgt_mask is None:
            tgt_mask = self.create_causal_mask(tgt)

        # Embedding + Scale + Positional Encoding
        x = self.embedding(tgt) * self.scale
        x = self.pos_encoding(x)

        # Pass through N decoder layers
        for layer in self.layers:
            x = layer(x, memory, tgt_mask, memory_mask)

        # Project to vocabulary size
        logits = self.output_projection(x)

        return logits

---

## 7. Sanity Test
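
The checks in this section confirm the output shape and print the mask. A complementary behavioural check is to perturb only future positions and confirm that earlier outputs do not change. The sketch below does this on a stripped-down single-head attention with no learned weights (`causal_self_attention` is a helper written for this illustration, not part of the notebook's classes), so it runs on its own.

```python
import math

import torch


def causal_self_attention(x: torch.Tensor) -> torch.Tensor:
    """Single-head self-attention with a causal mask and no learned weights."""
    seq_len, d = x.size(-2), x.size(-1)
    scores = torch.matmul(x, x.transpose(-2, -1)) / math.sqrt(d)
    future = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    scores = scores.masked_fill(future, float("-inf"))
    return torch.matmul(torch.softmax(scores, dim=-1), x)


torch.manual_seed(0)
x = torch.randn(1, 6, 16)
t = 3  # outputs at positions 0..t must not depend on positions > t

x_perturbed = x.clone()
x_perturbed[:, t + 1:, :] += 10.0  # change only the "future" inputs

out = causal_self_attention(x)
out_perturbed = causal_self_attention(x_perturbed)

# Earlier positions are unaffected; later positions do change.
assert torch.allclose(out[:, : t + 1], out_perturbed[:, : t + 1], atol=1e-6)
assert not torch.allclose(out[:, t + 1:], out_perturbed[:, t + 1:])
print("causality check passed")
```

The same idea should carry over to the `Decoder` itself: in `eval()` mode (so dropout is off), changing token IDs after position `t` in `tgt` should leave `logits[:, : t + 1]` unchanged.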

In [None]:
torch.manual_seed(42)

# Hyperparameters
vocab_size = 10000
d_model = 512
n_layers = 6
n_heads = 8
d_ff = 2048
batch_size = 2
src_seq_len = 12
tgt_seq_len = 10

# Instantiate the Decoder
decoder = Decoder(
    vocab_size=vocab_size,
    d_model=d_model,
    n_layers=n_layers,
    n_heads=n_heads,
    d_ff=d_ff,
)

print(f"Decoder parameter count: {sum(p.numel() for p in decoder.parameters()):,}")

In [None]:
# Build the inputs
tgt = torch.randint(1, vocab_size, (batch_size, tgt_seq_len))
memory = torch.randn(batch_size, src_seq_len, d_model)  # simulated Encoder output

print(f"Target sequence tgt shape: {tgt.shape}")
print(f"Encoder output memory shape: {memory.shape}")

In [None]:
# Forward pass
decoder.eval()
with torch.no_grad():
    logits = decoder(tgt, memory)

print(f"\nOutput logits shape: {logits.shape}")
print(f"Expected shape: (batch={batch_size}, tgt_seq={tgt_seq_len}, vocab_size={vocab_size})")

In [None]:
# Assertion checks
assert logits.shape == (batch_size, tgt_seq_len, vocab_size), "Unexpected output shape!"
print("\n[PASS] Output shape check passed: (batch, tgt_seq, vocab_size)")

# Inspect the causal mask
causal_mask = decoder.create_causal_mask(tgt)
print(f"\nCausal mask shape: {causal_mask.shape}")
print("Causal mask example (leading singleton dims indexed away):")
print(causal_mask[0, 0])

---

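## 8. Summary

| Component | Role |
|:-----|:-----|
| `Masked Self-Attention` | Autoregression: attends only to tokens generated so far |
| `Cross-Attention` | Attends to the Encoder output (Q from the Decoder, K/V from the Encoder) |
| `Causal mask` | Upper-triangular matrix that prevents "peeking" at the future |
| `Output Projection` | Linear layer that maps to vocabulary size |

**Key differences between the Decoder and the Encoder**:
1. One extra Cross-Attention layer
2. Self-Attention requires a causal mask
3. A final Linear layer outputs logits

**Next step**: combine the Encoder and Decoder into a complete Transformer model.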