In [1]:
import math

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor

In [2]:
from transformers import AutoTokenizer
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader

In [3]:
class MTTrainDataset(Dataset):
    """English-to-Chinese parallel corpus used to train the translation model.

    Both input files are text files holding one tab-separated
    ``english<TAB>chinese`` pair per line. Each pair is stored as a dict with
    the keys ``"en"`` and ``"zh"``.

    Attributes:
        terms: Pairs read from the term dictionary file.
        data: Pairs read from the training file, without over-long lines.
        en_tokenizer: Tokenizer used for the ``"en"`` side.
        ch_tokenizer: Tokenizer used for the ``"zh"`` side.
    """

    # Training lines with this many characters or more (tab included) are dropped.
    MAX_LINE_CHARS = 512

    def __init__(self, train_path, dic_path, encoding="utf-8"):
        """Load the corpus, the term dictionary and both tokenizers.

        Args:
            train_path: Path of the training corpus file.
            dic_path: Path of the term dictionary file.
            encoding: Text encoding of both files. It is set explicitly so the
                Chinese text does not depend on the platform's default encoding.

        Raises:
            IndexError: If a non-empty line contains no tab separator.
        """
        self.terms = self._read_pairs(dic_path, encoding=encoding)
        self.data = self._read_pairs(train_path, max_chars=self.MAX_LINE_CHARS, encoding=encoding)
        self.en_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased", cache_dir="../../cache")
        self.ch_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-chinese", cache_dir="../../cache")
        # NOTE: self.terms is loaded but not registered with either tokenizer;
        # the add_tokens() calls that used to do so were disabled.

    @staticmethod
    def _read_pairs(path, max_chars=None, encoding="utf-8"):
        """Read tab-separated (en, zh) pairs from ``path``.

        Empty lines are skipped, so a file is read completely whether or not it
        ends with a newline. Only the first two tab-separated fields of a line
        are used.

        Args:
            path: File to read.
            max_chars: When given, lines whose length is ``>= max_chars`` are skipped.
            encoding: Text encoding of the file.

        Returns:
            A list of ``{"en": ..., "zh": ...}`` dicts in file order.
        """
        pairs = []
        # The context manager closes the file even if a malformed line raises.
        with open(path, encoding=encoding) as handle:
            for raw_line in handle:
                line = raw_line.rstrip("\n")
                if not line:
                    continue
                if max_chars is not None and len(line) >= max_chars:
                    continue
                fields = line.split("\t")
                pairs.append({"en": fields[0], "zh": fields[1]})
        return pairs

    def __len__(self):
        """Return the number of training pairs."""
        return len(self.data)

    def __getitem__(self, index) -> dict:
        """Tokenize one pair.

        Returns:
            A dict with ``"src"`` (token ids of the English text) and ``"trg"``
            (token ids of the Chinese text), both 1-D ``torch.long`` tensors
            placed on the module-level ``device``, which is read at call time.
        """
        pair = self.data[index]
        # Build integer tensors directly instead of going through float32.
        return {
            "src": torch.tensor(self.en_tokenizer.encode(pair["en"]), dtype=torch.long, device=device),
            "trg": torch.tensor(self.ch_tokenizer.encode(pair["zh"]), dtype=torch.long, device=device),
        }

    def get_raw(self, index):
        """Return the untokenized ``{"en": ..., "zh": ...}`` pair at ``index``."""
        return self.data[index]

In [4]:
# Build the training set from the parallel corpus and the en-zh term dictionary.
# Constructing the dataset also loads both tokenizers.
train_data = MTTrainDataset("./data/train.txt", "./data/en-zh.dic")

In [5]:
# Positional encoding
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a sequence-first tensor.

    The table is precomputed once for ``max_len`` positions and stored as the
    buffer ``pe`` with shape ``(max_len, 1, d_model)``. Even feature columns
    hold sines and odd columns hold cosines.

    Args:
        d_model: Size of the feature dimension; may be even or odd.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so trim the angles to fit instead of failing on the shape.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Insert a singleton batch axis so the table broadcasts over the batch.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Add the encoding for the first ``x.size(0)`` positions, then apply dropout.

        Args:
            x: Tensor whose first dimension indexes positions and whose last
                dimension has size ``d_model``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

# Transformer
class TransformerModel(nn.Module):
    """Encoder-decoder Transformer that maps source token ids to target-vocabulary logits.

    Args:
        src_tk: Source vocabulary/tokenizer; only ``len(src_tk)`` is used to size the embedding.
        tgt_tk: Target vocabulary/tokenizer; ``len(tgt_tk)`` sizes the embedding and the output layer.
        d_model: Embedding and model width.
        nhead: Number of attention heads.
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        dim_feedforward: Width of the feed-forward sublayers.
        dropout: Dropout probability for the Transformer and the positional encoding.
    """

    def __init__(self, src_tk, tgt_tk, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout):
        super().__init__()
        self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout)
        self.src_embedding = nn.Embedding(len(src_tk), d_model)
        self.tgt_embedding = nn.Embedding(len(tgt_tk), d_model)
        self.positional_encoding = PositionalEncoding(d_model, dropout)
        self.fc_out = nn.Linear(d_model, len(tgt_tk))
        self.src_vocab = src_tk
        self.tgt_vocab = tgt_tk
        self.d_model = d_model

    def forward(self, src, tgt):
        """Compute target-vocabulary logits for every target position.

        Args:
            src: ``(batch, src_len)`` integer tensor of source token ids; id 0 is treated as padding.
            tgt: ``(batch, tgt_len)`` integer tensor of target token ids; id 0 is treated as padding.

        Returns:
            Tensor of shape ``(batch, tgt_len, len(tgt_tk))``.
        """
        # nn.Transformer is used in its sequence-first layout here.
        src = src.transpose(0, 1)  # (seq_len, batch_size)
        tgt = tgt.transpose(0, 1)  # (seq_len, batch_size)

        # Only the decoder is causal: position i may not attend to positions > i.
        # The encoder gets no attention mask, so every source token can see the
        # whole source sentence. The mask is boolean (True = blocked) to match
        # the dtype of the padding masks below.
        tgt_len = tgt.size(0)
        tgt_mask = torch.triu(
            torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device), diagonal=1
        )

        # True marks padding positions; shape (batch, seq_len).
        src_padding_mask = (src == 0).transpose(0, 1)
        tgt_padding_mask = (tgt == 0).transpose(0, 1)

        scale = math.sqrt(self.d_model)
        src_embedded = self.positional_encoding(self.src_embedding(src) * scale)
        tgt_embedded = self.positional_encoding(self.tgt_embedding(tgt) * scale)

        output = self.transformer(
            src_embedded,
            tgt_embedded,
            src_mask=None,
            tgt_mask=tgt_mask,
            memory_mask=None,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            # The encoder output keeps the source layout, so its padding mask is reused.
            memory_key_padding_mask=src_padding_mask,
        )
        return self.fc_out(output).transpose(0, 1)

In [6]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [7]:
def initialize_model(src_tk, tgt_tk, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1):
    """Create a TransformerModel, filling in default hyperparameters.

    Args:
        src_tk: Source vocabulary/tokenizer passed through to the model.
        tgt_tk: Target vocabulary/tokenizer passed through to the model.
        d_model: Model width.
        nhead: Number of attention heads.
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        dim_feedforward: Width of the feed-forward sublayers.
        dropout: Dropout probability.

    Returns:
        The newly constructed TransformerModel (not moved to any device).
    """
    return TransformerModel(
        src_tk=src_tk,
        tgt_tk=tgt_tk,
        d_model=d_model,
        nhead=nhead,
        num_encoder_layers=num_encoder_layers,
        num_decoder_layers=num_decoder_layers,
        dim_feedforward=dim_feedforward,
        dropout=dropout,
    )

In [8]:
# Tokenizers handed to the model; their lengths size the embeddings and the output layer.
# NOTE(review): MTTrainDataset loads the same two checkpoints with
# cache_dir="../../cache", while this cell uses "../../../cache" — confirm
# which cache directory is intended.
en_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased", cache_dir="../../../cache")
ch_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-chinese", cache_dir="../../../cache")

In [9]:
import math

In [10]:
# Build the translation model with initialize_model's default hyperparameters
# and move it to the selected device.
model = initialize_model(en_tokenizer, ch_tokenizer).to(device)



In [11]:
# Put the model in training mode. The trailing `pass` presumably keeps the
# cell from echoing the model's repr as its output.
model.train()
pass

In [12]:
def collate_fn(batch):
    """Merge dataset samples into padded source and target batches.

    Args:
        batch: Sequence of dicts, each holding 1-D tensors under "src" and "trg".

    Returns:
        A ``(src, trg)`` tuple of batch-first tensors, right-padded with 0 to
        the longest sequence of the batch.
    """
    pad = torch.nn.utils.rnn.pad_sequence

    def stack(key):
        # Pad every sample's sequence for one key to a common length.
        sequences = [sample[key] for sample in batch]
        return pad(sequences, batch_first=True, padding_value=0)

    return stack("src"), stack("trg")

In [13]:
# Re-initialize every parameter that has more than one dimension (weight
# matrices, embedding tables) with Xavier-uniform values; 1-D parameters such
# as biases keep their default initialization.
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

In [14]:
def train(
    epochs: int=10,
    steps: int | None=None,
    batch_size: int=4,
    logging_times: int=200,
    check_each_epoch_times: int=3,
    shuffle: bool=False,
) -> list[float]:
    """Train the module-level ``model`` on ``train_data`` and return the per-step losses.

    The decoder input is the target without its last token and the loss
    compares the predictions with the target shifted by one position. Token
    id 0 (padding) is ignored by the loss.

    Args:
        epochs: Number of passes over the data loader.
        steps: Maximum number of batches per epoch; ``None`` uses every batch.
        batch_size: Number of samples per batch.
        logging_times: Approximate number of progress reports per epoch.
        check_each_epoch_times: Approximate number of checkpoints per epoch.
        shuffle: Whether the data loader reshuffles the samples every epoch.

    Returns:
        The loss of every optimization step, in order, across all epochs.

    Side effects:
        Updates ``model`` in place, prints progress, and writes checkpoints to
        ``./m_{step}_{epoch}.pth``.
    """
    from tqdm.notebook import tqdm

    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    losses = []

    planned_steps = steps if steps is not None else (len(train_data) // batch_size)
    # Clamp both intervals to at least 1: a short run (for example steps=50
    # with logging_times=200) would otherwise yield 0 and crash on `step % 0`.
    logging_steps = max(1, planned_steps // max(1, logging_times))
    check_steps = max(1, planned_steps // max(1, check_each_epoch_times))

    data_loader = DataLoader(train_data, batch_size=batch_size, collate_fn=collate_fn, shuffle=shuffle)

    # Make sure dropout is active even if the model was last left in eval mode.
    model.train()
    for epoch in tqdm(range(epochs)):
        for step, (src, trg) in tqdm(
            enumerate(data_loader), total=len(data_loader) if steps is None else steps, desc=f"Epoch: {epoch}"
        ):
            # Stop before processing the batch so exactly `steps` updates run.
            if steps is not None and step >= steps:
                break
            src = src.to(device)
            trg = trg.to(device)
            optimizer.zero_grad()
            output = model(src, trg[:, :-1])
            loss = criterion(output.reshape(-1, output.size(-1)), trg[:, 1:].reshape(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            losses.append(loss.item())
            if step % logging_steps == 0:
                print(f"Epoch: {epoch}, Step: {step}, Loss: {losses[-1]}")
                print(f"avg_loss: {np.mean(losses[-logging_steps:])}")
                print(f"Input: {train_data.en_tokenizer.decode(src[0].tolist())}")
                print(f"Prediction: {train_data.ch_tokenizer.decode(output.argmax(-1)[0].tolist())}")
                print(f"Target: {train_data.ch_tokenizer.decode(trg[0].tolist())}")
                print("=" * 100)
            # Note: this also fires on step 0 of every epoch.
            if step % check_steps == 0:
                torch.save(model.state_dict(), f"./m_{step}_{epoch}.pth")
    return losses

In [15]:
# Train for 3 full passes over the data with batch size 32; logging_times=100
# asks for roughly 100 progress reports per epoch.
losses = train(epochs=3, steps=None, logging_times=100, batch_size=32)

  0%|          | 0/3 [00:00<?, ?it/s]

Epoch: 0:   0%|          | 0/4616 [00:00<?, ?it/s]



Epoch: 0, Step: 0, Loss: 9.982889175415039
avg_loss: 9.982889175415039
Input: [CLS] there's a tight and surprising link between the ocean's health and ours, says marine biologist stephen palumbi. he shows how toxins at the bottom of the ocean food chain find their way into our bodies, with a shocking story of toxic contamination from a japanese fish market. his work points a way forward for saving the oceans'health - - and humanity's. [SEP]
Prediction: 実 桨 假load 2012ーションetload酩load 矍ff 假膏et chan mems籠 鏈 86 鉚 鉚ａloadload 假 chan chan屑 outlet 蓓 組ff烷膏 戈ａａーションloadads 2012烷ａａ 恚 氤ーション陣鳳ａ湾etload 薦load 鉚load 2012 捋湾ffff 缱load 2012 chan陣 雹ffａ湾et总ーション頰 捋loadloadet 捋烷 chan烷湾loadloadload veload韜槍load chanload 蓓ａ exo酩 捋 蟋 chan湾load chanloadningａ 2012ａ槍loadet chan頰 捋ff 2012 chan淼 蓓
Target: [CLS] 生 物 学 家 史 蒂 芬 · 帕 伦 认 为 ， 海 洋 的 健 康 和 我 们 的 健 康 之 间 有 着 紧 密 而 神 奇 的 联 系 。 他 通 过 日 本 一 个 渔 场 发 生 的 让 人 震 惊 的 有 毒 污 染 的 事 件 ， 展 示 了 位 于 海 洋 食 物 链 底 部 的 有 毒 物 质 是 如 何 进 入 我 们 的 身 体 的 。 他 的 工 作 主 要 是 未 来 拯 救 海 洋 健

In [None]:
# Put the model in training mode again (same statements as the earlier cell).
# The trailing `pass` presumably keeps the cell from echoing the model's repr.
model.train()
pass