In [2]:
from model import TransformerDecoder

from dataset import TextDataset

%load_ext autoreload
%autoreload 2

In [11]:
# we want to train our tokenizer on the whole dataset

VOCAB_SIZE = 3000
MAX_LENGTH = 512


text_dataset = TextDataset(data_file="stories.txt",
                           train=True,
                           sp_model_prefix="sp_model",
                           vocab_size=VOCAB_SIZE,
                           max_length=MAX_LENGTH)

In [4]:

# !python3 parse_stories.py

In [5]:
# !head -n 128 short_stories.txt > train.txt
# !tail -n 10 short_stories.txt > test.txt

# !head -n 128 stories.txt > large_train.txt
# !tail -n 64 stories.txt > large_test.txt

In [16]:
import torch
from trainloop import train
from torch.utils.data import DataLoader
from utils import collate_fn

BATCH_SIZE = 32
NUM_EPOCHS = 600

train_set = TextDataset(data_file="large_train.txt", 
                        train=True, 
                        sp_model_prefix="sp_model", 
                        vocab_size=VOCAB_SIZE, 
                        max_length=MAX_LENGTH)
valid_set = TextDataset(data_file="large_train.txt", 
                        train=True, 
                        sp_model_prefix="sp_model", 
                        vocab_size=VOCAB_SIZE,
                        max_length=MAX_LENGTH)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, collate_fn=collate_fn)
val_loader = DataLoader(valid_set, batch_size=BATCH_SIZE, collate_fn=collate_fn)

In [17]:
train_set[1]

(tensor([   1,  136,  141,    6,  115,  118,  419,    6,  656,   31,  816,   13,
          613, 3983,  260,   70,   13,  613,  266,  201,   13,    8,  656,   15,
         3949,  440,   13,    8,  741, 3983,  104,   70,   13,  613,  159,   13,
            8,  656,   15,  966,   89,    8,  741,  168,  135, 1168, 3983,   20,
          159,   13,    8, 2723,   15,  263,   13, 3949,    8,  741,   75, 3533,
         3989,  507,  440, 3983, 3967,    3, 1092,  262,    8,  741,  102, 1168,
         3997,    3,   20,  221, 1085, 3983,  420,  239,    6,  113,  179,  816,
         1064,  259,  833,  528, 3983,   13,  613,  221, 1064,  980,    8,  741,
          168,  102, 1168, 3983, 3967,    3,   18,    3, 3976,  336,  142,   30,
          292,   75,  235, 3989,    3,   78, 1064, 3983,   44,  117, 1173, 2966,
            8,  741,   15,  928,  235,  436,  468,   12,   27,  338,   63, 3976,
         3983,   13,  613,  307,   15,   78, 3989, 3967,    3,   18,  198,  737,
           13,  454,    8,  

In [18]:
import wandb
wandb.login()

wandb.init(project="LLM-Homework")

In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = TransformerDecoder(vocab_size=text_dataset.vocab_size,
                            embed_dim=64,
                            n_blocks=2,
                            n_head=2,
                            ff_dim=64,
                            text_dataset=text_dataset,)
print(f"Number of parameters in the model: {sum(p.numel() for p in model.parameters())}")
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, betas=(0.9, 0.98), eps=1e-9)
criterion = torch.nn.CrossEntropyLoss(ignore_index=text_dataset.pad_id)

train_losses, test_losses = train(model=model, optimizer=optimizer, criterion=criterion,
                                  train_loader=train_loader, test_loader=val_loader, num_epochs=NUM_EPOCHS)

Number of parameters in the model: 566432






0,1
test_loss,█▇▆▅▅▅▄▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
training_loss,█▇▆▅▅▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁

0,1
test_loss,1.56721
training_loss,1.82705


In [27]:
model.generate(text_dataset, prompt="Once upon a time there was a", max_len=50)

torch.Size([1, 9]) torch.Size([1, 1])
torch.Size([1, 10]) torch.Size([1, 1])
torch.Size([1, 11]) torch.Size([1, 1])
torch.Size([1, 12]) torch.Size([1, 1])
torch.Size([1, 13]) torch.Size([1, 1])
torch.Size([1, 14]) torch.Size([1, 1])
torch.Size([1, 15]) torch.Size([1, 1])
torch.Size([1, 16]) torch.Size([1, 1])
torch.Size([1, 17]) torch.Size([1, 1])
torch.Size([1, 18]) torch.Size([1, 1])
torch.Size([1, 19]) torch.Size([1, 1])
torch.Size([1, 20]) torch.Size([1, 1])
torch.Size([1, 21]) torch.Size([1, 1])
torch.Size([1, 22]) torch.Size([1, 1])
torch.Size([1, 23]) torch.Size([1, 1])
torch.Size([1, 24]) torch.Size([1, 1])
torch.Size([1, 25]) torch.Size([1, 1])
torch.Size([1, 26]) torch.Size([1, 1])
torch.Size([1, 27]) torch.Size([1, 1])
torch.Size([1, 28]) torch.Size([1, 1])
torch.Size([1, 29]) torch.Size([1, 1])
torch.Size([1, 30]) torch.Size([1, 1])
torch.Size([1, 31]) torch.Size([1, 1])
torch.Size([1, 32]) torch.Size([1, 1])
torch.Size([1, 33]) torch.Size([1, 1])
torch.Size([1, 34]) torch.

'once upon a time there was a upon a time, there was a bird lived in the boat was a lot of her mombug traffic all day, something around the village him. it was a tree. the wild rabbit knew the itch his toy, it was too there was'