# Seq2Seq + Attention

- [论文pdf](https://arxiv.org/abs/1409.0473)

模型训练特别慢，大约慢了20倍，这是 Bahdanau Attention的设计缺陷，而 Transformer 是解决方案

## autodl 环境准备

In [1]:
# !pip install datasets transformers scikit-learn

In [2]:
import os

# os.environ['no_proxy'] = 'localhost,127.0.0.1,modelscope.com,aliyuncs.com,tencentyun.com,wisemodel.cn'
# os.environ['NO_PROXY'] = 'localhost,127.0.0.1,modelscope.com,aliyuncs.com,tencentyun.com,wisemodel.cn'

# os.environ['http_proxy'] = 'http://100.72.64.19:12798'
# os.environ['HTTP_PROXY'] = 'http://100.72.64.19:12798'

# os.environ['https_proxy'] = 'http://100.72.64.19:12798'
# os.environ['HTTPS_PROXY'] = 'http://100.72.64.19:12798'
os.environ['HF_ENDPOINT']="https://hf-mirror.com"

## 初始

In [3]:
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer

In [4]:
BATCH_SIZE = 128
NUM_EPOCHS = 6
MAX_DATA = 50000
device = 'cuda' if torch.cuda.is_available() else 'cpu'
learning_rate=5e-4
weight_decay=1e-5

## 数据集

In [5]:
from datasets import load_dataset
from datasets.download import DownloadConfig


full_dataset = load_dataset("wmt/wmt17", "zh-en", split="train")

In [6]:
import re

def is_valid_trans(x) -> bool:
    en = x["translation"]["en"]
    zh = x["translation"]["zh"]
    if not en or not zh:
        return False
    if len(en) < 3 or len(zh) < 3:
        return False
    # 核心过滤条件
    if len(en) > 100 or len(zh) > 100:
        return False
    ratio = len(en) / len(zh)
    if ratio < 0.5 or ratio > 2:
        return False
    if not re.search(r'[\u4e00-\u9fff]', zh):
        return False
    return True

full_dataset = full_dataset.filter(is_valid_trans, num_proc=10)

dataset = full_dataset.select(range(min(MAX_DATA, len(full_dataset))))
dataset = dataset.shuffle(seed=20250709)

print(f"从 总共 {len(full_dataset)} 个数据中选择了 {len(dataset)} 个")

sample = dataset.select(range(3))
print("-"*100)
for i in sample:
    print(i["translation"]["en"])
    print(i["translation"]["zh"])
    # print("-"*100)

从 总共 1141860 个数据中选择了 50000 个
----------------------------------------------------------------------------------------------------
Salta Province, July 1995;
萨尔塔省，1995年7月；
Sadly, they are unlikely to end soon. Talk of a “grand bargain” remains just that – talk.
不幸的是，这些谈判能在短期内结束的希望颇为渺茫，关于“大妥协”事宜的谈判仍将仅仅停留在口头之上。
But ideas matter.
但是这个观念很重要。


In [7]:

class WmtDataset(Dataset):
    def __init__(self, hf_dataset, tokenizer_en, tokenizer_zh, max_length=100):
        self.dataset = hf_dataset
        self.tokenizer_en = tokenizer_en
        self.tokenizer_zh = tokenizer_zh
        self.max_length = max_length
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        item = self.dataset[idx]
        en_text = item["translation"]["en"]
        zh_text = item["translation"]["zh"]

        en_tokens = self.tokenizer_en(en_text, 
                                        max_length=self.max_length, 
                                        padding='max_length', 
                                        truncation=True, 
                                        return_tensors='pt')
            
        zh_tokens = self.tokenizer_zh(zh_text, 
                                        max_length=self.max_length, 
                                        padding='max_length', 
                                        truncation=True, 
                                        return_tensors='pt')
            
        return en_tokens['input_ids'].squeeze(), zh_tokens['input_ids'].squeeze(),
        

# bert-base-multilingual-cased 有11万单词，太大了
encoder_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
decoder_tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
encoder_vocab_size = encoder_tokenizer.vocab_size
decoder_vocab_size = decoder_tokenizer.vocab_size

total = len(dataset)
split_idx = int(len(dataset)* 0.95)
train_hf_dataset = dataset.select(range(0, split_idx))
test_hf_dataset = dataset.select(range(split_idx, total))
train_data = WmtDataset(train_hf_dataset, tokenizer_en=encoder_tokenizer, tokenizer_zh=decoder_tokenizer)
val_data = WmtDataset(test_hf_dataset, tokenizer_en=encoder_tokenizer, tokenizer_zh=decoder_tokenizer)


train_iter = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=12,pin_memory=True if torch.cuda.is_available() else False)
val_iter = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=12,pin_memory=True if torch.cuda.is_available() else False)


## 模型

In [8]:
import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_size=512, hidden_size=1024, num_layers=2, dropout=0.1):
        super(Encoder, self).__init__()
        
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        # Embedding layer to convert token IDs to dense vectors
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        
        # GRU layer for processing sequences
        self.rnn = nn.GRU(embed_size, hidden_size, num_layers, 
                        batch_first=True, dropout=dropout, bidirectional=False)
        
    def forward(self, input_seq):
        # Convert token IDs to embeddings
        embedded = self.embedding(input_seq)  # [batch_size, seq_len, embed_size]
        
        # Pass through GRU
        outputs, hidden = self.rnn(embedded)
        
        # outputs: [batch_size, seq_len, hidden_size]
        # hidden: [num_layers, batch_size, hidden_size] 
        
        return outputs, hidden

class Attention(nn.Module):
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        # e_ij = alpha(s_{t-1}, h_j)
        # 这里是： W_a*h_j + U_a*s_{t-1}
        self.W_a = nn.Linear(hidden_size, hidden_size, bias=False)
        self.U_a = nn.Linear(hidden_size, hidden_size, bias=False)
        self.v_a = nn.Linear(hidden_size, 1)
    def forward(self, hidden:torch.Tensor, encoder_out:torch.Tensor):
        # hidden: [num_layer, batch_size, hidden_size]
        # encoder_out: [batch_size, seq_len, hidden_size]
        hidden = hidden[-1].unsqueeze(1) # [batch_size, 1, hidden_size]
        energy1 = self.W_a(encoder_out) # [batch_size, 1, hidden_size]
        energy2 = self.U_a(hidden) # [batch_size, seq_len, hidden_size]
        energy = F.tanh(energy1 + energy2)
        energy = self.v_a(energy) # [batch_size, seq_len, 1]
        # [batch_size, seq_len, 1] --> # [batch_size, 1, seq_len]
        attn_weights = F.softmax(energy, dim=1).transpose(1, 2)
        return attn_weights

class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_size=512, hidden_size=1024, num_layers=2, dropout=0.1):
        super(Decoder, self).__init__()

        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        # 需要 嵌入表示信息、注意力信息
        self.rnn = nn.GRU(embed_size + hidden_size, hidden_size, num_layers, 
                        batch_first=True, dropout=dropout, bidirectional=False)
        self.output_projection = nn.Linear(hidden_size, vocab_size)
        self.attention = Attention(hidden_size)
        
    def forward(self, input_token, hidden, encoder_outs):
        # input_token: [batch_size, seq_len]
        # hidden: [num_layer, batch_size, hidden_size]
        
        # encoder_out: [batch_size, seq_len, hidden_size]
        embedded = self.embedding(input_token)  # [batch_size, seq_len, embed_size]
        # 计算注意力权重
        attn_weights = self.attention.forward(hidden, encoder_outs)
        context = torch.bmm(attn_weights, encoder_outs)
        # rnn_input: [batch_size, 1, embed_size + hidden_size]
        rnn_input = torch.cat([embedded, context], dim=2)
        rnn_out, hidden = self.rnn(rnn_input, hidden)
        outputs = self.output_projection(rnn_out)
        return outputs, hidden

class Seq2Seq(nn.Module):
    def __init__(self, encoder_vocab_size, decoder_vocab_size, embed_size=512, 
                hidden_size=1024, num_layers=2, dropout=0.1):
        super(Seq2Seq, self).__init__()
        
        self.encoder_vocab_size = encoder_vocab_size
        self.decoder_vocab_size = decoder_vocab_size
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Initialize encoder and decoder
        self.encoder = Encoder(encoder_vocab_size, embed_size, hidden_size, num_layers, dropout)
        # self.attention = nn.Linear()
        self.decoder = Decoder(decoder_vocab_size, embed_size, hidden_size, num_layers, dropout)
        
    def forward(self, source_seq:torch.Tensor, target_seq:torch.Tensor):
        batch_size, tgt_len = target_seq.shape
        all_outs = torch.zeros(batch_size, tgt_len, self.decoder_vocab_size).to(source_seq.device)
        encoder_outs, hidden = self.encoder(source_seq)

        decoder_input = target_seq[:, 0].unsqueeze(1)
        for i in range(tgt_len):
            outs, hidden = self.decoder.forward(decoder_input, hidden, encoder_outs)
            all_outs[:, i, :] = outs.squeeze(1)
            # teacher forcing
            if i < tgt_len - 1:
                decoder_input = target_seq[:, i+1].unsqueeze(1)

        return all_outs
    
    def generate(self, source_seq, max_length=100, start_token=101, end_token=102):
        self.eval()
        batch_size = source_seq.size(0)
        
        with torch.no_grad():
            # Encode source sequence
            _, hidden = self.encoder(source_seq)
            
            # Initialize with start token
            decoder_input = torch.full((batch_size, 1), start_token, dtype=torch.long).to(source_seq.device)
            
            # Store generated tokens
            generated_tokens = []
            
            for _ in range(max_length):
                # (bs, 1, vocab), (bs, l, hidden)
                output, hidden = self.decoder(decoder_input, hidden)
                
                # Get the token with highest probability
                next_token = output.argmax(dim=2)
                generated_tokens.append(next_token)
                
                # Use predicted token as next input
                decoder_input = next_token
                
                # Stop if all sequences generated EOS token
                if torch.all(next_token == end_token):
                    break
            
            # Concatenate all generated tokens
            generated_seq = torch.cat(generated_tokens, dim=1)
            
        return generated_seq

In [9]:
model = Seq2Seq(encoder_vocab_size, decoder_vocab_size, embed_size=128, hidden_size=256, num_layers=2, dropout=0.2).to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"总参数: {total_params:,}")
print(f"可训练参数: {trainable_params:,}")
print(f"占用空间: {total_params * 4 / 1024 / 1024:.2f} MB")

总参数: 13,751,433
可训练参数: 13,751,433
占用空间: 52.46 MB


## 训练

In [10]:
# 半精度训练
from collections import defaultdict
from tqdm.notebook import tqdm

from torch import autocast, GradScaler
scaler = GradScaler()


criterion = nn.CrossEntropyLoss(ignore_index=0, reduction='mean')
optimizer = optim.Adam(model.parameters(), 
                            lr=learning_rate, 
                            weight_decay=weight_decay)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

history = defaultdict(list)

def train_epoch():
    model.train()
    train_loss = 0
    progress_bar = tqdm(train_iter, desc="Train")
    
    for i, (x, y) in enumerate(progress_bar):
        x = x.to(device)
        y = y.to(device)
        
        with autocast(device, dtype=torch.bfloat16):
            pred = model(x, y[:,:-1])
            loss = criterion(pred.permute(0, 2, 1), y[:, 1:])
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
            
        # pred = model(x, y[:,:-1])
        # loss = criterion(pred.permute(0, 2, 1), y[:, 1:])
        # loss.backward()
        optimizer.zero_grad()

        train_loss += loss.item()
        progress_bar.set_postfix({'Loss': f'{loss.item():.4f}'})
    
    return train_loss / len(train_iter)

def test_epoch():
    model.eval()
    val_loss = 0
    
    with torch.no_grad(): 
        progress_bar = tqdm(val_iter, desc="Validation")
        
        for x, y in progress_bar:
            x = x.to(device)
            y = y.to(device)        
            
            pred = model(x, y[:,:-1]) 
            loss = criterion(pred.permute(0, 2, 1), y[:, 1:]) 
            val_loss += loss.item()
            progress_bar.set_postfix({'Loss': f'{loss.item():.4f}'})

    return val_loss / len(val_iter)


epoch= 0


Train:   0%|          | 0/372 [00:00<?, ?it/s]

Validation:   0%|          | 0/20 [00:00<?, ?it/s]

EP1/6, Train Loss:5.894558247699533, Val Loss:5.517913913726806
Saved best model with validation loss: 5.5179
epoch= 1


Train:   0%|          | 0/372 [00:00<?, ?it/s]

Validation:   0%|          | 0/20 [00:00<?, ?it/s]

EP2/6, Train Loss:4.9056418044592744, Val Loss:4.402488088607788
Saved best model with validation loss: 4.4025
epoch= 2


Train:   0%|          | 0/372 [00:00<?, ?it/s]

Validation:   0%|          | 0/20 [00:00<?, ?it/s]

EP3/6, Train Loss:3.9667646461917507, Val Loss:3.687987804412842
Saved best model with validation loss: 3.6880
epoch= 3


Train:   0%|          | 0/372 [00:00<?, ?it/s]

Validation:   0%|          | 0/20 [00:00<?, ?it/s]

EP4/6, Train Loss:3.365915233729988, Val Loss:3.2367999792099
Saved best model with validation loss: 3.2368
epoch= 4


Train:   0%|          | 0/372 [00:00<?, ?it/s]

Validation:   0%|          | 0/20 [00:00<?, ?it/s]

EP5/6, Train Loss:2.9838629576467697, Val Loss:2.9496025681495666
Saved best model with validation loss: 2.9496
epoch= 5


Train:   0%|          | 0/372 [00:00<?, ?it/s]

Validation:   0%|          | 0/20 [00:00<?, ?it/s]

EP6/6, Train Loss:2.7182633553140905, Val Loss:2.7491687417030333
Saved best model with validation loss: 2.7492


In [None]:


best_val_loss = np.inf
for epoch in range(NUM_EPOCHS):
    print("epoch=", epoch)
    train_loss = train_epoch()
    history['train_loss'].append(train_loss)

    val_loss = test_epoch()
    history['val_loss'].append(val_loss)
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']

    print("EP{}/{}, Train Loss:{}, Val Loss:{}".format(epoch+1, NUM_EPOCHS, history['train_loss'][-1], history['val_loss'][-1]))

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_seq2seq_attention.pth")
        print(f"Saved best model with validation loss: {best_val_loss:.4f}")
    

## 可视化

## 翻译

In [11]:
def translate(
    model: Seq2Seq,
    english_text: str,
    encoder_tokenizer,
    decoder_tokenizer,
    max_length: int = 50,
):
    """
    使用指定的Seq2Seq模型将英文文本翻译为中文。
    """
    model.eval()

    input_token = encoder_tokenizer(
        english_text,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    source_seq = input_token["input_ids"].to(device)

    start_token_id = decoder_tokenizer.cls_token_id
    end_token_id = decoder_tokenizer.sep_token_id

    generated_ids = model.generate(
        source_seq,
        max_length=max_length,
        start_token=start_token_id,
        end_token=end_token_id,
    )
    generated_text = decoder_tokenizer.decode(
        generated_ids[0], skip_special_tokens=True
    )

    # BERT中文tokenizer解码后可能在字与字之间添加空格
    return generated_text.replace(" ", "")

In [12]:
model = Seq2Seq(encoder_vocab_size, decoder_vocab_size, embed_size=128, hidden_size=256, num_layers=2, dropout=0.2).to(device)


state_dict = torch.load("./best_seq2seq_attention.pth", weights_only=True)
model.load_state_dict(state_dict)

speed_test_cases = [
    "Hello.",
    "How are you doing today?",
    "I really enjoy reading books and learning about different cultures around the world.",
    "The quick brown fox jumps over the lazy dog while the sun is shining brightly in the clear blue sky."
]

for sentence in speed_test_cases:
    print(sentence, translate(model, sentence, encoder_tokenizer, decoder_tokenizer))

TypeError: Decoder.forward() missing 1 required positional argument: 'encoder_outs'