In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

Transformer Encoder Block结构是：
```text
Input → MultiHeadSelfAttention → Add & Norm
      → FeedForward → Add & Norm
```

原版 Transformer 使用的是 Post-LN（LayerNorm 放在残差相加之后），层数较深时存在梯度消失、爆炸的隐患，训练不够稳定。
下面实现 Pre-LN 结构，也就是把 LayerNorm 放在每个子层（Self-Attention、FeedForward）的输入之前，残差支路本身不经过归一化。

In [2]:
from transformers.MultiHeadAttentijon import MultiHeadAttention

class TransformerEncoderBlock(nn.Module):
    """A single Pre-LN Transformer encoder block.

    Both sub-layers (self-attention and the position-wise feed-forward
    network) are wrapped the same way::

        x = x + Dropout(SubLayer(LayerNorm(x)))

    Args:
        d_model: Feature width; used for the LayerNorms and as the input and
            output width of the feed-forward network.
        num_heads: Forwarded to ``MultiHeadAttention``.
        dim_ff: Hidden width of the feed-forward network.
        dropout: Dropout probability applied to each sub-layer's output.
    """

    def __init__(self, d_model, num_heads, dim_ff, dropout=0.1):
        super().__init__()

        self.attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads)

        # Feed-forward network acting on the last axis:
        # (batch_size, seq_len, d_model) -> (batch_size, seq_len, dim_ff)
        #                                -> (batch_size, seq_len, d_model)
        widen = nn.Linear(d_model, dim_ff)
        activation = nn.ReLU()
        narrow = nn.Linear(dim_ff, d_model)
        self.ffn = nn.Sequential(widen, activation, narrow)

        # One LayerNorm / Dropout pair per sub-layer.
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.dropout1, self.dropout2 = nn.Dropout(dropout), nn.Dropout(dropout)

    def _attention_branch(self, x, padding_mask):
        """Return Dropout(SelfAttention(LayerNorm(x))), without the residual."""
        # The attention module returns a pair; only its first element is used
        # here (the second is presumably the attention weights -- TODO confirm).
        attended, _ = self.attn(
            self.norm1(x),
            padding_mask=padding_mask,
            causal_mask=False,
        )
        return self.dropout1(attended)

    def _feed_forward_branch(self, x):
        """Return Dropout(FFN(LayerNorm(x))), without the residual."""
        return self.dropout2(self.ffn(self.norm2(x)))

    def forward(self, x, padding_mask=None):
        """Apply the attention and feed-forward sub-layers with residuals.

        Args:
            x: Input of shape (batch_size, seq_len, d_model).
            padding_mask: Optional mask passed through unchanged to the
                attention module.

        Returns:
            Tensor of shape (batch_size, seq_len, d_model).
        """
        # Sub-layer 1: Pre-LN -> attention -> dropout, then residual add.
        after_attention = x + self._attention_branch(x, padding_mask)

        # Sub-layer 2: Pre-LN -> FFN -> dropout, then residual add.
        return after_attention + self._feed_forward_branch(after_attention)

结构长这样（Pre-LN：LayerNorm 在每个子层之前，残差支路不经过 LayerNorm）
```
      Input (x)
         │
    ┌────┴─────┐
    ↓          │
 LayerNorm     │
    ↓          │
 Self-Attn   Shortcut
    ↓          │
 Dropout       │
    ↓          │
   Add ◄───────┘
    │
    ├──────────┐
    ↓          │
 LayerNorm     │
    ↓          │
 FeedForward Shortcut
    ↓          │
 Dropout       │
    ↓          │
   Add ◄───────┘
    ↓
  Output
```

In [3]:
def test_encoder_block_pre_ln():
    """Smoke-test one ``TransformerEncoderBlock`` on a small padded batch.

    Builds a random (2, 6, 512) input together with a padding mask, runs a
    single forward pass, prints the input/output shapes and asserts that the
    block preserves the input shape.

    Raises:
        AssertionError: If the output shape differs from the input shape.
    """
    batch_size = 2
    seq_len = 6
    d_model = 512
    num_heads = 8
    dim_ff = 2048

    # Simulated input
    x = torch.randn(batch_size, seq_len, d_model)

    # Simulated padding mask (the 2nd sequence in the batch has 3 padding positions)
    padding_mask = torch.tensor([
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 1, 1]
    ])  # shape: (batch_size, seq_len)

    # Build the module and run it
    block = TransformerEncoderBlock(d_model=d_model, num_heads=num_heads, dim_ff=dim_ff)
    out = block(x, padding_mask=padding_mask)

    print("Input shape:", x.shape)       # (2, 6, 512)
    print("Output shape:", out.shape)    # (2, 6, 512)

    # Actually verify the contract instead of relying on eyeballing the prints.
    assert out.shape == x.shape, (
        f"encoder block changed the shape: {tuple(x.shape)} -> {tuple(out.shape)}"
    )

test_encoder_block_pre_ln()

Input shape: torch.Size([2, 6, 512])
Output shape: torch.Size([2, 6, 512])


In [4]:
class TransformerEncoder(nn.Module):
    """A stack of ``num_layers`` ``TransformerEncoderBlock`` modules.

    The blocks are applied sequentially and every block receives the same
    ``padding_mask``.

    Args:
        d_model: Forwarded to every block; also the width of the optional
            final LayerNorm.
        num_heads: Forwarded to every block.
        dim_ff: Forwarded to every block.
        num_layers: Number of blocks in the stack.
        dropout: Forwarded to every block.
        final_norm: If True, apply ``nn.LayerNorm(d_model)`` to the output of
            the last block. The Pre-LN blocks in this file return
            ``x + Dropout(FFN(LayerNorm(x)))``, i.e. the residual stream is
            never normalised on the way out, so a trailing LayerNorm is the
            usual companion of a Pre-LN stack. Defaults to False, which keeps
            the previous output and ``state_dict`` layout unchanged.
    """

    def __init__(self, d_model, num_heads, dim_ff, num_layers, dropout=0.1, final_norm=False):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerEncoderBlock(d_model, num_heads, dim_ff, dropout)
            for _ in range(num_layers)
        ])
        # Left as None when disabled so no extra parameters are registered.
        self.final_norm = nn.LayerNorm(d_model) if final_norm else None

    def forward(self, x, padding_mask=None):
        """Run ``x`` through every block, then the optional final LayerNorm.

        Args:
            x: Input tensor handed to the first block.
            padding_mask: Optional mask forwarded unchanged to every block.

        Returns:
            The output of the last block (normalised if ``final_norm`` was
            enabled); ``x`` itself when the stack is empty and no final norm
            is configured.
        """
        for layer in self.layers:
            x = layer(x, padding_mask=padding_mask)
        if self.final_norm is not None:
            x = self.final_norm(x)
        return x


In [5]:
def test_transformer_encoder():
    """Smoke-test a 3-layer ``TransformerEncoder`` on a small padded batch.

    Builds a random (2, 6, 512) input together with a padding mask, runs a
    single forward pass through the stack, prints the input/output shapes and
    asserts that the encoder preserves the input shape.

    Raises:
        AssertionError: If the output shape differs from the input shape.
    """
    batch_size = 2
    seq_len = 6
    d_model = 512
    num_heads = 8
    dim_ff = 2048
    num_layers = 3

    x = torch.randn(batch_size, seq_len, d_model)
    # Second sequence has its last 3 positions marked with 1 in the mask.
    padding_mask = torch.tensor([
        [0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 1, 1]
    ])

    encoder = TransformerEncoder(d_model, num_heads, dim_ff, num_layers)
    out = encoder(x, padding_mask=padding_mask)

    print("Input shape:", x.shape)
    print("Output shape:", out.shape)

    # Actually verify the contract instead of relying on eyeballing the prints.
    assert out.shape == x.shape, (
        f"encoder changed the shape: {tuple(x.shape)} -> {tuple(out.shape)}"
    )
test_transformer_encoder()

Input shape: torch.Size([2, 6, 512])
Output shape: torch.Size([2, 6, 512])


将 input embedding 层与上面的 encoder 堆叠组合在一起，构成一个完整的 encoder 模型

In [6]:
from InputEmbedding import InputEmbedding
class TransformerEncoderModel(nn.Module):
    """Complete encoder: ``InputEmbedding`` followed by ``TransformerEncoder``.

    Args:
        vocab_size: Forwarded to ``InputEmbedding``.
        d_model: Forwarded to both the embedding and the encoder stack.
        num_heads: Forwarded to the encoder stack.
        dim_ff: Forwarded to the encoder stack.
        num_layers: Forwarded to the encoder stack.
        max_len: Forwarded to ``InputEmbedding``.
        dropout: Forwarded to the encoder stack only.
    """

    def __init__(self, vocab_size, d_model, num_heads, dim_ff, num_layers, max_len=5000, dropout=0.1):
        super().__init__()
        self.input_embedding = InputEmbedding(vocab_size, d_model, max_len)
        self.encoder = TransformerEncoder(
            d_model=d_model,
            num_heads=num_heads,
            dim_ff=dim_ff,
            num_layers=num_layers,
            dropout=dropout,
        )

    def forward(self, input_ids, padding_mask=None):
        """Embed ``input_ids`` and encode the result.

        Args:
            input_ids: Token ids of shape (batch_size, seq_len).
            padding_mask: Optional mask forwarded to the encoder stack.

        Returns:
            Encoder output of shape (batch_size, seq_len, d_model).
        """
        embedded = self.input_embedding(input_ids)  # (batch_size, seq_len, d_model)
        return self.encoder(embedded, padding_mask=padding_mask)


In [7]:
if __name__ == '__main__':
    # End-to-end demo: random token ids -> TransformerEncoderModel -> shapes.

    # Model hyper-parameters.
    vocab_size, d_model, num_heads = 1000, 512, 8
    dim_ff, num_layers = 2048, 6
    # Batch geometry.
    seq_len, batch_size = 10, 2

    input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))

    # Row 0: no padding. Row 1: padding after the 4th token.
    padding_mask = torch.tensor([
        [0] * seq_len,
        [0] * 4 + [1] * (seq_len - 4),
    ])

    model = TransformerEncoderModel(
        vocab_size, d_model, num_heads, dim_ff, num_layers,
        max_len=5000,
        dropout=0.1,
    )

    out = model(input_ids, padding_mask=padding_mask)
    print("Input IDs shape:", input_ids.shape)   # (2, 10)
    print("Encoder output shape:", out.shape)     # (2, 10, 512)


Input IDs shape: torch.Size([2, 10])
Encoder output shape: torch.Size([2, 10, 512])
