In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from transformers import GPT2Tokenizer
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

# 1. 初始化GPT2 Tokenizer
en_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
de_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# 添加特殊标记
special_tokens = {"bos_token": "<sos>", "eos_token": "<eos>", "pad_token": "<pad>"}
en_tokenizer.add_special_tokens(special_tokens)
de_tokenizer.add_special_tokens(special_tokens)

# 2. 自定义数据集
class Multi30KDataset(Dataset):
    def __init__(self, en_path, de_path, en_tokenizer, de_tokenizer):
        self.en_sentences = self._read_file(en_path)
        self.de_sentences = self._read_file(de_path)
        self.en_tokenizer = en_tokenizer
        self.de_tokenizer = de_tokenizer
        assert len(self.en_sentences) == len(self.de_sentences), "数据不匹配！"

    def _read_file(self, path):
        with open(path, 'r', encoding='utf-8') as f:
            return [line.strip() for line in f]

    def __len__(self):
        return len(self.en_sentences)

    def __getitem__(self, idx):
        en_encoded = self.en_tokenizer(
            self.en_sentences[idx],
            return_tensors="pt",
            padding=False,
            truncation=True,
            add_special_tokens=True,
        )["input_ids"].squeeze(0)

        de_encoded = self.de_tokenizer(
            self.de_sentences[idx],
            return_tensors="pt",
            padding=False,
            truncation=True,
            add_special_tokens=True,
        )["input_ids"].squeeze(0)

        return en_encoded, de_encoded

# 3. 定义collate_fn
def collate_fn(batch):
    en_batch, de_batch = zip(*batch)
    en_batch = pad_sequence(en_batch, batch_first=True, padding_value=en_tokenizer.pad_token_id)
    de_batch = pad_sequence(de_batch, batch_first=True, padding_value=de_tokenizer.pad_token_id)
    return en_batch, de_batch

# 4. 初始化数据集和数据加载器
en_file_path = 'Multi30K/datasets/train/train.en'
de_file_path = 'Multi30K/datasets/train/train.de'

dataset = Multi30KDataset(en_file_path, de_file_path, en_tokenizer, de_tokenizer)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

# 5. 测试数据加载器
for en_batch, de_batch in dataloader:
    print("English batch shape:", en_batch.shape)
    print("German batch shape:", de_batch.shape)
    print("English batch example (tokens):", en_batch[0])
    print("German batch example (tokens):", de_batch[0])
    print("Decoded English:", en_tokenizer.decode(en_batch[0]))
    print("Decoded German:", de_tokenizer.decode(de_batch[0]))
    break


English batch shape: torch.Size([32, 34])
German batch shape: torch.Size([32, 58])
English batch example (tokens): tensor([  464,  7586,  3290,   318,  2491,   319,   262,  8701,    13, 50259,
        50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
        50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
        50259, 50259, 50259, 50259])
German batch example (tokens): tensor([   36,   259,  8290, 38886,   367,   917,  8851,   429,   257,  3046,
         1902,   292,    13, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
        50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
        50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
        50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259,
        50259, 50259, 50259, 50259, 50259, 50259, 50259, 50259])
Decoded English: The brown dog is running on the grass.<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pa

In [3]:
from transformer import Transformer


# 1. Transformer 模型参数
vocab_size = len(en_tokenizer)
d_model = 512
num_heads = 8
num_layers = 2
d_ff = 2048
max_seq_len = 100
dropout = 0.1

# 2. 填充值索引
src_pad_idx = en_tokenizer.pad_token_id
tgt_pad_idx = de_tokenizer.pad_token_id

# 3. 初始化 Transformer
transformer = Transformer(vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_len, dropout)

# 打印验证
print(f"Transformer initialized.")
print(f"Source padding index: {src_pad_idx}, Target padding index: {tgt_pad_idx}")
print(f"Vocabulary size: {vocab_size}")



Transformer initialized.
Source padding index: 50259, Target padding index: 50259
Vocabulary size: 50260


In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# 1. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss(ignore_index=tgt_pad_idx)  # 忽略填充标记的损失
optimizer = optim.Adam(transformer.parameters(), lr=0.0001)

# 2. 定义训练函数
def train_epoch(transformer, dataloader, criterion, optimizer, device):
    transformer.train()  # 切换到训练模式
    total_loss = 0

    for batch in tqdm(dataloader, desc="Training"):
        src, tgt = batch
        src, tgt = src.to(device), tgt.to(device)
        
        # 修正后的生成掩码
        tgt_input = tgt[:, :-1]
        tgt_target = tgt[:, 1:]

        # 构造掩码
        src_mask = transformer.make_src_mask(src, src_pad_idx)
        tgt_mask = transformer.make_trg_mask(tgt_input, tgt_pad_idx)  # 修正为 tgt_input


        print(f"src shape: {src.shape}, tgt_input shape: {tgt_input.shape}, src_mask shape: {src_mask.shape}, tgt_mask shape: {tgt_mask.shape}")
        # 前向传播
        output = transformer(src, tgt_input, src_mask, tgt_mask)

        # 调整输出形状以计算损失
        output = output.reshape(-1, vocab_size)
        tgt_target = tgt_target.reshape(-1)

        # 计算损失
        loss = criterion(output, tgt_target)

        # 反向传播与优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

# 3. 定义训练主循环
def train_model(transformer, dataloader, num_epochs, device):
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        epoch_loss = train_epoch(transformer, dataloader, criterion, optimizer, device)
        print(f"Epoch Loss: {epoch_loss:.4f}")

# 4. 开始训练
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transformer = transformer.to(device)

num_epochs = 10
train_model(transformer, dataloader, num_epochs, device)


Epoch 1/10


Training:   0%|          | 0/907 [00:00<?, ?it/s]

src shape: torch.Size([32, 29]), tgt_input shape: torch.Size([32, 65]), src_mask shape: torch.Size([32, 1, 1, 29]), tgt_mask shape: torch.Size([32, 1, 65, 65])
mask shape in mha: torch.Size([32, 1, 65, 65])
scores shape in mha: torch.Size([32, 8, 65, 65])
mask shape in mha: torch.Size([32, 1, 1, 29])
scores shape in mha: torch.Size([32, 8, 65, 29])
mask shape in mha: torch.Size([32, 1, 65, 65])
scores shape in mha: torch.Size([32, 8, 65, 65])
mask shape in mha: torch.Size([32, 1, 1, 29])
scores shape in mha: torch.Size([32, 8, 65, 29])


Training:   0%|          | 1/907 [00:02<41:13,  2.73s/it]

src shape: torch.Size([32, 31]), tgt_input shape: torch.Size([32, 55]), src_mask shape: torch.Size([32, 1, 1, 31]), tgt_mask shape: torch.Size([32, 1, 55, 55])
mask shape in mha: torch.Size([32, 1, 55, 55])
scores shape in mha: torch.Size([32, 8, 55, 55])
mask shape in mha: torch.Size([32, 1, 1, 31])
scores shape in mha: torch.Size([32, 8, 55, 31])
mask shape in mha: torch.Size([32, 1, 55, 55])
scores shape in mha: torch.Size([32, 8, 55, 55])
mask shape in mha: torch.Size([32, 1, 1, 31])
scores shape in mha: torch.Size([32, 8, 55, 31])


Training:   0%|          | 2/907 [00:05<39:21,  2.61s/it]

src shape: torch.Size([32, 35]), tgt_input shape: torch.Size([32, 67]), src_mask shape: torch.Size([32, 1, 1, 35]), tgt_mask shape: torch.Size([32, 1, 67, 67])
mask shape in mha: torch.Size([32, 1, 67, 67])
scores shape in mha: torch.Size([32, 8, 67, 67])
mask shape in mha: torch.Size([32, 1, 1, 35])
scores shape in mha: torch.Size([32, 8, 67, 35])
mask shape in mha: torch.Size([32, 1, 67, 67])
scores shape in mha: torch.Size([32, 8, 67, 67])
mask shape in mha: torch.Size([32, 1, 1, 35])
scores shape in mha: torch.Size([32, 8, 67, 35])


Training:   0%|          | 2/907 [00:07<56:12,  3.73s/it]


KeyboardInterrupt: 