In [None]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import PreTrainedTokenizerFast
from dataset import AbstractiveSummarizationDataset, TokenizeCollate
from decoder import DecoderOnlyModel

In [None]:
# --- Training hyperparameters -------------------------------------------
EPOCHS = 20        # number of full passes over the DataLoader
LR = 0.0001        # learning rate handed to the Adam optimizer
BATCH_SIZE = 32    # batch size handed to the DataLoader

# Seed torch's global RNG so that a "Restart Kernel -> Run All" is repeatable
# (the notebook shuffles its DataLoader and the model is built with dropout).
SEED = 42
torch.manual_seed(SEED)

# Pick the best available accelerator instead of hard-coding "mps", which
# makes every later `.to(DEV)` fail on machines without Apple-silicon support.
if torch.backends.mps.is_available():
    DEV = torch.device("mps")
elif torch.cuda.is_available():
    DEV = torch.device("cuda")
else:
    DEV = torch.device("cpu")

# --- Model hyperparameters ----------------------------------------------
# These are forwarded as same-named keyword arguments to DecoderOnlyModel;
# their exact meaning is defined by that class (see decoder.py).
D_MODEL = 512
NUM_HEADS = 8
NUM_LAYERS = 5
COMPRESS_FACTOR = 3
BLOCK_SIZE = 256
DROPOUT_P = 0.1

In [None]:
# Summarization examples are read from the local CSV file.
dataset = AbstractiveSummarizationDataset("xsum.csv")

# Tokenizer is restored from the local "tokenizer" directory / identifier.
tokenizer = PreTrainedTokenizerFast.from_pretrained("tokenizer")

# Batches are assembled by TokenizeCollate, which receives the tokenizer;
# the order of examples is reshuffled every epoch.
loader = DataLoader(
    dataset,
    collate_fn=TokenizeCollate(tokenizer),
    shuffle=True,
    batch_size=BATCH_SIZE,
)

In [None]:
# Constructor arguments for the decoder, taken from the tokenizer and the
# hyperparameter constants defined in the config cell.
# NOTE(review): `tokenizer.vocab_size` may not count tokens added on top of
# the base vocabulary -- confirm it matches the largest id the collate emits.
model_kwargs = {
    "vocab_size": tokenizer.vocab_size,
    "d_model": D_MODEL,
    "num_heads": NUM_HEADS,
    "num_layers": NUM_LAYERS,
    "compress_factor": COMPRESS_FACTOR,
    "block_size": BLOCK_SIZE,
    "dropout_p": DROPOUT_P,
}

# Build the model and move it onto the training device.
model = DecoderOnlyModel(**model_kwargs).to(DEV)

In [None]:
# Adam updates every parameter exposed by the model, using the configured LR.
opt = optim.Adam(params=model.parameters(), lr=LR)

# Cross-entropy with default settings (mean reduction over all positions).
# NOTE(review): if the collate pads targets, padded positions are included
# in the loss -- confirm whether `ignore_index` should be set.
crit = nn.CrossEntropyLoss()

In [None]:
# Training loop: for each epoch, iterate over every batch from `loader`,
# run a forward pass, compute the cross-entropy loss, back-propagate and
# take one optimizer step. A tqdm bar shows the running mean loss.
for e in range(1, EPOCHS + 1):
    # Ensure training-mode behavior even if an earlier cell switched the
    # model to eval mode (assumes `model` is an nn.Module, as it exposes
    # `.parameters()` and `.to()`).
    model.train()

    loop = tqdm(enumerate(loader), total=len(loader), leave=True, position=0)
    loop.set_description(f"Epoch : [{e}/{EPOCHS}]")
    total_loss = 0.0  # sum of per-batch losses for this epoch

    for i, (src, tgt) in loop:
        src, tgt = src.to(DEV), tgt.to(DEV)

        # Clear gradients left over from the previous batch. Without this,
        # `.backward()` keeps adding into `.grad`, so each step would apply
        # the sum of all gradients seen so far instead of the current one.
        opt.zero_grad()

        yhat = model(src)
        # Flatten every leading dimension so the loss sees one row of scores
        # (last dim of `yhat`) per flattened target entry.
        loss = crit(yhat.reshape(-1, yhat.shape[-1]), tgt.reshape(-1))
        loss.backward()
        opt.step()

        total_loss += loss.item()
        # Running mean of the batch losses seen so far in this epoch.
        loop.set_postfix(loss=total_loss / (i + 1))