Decoder 
一个Decoder Block包含三层结构：
Masked Self-Attention
Cross-Attention
Feed-Forward Network
为了缓解过拟合、训练不稳定（如梯度消失/爆炸）等问题，每个子层都加入了 LayerNorm、Dropout 和 Residual Connection：
```
          Input (x)      ← shape: (batch, tgt_len, d_model)
               │
        ┌──────▼──────┐
        │  LayerNorm  │
        └──────┬──────┘
               ↓
     Masked Multi-Head Self-Attention
               ↓
          Dropout + Residual (+ x)
               │
        ┌──────▼──────┐
        │  LayerNorm  │
        └──────┬──────┘
               ↓
   Multi-Head Cross-Attention (with encoder output)
               ↓
          Dropout + Residual
               │
        ┌──────▼──────┐
        │  LayerNorm  │
        └──────┬──────┘
               ↓
       FeedForward Network (2-layer MLP)
               ↓
          Dropout + Residual
               │
               ▼
        Output (same shape as input)

```


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.MultiHeadAttentijon import MultiHeadAttention

In [8]:
class TransformerDecoderBlock(nn.Module):
    """One pre-LayerNorm Transformer decoder block.

    Three sub-layers run in sequence: masked self-attention, cross-attention
    over the encoder output, and a two-layer feed-forward network. Every
    sub-layer is wrapped as ``x + Dropout(SubLayer(LayerNorm(x)))``, so the
    tensor returned by the block is the un-normalised residual stream.
    """

    def __init__(self, d_model, num_heads, dim_ff, dropout=0.1):
        """
        :param d_model: feature size of the decoder stream
        :param num_heads: head count handed to both attention modules
        :param dim_ff: hidden width of the feed-forward network
        :param dropout: drop probability applied to each sub-layer output
        """
        super().__init__()

        # Sub-modules are created in a fixed order (attention, FFN, norms,
        # dropouts); that order defines parameter / state_dict ordering.
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)

        expand = nn.Linear(d_model, dim_ff)
        contract = nn.Linear(dim_ff, d_model)
        self.ffn = nn.Sequential(expand, nn.ReLU(), contract)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, encoder_output, tgt_padding_mask=None, memory_padding_mask=None):
        """
        Run the block on a batch of decoder states.

        :param x: decoder input (batch_size, tgt_len, d_model)
        :param encoder_output: encoder output (batch_size, src_len, d_model)
        :param tgt_padding_mask: (batch_size, tgt_len)
        :param memory_padding_mask: (batch_size, src_len)
        :return: tensor of shape (batch_size, tgt_len, d_model)
        """
        # Sub-layer 1: causal self-attention, Q = K = V = normalised decoder state.
        normed = self.norm1(x)
        attended, _ = self.self_attn(
            normed,
            query=normed,
            key=normed,
            value=normed,
            padding_mask=tgt_padding_mask,
            causal_mask=True,
        )
        x = x + self.dropout1(attended)

        # Sub-layer 2: cross-attention, queries from the decoder and
        # keys/values from the encoder output (no causal mask).
        normed = self.norm2(x)
        attended, _ = self.cross_attn(
            normed,
            query=normed,
            key=encoder_output,
            value=encoder_output,
            padding_mask=memory_padding_mask,
            causal_mask=False,
        )
        x = x + self.dropout2(attended)

        # Sub-layer 3: feed-forward network.
        transformed = self.ffn(self.norm3(x))
        return x + self.dropout3(transformed)


Transformer Decoder

In [9]:
def test_decoder_block():
    """Smoke-test a single TransformerDecoderBlock.

    Feeds random decoder/encoder tensors plus padding masks through one
    block, prints the tensor shapes, and asserts that the output keeps the
    decoder input's shape.

    :raises AssertionError: if the block output shape differs from
        (batch_size, tgt_len, d_model)
    """
    batch_size = 2
    tgt_len = 6     # decoder input length (e.g. number of generated tokens)
    src_len = 8     # encoder output length (number of source-sentence tokens)
    d_model = 512
    num_heads = 8
    dim_ff = 2048

    # Input data: decoder input + encoder output
    decoder_input = torch.randn(batch_size, tgt_len, d_model)
    encoder_output = torch.randn(batch_size, src_len, d_model)

    # Padding masks: 1 marks a padding position that must be masked out
    tgt_padding_mask = torch.tensor([
        [0, 0, 0, 1, 1, 1],
        [0, 0, 0, 0, 0, 1]
    ])  # (batch_size, tgt_len)

    memory_padding_mask = torch.tensor([
        [0, 0, 0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 1]
    ])  # (batch_size, src_len)

    # Build the decoder block under test
    decoder_block = TransformerDecoderBlock(
        d_model=d_model,
        num_heads=num_heads,
        dim_ff=dim_ff,
        dropout=0.1
    )

    output = decoder_block(
        decoder_input,
        encoder_output,
        tgt_padding_mask=tgt_padding_mask,
        memory_padding_mask=memory_padding_mask
    )

    print("Decoder input shape:", decoder_input.shape)         # (2, 6, 512)
    print("Encoder output shape:", encoder_output.shape)       # (2, 8, 512)
    print("Decoder output shape:", output.shape)               # (2, 6, 512)

    # Fail loudly on a shape regression instead of relying on someone
    # reading the printed sizes.
    expected_shape = (batch_size, tgt_len, d_model)
    assert output.shape == expected_shape, (
        f"expected output shape {expected_shape}, got {tuple(output.shape)}"
    )
test_decoder_block()

Decoder input shape: torch.Size([2, 6, 512])
Encoder output shape: torch.Size([2, 8, 512])
Decoder output shape: torch.Size([2, 6, 512])


In [10]:
class TransformerDecoder(nn.Module):
    """Stack of ``num_layers`` decoder blocks followed by one LayerNorm.

    Each block returns the residual stream without normalising it, so a
    final LayerNorm is applied to the output of the last block.
    """

    def __init__(self, d_model, num_heads, dim_ff, num_layers, dropout=0.1):
        """
        :param d_model: feature size of the decoder stream
        :param num_heads: attention head count used by every block
        :param dim_ff: feed-forward hidden width used by every block
        :param num_layers: number of stacked decoder blocks
        :param dropout: drop probability used by every block
        """
        super().__init__()
        self.layers = nn.ModuleList(
            TransformerDecoderBlock(d_model, num_heads, dim_ff, dropout)
            for _ in range(num_layers)
        )
        # Normalises the output of the final block.
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, encoder_output, tgt_padding_mask=None, memory_padding_mask=None):
        """
        Pass ``x`` through every block in order, then normalise the result.

        The same encoder output and the same two padding masks are handed to
        each block.

        :return: tensor of shape (batch_size, tgt_len, d_model)
        """
        masks = {
            "tgt_padding_mask": tgt_padding_mask,
            "memory_padding_mask": memory_padding_mask,
        }
        hidden = x
        for block in self.layers:
            hidden = block(hidden, encoder_output, **masks)
        return self.norm(hidden)


In [12]:
def test_transformer_decoder():
    """Smoke-test the stacked TransformerDecoder.

    Feeds random decoder/encoder tensors plus padding masks through a
    4-layer decoder, prints the tensor shapes, and asserts that the output
    keeps the decoder input's shape.

    :raises AssertionError: if the decoder output shape differs from
        (batch_size, tgt_len, d_model)
    """
    batch_size = 2
    tgt_len = 6     # decoder input length (target sequence)
    src_len = 8     # encoder output length (source sequence)
    d_model = 512
    num_heads = 8
    dim_ff = 2048
    num_layers = 4

    decoder_input = torch.randn(batch_size, tgt_len, d_model)
    encoder_output = torch.randn(batch_size, src_len, d_model)

    tgt_padding_mask = torch.tensor([
        [0, 0, 0, 1, 1, 1],
        [0, 0, 0, 0, 0, 1]
    ])  # shape: (2, 6)

    memory_padding_mask = torch.tensor([
        [0, 0, 0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 1]
    ])  # shape: (2, 8)

    decoder = TransformerDecoder(
        d_model=d_model,
        num_heads=num_heads,
        dim_ff=dim_ff,
        num_layers=num_layers,
        dropout=0.1
    )

    output = decoder(
        decoder_input,
        encoder_output,
        tgt_padding_mask=tgt_padding_mask,
        memory_padding_mask=memory_padding_mask
    )

    print("Decoder input shape:   ", decoder_input.shape)    # (2, 6, 512)
    print("Encoder output shape:  ", encoder_output.shape)   # (2, 8, 512)
    print("Decoder output shape:  ", output.shape)           # (2, 6, 512)

    # Fail loudly on a shape regression instead of relying on someone
    # reading the printed sizes.
    expected_shape = (batch_size, tgt_len, d_model)
    assert output.shape == expected_shape, (
        f"expected output shape {expected_shape}, got {tuple(output.shape)}"
    )
test_transformer_decoder()

Decoder input shape:    torch.Size([2, 6, 512])
Encoder output shape:   torch.Size([2, 8, 512])
Decoder output shape:   torch.Size([2, 6, 512])
