In [1]:
# 1. 加载数据
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from collections import Counter
import spacy
# Load the WMT19 Chinese-English translation dataset.
print("Loading dataset...")
dataset = load_dataset("wmt19", "zh-en")

# Sub-sample sizes (the full training split has ~2M pairs).
N_TRAIN = 20000
N_VALID = 1000

# `select` picks the rows first, so only the requested pairs are decoded; indexing the
# column first (`dataset["train"]["translation"][:N]`) may materialize the whole column.
train_dataset = list(dataset["train"].select(range(min(N_TRAIN, len(dataset["train"]))))["translation"])
validation_dataset = list(
    dataset["validation"].select(range(min(N_VALID, len(dataset["validation"]))))["translation"]
)
# Later cells read the raw pairs as `train_dataset` / `validation_dataset`; keep the short names as aliases.
train_data, valid_data = train_dataset, validation_dataset
print(dataset)
print((train_dataset[0], validation_dataset[0]))




DatasetDict({
    train: Dataset({
        features: ['translation'],
        num_rows: 1998814
    })
    validation: Dataset({
        features: ['translation'],
        num_rows: 3981
    })
})
({'en': 'It was emphasized that the cost-sharing arrangements in Nairobi between the United Nations Centre for Human Settlements, the United Nations Office at Nairobi and UNEP needed to be clarified and that savings resulting from the restructuring of the Centre, in particular in connection with the transfer of certain functions to the United Nations Office at Nairobi, needed to be identified.', 'zh': '有人强调,联合国人类住区中心、联合国内罗毕办事处和环境规划署之间在内罗毕订立的费用分担协议需要加以澄清,而且该中心调整后产生的节余,尤其是因为把某些职能转移给联合国内罗毕办事处而产生的节余都需要加以查明。'}, {'en': 'Last week, the broadcast of period drama “Beauty Private Kitchen” was temporarily halted, and accidentally triggered heated debate about faked ratings of locally produced dramas.', 'zh': '上周，古装剧《美人私房菜》临时停播，意外引发了关于国产剧收视率造假的热烈讨论。'})


In [2]:
# 2. Load the spaCy language models; only their tokenizers are used below (Chinese first, then English)
spacy_zh, spacy_en = [spacy.load(model_name) for model_name in ("zh_core_web_sm", "en_core_web_sm")]


# 句子分词函数
def tokenize_zh(text):
    """Return the list of token strings spaCy's Chinese tokenizer produces for ``text``."""
    doc = spacy_zh.tokenizer(text)
    return list(map(lambda token: token.text, doc))


def tokenize_en(text):
    """Return the list of token strings spaCy's English tokenizer produces for ``text``."""
    doc = spacy_en.tokenizer(text)
    return list(map(lambda token: token.text, doc))


# 3. Translation direction: Chinese is the source language, English the target
SRC_LANGUAGE, TGT_LANGUAGE = 'zh', 'en'


# 4. 手动生成词汇表
def build_vocab_manual(data, tokenizer, min_freq=1):
    """Build a token -> id vocabulary and its inverse from raw sentences.

    Tokens whose frequency is >= ``min_freq`` get ids 0..K-1 in first-occurrence
    order; the special tokens ``<unk>``, ``<pad>``, ``<sos>``, ``<eos>`` follow as
    K, K+1, K+2, K+3. Ids are always contiguous and unique.

    Args:
        data: iterable of sentences.
        tokenizer: callable mapping one sentence to a list of token strings.
        min_freq: minimum count for a token to receive its own id.

    Returns:
        (vocab, idx_to_token): dict token -> id and dict id -> token.
    """
    specials = ("<unk>", "<pad>", "<sos>", "<eos>")

    # Count token frequencies.
    counter = Counter()
    for text in data:
        counter.update(tokenizer(text))

    # Number only the tokens that survive the filter. Enumerating before filtering
    # leaves gaps when min_freq > 1, and the special tokens appended at len(vocab)
    # would then collide with real token ids. Specials are excluded here so a
    # literal "<unk>" etc. in the corpus cannot be reassigned and break the numbering.
    kept = [token for token, freq in counter.items() if freq >= min_freq and token not in specials]
    vocab = {token: idx for idx, token in enumerate(kept)}
    for special in specials:
        vocab[special] = len(vocab)

    # Reverse mapping (id -> token).
    idx_to_token = {idx: token for token, idx in vocab.items()}

    return vocab, idx_to_token


# Source-side (Chinese) and target-side (English) sentences of the training pairs
zh_texts = list(map(lambda pair: pair[SRC_LANGUAGE], train_dataset))
en_texts = list(map(lambda pair: pair[TGT_LANGUAGE], train_dataset))

# One vocabulary (and inverse mapping) per language, using the default min_freq
vocab_zh, idx_to_token_zh = build_vocab_manual(zh_texts, tokenize_zh)
vocab_en, idx_to_token_en = build_vocab_manual(en_texts, tokenize_en)

print(f"中文词汇表长度: {len(vocab_zh)}，英文词汇表长度: {len(vocab_en)}")


中文词汇表长度: 16955，英文词汇表长度: 18128


In [3]:
# 5. 定义数据处理函数（文本转 ID）
def numericalize(text, vocab, tokenizer):
    """Tokenize ``text`` and map it to ids, framed by ``<sos>`` ... ``<eos>``.

    Tokens missing from ``vocab`` are mapped to the ``<unk>`` id.
    """
    ids = [vocab["<sos>"]]
    ids.extend(vocab.get(token, vocab["<unk>"]) for token in tokenizer(text))
    ids.append(vocab["<eos>"])
    return ids

# 6. 填补空缺
def pad_sequences(seq, vocab, max_length = 50):
    """Return a new list of exactly ``max_length`` ids.

    Sequences that are too long are cut at ``max_length``; shorter ones are
    right-padded with ``vocab['<pad>']``.
    """
    if len(seq) >= max_length:
        return seq[:max_length]
    shortfall = max_length - len(seq)
    return seq + [vocab['<pad>']] * shortfall


# 7. 构建 Dataset
class TranslationDataset(Dataset):
    """Map-style dataset turning parallel sentence pairs into fixed-length id tensors.

    Each element of ``data`` is a dict with one sentence per language key
    (by default ``'zh'`` for the source and ``'en'`` for the target).

    Args:
        data: indexable sequence of translation dicts.
        src_vocab / tgt_vocab: token -> id maps containing '<unk>', '<pad>', '<sos>', '<eos>'.
        src_tokenizer / tgt_tokenizer: callables mapping a sentence to a list of tokens.
        max_length: length every returned id sequence is padded / truncated to.
        src_lang / tgt_lang: dict keys of the source / target sentence.
    """

    def __init__(self, data, src_vocab, tgt_vocab, src_tokenizer, tgt_tokenizer, max_length=50,
                 src_lang='zh', tgt_lang='en'):
        self.data = data
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.max_length = max_length
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

    def __len__(self):
        return len(self.data)

    def _encode(self, text, vocab, tokenizer):
        """Convert ``text`` to ids and pad/truncate to ``max_length``, keeping ``<eos>`` last."""
        ids = numericalize(text, vocab, tokenizer)
        if len(ids) > self.max_length:
            # Plain truncation would drop <eos>; long pairs would then never show the
            # model where a sentence ends, teaching it to keep generating.
            ids = ids[:self.max_length - 1] + [vocab["<eos>"]]
        return pad_sequences(ids, vocab, self.max_length)

    def __getitem__(self, idx):
        pair = self.data[idx]
        src_ids = self._encode(pair[self.src_lang], self.src_vocab, self.src_tokenizer)
        tgt_ids = self._encode(pair[self.tgt_lang], self.tgt_vocab, self.tgt_tokenizer)
        return torch.tensor(src_ids), torch.tensor(tgt_ids)

# 7. Create the DataLoaders.
# The wrapped datasets get their own names: rebinding `train_dataset` here would hide the
# raw sentence pairs that later cells read (`train_dataset[0]['zh']`), and re-running
# this cell would wrap the wrapper.
train_torch_dataset = TranslationDataset(train_dataset, vocab_zh, vocab_en, tokenize_zh, tokenize_en)
valid_torch_dataset = TranslationDataset(validation_dataset, vocab_zh, vocab_en, tokenize_zh, tokenize_en)

# Identity collate: a batch stays a list of (src, tgt) tensor pairs; the training loop stacks them.
train_loader = DataLoader(train_torch_dataset, batch_size=32, shuffle=True, collate_fn=lambda batch: batch)
valid_loader = DataLoader(valid_torch_dataset, batch_size=32, shuffle=False, collate_fn=lambda batch: batch)
print("数据加载成功!")

Data loaded successfully!


In [4]:
# 8. Transformer 模型
# Transformer 模型
import numpy as np
class TransformerMT(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, embed_dim=256, num_heads=8, num_layers=6, ff_dim=512,
                 dropout=0.1, max_seq_len=50):
        super().__init__()
        self.src_embedding = nn.Embedding(src_vocab_size, embed_dim)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, embed_dim)
        self.max_len = max_seq_len
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.ff_dim = ff_dim
        self.dropout = dropout
        self.transformer = nn.Transformer(
            d_model=embed_dim,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=ff_dim,
            dropout=dropout,
            batch_first=True  # 启用 batch_first
        )

        self.fc_out = nn.Linear(embed_dim, tgt_vocab_size)
        self.src_pad_idx = vocab_zh["<pad>"]
        self.tgt_pad_idx = vocab_en["<pad>"]

    def create_positional_encoding(self, embed_dim, max_len):
        """生成固定的 Positional Encoding"""
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * -(np.log(10000.0) / embed_dim))
        pe = torch.zeros(max_len, embed_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe.unsqueeze(0).to(device)

    def make_src_mask(self, src):
        return (src == self.src_pad_idx)  # 形状: [batch_size, seq_len]

    def make_tgt_mask(self, tgt):
        seq_len = tgt.shape[1]
        tgt_sub_mask = torch.tril(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=tgt.device))  # [seq_len, seq_len]
        return tgt_sub_mask

    def forward(self, src, tgt):
        # [batch_size, seq_len, embed_dim]
        src_position = self.create_positional_encoding(self.embed_dim, src.shape[1])
        tgt_position = self.create_positional_encoding(self.embed_dim, tgt.shape[1])
        src1 = self.src_embedding(src) + src_position
        tgt1 = self.tgt_embedding(tgt) + tgt_position

        src_mask = self.make_src_mask(src)  # [batch_size, seq_len]
        tgt_mask = self.make_tgt_mask(tgt)  # [seq_len, seq_len]

        output = self.transformer(src1, tgt1, src_key_padding_mask=src_mask, tgt_mask=tgt_mask)  # [batch_size, seq_len, embed_dim]
        return self.fc_out(output)  # [batch_size, seq_len, tgt_vocab_size]


print("Transformer模型构建成功！")

Transformer模型构建成功！


In [5]:
# 9. 训练
import os
model_path = './model/Translate_transformer.pt'


def train_model(model, train_loader, epochs=10, lr=1e-4):
    """Train ``model`` with teacher forcing, resuming from and checkpointing to ``model_path``.

    Each batch from ``train_loader`` is a list of (src_ids, tgt_ids) tensor pairs.
    The decoder is fed ``tgt[:, :-1]`` and trained to predict ``tgt[:, 1:]``;
    ``<pad>`` targets are ignored by the loss. After every epoch the weights are
    saved to ``model_path`` and, if ``translate`` has been defined, a sample
    translation of the first training sentence is printed.

    Args:
        model: the model to train; batches are moved to the device of its parameters.
        train_loader: DataLoader yielding lists of (src, tgt) tensor pairs.
        epochs: number of passes over ``train_loader``.
        lr: AdamW learning rate.
    """
    device = next(model.parameters()).device  # follow the model instead of a notebook-global `device`
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss(ignore_index=vocab_en["<pad>"])

    # Resume from an existing checkpoint.
    if os.path.isfile(model_path):
        print(f"发现模型文件 {model_path}，正在加载...")
        # map_location: a checkpoint written on a GPU can still be restored on a CPU-only machine.
        model.load_state_dict(torch.load(model_path, map_location=device))
        print("模型加载完成")

    checkpoint_dir = os.path.dirname(model_path)
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        for num, batch in enumerate(train_loader):
            src, tgt = zip(*batch)
            src = torch.nn.utils.rnn.pad_sequence(src, batch_first=True, padding_value=vocab_zh["<pad>"]).to(device)
            tgt = torch.nn.utils.rnn.pad_sequence(tgt, batch_first=True, padding_value=vocab_en["<pad>"]).to(device)

            optimizer.zero_grad()
            # Teacher forcing: the decoder input drops the last token, the target drops <sos>.
            output = model(src, tgt[:, :-1])
            loss = loss_fn(output.reshape(-1, output.shape[-1]), tgt[:, 1:].reshape(-1))
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            if (num + 1) % 125 == 0:
                print(f"\tBatch {num+1}/{len(train_loader)}, Loss: {loss.item():.4f}")

        print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / max(len(train_loader), 1):.4f}")

        # Save before the preview so a failing preview cannot lose a finished epoch;
        # create the checkpoint directory on first use.
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(model.state_dict(), model_path)

        # `translate` lives in a later cell, so on a fresh "Run All" it does not exist yet.
        translate_fn = globals().get("translate")
        if translate_fn is not None:
            print(translate_fn(model, train_dataset[0]['zh']))


# Train the model (on the GPU when one is available, otherwise on the CPU)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = TransformerMT(len(vocab_zh), len(vocab_en))
model = model.to(device)
train_model(model, train_loader)


发现模型文件 ./model/Translate_transformer.pt，正在加载...
模型加载完成
	Batch 125/625, Loss: 0.2532
	Batch 250/625, Loss: 0.2953
	Batch 375/625, Loss: 0.3702
	Batch 500/625, Loss: 0.3052
	Batch 625/625, Loss: 0.2601
Epoch 1/1, Loss: 0.2757


In [7]:
# 10. 推理
def translate(model, src_sentence, max_len=50):
    """Greedily decode an English translation of a Chinese sentence.

    Prints the source sentence, then starts from ``<sos>`` and repeatedly appends
    the arg-max next token until ``<eos>`` is produced or ``max_len`` tokens have
    been generated. Leaves ``model`` in eval mode.

    Args:
        model: translation model returning logits [batch, tgt_len, tgt_vocab_size].
        src_sentence: raw Chinese text.
        max_len: maximum number of decoding steps (previously fixed at 50).

    Returns:
        The generated English tokens joined by single spaces (without <sos>/<eos>).
    """
    model.eval()
    print(f"中文: {src_sentence}")
    device = next(model.parameters()).device  # decode wherever the model lives
    src_ids = numericalize(src_sentence, vocab_zh, tokenize_zh)
    src_tensor = torch.tensor(src_ids).unsqueeze(0).to(device)

    eos_id = vocab_en["<eos>"]
    tgt_ids = [vocab_en["<sos>"]]
    # Decoding needs no autograd graph; without no_grad every step would record one.
    with torch.no_grad():
        for _ in range(max_len):
            tgt_tensor = torch.tensor(tgt_ids).unsqueeze(0).to(device)
            logits = model(src_tensor, tgt_tensor)
            next_word = logits[0, -1].argmax().item()  # most likely token at the last position
            if next_word == eos_id:
                break
            tgt_ids.append(next_word)

    return " ".join(idx_to_token_en[i] for i in tgt_ids[1:])

# Rebuild the model and restore the trained weights (if a checkpoint exists) for inference.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TransformerMT(len(vocab_zh), len(vocab_en)).to(device)
if os.path.isfile(model_path):
    print(f"发现模型文件 {model_path}，正在加载...")
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    print("模型加载完成")
print(translate(model, train_dataset[0]['zh']))

发现模型文件 ./model/Translate_transformer.pt，正在加载...
模型加载完成
They , the Secretary - Self - Self - Self - Self - Self - Self - Self - Self - Self - Self - Self - Self - up of the Secretary - up of the Secretary - up of the Secretary - up of the Secretary - up
