In [None]:
import logging

# Root logging configuration for the notebook: INFO and above, with every
# record rendered as "<timestamp> - <level> - <message>".
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
_LOG_DATEFMT = '%m/%d/%Y %H:%M:%S'
logging.basicConfig(format=_LOG_FORMAT, datefmt=_LOG_DATEFMT, level=logging.INFO)

# Module-level logger used by the later cells for progress messages.
log = logging.getLogger(__name__)
log.info("Log initialized.")

In [None]:
import os

from torch.utils.data import DataLoader

from utils.data_loader import MINDDataset

# Dataset directory under the current user's home directory.
# os.path.expanduser("~") is used rather than os.getenv("HOME") because the
# latter returns None when HOME is unset, which would make the string
# concatenation fail with a TypeError. data_dir stays a plain str.
data_dir = os.path.join(os.path.expanduser("~"), "data", "MINDsmall")

# Training split, iterated in shuffled mini-batches of 16 using 4 loader workers.
train_ds = MINDDataset(data_dir, split='train')
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=4)

In [None]:
import torch
import torch.nn as nn
from torch.optim import AdamW

from utils.model import GPTRec

# Compute device for the model and every batch tensor. Defined here so that
# this cell and the training loop work from a freshly restarted kernel.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# vocab_size is the number of entries in train_ds.nid2idx
# (presumably a news-id -> index mapping; verify against MINDDataset).
model = GPTRec(vocab_size=len(train_ds.nid2idx)).to(device)
optimizer = AdamW(model.parameters(), lr=1e-4)
# The training loop passes this the model's logits and integer class targets.
loss_fn = nn.CrossEntropyLoss()

In [None]:
import torch
from tqdm.auto import tqdm

NUM_EPOCHS = 5                   # number of passes over train_loader
CHECKPOINT_PATH = "baseline.pt"  # where the trained weights are written


def _select_positive_samples(batch):
    """Keep the samples of ``batch`` that have a positive label.

    For each sample whose label list contains a 1, the sample's history is
    kept and the impression at the position of the first 1 becomes its
    target. Samples without a 1 in their labels are dropped.

    NOTE(review): this treats batch['history'], batch['impressions'] and
    batch['labels'] as per-sample Python lists. train_loader is created
    without a collate_fn, and PyTorch's default collate regroups nested
    lists element-wise -- confirm MINDDataset yields the layout assumed here.

    Returns:
        (histories, targets): two parallel lists, possibly empty.
    """
    histories, targets = [], []
    for hist, impressions, sample_labels in zip(
        batch['history'], batch['impressions'], batch['labels']
    ):
        # Guard with the same condition the lookup needs, so a label list
        # without a literal 1 is skipped instead of raising ValueError.
        if 1 in sample_labels:
            histories.append(hist)
            targets.append(impressions[sample_labels.index(1)])
    return histories, targets


def _build_input_ids(histories, cls_id, sep_id):
    """Pack histories into one LongTensor of ``[CLS] history [SEP]`` rows.

    Rows are right-padded with 0 to the length of the longest history plus
    the two special tokens. ``histories`` must be non-empty.
    """
    longest = max(len(hist) for hist in histories)
    packed = torch.zeros(len(histories), longest + 2, dtype=torch.long)
    for row, hist in enumerate(histories):
        sep_pos = 1 + len(hist)
        packed[row, 0] = cls_id
        packed[row, 1:sep_pos] = torch.tensor(hist, dtype=torch.long)
        packed[row, sep_pos] = sep_id
    return packed


# Ensure training mode even if this cell is re-run after an evaluation cell
# switched the model to eval mode.
model.train()

for epoch in range(NUM_EPOCHS):
    tot_loss, tot_acc, tot_samples = 0, 0, 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
    for batch in pbar:
        # Only samples with a positive label contribute to training.
        pos_hist, pos_target = _select_positive_samples(batch)
        if not pos_hist:
            continue

        input_ids = _build_input_ids(
            pos_hist, model.cls_token_id, model.sep_token_id
        ).to(device)
        # Attention mask: 1 for every non-zero id, 0 for padding.
        # NOTE(review): assumes id 0 is reserved for padding and is never a
        # real history item, CLS or SEP id -- confirm against nid2idx / GPTRec.
        mask = (input_ids != 0).to(torch.long)
        labels = torch.tensor(pos_target, dtype=torch.long, device=device)

        logits = model(input_ids, attention_mask=mask)
        loss = loss_fn(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Number of samples in this batch whose top-scoring class is the target.
        pred = logits.argmax(dim=-1)
        acc = (pred == labels).float().sum().item()

        # loss is a batch mean, so weight it by the batch's sample count to
        # keep the running average per-sample.
        tot_loss += loss.item() * len(pos_hist)
        tot_acc += acc
        tot_samples += len(pos_hist)

        pbar.set_postfix(loss=tot_loss/tot_samples, acc=tot_acc/tot_samples)

    # Without this guard an epoch with no usable sample would divide by zero.
    if tot_samples == 0:
        log.warning(f"Epoch {epoch}  no batch contained a positive sample; nothing to report.")
        continue

    log.info(f"Epoch {epoch}  Loss {tot_loss/tot_samples:.4f}  Acc {tot_acc/tot_samples:.4f}")

log.info("Training finished. Saving model...")
torch.save(model.state_dict(), CHECKPOINT_PATH)
log.info(f"Model saved to {CHECKPOINT_PATH}")

In [None]:
# TODO: implement evaluation of the trained model.