In [None]:
from pathlib import Path

import torch
from torch import Tensor, nn

from llm_from_scratch.transformer.transformer import Encoder, Transformer

In [None]:
# Fetch the small_parallel_enja parallel corpus on first run; the clone is
# skipped whenever the target directory already exists.
data_dir = Path("small_parallel_enja")
if not data_dir.exists():
    # IPython shell escape: {data_dir} is interpolated from the Python variable.
    !git clone https://github.com/odashi/small_parallel_enja.git {data_dir}

# Paths of the Japanese and English training files inside the checkout.
train_ja = data_dir / "train.ja.000"
train_en = data_dir / "train.en.000"

In [None]:
from typing import Callable, Iterator
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator


tokenizer_ja = get_tokenizer(None)  # whitespace split only
tokenizer_en = get_tokenizer(tokenizer="basic_english")  # lowercase, then split

Tokenizer = Callable[[str], list[str]]


def iter_corpus(
    path: Path,
    tokenizer: Tokenizer,
    bos: str | None = "<bos>",
    eos: str | None = "<eos>",
) -> Iterator[list[str]]:
    with path.open("r") as f:
        for line in f:
            if bos:
                line = bos + " " + line
            if eos:
                line = line + " " + eos
            yield tokenizer(line)


# Materialize both sides of the corpus once so the token lists can be reused.
train_tokens_ja = list(iter_corpus(train_ja, tokenizer_ja))
train_tokens_en = list(iter_corpus(train_en, tokenizer_en))

# Reserved tokens shared by both vocabularies. The begin-of-sentence entry
# must match the marker that iter_corpus actually inserts ("<bos>" by default);
# reserving "<sos>" instead leaves a never-seen special in the vocabulary while
# "<bos>" is registered as an ordinary token.
SPECIAL_TOKENS = ("<unk>", "<pad>", "<bos>", "<eos>")

vocab_ja = build_vocab_from_iterator(
    iterator=train_tokens_ja,
    specials=SPECIAL_TOKENS,
)
# Out-of-vocabulary lookups fall back to <unk> instead of raising.
vocab_ja.set_default_index(vocab_ja["<unk>"])

vocab_en = build_vocab_from_iterator(
    iterator=train_tokens_en,
    specials=SPECIAL_TOKENS,
)
vocab_en.set_default_index(vocab_en["<unk>"])

In [None]:
from torch.utils.data import DataLoader
from torchtext import transforms

def _token_pipeline(vocab):
    """Build the tokens -> ids -> padded tensor pipeline for one vocabulary."""
    to_ids = transforms.VocabTransform(vocab)
    # Shorter sequences in a batch are filled with this vocabulary's <pad> id.
    to_padded_tensor = transforms.ToTensor(padding_value=vocab["<pad>"])
    return transforms.Sequential(to_ids, to_padded_tensor)


src_transforms = _token_pipeline(vocab_ja)
tgt_transforms = _token_pipeline(vocab_en)


def collate_fn(
    batch: list[tuple[list[str], list[str]]],
) -> tuple[Tensor, Tensor]:
    """Turn a batch of (source tokens, target tokens) pairs into two tensors.

    Args:
        batch: Samples as delivered by the DataLoader; each sample is a pair
            of token lists (source sentence, target sentence).

    Returns:
        ``(src, tgt)``: the results of running the source token lists through
        ``src_transforms`` and the target token lists through
        ``tgt_transforms``.
    """
    # Hand real lists of token lists to the transforms, exactly as the
    # per-sample appends did before.
    src_texts = [src_tokens for src_tokens, _ in batch]
    tgt_texts = [tgt_tokens for _, tgt_tokens in batch]
    return src_transforms(src_texts), tgt_transforms(tgt_texts)

In [None]:
# Pair every source sentence with its translation. strict=True makes a length
# mismatch between the two corpora fail loudly instead of silently dropping
# the unmatched tail.
train_loader = DataLoader(
    list(zip(train_tokens_ja, train_tokens_en, strict=True)),
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn,
)

In [None]:
# One pad id is used for the source masks, the target masks and the loss's
# ignore_index, so both vocabularies must map <pad> to the same index.
PAD_ID = vocab_ja["<pad>"]
assert vocab_en["<pad>"] == PAD_ID, "source and target vocabularies must share the <pad> index"

In [None]:
def create_padding_mask(pad_id: int, batch_tokens: Tensor):
    """Flag padding positions in a batch of token ids.

    Args:
        pad_id: Token id that denotes padding.
        batch_tokens: Token-id tensor to inspect.

    Returns:
        Boolean tensor that is True wherever ``batch_tokens`` equals
        ``pad_id``, with an extra size-1 dimension inserted at position 1.
    """
    return batch_tokens.eq(pad_id).unsqueeze(1)


def create_subsequent_mask(batch_tokens: Tensor):
    """Build a strictly-upper-triangular boolean mask for a token batch.

    Args:
        batch_tokens: Tensor whose dimension 1 is the sequence length; only
            its size and device are used.

    Returns:
        Boolean tensor of shape ``(1, seq_len, seq_len)`` that is True at
        ``[0, i, j]`` exactly when ``j > i`` (later positions), allocated on
        the same device as ``batch_tokens``.
    """
    sequence_len = batch_tokens.size(1)
    # Allocate on the batch's device so the mask can be combined with the
    # padding mask without a cross-device error when training off-CPU.
    ones = torch.full(
        (sequence_len, sequence_len),
        1,
        device=batch_tokens.device,
    )
    # Keep only entries strictly above the main diagonal, then add a leading
    # size-1 dimension so the mask broadcasts over the batch.
    mask = torch.triu(ones, diagonal=1) == 1
    return mask.unsqueeze(0)

In [None]:
# Token count of the longest sentence on each side of the corpus, and the
# larger of the two.
max_len_ja = max(len(tokens) for tokens in train_tokens_ja)
max_len_en = max(len(tokens) for tokens in train_tokens_en)
max_length = max(max_len_ja, max_len_en)

In [None]:
# Model hyperparameters.
embedding_dim = 512
n_blocks = 6
n_heads = 8
expansion_rate = 4  # feed-forward width as a multiple of embedding_dim
# Vocabulary sizes: Japanese is the source side, English the target side.
src_vocab_size = len(vocab_ja)
tgt_vocab_size = len(vocab_en)

# NOTE(review): d_k / d_v are set to the full embedding_dim rather than
# embedding_dim // n_heads. Whether the project's Transformer expects per-head
# or total sizes is not visible from this notebook — confirm against its
# implementation.
model = Transformer(
    src_vocab_size,
    tgt_vocab_size,
    max_sequence_len=max_length,
    d_model=embedding_dim,
    n_blocks=n_blocks,
    n_heads=n_heads,
    d_k=embedding_dim,
    d_v=embedding_dim,
    d_ff=embedding_dim * expansion_rate,
)

In [None]:
# Targets equal to PAD_ID do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID)
lr = 0.001  # learning rate
n_epochs = 100
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
# Multiply the learning rate by 0.95 every 10 scheduler steps. step_size is an
# integer count of steps; the decay only happens when scheduler.step() is
# called.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.95)

In [None]:
from tqdm.auto import tqdm


def train(model: nn.Module, log_interval: int = 10) -> float:
    """Run one teacher-forced training epoch over ``train_loader``.

    Uses the notebook-level ``train_loader``, ``criterion``, ``optimizer``,
    ``PAD_ID`` and ``tgt_vocab_size``.

    Args:
        model: Model called as ``model(src, tgt_in, src_mask, tgt_mask, src_mask)``
            and expected to return per-position scores over the target
            vocabulary.
        log_interval: Print the current batch loss every this many steps.

    Returns:
        Mean per-batch loss for the epoch (NaN if the loader yielded nothing).
    """
    model.train()
    total_loss = 0.0
    n_batches = 0
    for i, (src_texts, tgt_texts) in enumerate(train_loader):
        # Teacher forcing: the decoder reads tokens [0, T-1) and is scored on
        # tokens [1, T). Feeding and scoring the unshifted sequence would let
        # every position see the very token it has to predict.
        tgt_input = tgt_texts[:, :-1]
        tgt_labels = tgt_texts[:, 1:]

        src_mask = create_padding_mask(PAD_ID, src_texts)
        # A target position is masked if it is padding OR lies in the future.
        tgt_mask = create_padding_mask(PAD_ID, tgt_input) | create_subsequent_mask(
            tgt_input
        )

        # The source padding mask is passed twice; presumably the second use
        # masks the encoder output in cross-attention — confirm against the
        # model's forward signature.
        out = model(src_texts, tgt_input, src_mask, tgt_mask, src_mask)

        # reshape (not view): the shifted label slice is not contiguous.
        out_flat = out.reshape(-1, tgt_vocab_size)
        tgt_flat = tgt_labels.reshape(-1)
        loss = criterion(out_flat, tgt_flat)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1
        if (i + 1) % log_interval == 0:
            print(f"step {i+1}: train loss = {loss.item()}")

    return total_loss / n_batches if n_batches else float("nan")


# Wrapping the range lets tqdm advance only after an epoch has actually
# finished and close the bar when the loop ends.
pbar = tqdm(range(n_epochs), desc="Epoch")
for epoch in pbar:
    train(model)
    # The StepLR schedule only takes effect when stepped; advance it once per
    # epoch, after that epoch's optimizer updates.
    scheduler.step()