In [6]:
# Automatically reload edited local modules (e.g. tinystories, llama) before each cell executes.
%load_ext autoreload
%autoreload 2

In [1]:
import torch
from torch.utils.data import DataLoader
from timeit import default_timer as timer
from tqdm.notebook import tqdm

from tinystories import *
from llama import LLAMA

In [11]:
# Training split of TinyStories (GPT-4 generated stories).
# NOTE(review): "/tokenizer" is an absolute path at the filesystem root — confirm this is
# intended rather than a project-relative prefix such as "./tokenizer".
train_ds = TinyStoriesDataset(
    data_file="./archive/TinyStoriesV3-GPT4-train.txt",
    sp_model_prefix="/tokenizer",
)

In [None]:
# Validation split. It must use the SAME SentencePiece model prefix as train_ds: a different
# (or empty) prefix would produce token ids that do not match the model's vocabulary,
# making the validation loss meaningless.
val_ds = TinyStoriesDataset(
    data_file="./archive/TinyStoriesV3-GPT4-valid.txt",
    sp_model_prefix="/tokenizer",
)

In [7]:
def evaluate(model, dataloader, loss_fn, pad_idx, device):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    The model is run in eval mode with gradient tracking disabled. Its previous
    train/eval mode is restored before returning, so it is safe to call this in
    the middle of a training loop.

    Args:
        model: called as ``model(tgt_input, tgt_mask, tgt_padding_mask)``; its
            output's last dimension is treated as the class (vocabulary) axis.
        dataloader: iterable of ``(tgt, length)`` batches. Only ``tgt`` is used;
            its dim 0 is sliced as the sequence axis (presumably
            (seq_len, batch) — TODO confirm against the dataset's collate fn).
        loss_fn: loss applied to flattened logits and shifted targets.
        pad_idx: padding token id, forwarded to ``create_mask``.
        device: device each batch is moved to.

    Returns:
        Average of the values returned by ``loss_fn`` across batches (each batch
        weighted equally), or ``float("nan")`` if ``dataloader`` yields no batches.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    n_batches = 0
    try:
        # No autograd graph is needed for evaluation; this avoids holding activations in memory.
        with torch.no_grad():
            for tgt, _length in tqdm(dataloader, desc="eval", leave=False):
                tgt = tgt.to(device)
                # Targets are the inputs shifted by one position along dim 0.
                tgt_input = tgt[:-1, :]
                tgt_out = tgt[1:, :]
                tgt_mask, tgt_padding_mask = create_mask(tgt_input, pad_idx, device)
                logits = model(tgt_input, tgt_mask, tgt_padding_mask)
                loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
                total_loss += loss.item()
                n_batches += 1
    finally:
        # Restore whatever mode the caller had (e.g. train mode inside train()).
        model.train(was_training)
    return total_loss / n_batches if n_batches else float("nan")


def train(n_epochs, model, pad_idx, optimizer, train_loader, val_loader, device, evaluation_step=4000):
    """Train ``model`` on shifted-token targets for ``n_epochs`` epochs.

    Every ``evaluation_step`` optimizer steps, prints the mean training loss over
    the steps since the previous report together with the validation loss from
    ``evaluate``. At the end of each epoch, any remaining un-reported steps get
    one final report.

    Args:
        n_epochs: number of passes over ``train_loader``.
        model: called as ``model(tgt_input, tgt_mask, tgt_padding_mask)``; its
            output's last dimension is treated as the vocabulary axis.
        pad_idx: token id ignored by the loss and passed to ``create_mask``.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: iterable of ``(tgt, length)`` batches. Only ``tgt`` is used,
            and its dim 0 is sliced as the sequence axis (presumably
            (seq_len, batch) — TODO confirm against the dataset's collate fn).
        val_loader: validation batches in the same format, passed to ``evaluate``.
        device: device each batch is moved to.
        evaluation_step: number of training steps between reports; must be > 0.

    Raises:
        ValueError: if ``evaluation_step`` is not positive.
    """
    if evaluation_step <= 0:
        raise ValueError(f"evaluation_step must be positive, got {evaluation_step}")
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_idx)

    def _report(epoch, window_loss, window_steps):
        # Print the mean train loss of the current window plus a fresh validation loss.
        val_loss = evaluate(model, val_loader, loss_fn, pad_idx, device)
        # evaluate() may leave the model in eval mode; switch back so training
        # continues with dropout etc. active.
        model.train()
        print(f"Epoch: {epoch}, Train loss: {(window_loss / window_steps):.3f}, Val loss: {val_loss:.3f}")

    for epoch in range(1, n_epochs + 1):
        model.train()
        window_loss = 0.0
        window_steps = 0

        # Steps are counted from 1 so the first report comes after a full window.
        for step, (tgt, _length) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch}"), start=1):
            tgt = tgt.to(device)
            # Targets are the inputs shifted by one position along dim 0.
            tgt_input = tgt[:-1, :]
            tgt_out = tgt[1:, :]
            tgt_mask, tgt_padding_mask = create_mask(tgt_input, pad_idx, device)

            optimizer.zero_grad()
            logits = model(tgt_input, tgt_mask, tgt_padding_mask)
            loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
            loss.backward()
            optimizer.step()

            window_loss += loss.item()
            window_steps += 1

            if step % evaluation_step == 0:
                _report(epoch, window_loss, window_steps)
                window_loss = 0.0
                window_steps = 0

        # Report leftover steps; if the epoch ended exactly on a report boundary
        # they were already reported, so nothing is left to average.
        if window_steps:
            _report(epoch, window_loss, window_steps)