# 增加了 Beam Search的 Seq2Seq 模型

## autodl 环境准备

In [1]:
# !pip install datasets transformers scikit-learn

In [2]:
import os

# os.environ['no_proxy'] = 'localhost,127.0.0.1,modelscope.com,aliyuncs.com,tencentyun.com,wisemodel.cn'
# os.environ['NO_PROXY'] = 'localhost,127.0.0.1,modelscope.com,aliyuncs.com,tencentyun.com,wisemodel.cn'

# os.environ['http_proxy'] = 'http://100.72.64.19:12798'
# os.environ['HTTP_PROXY'] = 'http://100.72.64.19:12798'

# os.environ['https_proxy'] = 'http://100.72.64.19:12798'
# os.environ['HTTPS_PROXY'] = 'http://100.72.64.19:12798'
os.environ['HF_ENDPOINT']="https://hf-mirror.com"

## 初始

In [3]:
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer


In [4]:
BATCH_SIZE = 256
NUM_EPOCHS = 6
device = 'cuda' if torch.cuda.is_available() else 'cpu'
learning_rate=5e-4
weight_decay=1e-5


## 数据集

In [5]:
from datasets import load_dataset
from datasets.download import DownloadConfig


full_dataset = load_dataset("wmt/wmt17", "zh-en", split="train")

In [6]:
import re

def is_valid_trans(x) -> bool:
    en = x["translation"]["en"]
    zh = x["translation"]["zh"]
    if not en or not zh:
        return False
    if len(en) < 3 or len(zh) < 3:
        return False
    # 核心过滤条件
    if len(en) > 100 or len(zh) > 100:
        return False
    ratio = len(en) / len(zh)
    if ratio < 0.5 or ratio > 2:
        return False
    if not re.search(r'[\u4e00-\u9fff]', zh):
        return False
    return True

full_dataset = full_dataset.filter(is_valid_trans, num_proc=10)

dataset = full_dataset.select(range(min(500_000, len(full_dataset))))
dataset = dataset.shuffle(seed=20250709)

print(f"从 总共 {len(full_dataset)} 个数据中选择了 {len(dataset)} 个")

sample = dataset.select(range(3))
print("-"*100)
for i in sample:
    print(i["translation"]["en"])
    print(i["translation"]["zh"])
    # print("-"*100)

从 总共 1141860 个数据中选择了 500000 个
----------------------------------------------------------------------------------------------------
meeting on 18 November 2005
审查会议于2005年11月18日
It had before it document FCCC/SBI/2003/15 and Add.1.
它收到了FCCC/SBI/2003/15和Add.1号文件。
21-23 September 2005
2005年9月21日-23日


In [7]:

class WmtDataset(Dataset):
    def __init__(self, hf_dataset, tokenizer_en, tokenizer_zh, max_length=100):
        self.dataset = hf_dataset
        self.tokenizer_en = tokenizer_en
        self.tokenizer_zh = tokenizer_zh
        self.max_length = max_length
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        item = self.dataset[idx]
        en_text = item["translation"]["en"]
        zh_text = item["translation"]["zh"]

        en_tokens = self.tokenizer_en(en_text, 
                                        max_length=self.max_length, 
                                        padding='max_length', 
                                        truncation=True, 
                                        return_tensors='pt')
            
        zh_tokens = self.tokenizer_zh(zh_text, 
                                        max_length=self.max_length, 
                                        padding='max_length', 
                                        truncation=True, 
                                        return_tensors='pt')
            
        return en_tokens['input_ids'].squeeze(), zh_tokens['input_ids'].squeeze(),
        

# bert-base-multilingual-cased 有11万单词，太大了
encoder_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
decoder_tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
encoder_vocab_size = encoder_tokenizer.vocab_size
decoder_vocab_size = decoder_tokenizer.vocab_size

total = len(dataset)
split_idx = int(len(dataset)* 0.95)
train_hf_dataset = dataset.select(range(0, split_idx))
test_hf_dataset = dataset.select(range(split_idx, total))
train_data = WmtDataset(train_hf_dataset, tokenizer_en=encoder_tokenizer, tokenizer_zh=decoder_tokenizer)
val_data = WmtDataset(test_hf_dataset, tokenizer_en=encoder_tokenizer, tokenizer_zh=decoder_tokenizer)


train_iter = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=12,pin_memory=True if torch.cuda.is_available() else False)
val_iter = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=12,pin_memory=True if torch.cuda.is_available() else False)


## 模型

In [None]:
import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_size=512, hidden_size=1024, num_layers=2, dropout=0.1):
        super(Encoder, self).__init__()
        
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        # Embedding layer to convert token IDs to dense vectors
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        
        # GRU layer for processing sequences
        self.rnn = nn.GRU(embed_size, hidden_size, num_layers, 
                        batch_first=True, dropout=dropout, bidirectional=False)
        
    def forward(self, input_seq):
        # Convert token IDs to embeddings
        embedded = self.embedding(input_seq)  # [batch_size, seq_len, embed_size]
        
        # Pass through GRU
        outputs, hidden = self.rnn(embedded)
        
        # outputs: [batch_size, seq_len, hidden_size]
        # hidden: [num_layers, batch_size, hidden_size] 
        
        return outputs, hidden

class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_size=512, hidden_size=1024, num_layers=2, dropout=0.1):
        super(Decoder, self).__init__()
        
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        
        # Embedding layer for target tokens
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        
        # GRU layer for generating sequences
        self.rnn = nn.GRU(embed_size, hidden_size, num_layers, 
                        batch_first=True, dropout=dropout, bidirectional=False)
        
        # Output projection layer to vocabulary
        self.output_projection = nn.Linear(hidden_size, vocab_size)
        
    def forward(self, input_token, hidden):
        embedded = self.embedding(input_token)  # [batch_size, seq_len, embed_size]
        
        # Process entire sequence in parallel
        gru_out, final_hidden = self.rnn(embedded, hidden)
        # gru_out: [batch_size, seq_len, hidden_size]
        
        # Project to vocabulary for all timesteps
        outputs = self.output_projection(gru_out)  # [batch_size, seq_len, vocab_size]
        
        return outputs, final_hidden


class Seq2Seq(nn.Module):
    def __init__(self, encoder_vocab_size, decoder_vocab_size, embed_size=512, 
                hidden_size=1024, num_layers=2, dropout=0.1):
        super(Seq2Seq, self).__init__()
        
        self.encoder_vocab_size = encoder_vocab_size
        self.decoder_vocab_size = decoder_vocab_size
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Initialize encoder and decoder
        self.encoder = Encoder(encoder_vocab_size, embed_size, hidden_size, num_layers, dropout)
        self.decoder = Decoder(decoder_vocab_size, embed_size, hidden_size, num_layers, dropout)
        
    def forward(self, source_seq, target_seq):
        
        _, hidden = self.encoder(source_seq)
        
        outputs, _ = self.decoder(target_seq, hidden)
        
        return outputs
    
    def generate(self, source_seq, beam_size,  max_length=100, start_token=101, end_token=102):
        self.eval()
        batch_size = source_seq.size(0)
        with torch.no_grad():
            # Encode source sequence
            _, hidden = self.encoder(source_seq)
            
            # Initialize with start token
            decoder_input = torch.full((batch_size, 1), start_token, dtype=torch.long).to(source_seq.device)
            
            # Store generated tokens
            generated_tokens = []
            
            for _ in range(max_length):
                # (bs, 1, vocab), (bs, l, hidden)
                output, hidden = self.decoder(decoder_input, hidden)
                
                # Get the token with highest probability
                next_token = output.argmax(dim=2)
                generated_tokens.append(next_token)
                
                # Use predicted token as next input
                decoder_input = next_token
                
                # Stop if all sequences generated EOS token
                if torch.all(next_token == end_token):
                    break
            
            # Concatenate all generated tokens
            generated_seq = torch.cat(generated_tokens, dim=1)
            
        return generated_seq

In [9]:
model = Seq2Seq(encoder_vocab_size, decoder_vocab_size, embed_size=128, hidden_size=256, num_layers=2, dropout=0.2).to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"总参数: {total_params:,}")
print(f"可训练参数: {trainable_params:,}")
print(f"占用空间: {total_params * 4 / 1024 / 1024:.2f} MB")

总参数: 13,423,496
可训练参数: 13,423,496
占用空间: 51.21 MB


## 训练

In [10]:
# 半精度训练
from collections import defaultdict
from tqdm.notebook import tqdm

from torch import autocast, GradScaler
scaler = GradScaler()


criterion = nn.CrossEntropyLoss(ignore_index=0, reduction='mean')
optimizer = optim.Adam(model.parameters(), 
                            lr=learning_rate, 
                            weight_decay=weight_decay)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

history = defaultdict(list)

def train_epoch():
    model.train()
    train_loss = 0
    progress_bar = tqdm(train_iter, desc="Train")
    
    for x, y in progress_bar:
        x = x.to(device)
        y = y.to(device)
        
        with autocast(device, dtype=torch.bfloat16):
            pred = model(x, y[:,:-1])
            loss = criterion(pred.permute(0, 2, 1), y[:, 1:])
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
            
        # pred = model(x, y[:,:-1])
        # loss = criterion(pred.permute(0, 2, 1), y[:, 1:])
        # loss.backward()
        # optimizer.zero_grad()

        train_loss += loss.item()
        progress_bar.set_postfix({'Loss': f'{loss.item():.4f}'})
    
    return train_loss / len(train_iter)

def test_epoch():
    model.eval()
    val_loss = 0
    
    with torch.no_grad(): 
        progress_bar = tqdm(val_iter, desc="Validation")
        
        for x, y in progress_bar:
            x = x.to(device)
            y = y.to(device)        
            
            pred = model(x, y[:,:-1]) 
            loss = criterion(pred.permute(0, 2, 1), y[:, 1:]) 
            val_loss += loss.item()
            progress_bar.set_postfix({'Loss': f'{loss.item():.4f}'})

    return val_loss / len(val_iter)


best_val_loss = np.inf
for epoch in range(NUM_EPOCHS):
    train_loss = train_epoch()
    history['train_loss'].append(train_loss)

    val_loss = test_epoch()
    history['val_loss'].append(val_loss)
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']

    print("EP{}/{}, Train Loss:{}, Val Loss:{}".format(epoch+1, NUM_EPOCHS, history['train_loss'][-1], history['val_loss'][-1]))

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_seq2seq.pth")
        print(f"Saved best model with validation loss: {best_val_loss:.4f}")
    

Train:   0%|          | 0/1856 [00:00<?, ?it/s]

Validation:   0%|          | 0/98 [00:00<?, ?it/s]

EP1/6, Train Loss:3.8344548018328073, Val Loss:2.4305096402460213
Saved best model with validation loss: 2.4305


Train:   0%|          | 0/1856 [00:00<?, ?it/s]

Validation:   0%|          | 0/98 [00:00<?, ?it/s]

EP2/6, Train Loss:2.1763440919718864, Val Loss:1.9013299954180816
Saved best model with validation loss: 1.9013


Train:   0%|          | 0/1856 [00:00<?, ?it/s]

Validation:   0%|          | 0/98 [00:00<?, ?it/s]

EP3/6, Train Loss:1.8303588089244118, Val Loss:1.67874306075427
Saved best model with validation loss: 1.6787


Train:   0%|          | 0/1856 [00:00<?, ?it/s]

Validation:   0%|          | 0/98 [00:00<?, ?it/s]

EP4/6, Train Loss:1.6462825958071083, Val Loss:1.548267908242284
Saved best model with validation loss: 1.5483


Train:   0%|          | 0/1856 [00:00<?, ?it/s]

Validation:   0%|          | 0/98 [00:00<?, ?it/s]

EP5/6, Train Loss:1.550671435892582, Val Loss:1.5419878509579872
Saved best model with validation loss: 1.5420


Train:   0%|          | 0/1856 [00:00<?, ?it/s]

Validation:   0%|          | 0/98 [00:00<?, ?it/s]

EP6/6, Train Loss:1.5466407470662018, Val Loss:1.541987846092302
Saved best model with validation loss: 1.5420


## 可视化

## 翻译

In [11]:
def translate(
    model: Seq2Seq,
    english_text: str,
    encoder_tokenizer,
    decoder_tokenizer,
    max_length: int = 50,
):
    """
    使用指定的Seq2Seq模型将英文文本翻译为中文。
    """
    model.eval()

    input_token = encoder_tokenizer(
        english_text,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    source_seq = input_token["input_ids"].to(device)

    start_token_id = decoder_tokenizer.cls_token_id
    end_token_id = decoder_tokenizer.sep_token_id

    generated_ids = model.generate(
        source_seq,
        max_length=max_length,
        start_token=start_token_id,
        end_token=end_token_id,
    )
    generated_text = decoder_tokenizer.decode(
        generated_ids[0], skip_special_tokens=True
    )

    # BERT中文tokenizer解码后可能在字与字之间添加空格
    return generated_text.replace(" ", "")

In [None]:
model = Seq2Seq(encoder_vocab_size, decoder_vocab_size, embed_size=128, hidden_size=256, num_layers=2, dropout=0.2).to(device)


state_dict = torch.load("./best_seq2seq.pth", weights_only=True)
model.load_state_dict(state_dict)

speed_test_cases = [
    "Hello.",
    "How are you doing today?",
    "I really enjoy reading books and learning about different cultures around the world.",
    "The quick brown fox jumps over the lazy dog while the sun is shining brightly in the clear blue sky."
]

for sentence in speed_test_cases:
    print(sentence, translate(model, sentence, encoder_tokenizer, decoder_tokenizer))

*。
如果是否是什么？
一个的一个案件。
的人口。


## 瓶颈点
- 只有50W数据
- 使用贪心搜索，没有beam search
- 没有注意力机制
- 没有预训练embedding- 