In [1]:
import torch
import torch.nn as nn
from lib.Transformer import Transformer
from lib.TextDataset import TextDataset
from transformers import PreTrainedTokenizerFast
from torch.utils.data import DataLoader
from lib.lib import load_dataset, collate_fn, generate_story_end

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
N = 0.0003
EPOCHS = 40
batch_size = 32

In [4]:
vocab_size = 60000
d_model = 512
n_heads = 8
d_ff = 2048
n_encoder_layers=4
n_decoder_layers=6
max_len = 1000
pad_idx = 0
dropout = 0.2

In [5]:
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="../models/bpe_tokenizer/tokenizer.json",
    bos_token="<s>",
    eos_token="</s>",
    unk_token="<unk>",
    pad_token="<pad>",
    additional_special_tokens=["<|endoftext|>"]
)

In [6]:
train_data = load_dataset("../data/train.txt", tokenizer, data_fraction=N)
test_data  = load_dataset("../data/test.txt", tokenizer, data_fraction=N)

pad_id = tokenizer.pad_token_id
train_loader = DataLoader(TextDataset(train_data), batch_size=batch_size, shuffle=True, collate_fn=lambda x: collate_fn(x, pad_id, max_len))

In [7]:
transformer = Transformer(vocab_size=vocab_size, d_model=d_model, n_heads=n_heads, d_ff=d_ff, n_encoder_layers=n_encoder_layers, n_decoder_layers=n_decoder_layers, max_len=max_len, pad_idx=pad_idx, dropout=dropout).to(device)

In [8]:
optimizer = torch.optim.AdamW(transformer.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss(ignore_index=pad_id)

In [9]:
for epoch in range(EPOCHS):
    transformer.train()
    total_loss = 0
    for src, tgt in train_loader:
        src, tgt = src.to(device), tgt.to(device)

        logits = transformer(src, tgt[:, :-1])
        loss = criterion(logits.reshape(-1, logits.size(-1)), tgt[:, 1:].reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1}: loss = {total_loss / len(train_loader):.4f}")

Epoch 1: loss = 8.3663
Epoch 2: loss = 6.1787
Epoch 3: loss = 5.7563
Epoch 4: loss = 5.2607
Epoch 5: loss = 4.8377
Epoch 6: loss = 4.5219
Epoch 7: loss = 4.2633
Epoch 8: loss = 4.0510
Epoch 9: loss = 3.8653
Epoch 10: loss = 3.6871
Epoch 11: loss = 3.5092
Epoch 12: loss = 3.3446
Epoch 13: loss = 3.1824
Epoch 14: loss = 3.0175
Epoch 15: loss = 2.8543
Epoch 16: loss = 2.6904
Epoch 17: loss = 2.5345
Epoch 18: loss = 2.3772
Epoch 19: loss = 2.2207
Epoch 20: loss = 2.0578
Epoch 21: loss = 1.9179
Epoch 22: loss = 1.7683
Epoch 23: loss = 1.6297
Epoch 24: loss = 1.4905
Epoch 25: loss = 1.3649
Epoch 26: loss = 1.2508
Epoch 27: loss = 1.1394
Epoch 28: loss = 1.0442
Epoch 29: loss = 0.9531
Epoch 30: loss = 0.8691
Epoch 31: loss = 0.7987
Epoch 32: loss = 0.7258
Epoch 33: loss = 0.6578
Epoch 34: loss = 0.6020
Epoch 35: loss = 0.5512
Epoch 36: loss = 0.5139
Epoch 37: loss = 0.4818
Epoch 38: loss = 0.4526
Epoch 39: loss = 0.4251
Epoch 40: loss = 0.3996


In [10]:
save_path = "../models/transformer/checkpoints/transformer_tinystories.pt"
torch.save({
    "model_state_dict": transformer.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "epoch": epoch + 1,
    "loss": total_loss / len(train_loader),
    "config": {
        "vocab_size": transformer.output_linear.out_features,
        "d_model": transformer.d_model,
        "pad_idx": transformer.pad_idx,
    }
}, save_path)

In [11]:
checkpoint = torch.load("../models/transformer/checkpoints/transformer_tinystories.pt", map_location=device)

transformer = Transformer(
    vocab_size=checkpoint["config"]["vocab_size"],
    d_model=checkpoint["config"]["d_model"],
    n_heads=n_heads,
    d_ff=d_ff,
    n_encoder_layers=n_encoder_layers,
    n_decoder_layers=n_decoder_layers,
    max_len=max_len,
    pad_idx=checkpoint["config"]["pad_idx"],
    dropout=dropout
).to(device)

transformer.load_state_dict(checkpoint["model_state_dict"])

<All keys matched successfully>

In [12]:
print(generate_story_end(transformer, tokenizer, "I was walking with my friend", device))

 around, they found a big pond. The boy was so happy that he had lots of fun. He showed him to stay and make sure he would take care of fun. The moral of the story is to stay connected. Together they both with the story is to stay and get more. When the story is to go home, the story is to stay connected. He thanked the story is to stay and even more and even more careful. Together they both with their cool and even more careful when they were safe and even more. From then on, the best birthday ever after. The moral of friends for helping me, the story is to always very happy and even more. Every week, the story is to always listen to explore together until they were very proud of friends. They had made sure that being different things in the story is to explore together until they both with their moms and even more. When they learned that being different things in the story is to explore together again. And the story is


In [13]:
print(generate_story_end(transformer, tokenizer, "Once upon a time, there was a graceful cat named Kitty. She loved to play with her ball and jump around the house. Kitty was very happy and liked to make her friends laugh. One day, Kitty saw a tap in the garden.", device))

 began to follow the friends. The rocket felt happy and smooth the cloth. She watched the rocket was so happy. The rocket thanked the rocket and gave it a big hug. Finally the rocket was time to go home. The little girl's friends. She knew that she had been creative enough to go home. They would always visit the moon. And the moon for the moon for hours and smiled and hugged the rocket was sure to share. And the cloth. And the rocket was always visit the rocket was always visit the rocket and smiled as a wonderful time to share for hours and thanked the moon. The end. It was always visit the rocket and smiled as a wonderful time to share. It was always visit the rocket and smiled at her journey and said goodbye to have good day. It was always visit the wonderful time to have good day. The old man for the wonderful time to have never forget it was always visit the wonderful time to have never forget it was always visit. The old man.


In [14]:
print(generate_story_end(transformer, tokenizer, "Once upon a time, there was a naughty bee named Buzzy. Buzzy loved to fly around the big tree and play with the other bees. One day, Buzzy saw a little girl named Lucy sitting on the grass. Buzzy flew down and said", device))

 began to the sky. The girl felt frustrated. She watched as she could get closer. As she heard a big, the peopleâ€™s house. The little girl was so happy. She gave the tree was so proud that she had a big hug and said. And the best of friends. And from then on, the park, and were happy and went on, the little girl had a big hug. And the best of friends. And the park was happy and joy, she would come back. And the park was happy and were happy and were happy again. And the park was happy and were happy again. And the best of friends had a wonderful adventure. And the best day on, she would always visit to have lots of friends. And the best day on, she would always visit to have lots of friends. And the most beautiful and share. And the best of friends. And the best of friends. And the most beautiful and share and share and share. And the best day
