In [1]:
import os

import torch
import torch.nn as nn
from transformers import PreTrainedTokenizerFast
from torch.utils.data import DataLoader

from lib.Transformer import Transformer
from lib.TextDataset import TextDataset
from lib.lib import load_dataset, collate_fn, generate_story_end

In [2]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
# Run configuration.
N = 0.1          # fraction of each data split to load (passed as `data_fraction`)
EPOCHS = 20
batch_size = 8
SEED = 42        # seeds PyTorch's global RNG so runs are reproducible

# Seed early: PyTorch's global generator is used by default for parameter
# initialisation and for DataLoader(shuffle=True) sampling.
torch.manual_seed(SEED)

In [4]:
# Transformer architecture hyperparameters.
vocab_size = 60_000        # NOTE(review): must be >= the tokenizer's vocabulary size — confirm
d_model = 512
n_heads = 8
d_ff = 2048
n_encoder_layers = 4
n_decoder_layers = 6
max_len = 500
pad_idx = 0                # NOTE(review): should equal tokenizer.pad_token_id used by the loss — verify
dropout = 0.1

In [5]:
# Special-token strings registered with the BPE tokenizer loaded from disk.
_special_tokens = {
    "bos_token": "<s>",
    "eos_token": "</s>",
    "unk_token": "<unk>",
    "pad_token": "<pad>",
    "additional_special_tokens": ["<|endoftext|>"],
}

tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="../models/bpe_tokenizer/tokenizer.json",
    **_special_tokens,
)

In [6]:
# Load the fraction N of each split, tokenized with the tokenizer above.
train_data = load_dataset("../data/train.txt", tokenizer, data_fraction=N)
test_data = load_dataset("../data/test.txt", tokenizer, data_fraction=N)

pad_id = tokenizer.pad_token_id


def _pad_collate(batch):
    """Collate a list of dataset items via `collate_fn`, padding with the tokenizer's pad id."""
    return collate_fn(batch, pad_id)


train_loader = DataLoader(
    TextDataset(train_data),
    batch_size=batch_size,
    shuffle=True,
    collate_fn=_pad_collate,
)

In [7]:
# The tokenizer is the source of truth for padding: the loss below ignores
# `pad_id`, so the model must use the same index (rather than the hard-coded
# `pad_idx` from the hyperparameter cell) or its padding handling disagrees
# with the loss.
assert len(tokenizer) <= vocab_size, (
    f"tokenizer has {len(tokenizer)} tokens but vocab_size is {vocab_size}"
)

transformer = Transformer(
    vocab_size=vocab_size,
    d_model=d_model,
    n_heads=n_heads,
    d_ff=d_ff,
    n_encoder_layers=n_encoder_layers,
    n_decoder_layers=n_decoder_layers,
    max_len=max_len,
    pad_idx=pad_id,
    dropout=dropout,
).to(device)

In [None]:
# Token-level cross-entropy that skips padding positions, optimised with AdamW.
criterion = nn.CrossEntropyLoss(ignore_index=pad_id)
optimizer = torch.optim.AdamW(params=transformer.parameters(), lr=3e-4)

: 

In [None]:
# Held-out loader over the test split (previously loaded but never used),
# evaluated once per epoch to track generalisation alongside training loss.
test_loader = DataLoader(
    TextDataset(test_data),
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda x: collate_fn(x, pad_id),
)

for epoch in range(EPOCHS):
    # --- train ---
    transformer.train()
    total_loss = 0.0
    for src, tgt in train_loader:
        src, tgt = src.to(device), tgt.to(device)

        # Teacher forcing: decoder sees tgt[:, :-1] and is scored on tgt[:, 1:].
        logits = transformer(src, tgt[:, :-1])
        loss = criterion(logits.reshape(-1, logits.size(-1)), tgt[:, 1:].reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # --- validate (dropout off, no autograd graph) ---
    transformer.eval()
    val_total = 0.0
    with torch.no_grad():
        for src, tgt in test_loader:
            src, tgt = src.to(device), tgt.to(device)
            logits = transformer(src, tgt[:, :-1])
            val_total += criterion(
                logits.reshape(-1, logits.size(-1)), tgt[:, 1:].reshape(-1)
            ).item()

    # max(..., 1) avoids ZeroDivisionError if a split yields no batches.
    train_loss = total_loss / max(len(train_loader), 1)
    val_loss = val_total / max(len(test_loader), 1)
    print(f"Epoch {epoch+1}: train loss = {train_loss:.4f}, val loss = {val_loss:.4f}")

In [None]:
# Persist model + optimizer state together with every hyperparameter needed to
# rebuild the architecture, so the checkpoint can be reloaded without relying
# on this notebook's globals.
save_path = "../models/transformer/checkpoints/transformer_tinystories.pt"
os.makedirs(os.path.dirname(save_path), exist_ok=True)

torch.save({
    "model_state_dict": transformer.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    # `epoch` / `total_loss` come from the last iteration of the training cell.
    "epoch": epoch + 1,
    "loss": total_loss / max(len(train_loader), 1),
    "config": {
        "vocab_size": transformer.output_linear.out_features,
        "d_model": transformer.d_model,
        "pad_idx": transformer.pad_idx,
        "n_heads": n_heads,
        "d_ff": d_ff,
        "n_encoder_layers": n_encoder_layers,
        "n_decoder_layers": n_decoder_layers,
        "max_len": max_len,
        "dropout": dropout,
    },
}, save_path)

In [None]:
# Rebuild the model from the checkpoint's stored config. A checkpoint may store
# only vocab_size, d_model and pad_idx; any hyperparameter missing from it
# falls back to the value defined in this notebook's config cell.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
checkpoint = torch.load("../models/transformer/checkpoints/transformer_tinystories.pt", map_location=device)
ckpt_config = checkpoint["config"]

transformer = Transformer(
    vocab_size=ckpt_config["vocab_size"],
    d_model=ckpt_config["d_model"],
    n_heads=ckpt_config.get("n_heads", n_heads),
    d_ff=ckpt_config.get("d_ff", d_ff),
    n_encoder_layers=ckpt_config.get("n_encoder_layers", n_encoder_layers),
    n_decoder_layers=ckpt_config.get("n_decoder_layers", n_decoder_layers),
    max_len=ckpt_config.get("max_len", max_len),
    pad_idx=ckpt_config["pad_idx"],
    dropout=ckpt_config.get("dropout", dropout),
).to(device)

transformer.load_state_dict(checkpoint["model_state_dict"])

In [None]:
# Inference: switch dropout off and skip autograd bookkeeping. Harmless if
# generate_story_end already does this internally (not visible from here).
transformer.eval()
with torch.no_grad():
    story = generate_story_end(transformer, tokenizer, "Once upon a time, a cat met a dog", device)
print(story)