In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from tqdm.notebook import tqdm, trange
from torch import nn
import torch
from torch import Tensor
import numpy as np
from torch.utils.data import Dataset
import math
from transformers import AutoTokenizer

In [2]:
class MTTrainDataset(Dataset):
    """Parallel English/Chinese corpus that yields padded token-id lists.

    Both input files are plain text with one record per line and two
    tab-separated columns: the English text first, the Chinese text second.
    The dictionary entries are registered as additional tokens on the
    tokenizers that are passed in, so constructing the dataset mutates those
    tokenizer objects (their vocabularies grow).
    """

    def __init__(
        self,
        train_path: str,
        dic_path: str,
        en_tokenizer: "AutoTokenizer",
        ch_tokenizer: "AutoTokenizer",
        truncate: int = 384,
        pad_multiple: int = 8,
    ):
        """Load the term dictionary and the training pairs.

        Args:
            train_path: Path of the tab-separated training file.
            dic_path: Path of the tab-separated term dictionary.
            en_tokenizer: Tokenizer used for the "en" column; must provide
                ``add_tokens`` and ``encode``.
            ch_tokenizer: Tokenizer used for the "zh" column; same protocol.
            truncate: Training lines longer than this many characters
                (counted on the whole raw line, tab included) are skipped.
            pad_multiple: Encoded sequences are padded up to a multiple of
                this length.

        Raises:
            ValueError: If ``pad_multiple`` is smaller than 1.
            IndexError: If a non-empty line contains no tab separator.
        """
        if pad_multiple < 1:
            raise ValueError(f"pad_multiple must be >= 1, got {pad_multiple}")

        self.terms = self._read_pairs(dic_path)
        self.data = self._read_pairs(train_path, max_chars=truncate)
        self.en_tokenizer = en_tokenizer
        self.ch_tokenizer = ch_tokenizer
        # Register every dictionary term as a single token on each side.
        self.en_tokenizer.add_tokens([term["en"] for term in self.terms])
        self.ch_tokenizer.add_tokens([term["zh"] for term in self.terms])
        self.pad_multiple = pad_multiple

    @staticmethod
    def _read_pairs(path: str, max_chars=None) -> list:
        """Read ``en<TAB>zh`` lines from ``path`` into a list of dicts.

        Empty lines are skipped, so the result is the same whether or not the
        file ends with a newline. When ``max_chars`` is given, lines longer
        than that are dropped.
        """
        # Explicit encoding: the corpus holds Chinese text and must not depend
        # on the platform's default codec. The context manager closes the file.
        with open(path, encoding="utf-8") as handle:
            lines = handle.read().split("\n")

        pairs = []
        for line in lines:
            if not line:
                continue
            if max_chars is not None and len(line) > max_chars:
                continue
            columns = line.split("\t")
            pairs.append({"en": columns[0], "zh": columns[1]})
        return pairs

    def _pad(self, ids: list, pad_token_id: int = 0) -> list:
        """Right-pad ``ids`` with ``pad_token_id`` up to the next multiple of ``pad_multiple``.

        A list whose length is already a multiple is returned unpadded.
        """
        missing = -len(ids) % self.pad_multiple
        return ids + [pad_token_id] * missing

    def __len__(self) -> int:
        """Number of training pairs kept after the length filter."""
        return len(self.data)

    def __getitem__(self, index: int) -> dict:
        """Return the encoded and padded pair at ``index`` as ``{"en": [...], "zh": [...]}``."""
        pair = self.data[index]
        return {
            "en": self._pad(self.en_tokenizer.encode(pair["en"])),
            "zh": self._pad(self.ch_tokenizer.encode(pair["zh"])),
        }

    def get_raw(self, index: int) -> dict:
        """Return the untokenized ``{"en": str, "zh": str}`` pair at ``index``."""
        return self.data[index]

In [3]:
def collect_fn(batch: dict) -> tuple[Tensor, Tensor]:
    """Collate dataset items into a pair of right-padded id tensors.

    ``batch`` is iterated item by item, and every item is indexed with the
    keys "en" and "zh". Each side is turned into one batch-first tensor in
    which shorter rows are filled with 0 on the right.

    Returns:
        ``(src, trg)``: the stacked "en" rows and the stacked "zh" rows.
    """
    def stack_padded(key):
        # One tensor per sample, then pad them to the longest row in the batch.
        rows = [torch.tensor(sample[key]) for sample in batch]
        return torch.nn.utils.rnn.pad_sequence(rows, batch_first=True, padding_value=0)

    return stack_padded("en"), stack_padded("zh")

In [4]:
# Pretrained tokenizers, one per language; both use the same local cache directory.
en_tokenizer = AutoTokenizer.from_pretrained(
    "google-bert/bert-base-uncased",
    cache_dir="../../../cache",
)
ch_tokenizer = AutoTokenizer.from_pretrained(
    "google-bert/bert-base-chinese",
    cache_dir="../../../cache",
)

In [5]:
# Shuffled loader over the training corpus. Note that building the dataset
# also adds the dictionary terms to both tokenizers.
train_loader = torch.utils.data.DataLoader(
    dataset=MTTrainDataset(
        train_path="./data/train.txt",
        dic_path="./data/en-zh.dic",
        en_tokenizer=en_tokenizer,
        ch_tokenizer=ch_tokenizer,
    ),
    batch_size=2,
    shuffle=True,
    collate_fn=collect_fn,
)

In [8]:
# Positional encoding
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to a sequence-first input, then apply dropout.

    The table is stored as the buffer ``pe`` with shape
    ``(max_len, 1, d_model)``; it is not a trainable parameter.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        """Precompute the encoding table.

        Args:
            d_model: Size of the last (feature) dimension; may be even or odd.
            dropout: Dropout probability applied after the addition.
            max_len: Number of positions in the table; inputs to ``forward``
                must not be longer than this along dimension 0.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine column fewer than sine columns,
        # so trim the angles to the number of odd-indexed slots.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Insert a broadcastable batch axis: (max_len, 1, d_model).
        pe = pe.unsqueeze(1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:x.size(0)])``.

        ``x`` is indexed along dimension 0 as the position axis; the table
        broadcasts over dimension 1.
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

# Transformer
class TransformerModel(nn.Module):
    """Encoder-decoder Transformer mapping source token ids to target-vocabulary logits.

    ``src_tk`` and ``tgt_tk`` are only used through ``len()`` to size the
    embedding tables and the output projection; they are also kept as
    ``src_vocab`` / ``tgt_vocab``.
    """

    def __init__(self, src_tk, tgt_tk, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout, pad_token_id=0):
        """Build the embeddings, the Transformer core and the output layer.

        Args:
            src_tk: Source vocabulary/tokenizer; ``len(src_tk)`` is the source vocab size.
            tgt_tk: Target vocabulary/tokenizer; ``len(tgt_tk)`` is the target vocab size.
            d_model: Embedding / hidden size.
            nhead: Number of attention heads.
            num_encoder_layers: Number of encoder layers.
            num_decoder_layers: Number of decoder layers.
            dim_feedforward: Hidden size of the feed-forward sublayers.
            dropout: Dropout probability (Transformer and positional encoding).
            pad_token_id: Token id treated as padding when building the
                key-padding masks. Defaults to 0.
        """
        super().__init__()
        self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout)
        self.src_embedding = nn.Embedding(len(src_tk), d_model)
        self.tgt_embedding = nn.Embedding(len(tgt_tk), d_model)
        self.positional_encoding = PositionalEncoding(d_model, dropout)
        self.fc_out = nn.Linear(d_model, len(tgt_tk))
        self.src_vocab = src_tk
        self.tgt_vocab = tgt_tk
        self.d_model = d_model
        self.pad_token_id = pad_token_id

    def forward(self, src, tgt):
        """Compute target-vocabulary logits.

        Args:
            src: Source token ids, batch-first ``(batch, src_len)``.
            tgt: Decoder input token ids, batch-first ``(batch, tgt_len)``.

        Returns:
            Logits of shape ``(batch, tgt_len, len(tgt_tk))``.
        """
        # Key-padding masks are batch-first: True marks padding positions.
        src_padding_mask = src == self.pad_token_id
        tgt_padding_mask = tgt == self.pad_token_id

        # nn.Transformer is built with the default batch_first=False, so it
        # expects sequence-first inputs.
        src = src.transpose(0, 1)  # (seq_len, batch_size)
        tgt = tgt.transpose(0, 1)  # (seq_len, batch_size)

        # Only the decoder is causal. The encoder gets no attention mask so
        # every source token can attend to the whole source sentence.
        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt.size(0)).to(tgt.device)

        scale = math.sqrt(self.d_model)
        src_embedded = self.positional_encoding(self.src_embedding(src) * scale)
        tgt_embedded = self.positional_encoding(self.tgt_embedding(tgt) * scale)

        output = self.transformer(
            src_embedded,
            tgt_embedded,
            src_mask=None,
            tgt_mask=tgt_mask,
            memory_mask=None,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=src_padding_mask,
        )
        return self.fc_out(output).transpose(0, 1)

In [9]:
def initialize_model(src_tk, tgt_tk, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1):
    """Construct a ``TransformerModel`` from the given vocabularies and hyperparameters.

    All arguments are forwarded unchanged to ``TransformerModel``; the new
    model is returned without being moved to any device.
    """
    return TransformerModel(
        src_tk=src_tk,
        tgt_tk=tgt_tk,
        d_model=d_model,
        nhead=nhead,
        num_encoder_layers=num_encoder_layers,
        num_decoder_layers=num_decoder_layers,
        dim_feedforward=dim_feedforward,
        dropout=dropout,
    )

In [11]:
# Use Apple's MPS backend when it is available (the original setting),
# otherwise fall back to CUDA and finally to the CPU so the notebook also
# runs on machines without MPS.
device = (
    "mps"
    if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)

In [12]:
# Build the model with the default hyperparameters, then move it to `device`.
model = initialize_model(src_tk=en_tokenizer, tgt_tk=ch_tokenizer)
model = model.to(device)



In [1]:
# One training pass over `train_loader`.
# This cell relies on `model`, `train_loader` and `device` from the cells
# above; on a fresh kernel those cells must be run first, otherwise this
# cell fails with a NameError.
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# Target id 0 is the padding id, so padded positions do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)

epoch_loss = 0.0
num_steps = 0
progress = tqdm(train_loader, total=len(train_loader))
for src, tgt in progress:
    if src.numel() == 0 or tgt.numel() == 0:
        continue

    src, tgt = src.to(device), tgt.to(device)

    # Teacher forcing: the decoder reads every token but the last and is
    # trained to predict every token but the first.
    decoder_input = tgt[:, :-1]
    decoder_target = tgt[:, 1:]

    optimizer.zero_grad()
    output = model(src, decoder_input)

    # Flatten (batch, seq, vocab) and (batch, seq) for the token-level loss.
    loss = criterion(
        output.reshape(-1, output.shape[-1]),
        decoder_target.reshape(-1),
    )
    loss.backward()
    optimizer.step()

    step_loss = loss.item()
    epoch_loss += step_loss
    num_steps += 1
    # Show the running loss on the progress bar instead of printing one line per step.
    progress.set_postfix(loss=f"{step_loss:.4f}")

if num_steps:
    print(f"mean training loss: {epoch_loss / num_steps:.4f}")

NameError: name 'model' is not defined