In [None]:
import torch
from torch import nn, optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from dataset import QQPDataset, TokenizeCollate
from utils import init_finetuning_model_and_tokenizer, inference

In [None]:
# --- Run configuration -------------------------------------------------------
# Pick the best available accelerator instead of assuming Apple Metal (MPS):
# on machines where MPS is present this resolves to exactly torch.device("mps"),
# otherwise it falls back to CUDA and finally to the CPU so the notebook still
# runs end to end.
if torch.backends.mps.is_available():
    DEV = torch.device("mps")
elif torch.cuda.is_available():
    DEV = torch.device("cuda")
else:
    DEV = torch.device("cpu")

BATCH_SIZE = 32   # samples per optimisation step
LR = 6.25e-5      # base learning rate handed to the optimiser
EPOCHS = 5        # full passes over the training set
LAMBDA = 0.5      # start_factor of the LinearLR schedule (initial LR multiplier)

WEIGHTS_PATH = "weights.pth"         # passed to init_finetuning_model_and_tokenizer
DATASET_PATH = "dataset/train.csv"   # CSV consumed by QQPDataset

In [None]:
# Build everything the training loop needs: tokenizer + model, the data
# pipeline, the loss, the optimiser and the learning-rate schedule.

# NOTE(review): the helper is named "..._model_and_tokenizer" but is unpacked
# as (tokenizer, model) — confirm the return order against utils.py.
tokenizer, model = init_finetuning_model_and_tokenizer(WEIGHTS_PATH, DEV)

dataset = QQPDataset(DATASET_PATH)
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,  # use the config constant rather than a duplicated literal
    shuffle=True,
    collate_fn=TokenizeCollate(tokenizer),
)

# The loss takes raw logits; the sigmoid is applied inside the loss itself.
crit = nn.BCEWithLogitsLoss()
opt = optim.Adam(model.parameters(), lr=LR)

# Linearly scales the LR multiplier from LAMBDA to 0.95 over the first 2000
# scheduler steps, then holds it at 0.95. The training loop calls
# scheduler.step() once per batch, so "steps" here means batches, not epochs.
scheduler = optim.lr_scheduler.LinearLR(opt, start_factor=LAMBDA, end_factor=0.95, total_iters=2000)

In [None]:
# Fine-tuning loop: one tqdm progress bar per epoch showing the running mean loss.
for e in range(1, EPOCHS + 1):
    # Put the model (back) into training mode at the start of every epoch.
    # The inference cell calls model.eval() on this same object, so re-running
    # this cell afterwards would otherwise train with dropout/batch-norm layers
    # still in evaluation mode.
    model.train()

    loop = tqdm(enumerate(loader), total=len(loader), leave=True, position=0)
    loop.set_description(f"Epoch : [{e}/{EPOCHS}]")
    total_loss = 0  # sum of per-batch losses for this epoch

    for i, (x1, x2, x1_mask, x2_mask, labels) in loop:
        # Move the whole batch onto the training device.
        x1, x2, x1_mask, x2_mask, labels = x1.to(DEV), x2.to(DEV), x1_mask.to(DEV), x2_mask.to(DEV), labels.to(DEV)

        opt.zero_grad()
        yhat = model(x1, x2, x1_mask, x2_mask)
        # NOTE(review): BCEWithLogitsLoss needs float targets shaped like yhat —
        # presumably TokenizeCollate guarantees this; confirm.
        loss = crit(yhat, labels)
        loss.backward()
        opt.step()
        scheduler.step()  # the LR schedule advances per batch, not per epoch

        total_loss += loss.item()
        # Running mean over the batches seen so far in this epoch.
        loop.set_postfix(loss = total_loss / (i + 1))

In [None]:
# Smoke test: run the model on one hand-written question pair.
q1, q2 = "What fraction is a quarter?", "What is a balanced diet?"

# Switch to evaluation mode up front; eval() returns the module itself, so
# passing `model` afterwards hands the helper the same object as before.
model.eval()

# Last expression of the cell, so the notebook displays whatever the helper returns.
inference(q1, q2, model, tokenizer, DEV)