# 6. 完整Transformer实现

在前面的教程中，我们学习了Transformer的各个组件。现在是时候将它们组合成完整的Transformer模型，并实现一个端到端的机器翻译任务。

## 6.1 完整Transformer架构

原始的Transformer包含编码器（Encoder）和解码器（Decoder）两部分：

### 编码器（Encoder）
- 输入：源序列 + 位置编码
- 结构：N层编码器块的堆叠
- 输出：编码后的表示

### 解码器（Decoder）
- 输入：目标序列 + 位置编码 + 编码器输出
- 结构：N层解码器块的堆叠
- 输出：生成的序列概率分布

![完整Transformer架构](images/complete_transformer.png)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import math
import copy
from typing import Optional, Tuple, List

# 从utils导入基础组件
from utils import (
    MultiHeadAttention, 
    SinusoidalPositionalEncoding, 
    FeedForwardNetwork,
    create_padding_mask,
    create_causal_mask,
    setup_matplotlib_chinese,
    set_random_seed
)

# 设置环境
setup_matplotlib_chinese()
set_random_seed(42)

print(f"PyTorch版本: {torch.__version__}")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")

## 6.2 解码器块（Decoder Block）

解码器块与编码器块类似，但包含三个子层：
1. **掩码多头自注意力**：防止看到未来的token
2. **编码器-解码器注意力**：关注编码器的输出
3. **前馈网络**：非线性变换

In [None]:
class TransformerDecoderBlock(nn.Module):
    """
    One Transformer decoder layer.

    Three sub-layers are applied in order, each wrapped as
    ``LayerNorm(input + Dropout(sublayer(input)))``:

    1. self-attention over the decoder input (restricted by ``tgt_mask``)
    2. cross-attention: queries from the decoder, keys/values from the encoder output
    3. position-wise feed-forward network

    Note that only two of the three sub-layers are attention layers.
    """
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()

        self.d_model = d_model
        self.num_heads = num_heads

        # Two attention sub-layers: (masked) self-attention and encoder-decoder attention.
        # Submodules are created in a fixed order because construction consumes RNG state.
        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_attention = MultiHeadAttention(d_model, num_heads, dropout)

        # Position-wise feed-forward sub-layer
        self.feed_forward = FeedForwardNetwork(d_model, d_ff, dropout)

        # One LayerNorm per sub-layer
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        # A single Dropout module shared by all three residual branches
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, encoder_output: torch.Tensor,
                src_mask: Optional[torch.Tensor] = None,
                tgt_mask: Optional[torch.Tensor] = None):
        """
        Run the decoder layer.

        Args:
            x: decoder input, [batch_size, tgt_len, d_model]
            encoder_output: encoder output, [batch_size, src_len, d_model]
            src_mask: mask handed to the cross-attention call. Its required shape is
                defined by utils.MultiHeadAttention -- presumably broadcastable to
                [batch_size, tgt_len, src_len]; TODO confirm.
            tgt_mask: mask handed to the self-attention call -- presumably broadcastable
                to [batch_size, tgt_len, tgt_len]; TODO confirm.

        Returns:
            A 3-tuple ``(output, self_attn_weights, cross_attn_weights)`` where
            ``output`` is [batch_size, tgt_len, d_model] and the two weight tensors
            are whatever the attention modules return as their second value.
        """
        # Sub-layer 1: self-attention (query = key = value = x)
        sa_out, self_attn_weights = self.self_attention(x, x, x, tgt_mask)
        hidden = self.norm1(x + self.dropout(sa_out))

        # Sub-layer 2: cross-attention over the encoder output
        ca_out, cross_attn_weights = self.cross_attention(
            hidden, encoder_output, encoder_output, src_mask
        )
        hidden = self.norm2(hidden + self.dropout(ca_out))

        # Sub-layer 3: feed-forward network
        hidden = self.norm3(hidden + self.dropout(self.feed_forward(hidden)))

        return hidden, self_attn_weights, cross_attn_weights

# Smoke-test the decoder block on random tensors.
# These module-level hyperparameters are plain notebook globals.
d_model = 256
num_heads = 8
d_ff = 1024
tgt_len = 10
src_len = 12
batch_size = 2

decoder_block = TransformerDecoderBlock(d_model, num_heads, d_ff)

# Random stand-ins for the decoder-side input and the encoder output
decoder_input = torch.randn(batch_size, tgt_len, d_model)
encoder_output = torch.randn(batch_size, src_len, d_model)

# Causal mask from the utils helper with a leading batch dimension added
# (assumes create_causal_mask returns a 2-D [tgt_len, tgt_len] tensor -- TODO confirm)
tgt_mask = create_causal_mask(tgt_len).unsqueeze(0).expand(batch_size, -1, -1)

# Forward pass; src_mask is left at its default (None)
output, self_attn, cross_attn = decoder_block(
    decoder_input, encoder_output, tgt_mask=tgt_mask
)

print(f"解码器块输入形状: {decoder_input.shape}")
print(f"编码器输出形状: {encoder_output.shape}")
print(f"解码器块输出形状: {output.shape}")
print(f"自注意力权重形状: {self_attn.shape}")
print(f"交叉注意力权重形状: {cross_attn.shape}")

## 6.3 完整Transformer模型

现在我们将编码器和解码器组合成完整的Transformer模型：

In [None]:
class TransformerEncoder(nn.Module):
    """
    Transformer encoder: token embedding + positional encoding followed by a
    stack of ``num_layers`` encoder blocks.
    """
    def __init__(self, vocab_size: int, d_model: int, num_heads: int, 
                 d_ff: int, num_layers: int, max_seq_len: int = 5000, 
                 dropout: float = 0.1):
        super().__init__()

        self.d_model = d_model

        # Token embedding and positional encoding
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len)

        # Encoder stack; the block class is imported lazily from utils
        from utils import TransformerBlock
        blocks = []
        for _ in range(num_layers):
            blocks.append(TransformerBlock(d_model, num_heads, d_ff, dropout))
        self.layers = nn.ModuleList(blocks)

        self.dropout = nn.Dropout(dropout)

    def forward(self, src: torch.Tensor, src_mask: Optional[torch.Tensor] = None):
        """
        Encode a batch of source token ids.

        Args:
            src: source token ids (fed to ``nn.Embedding``)
            src_mask: optional mask forwarded unchanged to every encoder block

        Returns:
            ``(encoded, attention_weights)`` where ``attention_weights`` is a list
            holding the second return value of each block, in layer order.
        """
        # Embeddings are scaled by sqrt(d_model) before positions are added
        scale = math.sqrt(self.d_model)
        hidden = self.embedding(src) * scale
        hidden = self.dropout(self.pos_encoding(hidden))

        # Run the stack, keeping each layer's attention weights
        per_layer_attention = []
        for block in self.layers:
            hidden, weights = block(hidden, src_mask)
            per_layer_attention.append(weights)

        return hidden, per_layer_attention


class TransformerDecoder(nn.Module):
    """
    Transformer decoder: token embedding + positional encoding followed by a
    stack of ``num_layers`` ``TransformerDecoderBlock`` layers.
    """
    def __init__(self, vocab_size: int, d_model: int, num_heads: int, 
                 d_ff: int, num_layers: int, max_seq_len: int = 5000, 
                 dropout: float = 0.1):
        super().__init__()

        self.d_model = d_model

        # Token embedding and positional encoding
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len)

        # Decoder stack
        blocks = []
        for _ in range(num_layers):
            blocks.append(TransformerDecoderBlock(d_model, num_heads, d_ff, dropout))
        self.layers = nn.ModuleList(blocks)

        self.dropout = nn.Dropout(dropout)

    def forward(self, tgt: torch.Tensor, encoder_output: torch.Tensor,
                src_mask: Optional[torch.Tensor] = None,
                tgt_mask: Optional[torch.Tensor] = None):
        """
        Decode a batch of target token ids against the encoder output.

        Args:
            tgt: target token ids (fed to ``nn.Embedding``)
            encoder_output: output of the encoder, passed to every layer
            src_mask: optional mask forwarded to each layer's cross-attention
            tgt_mask: optional mask forwarded to each layer's self-attention

        Returns:
            ``(decoded, self_attention_weights, cross_attention_weights)``; the two
            lists hold one entry per layer, in layer order.
        """
        # Embeddings are scaled by sqrt(d_model) before positions are added
        scale = math.sqrt(self.d_model)
        hidden = self.embedding(tgt) * scale
        hidden = self.dropout(self.pos_encoding(hidden))

        # Run the stack, collecting both kinds of attention weights per layer
        self_attention_weights, cross_attention_weights = [], []
        for block in self.layers:
            hidden, self_w, cross_w = block(hidden, encoder_output, src_mask, tgt_mask)
            self_attention_weights.append(self_w)
            cross_attention_weights.append(cross_w)

        return hidden, self_attention_weights, cross_attention_weights


class Transformer(nn.Module):
    """
    Full encoder-decoder Transformer with a linear projection onto the
    target vocabulary.
    """
    def __init__(self, src_vocab_size: int, tgt_vocab_size: int,
                 d_model: int = 512, num_heads: int = 8, 
                 d_ff: int = 2048, num_layers: int = 6,
                 max_seq_len: int = 5000, dropout: float = 0.1):
        super().__init__()

        self.src_vocab_size = src_vocab_size
        self.tgt_vocab_size = tgt_vocab_size
        self.d_model = d_model

        # Encoder and decoder share every hyperparameter except the vocabulary size
        shared_args = (d_model, num_heads, d_ff, num_layers, max_seq_len, dropout)
        self.encoder = TransformerEncoder(src_vocab_size, *shared_args)
        self.decoder = TransformerDecoder(tgt_vocab_size, *shared_args)

        # Maps decoder states to target-vocabulary logits
        self.output_projection = nn.Linear(d_model, tgt_vocab_size)

        # Re-initialize all Linear / Embedding weights
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform for Linear weights (zero bias); N(0, 0.1) for embeddings."""
        for module in self.modules():
            if isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0, std=0.1)
                continue
            if not isinstance(module, nn.Linear):
                continue
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, src: torch.Tensor, tgt: torch.Tensor,
                src_mask: Optional[torch.Tensor] = None,
                tgt_mask: Optional[torch.Tensor] = None):
        """
        Run the encoder, then the decoder, then the vocabulary projection.

        Args:
            src: source token ids, [batch_size, src_len]
            tgt: target token ids, [batch_size, tgt_len]
            src_mask: optional source mask (used by the encoder and by decoder cross-attention)
            tgt_mask: optional target mask (used by decoder self-attention)

        Returns:
            ``(output, attention_weights)``: ``output`` is
            [batch_size, tgt_len, tgt_vocab_size]; ``attention_weights`` is a dict with
            keys 'encoder_attention', 'decoder_self_attention' and
            'decoder_cross_attention', each a per-layer list.
        """
        memory, encoder_attention = self.encoder(src, src_mask)

        decoded, decoder_self_attention, decoder_cross_attention = self.decoder(
            tgt, memory, src_mask, tgt_mask
        )

        logits = self.output_projection(decoded)

        attention_weights = dict(
            encoder_attention=encoder_attention,
            decoder_self_attention=decoder_self_attention,
            decoder_cross_attention=decoder_cross_attention,
        )
        return logits, attention_weights

# Instantiate a full Transformer with a reduced configuration
# (these assignments rebind the notebook-level hyperparameter globals)
src_vocab_size = 1000
tgt_vocab_size = 1000
d_model = 256
num_heads = 8
d_ff = 1024
num_layers = 4

transformer = Transformer(
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    d_model=d_model,
    num_heads=num_heads,
    d_ff=d_ff,
    num_layers=num_layers
)

# Count parameters (all, and those with requires_grad set)
total_params = sum(p.numel() for p in transformer.parameters())
trainable_params = sum(p.numel() for p in transformer.parameters() if p.requires_grad)

print(f"Transformer模型参数总量: {total_params:,}")
print(f"可训练参数量: {trainable_params:,}")
# Size estimate uses 4 bytes per parameter, i.e. it assumes float32 storage
print(f"模型大小: {total_params * 4 / (1024**2):.2f} MB")

## 6.4 机器翻译任务实战

让我们用完整的Transformer模型来实现一个简单的机器翻译任务：

In [None]:
class SimpleTokenizer:
    """
    Minimal whitespace tokenizer used for the demo.

    Text is lower-cased and split on whitespace. ``build_vocab`` assigns the
    special tokens <PAD>, <SOS>, <EOS>, <UNK> the ids 0-3 (in that order);
    ``encode``, ``decode`` and the ``*_token_id`` properties require
    ``build_vocab`` to have been called first.
    """
    def __init__(self):
        self.pad_token = '<PAD>'
        self.sos_token = '<SOS>'
        self.eos_token = '<EOS>'
        self.unk_token = '<UNK>'

        self.special_tokens = [self.pad_token, self.sos_token, self.eos_token, self.unk_token]
        self.vocab = {}          # token -> id
        self.idx_to_token = {}   # id -> token

    def build_vocab(self, sentences: List[str]):
        """Build the vocabulary: special tokens first, then words by descending frequency."""
        # Special tokens always take the lowest ids
        for i, token in enumerate(self.special_tokens):
            self.vocab[token] = i
            self.idx_to_token[i] = token

        # Count word occurrences
        word_count = {}
        for sentence in sentences:
            for word in sentence.lower().split():
                word_count[word] = word_count.get(word, 0) + 1

        # sorted() is stable, so equally frequent words keep first-seen order
        sorted_words = sorted(word_count.items(), key=lambda item: item[1], reverse=True)

        for word, _ in sorted_words:
            if word not in self.vocab:
                idx = len(self.vocab)
                self.vocab[word] = idx
                self.idx_to_token[idx] = word

    def encode(self, sentence: str, max_len: Optional[int] = None) -> List[int]:
        """
        Convert a sentence to token ids, wrapped in <SOS> ... <EOS>.

        Args:
            sentence: whitespace-separated text; words missing from the
                vocabulary map to <UNK>.
            max_len: if given, the result is padded with <PAD> or truncated
                to exactly this many ids. ``None`` means no padding or
                truncation.

        Returns:
            The list of token ids. When truncation happens, the last kept
            position is overwritten with <EOS> so the sequence stays terminated.
        """
        unk_id = self.vocab[self.unk_token]
        eos_id = self.vocab[self.eos_token]

        tokens = [self.vocab[self.sos_token]]
        tokens.extend(self.vocab.get(word, unk_id) for word in sentence.lower().split())
        tokens.append(eos_id)

        # Explicit None check: max_len=0 must not be mistaken for "no limit"
        if max_len is None:
            return tokens

        if len(tokens) > max_len:
            tokens = tokens[:max_len]
            # Plain slicing would cut off <EOS>; keep the sequence terminated
            if tokens:
                tokens[-1] = eos_id
        else:
            tokens += [self.vocab[self.pad_token]] * (max_len - len(tokens))

        return tokens

    def decode(self, tokens: List[int]) -> str:
        """
        Convert token ids back to a space-joined string.

        Decoding stops at the first <EOS>; all special tokens are skipped and
        unknown ids are treated as <UNK> (and therefore skipped too).
        """
        words = []
        for token in tokens:
            word = self.idx_to_token.get(token, self.unk_token)
            if word == self.eos_token:
                break
            if word not in self.special_tokens:
                words.append(word)
        return ' '.join(words)

    @property
    def vocab_size(self):
        """Number of entries in the vocabulary, special tokens included."""
        return len(self.vocab)

    @property
    def pad_token_id(self):
        """Id of the <PAD> token."""
        return self.vocab[self.pad_token]

    @property
    def sos_token_id(self):
        """Id of the <SOS> token."""
        return self.vocab[self.sos_token]

    @property
    def eos_token_id(self):
        """Id of the <EOS> token."""
        return self.vocab[self.eos_token]

# 创建示例数据集（英语->中文的简单例子）
def create_toy_dataset():
    """
    Return the ten-sentence toy parallel corpus.

    Returns:
        ``(en_sentences, zh_sentences)``: two lists of equal length where the
        i-th Chinese sentence (pre-segmented with spaces) translates the i-th
        English one.
    """
    sentence_pairs = [
        ("hello world", "你好 世界"),
        ("how are you", "你 好 吗"),
        ("good morning", "早上 好"),
        ("thank you", "谢谢 你"),
        ("goodbye", "再见"),
        ("i love you", "我 爱 你"),
        ("what is your name", "你 的 名字 是 什么"),
        ("nice to meet you", "很 高兴 见到 你"),
        ("have a good day", "祝 你 有 美好 的 一天"),
        ("see you later", "回头 见"),
    ]

    # Split the aligned pairs back into two parallel lists
    en_sentences = [english for english, _ in sentence_pairs]
    zh_sentences = [chinese for _, chinese in sentence_pairs]

    return en_sentences, zh_sentences

# Load the toy parallel corpus
en_sentences, zh_sentences = create_toy_dataset()

# One tokenizer per language, each with its own vocabulary
en_tokenizer = SimpleTokenizer()
zh_tokenizer = SimpleTokenizer()

en_tokenizer.build_vocab(en_sentences)
zh_tokenizer.build_vocab(zh_sentences)

print(f"英语词汇表大小: {en_tokenizer.vocab_size}")
print(f"中文词汇表大小: {zh_tokenizer.vocab_size}")

# Encode/decode round trip on one sentence pair
example_en = "hello world"
example_zh = "你好 世界"

# max_len=10 pads the short sentences out to ten ids
en_encoded = en_tokenizer.encode(example_en, max_len=10)
zh_encoded = zh_tokenizer.encode(example_zh, max_len=10)

print(f"\n编码示例:")
print(f"英语: '{example_en}' -> {en_encoded}")
print(f"中文: '{example_zh}' -> {zh_encoded}")
print(f"解码: {en_tokenizer.decode(en_encoded)}")
print(f"解码: {zh_tokenizer.decode(zh_encoded)}")

## 6.5 训练过程实现

现在让我们实现完整的训练过程：

In [None]:
class TranslationDataset(torch.utils.data.Dataset):
    """
    Map-style dataset over aligned source/target sentence lists.

    Each item is tokenized on access and returned as a dict of two
    ``torch.long`` tensors under the keys 'src' and 'tgt'.
    """
    def __init__(self, src_sentences, tgt_sentences, src_tokenizer, tgt_tokenizer, max_len=20):
        self.src_sentences, self.tgt_sentences = src_sentences, tgt_sentences
        self.src_tokenizer, self.tgt_tokenizer = src_tokenizer, tgt_tokenizer
        # Passed to both tokenizers' encode() as the max_len argument
        self.max_len = max_len

    def __len__(self):
        # Length is defined by the source side
        return len(self.src_sentences)

    def __getitem__(self, idx):
        src_ids = self.src_tokenizer.encode(self.src_sentences[idx], self.max_len)
        tgt_ids = self.tgt_tokenizer.encode(self.tgt_sentences[idx], self.max_len)

        return dict(
            src=torch.tensor(src_ids, dtype=torch.long),
            tgt=torch.tensor(tgt_ids, dtype=torch.long),
        )

def create_masks(src, tgt, src_pad_idx, tgt_pad_idx):
    """
    Build the boolean attention masks used for training.

    Args:
        src: source token ids, [batch_size, src_len]
        tgt: target (decoder input) token ids, [batch_size, tgt_len]
        src_pad_idx: padding id in the source vocabulary
        tgt_pad_idx: padding id in the target vocabulary

    Returns:
        ``(src_mask, tgt_mask)``, both ``torch.bool`` with True = "may attend":
        - src_mask: [batch_size, 1, 1, src_len], False at source padding
        - tgt_mask: [batch_size, 1, tgt_len, tgt_len], False at target padding
          (key side) and at future positions
    """
    # Source padding mask
    src_mask = (src != src_pad_idx).unsqueeze(1).unsqueeze(2)

    # Target padding mask, [batch_size, 1, 1, tgt_len]
    tgt_len = tgt.size(1)
    tgt_pad_mask = (tgt != tgt_pad_idx).unsqueeze(1).unsqueeze(2)

    # Lower-triangular causal mask. It must be bool: `&` between a bool tensor
    # and the float32 tensor produced by torch.ones raises a RuntimeError.
    tgt_causal_mask = torch.tril(
        torch.ones(tgt_len, tgt_len, device=tgt.device)
    ).bool()

    # Broadcasts to [batch_size, 1, tgt_len, tgt_len]
    tgt_mask = tgt_pad_mask & tgt_causal_mask

    return src_mask, tgt_mask

def train_step(model, batch, criterion, optimizer, src_tokenizer, tgt_tokenizer):
    """
    Run one optimization step on a single batch.

    Args:
        model: callable as ``model(src, tgt_input, src_mask, tgt_mask)`` returning
            ``(logits, extra)``
        batch: dict with 'src' and 'tgt' token-id tensors
        criterion: loss taking flattened logits and flattened labels
        optimizer: optimizer over ``model.parameters()``
        src_tokenizer, tgt_tokenizer: only their ``pad_token_id`` is read

    Returns:
        The batch loss as a Python float.
    """
    model.train()

    src, tgt = batch['src'], batch['tgt']

    # Teacher forcing: the decoder sees tgt without its last position and is
    # trained to predict tgt shifted left by one.
    decoder_in, labels = tgt[:, :-1], tgt[:, 1:]

    src_mask, tgt_mask = create_masks(
        src,
        decoder_in,
        src_tokenizer.pad_token_id,
        tgt_tokenizer.pad_token_id,
    )

    logits, _ = model(src, decoder_in, src_mask, tgt_mask)

    # Flatten batch and time so the criterion sees one row per token
    vocab_dim = logits.size(-1)
    loss = criterion(logits.reshape(-1, vocab_dim), labels.reshape(-1))

    optimizer.zero_grad()
    loss.backward()

    # Clip the global gradient norm to 1.0 before the update
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    optimizer.step()

    return loss.item()

# Dataset and data loader over the toy corpus (sentences are encoded with max_len=15)
dataset = TranslationDataset(en_sentences, zh_sentences, en_tokenizer, zh_tokenizer, max_len=15)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)

# A small model sized to the two tokenizer vocabularies
model = Transformer(
    src_vocab_size=en_tokenizer.vocab_size,
    tgt_vocab_size=zh_tokenizer.vocab_size,
    d_model=128,
    num_heads=4,
    d_ff=512,
    num_layers=2,
    dropout=0.1
)

# Loss ignores target padding positions; Adam runs with a constant learning rate
criterion = nn.CrossEntropyLoss(ignore_index=zh_tokenizer.pad_token_id)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)

# Training loop; `losses` collects the mean batch loss of every epoch
num_epochs = 50
losses = []

print("开始训练...")
for epoch in range(num_epochs):
    epoch_losses = []
    
    for batch in dataloader:
        loss = train_step(model, batch, criterion, optimizer, en_tokenizer, zh_tokenizer)
        epoch_losses.append(loss)
    
    avg_loss = np.mean(epoch_losses)
    losses.append(avg_loss)
    
    # Report progress every 10 epochs
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

print("训练完成！")

## 6.6 推理和翻译

训练完成后，让我们实现翻译功能：

In [None]:
def translate(model, src_sentence, src_tokenizer, tgt_tokenizer, max_len=20):
    """
    Translate one sentence with greedy autoregressive decoding.

    The source is encoded once; the decoder is then re-run on the growing
    target prefix and the arg-max token of the last position is appended,
    until <EOS> is produced or ``max_len`` tokens (including <SOS>) exist.

    Args:
        model: a model exposing ``encoder``, ``decoder`` and ``output_projection``
        src_sentence: source text
        src_tokenizer: tokenizer for the source language
        tgt_tokenizer: tokenizer for the target language
        max_len: length the source is padded/truncated to, and the upper bound
            on the number of generated target positions

    Returns:
        The decoded target string (special tokens removed by the tokenizer).
    """
    model.eval()

    # Create every tensor on the model's device so the function keeps working
    # after model.to(device); a parameter-less model falls back to the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    with torch.no_grad():
        # Encode the source sentence and add a batch dimension
        src_tokens = src_tokenizer.encode(src_sentence, max_len)
        src = torch.tensor(src_tokens, dtype=torch.long, device=device).unsqueeze(0)

        # Run the encoder once; padding positions are masked out
        src_mask = (src != src_tokenizer.pad_token_id).unsqueeze(1).unsqueeze(2)
        encoder_output, _ = model.encoder(src, src_mask)

        # The target prefix starts with <SOS>
        tgt_tokens = [tgt_tokenizer.sos_token_id]

        for _ in range(max_len - 1):
            tgt = torch.tensor(tgt_tokens, dtype=torch.long, device=device).unsqueeze(0)
            tgt_len = tgt.size(1)

            # Causal mask over the current prefix (no padding in the prefix)
            tgt_mask = torch.tril(torch.ones(tgt_len, tgt_len, device=device))
            tgt_mask = tgt_mask.unsqueeze(0)

            decoder_output, _, _ = model.decoder(tgt, encoder_output, src_mask, tgt_mask)
            output = model.output_projection(decoder_output)

            # Greedy choice: most likely token at the last position
            next_token = output[0, -1, :].argmax().item()
            tgt_tokens.append(next_token)

            if next_token == tgt_tokenizer.eos_token_id:
                break

    # Drop the leading <SOS> before decoding
    translated = tgt_tokenizer.decode(tgt_tokens[1:])
    return translated

# Try the trained model on a few sentences taken from the toy corpus
test_sentences = [
    "hello world",
    "good morning",
    "thank you",
    "i love you"
]

print("翻译测试结果:")
print("=" * 40)

# Print each source sentence next to its greedy translation
for en_sentence in test_sentences:
    translated = translate(model, en_sentence, en_tokenizer, zh_tokenizer)
    print(f"英语: {en_sentence}")
    print(f"翻译: {translated}")
    print("-" * 20)

## 6.7 可视化分析

让我们分析训练过程和注意力权重：

In [None]:
# Plot the per-epoch mean training loss collected in `losses`
plt.figure(figsize=(10, 6))
plt.plot(losses, 'b-', linewidth=2)
plt.title('训练损失曲线')
plt.xlabel('Epoch')
plt.ylabel('损失')
plt.grid(True, alpha=0.3)
plt.show()

# 可视化注意力权重
def visualize_translation_attention(model, src_sentence, tgt_sentence, 
                                   src_tokenizer, tgt_tokenizer):
    """
    Plot the last decoder layer's cross-attention (first head) as a heatmap.

    The model is run with teacher forcing on ``tgt_sentence``; rows of the
    heatmap are decoder input tokens (queries), columns are source tokens (keys).

    Args:
        model: the trained Transformer
        src_sentence: source text
        tgt_sentence: reference target text fed to the decoder
        src_tokenizer: tokenizer for the source language
        tgt_tokenizer: tokenizer for the target language
    """
    model.eval()

    # Create tensors on the model's device; a parameter-less model falls back to CPU
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    with torch.no_grad():
        # Encode WITHOUT padding. With a padded sequence, slicing off the last
        # element removes a <PAD>, not the <EOS>, so <EOS> would wrongly be fed
        # to the decoder and shown as a query row.
        src_tokens = src_tokenizer.encode(src_sentence)
        tgt_tokens = tgt_tokenizer.encode(tgt_sentence)
        decoder_tokens = tgt_tokens[:-1]  # drop the trailing <EOS>

        src = torch.tensor(src_tokens, dtype=torch.long, device=device).unsqueeze(0)
        tgt_input = torch.tensor(decoder_tokens, dtype=torch.long, device=device).unsqueeze(0)

        src_mask, tgt_mask = create_masks(
            src, tgt_input, src_tokenizer.pad_token_id, tgt_tokenizer.pad_token_id
        )

        _, attention_weights = model(src, tgt_input, src_mask, tgt_mask)

        # Last layer, first batch element, first head
        # (assumes weights are [batch, heads, tgt_len, src_len] -- TODO confirm)
        cross_attention = attention_weights['decoder_cross_attention'][-1][0, 0]

    # Axis labels: the token strings actually fed to the model
    src_words = [src_tokenizer.idx_to_token[idx] for idx in src_tokens]
    tgt_words = [tgt_tokenizer.idx_to_token[idx] for idx in decoder_tokens]

    attention_matrix = cross_attention[:len(tgt_words), :len(src_words)].cpu().numpy()

    plt.figure(figsize=(10, 8))
    sns.heatmap(attention_matrix, annot=True, fmt='.3f', cmap='Blues',
               xticklabels=src_words, yticklabels=tgt_words)
    plt.title(f'交叉注意力权重\n源句: {src_sentence}\n目标句: {tgt_sentence}')
    plt.xlabel('源语言 (Key)')
    plt.ylabel('目标语言 (Query)')
    plt.tight_layout()
    plt.show()

# Draw cross-attention heatmaps for two sentence pairs from the toy corpus
visualize_translation_attention(
    model, "hello world", "你好 世界", en_tokenizer, zh_tokenizer
)

visualize_translation_attention(
    model, "thank you", "谢谢 你", en_tokenizer, zh_tokenizer
)

## 6.8 模型评估和改进

让我们分析模型性能并讨论改进方向：

In [None]:
def evaluate_model(model, test_data, src_tokenizer, tgt_tokenizer):
    """
    Evaluate translations by exact string match and print a per-sentence report.

    Args:
        model: the trained Transformer
        test_data: iterable of ``(source_sentence, expected_translation)`` pairs
        src_tokenizer: tokenizer for the source language
        tgt_tokenizer: tokenizer for the target language

    Returns:
        Fraction of sentences whose translation equals the expected string
        after stripping surrounding whitespace; 0.0 when ``test_data`` is empty.
    """
    model.eval()

    # Materialize once so any iterable (e.g. a bare zip) can be counted
    test_data = list(test_data)

    correct_translations = 0
    total_translations = len(test_data)

    print("模型评估结果:")
    print("=" * 50)

    for en_sentence, expected_zh in test_data:
        predicted_zh = translate(model, en_sentence, src_tokenizer, tgt_tokenizer)

        # Simple exact-match scoring
        is_correct = predicted_zh.strip() == expected_zh.strip()
        if is_correct:
            correct_translations += 1

        print(f"源句: {en_sentence}")
        print(f"期望: {expected_zh}")
        print(f"预测: {predicted_zh}")
        print(f"正确: {'✓' if is_correct else '✗'}")
        print("-" * 30)

    # Guard against division by zero on an empty evaluation set
    accuracy = correct_translations / total_translations if total_translations else 0.0
    print(f"\n整体准确率: {accuracy:.2%} ({correct_translations}/{total_translations})")

    return accuracy

# Evaluate on the first five sentence pairs.
# NOTE: these pairs are part of the training data, so the score reflects
# memorization of the toy corpus rather than generalization.
test_data = list(zip(en_sentences[:5], zh_sentences[:5]))
accuracy = evaluate_model(model, test_data, en_tokenizer, zh_tokenizer)

# Summary statistics (size estimate uses 4 bytes per parameter, i.e. assumes float32)
print(f"\n模型统计信息:")
print(f"参数总量: {sum(p.numel() for p in model.parameters()):,}")
print(f"模型大小: {sum(p.numel() for p in model.parameters()) * 4 / (1024**2):.2f} MB")
print(f"训练轮数: {num_epochs}")
print(f"最终损失: {losses[-1]:.4f}")
print(f"测试准确率: {accuracy:.2%}")

## 6.9 训练技巧和优化策略

让我们讨论一些重要的训练技巧：

In [None]:
class LearningRateScheduler:
    """
    Warmup-then-decay learning-rate schedule from the Transformer paper:

        lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)

    The object keeps its own step counter in ``step_num``.
    """
    def __init__(self, d_model, warmup_steps=4000):
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.step_num = 0

    def get_lr(self):
        """Advance the step counter by one and return the rate for the new step."""
        self.step_num += 1
        step = self.step_num

        decay_term = step ** (-0.5)                          # inverse-sqrt decay
        warmup_term = step * (self.warmup_steps ** (-1.5))   # linear warmup
        scale = self.d_model ** (-0.5)

        # Equivalent to min(decay_term, warmup_term)
        smaller = warmup_term if warmup_term < decay_term else decay_term
        return scale * smaller

# Visualize the learning-rate schedule over steps 1..4999
scheduler = LearningRateScheduler(d_model=128, warmup_steps=1000)
steps = list(range(1, 5000))
lrs = []

for step in steps:
    # get_lr() increments step_num before computing, so preset it to step - 1
    # to evaluate the schedule at exactly `step`
    scheduler.step_num = step - 1
    lrs.append(scheduler.get_lr())

plt.figure(figsize=(10, 6))
plt.plot(steps, lrs, 'b-', linewidth=2)
plt.title('Transformer学习率调度策略')
plt.xlabel('训练步数')
plt.ylabel('学习率')
plt.grid(True, alpha=0.3)
# Vertical marker at warmup_steps, where warmup ends and decay begins
plt.axvline(x=1000, color='r', linestyle='--', alpha=0.7, label='Warmup结束点')
plt.legend()
plt.show()

# Summary of training tricks, printed as one block of text
# (empty strings produce the blank separator lines)
_training_tips = (
    "Transformer训练技巧和优化策略:",
    "=" * 40,
    "1. 学习率调度:",
    "   - 预热阶段：线性增加",
    "   - 后续阶段：按步数平方根衰减",
    "",
    "2. 正则化技术:",
    "   - Dropout: 防止过拟合",
    "   - Layer Normalization: 稳定训练",
    "   - 梯度裁剪: 防止梯度爆炸",
    "",
    "3. 权重初始化:",
    "   - Xavier/Glorot初始化",
    "   - 嵌入层使用较小的标准差",
    "",
    "4. 训练策略:",
    "   - Label Smoothing: 提高泛化能力",
    "   - Beam Search: 提高解码质量",
    "   - 数据增强: 增加训练数据多样性",
    "",
    "5. 效率优化:",
    "   - 梯度累积: 模拟大批次训练",
    "   - 混合精度训练: 减少显存占用",
    "   - 模型并行: 处理大模型",
)
print("\n".join(_training_tips))

## 总结

在这个教程中，我们实现了完整的Transformer模型：

### 🏗️ 架构实现：
1. **编码器**：多层编码器块的堆叠
2. **解码器**：包含自注意力和交叉注意力的解码器块
3. **完整模型**：编码器-解码器架构

### 💼 实际应用：
- **机器翻译任务**：英语到中文的简单翻译
- **端到端训练**：从数据预处理到模型训练
- **推理实现**：自回归生成过程

### 🔧 关键技巧：
- **掩码机制**：填充掩码和因果掩码
- **学习率调度**：预热和衰减策略
- **梯度裁剪**：稳定训练过程
- **权重初始化**：合适的参数初始化

### 📊 可视化分析：
- **训练过程**：损失曲线监控
- **注意力权重**：理解模型关注点
- **翻译质量**：定性和定量评估

### 🚀 扩展方向：
1. **更大的数据集**：使用真实的平行语料
2. **更复杂的分词**：BPE、SentencePiece等
3. **评估指标**：BLEU、ROUGE等自动评估
4. **优化策略**：Label Smoothing、Beam Search等
5. **效率提升**：模型压缩、知识蒸馏等

这个完整的实现展示了Transformer的强大能力，也为理解现代NLP模型（如BERT、GPT）奠定了基础。

### 下一步学习：
- [07-transformer-variants.ipynb](07-transformer-variants.ipynb) - Transformer变体和现代应用