In [1]:
%load_ext autoreload
%autoreload 2

import os
import zipfile

# One-time dataset setup: unpack the Multi30K archive into the working
# directory, but only when no "datasets/" path exists yet.
needs_extraction = not os.path.exists("datasets/")
if needs_extraction:
    with zipfile.ZipFile("Multi30K.zip", mode="r") as archive:
        archive.extractall()

---

## Tokenizer, Dataset and DataLoaders

In [2]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from transformer_tokenizer import Tokenizer

class Multi30KDataset(Dataset):
    """Parallel English/German sentence dataset backed by two line-aligned text files.

    Line ``i`` of ``en_file`` is paired with line ``i`` of ``de_file``. Both files
    are read fully into memory at construction time; stripping and tokenization
    happen lazily on every ``__getitem__`` call.
    """

    def __init__(self, en_file, de_file, en_tokenizer, de_tokenizer):
        """Load both files and check that they contain the same number of lines.

        Args:
            en_file: Path to the UTF-8 English text file, one sentence per line.
            de_file: Path to the UTF-8 German text file, one sentence per line.
            en_tokenizer: Object whose ``tokenize(str)`` returns the English token ids.
            de_tokenizer: Object whose ``tokenize(str)`` returns the German token ids.

        Raises:
            AssertionError: If the two files differ in line count.
        """
        self.en_tokenizer = en_tokenizer
        self.de_tokenizer = de_tokenizer
        self.en_lines = self._read_lines(en_file)
        self.de_lines = self._read_lines(de_file)
        assert len(self.en_lines) == len(self.de_lines), "English and German files must have the same number of lines."

    @staticmethod
    def _read_lines(path):
        """Return every line of the UTF-8 text file at ``path`` (newlines kept)."""
        with open(path, 'r', encoding='utf-8') as f:
            return f.readlines()

    def __len__(self):
        """Return the number of sentence pairs."""
        return len(self.en_lines)

    def __getitem__(self, idx):
        """Return the ``idx``-th pair as ``(en_ids, de_ids)`` 1-D int64 tensors."""
        en_sentence = self.en_lines[idx].strip()
        de_sentence = self.de_lines[idx].strip()
        en_tokens = self.en_tokenizer.tokenize(en_sentence)
        de_tokens = self.de_tokenizer.tokenize(de_sentence)
        # Pin the dtype: torch.tensor([]) would otherwise be float32, which cannot
        # be used as embedding indices and mixes dtypes inside a padded batch.
        return (
            torch.tensor(en_tokens, dtype=torch.long),
            torch.tensor(de_tokens, dtype=torch.long),
        )
    
def collate_fn(batch, en_padding_value=0, de_padding_value=0):
    """Pad a list of ``(en_ids, de_ids)`` tensor pairs into two batch tensors.

    Each side is right-padded independently to the longest sequence on that side.

    Args:
        batch: Iterable of ``(en_tensor, de_tensor)`` pairs of 1-D tensors.
        en_padding_value: Fill value for the English batch (default 0).
        de_padding_value: Fill value for the German batch (default 0).
            To use non-default values with a ``DataLoader``, bind them with
            ``functools.partial(collate_fn, en_padding_value=..., ...)``.

    Returns:
        Tuple ``(en_batch, de_batch)``, each of shape ``(batch, max_len)``.
    """
    en_batch, de_batch = zip(*batch)
    pad_sequence = torch.nn.utils.rnn.pad_sequence
    en_batch = pad_sequence(en_batch, batch_first=True, padding_value=en_padding_value)
    de_batch = pad_sequence(de_batch, batch_first=True, padding_value=de_padding_value)
    return en_batch, de_batch

# One tokenizer per language, each capped at a 10k-entry vocabulary.
en_tokenizer = Tokenizer(vocab_size=10000)
de_tokenizer = Tokenizer(vocab_size=10000)

# Fit each tokenizer on the complete training file for its language.
# NOTE: fitting happens before the train/val/test split below, so sentences that
# end up in the validation and test subsets also contribute to the vocabulary.
with open("datasets/train/train.en", "r", encoding="utf-8") as f:
    en_lines = f.readlines()
en_tokenizer.fit(en_lines)
with open("datasets/train/train.de", "r", encoding="utf-8") as f:
    de_lines = f.readlines()
de_tokenizer.fit(de_lines)

dataset = Multi30KDataset(
    en_file="datasets/train/train.en",
    de_file="datasets/train/train.de",
    en_tokenizer=en_tokenizer,
    de_tokenizer=de_tokenizer
)

# 80 / 10 / 10 split; the test subset absorbs the rounding remainder.
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

# Seed the split so the partition is identical after a kernel restart. Without
# this, reloading a saved checkpoint and evaluating on `test_dataset` could score
# the model on sentences it was trained on in an earlier session.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size], generator=split_generator
)

# Only the training loader shuffles; validation and test keep a fixed order.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)



In [3]:

# Sanity check: pull a single batch from the test loader to report its shape,
# and round-trip one hand-written sentence per language through its tokenizer.
first_batch = next(iter(test_dataloader), None)
if first_batch is not None:
    en_batch, de_batch = first_batch
    print("\nText Check:")
    sample_en = "This is a sample English sentence."
    sample_de = "Dies ist ein Beispiel für einen deutschen Satz."
    checks = (
        ("English:", "English batch shape:", sample_en, en_batch, en_tokenizer),
        ("\nGerman:", "German batch shape:", sample_de, de_batch, de_tokenizer),
    )
    for heading, shape_label, sentence, batch_tensor, tokenizer in checks:
        print(heading)
        print("Original:", sentence)
        print(shape_label, batch_tensor.shape)
        print("Tokenized:", tokenizer.tokenize(sentence))
        print("Decoded:", tokenizer.detokenize(tokenizer.tokenize(sentence)))
    


Text Check:
English:
Original: This is a sample English sentence.
English batch shape: torch.Size([32, 22])
Tokenized: [2, 204, 10, 4, 2729, 6902, 1, 5, 3]
Decoded: <SOS> this is a sample english <UNK>. <EOS>

German:
Original: Dies ist ein Beispiel für einen deutschen Satz.
German batch shape: torch.Size([32, 21])
Tokenized: [2, 1212, 48, 5, 6263, 85, 19, 3856, 4250, 4, 3]
Decoded: <SOS> dies ist ein beispiel für einen deutschen satz. <EOS>


---

## Model Setup

In [4]:
from transformer import Transformer

# 1. Transformer hyperparameters (small 2-layer configuration).
#    The vocabulary size is read from the English tokenizer only.
vocab_size = en_tokenizer.get_vocab_size()
d_model, num_heads, num_layers = 512, 8, 2
d_ff, max_seq_len, dropout = 2048, 100, 0.1

# 2. Padding-token ids for the source (English) and target (German) sides.
src_pad_idx, tgt_pad_idx = en_tokenizer.pad_token_id, de_tokenizer.pad_token_id

# 3. Instantiate the Transformer.
transformer = Transformer(
    vocab_size,
    d_model,
    num_heads,
    num_layers,
    d_ff,
    max_seq_len,
    dropout,
)

# Print the key settings as a quick verification.
print("Transformer initialized.")
print(f"Source padding index: {src_pad_idx}, Target padding index: {tgt_pad_idx}")
print(f"Vocabulary size: {vocab_size}")



Transformer initialized.
Source padding index: 0, Target padding index: 0
Vocabulary size: 10000


---

## Train

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.auto import tqdm

# Hyperparameters for the model that is actually trained (6 layers, dropout 0.3).
vocab_size = en_tokenizer.get_vocab_size()
d_model, num_heads, num_layers = 512, 8, 6
d_ff, max_seq_len, dropout = 2048, 100, 0.3

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Build the model, then move its parameters onto the chosen device.
transformer = Transformer(
    vocab_size,
    d_model,
    num_heads,
    num_layers,
    d_ff,
    max_seq_len,
    dropout,
)
transformer = transformer.to(device)

# 1. Loss function and optimizer.
#    `tgt_pad_idx` comes from the earlier model-setup cell; positions holding
#    that id are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tgt_pad_idx)
optimizer = optim.AdamW(transformer.parameters(), lr=1e-5)

# 2. 定义训练函数
def train_epoch(transformer, dataloader, criterion, optimizer, device):
    """Run one training pass over ``dataloader`` and return the mean per-batch loss.

    Uses teacher forcing: the decoder input is the target without its last token
    and the prediction target is the target without its first token.

    Relies on the module-level ``src_pad_idx`` and ``tgt_pad_idx`` for mask
    construction.

    Args:
        transformer: Model exposing ``make_src_mask``, ``make_trg_mask`` and a
            forward call ``(src, tgt_input, src_mask, tgt_mask)``.
        dataloader: Iterable yielding ``(src, tgt)`` batches of token-id tensors.
        criterion: Loss taking ``(logits, targets)``.
        optimizer: Optimizer over ``transformer``'s parameters.
        device: Device the batches are moved to.

    Returns:
        Mean of the per-batch losses as a float, or ``nan`` if ``dataloader``
        yielded no batches.
    """
    transformer.train()  # enable dropout etc.
    total_loss = 0.0
    num_batches = 0
    progress_bar = tqdm(dataloader, desc="Training", leave=True)
    for src, tgt in progress_bar:
        src, tgt = src.to(device), tgt.to(device)
        # Shift by one: the decoder sees tokens [0, T-1) and predicts tokens [1, T).
        tgt_input = tgt[:, :-1]
        tgt_target = tgt[:, 1:]
        # Masks are built from the decoder *input*, not the full target.
        src_mask = transformer.make_src_mask(src, src_pad_idx)
        tgt_mask = transformer.make_trg_mask(tgt_input, tgt_pad_idx)
        # Forward pass.
        output = transformer(src, tgt_input, src_mask, tgt_mask)
        # Flatten to (N, C) / (N,) for the loss. The class dimension is taken from
        # the output itself rather than a global `vocab_size`.
        # Assumes the model's last output dimension is the vocabulary — TODO confirm.
        logits = output.reshape(-1, output.size(-1))
        targets = tgt_target.reshape(-1)
        loss = criterion(logits, targets)
        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Read the scalar once and reuse it for both the total and the progress bar.
        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        progress_bar.set_postfix(batch_loss=batch_loss)
    if num_batches == 0:
        # Nothing was processed; avoid a ZeroDivisionError and signal "no data".
        return float("nan")
    return total_loss / num_batches

# 3. 定义验证函数
def validate_epoch(transformer, dataloader, criterion, device):
    """Evaluate ``transformer`` on ``dataloader`` and return the mean per-batch loss.

    Runs in eval mode with gradients disabled, using the same shifted-target
    setup as training. Relies on the module-level ``src_pad_idx`` and
    ``tgt_pad_idx`` for mask construction.

    Args:
        transformer: Model exposing ``make_src_mask``, ``make_trg_mask`` and a
            forward call ``(src, tgt_input, src_mask, tgt_mask)``.
        dataloader: Iterable yielding ``(src, tgt)`` batches of token-id tensors.
        criterion: Loss taking ``(logits, targets)``.
        device: Device the batches are moved to.

    Returns:
        Mean of the per-batch losses as a float, or ``nan`` if ``dataloader``
        yielded no batches (e.g. a validation split of size zero).
    """
    transformer.eval()  # disable dropout etc.
    total_loss = 0.0
    num_batches = 0
    progress_bar = tqdm(dataloader, desc="Validation", leave=True)
    with torch.no_grad():
        for src, tgt in progress_bar:
            src, tgt = src.to(device), tgt.to(device)
            # Shift by one: the decoder sees tokens [0, T-1) and predicts tokens [1, T).
            tgt_input = tgt[:, :-1]
            tgt_target = tgt[:, 1:]
            # Masks are built from the decoder *input*, not the full target.
            src_mask = transformer.make_src_mask(src, src_pad_idx)
            tgt_mask = transformer.make_trg_mask(tgt_input, tgt_pad_idx)
            # Forward pass.
            output = transformer(src, tgt_input, src_mask, tgt_mask)
            # Flatten to (N, C) / (N,) for the loss. The class dimension is taken
            # from the output itself rather than a global `vocab_size`.
            # Assumes the model's last output dimension is the vocabulary — TODO confirm.
            logits = output.reshape(-1, output.size(-1))
            targets = tgt_target.reshape(-1)
            # Read the scalar once and reuse it for the total and the progress bar.
            batch_loss = criterion(logits, targets).item()
            total_loss += batch_loss
            num_batches += 1
            progress_bar.set_postfix(batch_loss=batch_loss)
    if num_batches == 0:
        # Nothing was processed; avoid a ZeroDivisionError and signal "no data".
        return float("nan")
    return total_loss / num_batches

# 4. 定义训练主循环
def train_model(transformer, train_dataloader, val_dataloader, num_epochs, device, pretrain=None):
    """Train for ``num_epochs`` epochs, printing training and validation loss each epoch.

    Uses the module-level ``criterion`` and ``optimizer``; they are not parameters
    of this function.

    Args:
        transformer: Model to train (updated in place).
        train_dataloader: Batches for ``train_epoch``.
        val_dataloader: Batches for ``validate_epoch``.
        num_epochs: Number of full passes over ``train_dataloader``.
        device: Device used for batches and for mapping a loaded checkpoint.
        pretrain: Optional path to a state-dict checkpoint to load before training.
    """
    if pretrain:
        # map_location lets a checkpoint saved on a GPU be restored on a CPU-only
        # machine (or another GPU) instead of failing during deserialization.
        transformer.load_state_dict(torch.load(pretrain, map_location=device))
        print(f"Loaded pre-trained model from {pretrain}")
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        epoch_loss = train_epoch(transformer, train_dataloader, criterion, optimizer, device)
        print(f"Training Loss: {epoch_loss:.4f}")
        val_loss = validate_epoch(transformer, val_dataloader, criterion, device)
        print(f"Validation Loss: {val_loss:.4f}")

# 5. Start training: record the checkpoint path and make sure the model is on `device`.
pretrain_path = "transformer.pth"
transformer = transformer.to(device)

In [None]:
# Train from scratch (no checkpoint is loaded) for 100 epochs.
train_model(
    transformer,
    train_dataloader,
    val_dataloader,
    num_epochs=100,
    device=device,
    pretrain=None,
)

In [None]:
# Persist only the model's state dict, not the whole module object.
model_state = transformer.state_dict()
torch.save(model_state, "transformer.pth")

---

## Eval

In [6]:
# Restore the trained weights. map_location=device lets a checkpoint written on a
# GPU load on a CPU-only machine instead of failing during deserialization.
transformer.load_state_dict(torch.load('transformer.pth', map_location=device))
# Switch to evaluation mode; as the last expression this also displays the model.
transformer.eval()

Transformer(
  (preprocessor): TransformerPreprocessor(
    (embedding): Embedding(10000, 512)
  )
  (encoder): Encoder(
    (layers): ModuleList(
      (0-5): 6 x EncoderLayer(
        (attention): MultiHeadAttention(
          (query): Linear(in_features=512, out_features=512, bias=True)
          (key): Linear(in_features=512, out_features=512, bias=True)
          (value): Linear(in_features=512, out_features=512, bias=True)
          (fc_out): Linear(in_features=512, out_features=512, bias=True)
        )
        (feed_forward): FeedForward(
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
        )
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.3, inplace=False)
      )
    )
  )
  (decoder): Decoder(
    (layers): ModuleList(
      (0-5): 6 x DecoderLayer(
       