In [None]:
from google.colab import drive

# Mount Google Drive so the dataset and checkpoint paths under /content/drive resolve.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Fetch the course GPT-2 implementation and work from its root so `models.gpt2` is importable.
# NOTE(review): the clone is not pinned to a commit, so a re-run may pick up different code.
!git clone https://github.com/ItWasAllYellow/public_cs224n_gpt.git
%cd public_cs224n_gpt

Cloning into 'public_cs224n_gpt'...
remote: Enumerating objects: 69, done.[K
remote: Counting objects: 100% (38/38), done.[K
remote: Compressing objects: 100% (25/25), done.[K
remote: Total 69 (delta 21), reused 13 (delta 13), pack-reused 31 (from 1)[K
Receiving objects: 100% (69/69), 30.87 MiB | 29.17 MiB/s, done.
Resolving deltas: 100% (22/22), done.
/content/public_cs224n_gpt


In [None]:
"""
Answer Margin Ranking Trainer

Trains a GPT-2 based model to rank multiple-choice answers by minimizing a margin loss
on perplexity scores. Reports Top-1/2/3 accuracy and average perplexity on the validation split after each epoch.
"""

import argparse
import os
import random
import sys

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import GPT2Tokenizer
from models.gpt2 import GPT2Model      # public_cs224n_gpt implementation
from transformers import GPT2LMHeadModel

# Forwarded as `disable=` to every tqdm progress bar in this cell;
# set to True to silence the bars.
TQDM_DISABLE = False

def seed_everything(seed: int = 42):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN behaviour."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    # Seed every visible GPU only when CUDA is actually present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

class AnswerMarginModel(nn.Module):
    """GPT-2 backbone plus a language-modelling head that yields per-token logits.

    The head is a bias-free linear projection from the hidden size ``args.d`` to
    the tokenizer vocabulary. A freshly created ``nn.Linear`` is randomly
    initialised, which makes every perplexity computed from it meaningless until
    the head has been trained from scratch. To start from the pretrained
    projection instead, the head shares its weight with the backbone's token
    embedding whenever the backbone exposes a ``word_embedding`` table of the
    same shape (assumes the course implementation uses that attribute name --
    TODO confirm). If no compatible table is found the head keeps its random
    initialisation, exactly as before.
    """

    def __init__(self, args):
        """Build backbone and head.

        Args:
            args: namespace providing ``model_size``, ``d``, ``l`` and ``num_heads``.
        """
        super().__init__()
        # backbone
        self.gpt = GPT2Model.from_pretrained(
            model=args.model_size,
            d=args.d,
            l=args.l,
            num_heads=args.num_heads
        )
        # LM head
        vocab_size = GPT2Tokenizer.from_pretrained(args.model_size).vocab_size
        self.lm_head = nn.Linear(args.d, vocab_size, bias=False)
        self._tie_lm_head(self.gpt, self.lm_head)
        # fine-tune everything (backbone and head)
        for p in self.parameters():
            p.requires_grad = True

    @staticmethod
    def _tie_lm_head(backbone, lm_head):
        """Share ``backbone.word_embedding.weight`` with ``lm_head`` when compatible.

        Returns:
            True if the weight was tied, False if the backbone has no
            ``word_embedding`` parameter or its shape differs from the head's.
        """
        embedding = getattr(backbone, "word_embedding", None)
        weight = getattr(embedding, "weight", None)
        if not isinstance(weight, nn.Parameter):
            return False
        if weight.shape != lm_head.weight.shape:
            return False
        lm_head.weight = weight
        return True

    def forward(self, input_ids, attention_mask):
        """Returns token-level logits of shape (B, T, V)."""
        out = self.gpt(input_ids=input_ids, attention_mask=attention_mask)
        hidden = out["last_hidden_state"]
        logits = self.lm_head(hidden)
        return logits

def calculate_perplexity(logits, labels, attention_mask):
    """
    Per-sequence perplexity of next-token prediction.
    - logits: (N, T, V) token logits
    - labels & attention_mask: (N, T); the mask multiplies the token losses,
      so positions with mask 0 contribute nothing
    The logits at step t are scored against the label at step t+1, the masked
    losses are averaged per sequence, and the mean is exponentiated.
    Returns a tensor of shape (N,); a row with no unmasked target yields 1.0.
    """
    # Align predictions (all but the last step) with targets (all but the first).
    pred = logits[..., :-1, :].contiguous()
    target = labels[..., 1:].contiguous()
    keep = attention_mask[..., 1:].contiguous()

    criterion = nn.CrossEntropyLoss(reduction="none")
    vocab = pred.size(-1)
    token_nll = criterion(pred.view(-1, vocab), target.view(-1))
    token_nll = token_nll.view_as(target) * keep

    # clamp_min(1) avoids dividing by zero for fully masked rows.
    n_valid = keep.sum(dim=1).clamp_min(1)
    return torch.exp(token_nll.sum(dim=1) / n_valid)

class MarginLoss(nn.Module):
    """Pairwise hinge loss: penalises a correct PPL that is not at least ``margin`` below an incorrect PPL."""

    def __init__(self, margin: float = 1.0):
        super().__init__()
        self.margin = margin

    def forward(self, correct_ppl, incorrect_ppls):
        """Mean of ``relu(correct - incorrect + margin)`` over all pairs.

        correct_ppl: (B,), incorrect_ppls: (B, C-1) -> scalar tensor.
        """
        # Repeat each correct score across its row of incorrect scores.
        tiled_correct = correct_ppl[:, None].expand(incorrect_ppls.shape)
        violation = tiled_correct - incorrect_ppls + self.margin
        return F.relu(violation).mean()

class AnswerMarginDataset(Dataset):
    """Yields, per row, the five candidate texts (prefix + choice + suffix) and the 0-based answer index."""

    def __init__(self, df: pd.DataFrame, tokenizer: GPT2Tokenizer, max_length: int = 512):
        self.df = df.reset_index(drop=True)
        self.tok = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(texts, label)`` for row ``idx``; ``label`` is ``answer - 1``."""
        record = self.df.iloc[idx]
        head, tail = record["prefix"], record["suffix"]
        candidates = []
        for k in range(1, 6):
            option = record[f"choice_{k}"]
            candidates.append(f"{head}{option}{tail}")
        return candidates, int(record["answer"]) - 1

    def collate_fn(self, batch):
        """Tokenize all candidates of a batch and reshape to (B, 5, max_length) tensors."""
        n_items = len(batch)
        per_item_texts, answers = zip(*batch)
        # Flatten to B*5 strings so the tokenizer sees one flat list.
        all_texts = [text for group in per_item_texts for text in group]
        encoded = self.tok.batch_encode_plus(
            all_texts,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        collated = {
            key: encoded[key].view(n_items, 5, -1)
            for key in ("input_ids", "attention_mask")
        }
        collated["labels"] = torch.tensor(answers, dtype=torch.long)
        return collated

def load_data(path: str, split: str = None):
    """Read the CSV at ``path``; if ``split`` is truthy, keep only rows whose ``split`` column equals it (index reset)."""
    frame = pd.read_csv(path, encoding="utf-8-sig")
    if not split:
        return frame
    selected = frame[frame.split == split]
    return selected.reset_index(drop=True)

def save_model(model: nn.Module, optimizer, args, path: str):
    """Persist the model + optimizer state (and the run ``args``) to ``path``.

    Args:
        model: module whose ``state_dict()`` is stored under ``"model"``.
        optimizer: optimizer whose ``state_dict()`` is stored under ``"optim"``.
        args: run configuration, stored as-is under ``"args"``.
        path: destination file; missing parent directories are created.
    """
    target_dir = os.path.dirname(path)
    # A bare filename has no directory part, and os.makedirs("") raises
    # FileNotFoundError, so only create directories when there is one.
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    torch.save({
        "model": model.state_dict(),
        "optim": optimizer.state_dict(),
        "args": args
    }, path)
    print(f"Saved model to {path}")

def evaluate(dataloader, model, lm_head_model, device):
    """Rank choices by ascending perplexity; return (top1, top2, top3, mean PPL) over the loader.

    ``lm_head_model`` is accepted but not used by this function.
    """
    model.eval()
    n_examples = 0
    hits_at = [0, 0, 0]  # running counts for top-1, top-2, top-3
    ppl_sum = 0.0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating", disable=TQDM_DISABLE):
            n_batch, n_choices, seq_len = batch["input_ids"].shape
            flat_ids = batch["input_ids"].view(n_batch * n_choices, seq_len).to(device)
            flat_mask = batch["attention_mask"].view(n_batch * n_choices, seq_len).to(device)
            token_logits = model(flat_ids, flat_mask)
            ppl = calculate_perplexity(token_logits, flat_ids, flat_mask).view(n_batch, n_choices)

            n_examples += n_batch
            # Lowest perplexity first: column 0 holds the predicted choice.
            order = torch.argsort(ppl, dim=1)
            gold = batch["labels"].to(device)
            hits_at[0] += (order[:, 0] == gold).sum().item()
            for k in (2, 3):
                in_top_k = (order[:, :k] == gold.unsqueeze(1)).any(1)
                hits_at[k - 1] += in_top_k.sum().item()
            # Weight the batch mean by batch size so the final value is a per-example average.
            ppl_sum += ppl.mean().item() * n_batch

    top1, top2, top3 = (count / n_examples for count in hits_at)
    return top1, top2, top3, ppl_sum / n_examples

def train(args):
    """Main training loop for Margin Ranking.

    Loads the "train" and "valid" splits from ``args.data_path``, fine-tunes
    ``AnswerMarginModel`` with ``MarginLoss`` on per-choice perplexities, evaluates
    on the validation split after every epoch and checkpoints to
    ``args.save_path`` whenever validation Top-1 improves.

    Args:
        args: namespace as produced by ``get_args()``.
    """
    seed_everything(args.seed)
    device = torch.device("cuda" if args.use_gpu and torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # prepare data
    train_df = load_data(args.data_path, "train")
    dev_df   = load_data(args.data_path, "valid")

    tokenizer = GPT2Tokenizer.from_pretrained(args.model_size)
    # Reuse EOS as the padding token so the tokenizer can pad.
    tokenizer.pad_token = tokenizer.eos_token

    train_ds = AnswerMarginDataset(train_df, tokenizer, args.max_length)
    dev_ds   = AnswerMarginDataset(dev_df, tokenizer, args.max_length)

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              collate_fn=train_ds.collate_fn)
    dev_loader   = DataLoader(dev_ds, batch_size=args.batch_size, shuffle=False,
                              collate_fn=dev_ds.collate_fn)

    model = AnswerMarginModel(args).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    loss_fn = MarginLoss(margin=args.margin)

    best_top1 = 0.0
    for epoch in range(1, args.epochs+1):
        model.train()
        epoch_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Train Epoch {epoch}", disable=TQDM_DISABLE):
            optimizer.zero_grad()
            B, C, L = batch["input_ids"].shape
            ids = batch["input_ids"].view(B*C, L).to(device)
            mask = batch["attention_mask"].view(B*C, L).to(device)
            # Move the labels once so all indexing below stays on one device.
            labels = batch["labels"].to(device)
            logits = model(ids, mask)
            ppl = calculate_perplexity(logits, ids, mask).view(B, C)
            correct = ppl[torch.arange(B, device=device), labels]
            # (B, C) mask that is False only at each row's gold column; selecting with it
            # keeps the C-1 incorrect scores per row in their original order, replacing
            # the former per-row Python loop.
            wrong_mask = torch.arange(C, device=device).unsqueeze(0) != labels.unsqueeze(1)
            wrongs = ppl[wrong_mask].view(B, C - 1)
            loss = loss_fn(correct, wrongs)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        # max(..., 1) keeps an empty training split from raising ZeroDivisionError.
        avg_loss = epoch_loss / max(len(train_loader), 1)
        top1, top2, top3, avg_ppl = evaluate(dev_loader, model, None, device)
        print(f"Epoch {epoch} | Loss: {avg_loss:.4f} | Dev Top1/2/3: {top1:.4f}/{top2:.4f}/{top3:.4f} | PPL: {avg_ppl:.4f}")

        # checkpoint only on a strict improvement of validation Top-1
        if top1 > best_top1:
            best_top1 = top1
            save_model(model, optimizer, args, args.save_path)

def get_args():
    """Build the run configuration from defaults (an empty argv is parsed, so Jupyter's own flags are ignored)."""
    parser = argparse.ArgumentParser()
    option_specs = [
        ("--data_path",  dict(type=str, default="/content/drive/MyDrive/CSEG321/dataset/answer_margin.csv")),
        ("--model_size", dict(type=str, default="gpt2", choices=["gpt2","gpt2-medium","gpt2-large"])),
        ("--batch_size", dict(type=int, default=4)),
        ("--max_length", dict(type=int, default=512)),
        ("--lr",         dict(type=float, default=5e-5)),
        ("--epochs",     dict(type=int, default=20)),
        ("--margin",     dict(type=float, default=1000.0)),
        ("--seed",       dict(type=int, default=42)),
        ("--use_gpu",    dict(action="store_true")),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    # Jupyter friendly
    args = parser.parse_args([])
    if torch.cuda.is_available():
        args.use_gpu = True
    # (d, l, num_heads) per checkpoint name; any other choice falls through to gpt2-large.
    dims_by_size = {
        "gpt2": (768, 12, 12),
        "gpt2-medium": (1024, 24, 16),
    }
    args.d, args.l, args.num_heads = dims_by_size.get(args.model_size, (1280, 36, 20))
    # save path
    args.save_path = f"/content/drive/MyDrive/CSEG321/models/answer_margin_{args.model_size}_{args.epochs}e_{args.lr}lr.pt"
    return args

# Entry point: build the default configuration and run training.
# `args` is bound at cell level, so it stays available in the notebook namespace afterwards.
if __name__ == "__main__":
    args = get_args()
    train(args)

Using device: cuda


Train Epoch 1: 100%|██████████| 36/36 [00:22<00:00,  1.58it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.89it/s]


Epoch 1 | Loss: 655292.4014 | Dev Top1/2/3: 0.2222/0.5000/0.6111 | PPL: 410022.9861
Saved model to /content/drive/MyDrive/CSEG321/models/answer_margin_gpt2_20e_5e-05lr.pt


Train Epoch 2: 100%|██████████| 36/36 [00:22<00:00,  1.59it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.97it/s]


Epoch 2 | Loss: 14616.7176 | Dev Top1/2/3: 0.2222/0.5000/0.6667 | PPL: 324473.6424


Train Epoch 3: 100%|██████████| 36/36 [00:22<00:00,  1.59it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.96it/s]


Epoch 3 | Loss: 10014.4197 | Dev Top1/2/3: 0.1667/0.4444/0.6667 | PPL: 287841.7986


Train Epoch 4: 100%|██████████| 36/36 [00:22<00:00,  1.59it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.97it/s]


Epoch 4 | Loss: 6158.5101 | Dev Top1/2/3: 0.1667/0.4444/0.6111 | PPL: 266362.0955


Train Epoch 5: 100%|██████████| 36/36 [00:22<00:00,  1.60it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.99it/s]


Epoch 5 | Loss: 7479.2651 | Dev Top1/2/3: 0.1667/0.4444/0.6667 | PPL: 244116.0069


Train Epoch 6: 100%|██████████| 36/36 [00:22<00:00,  1.60it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.92it/s]


Epoch 6 | Loss: 6509.7757 | Dev Top1/2/3: 0.1667/0.3889/0.6111 | PPL: 224857.0226


Train Epoch 7: 100%|██████████| 36/36 [00:22<00:00,  1.60it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.96it/s]


Epoch 7 | Loss: 5731.4290 | Dev Top1/2/3: 0.1667/0.4444/0.5556 | PPL: 207631.5660


Train Epoch 8: 100%|██████████| 36/36 [00:22<00:00,  1.60it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.97it/s]


Epoch 8 | Loss: 4531.2536 | Dev Top1/2/3: 0.1667/0.4444/0.5556 | PPL: 195381.7326


Train Epoch 9: 100%|██████████| 36/36 [00:22<00:00,  1.60it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.98it/s]


Epoch 9 | Loss: 5592.2475 | Dev Top1/2/3: 0.2222/0.4444/0.5556 | PPL: 182302.9340


Train Epoch 10: 100%|██████████| 36/36 [00:22<00:00,  1.60it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.98it/s]


Epoch 10 | Loss: 5107.2027 | Dev Top1/2/3: 0.2222/0.4444/0.5556 | PPL: 169638.4948


Train Epoch 11: 100%|██████████| 36/36 [00:22<00:00,  1.60it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.98it/s]


Epoch 11 | Loss: 4476.1131 | Dev Top1/2/3: 0.2222/0.4444/0.5556 | PPL: 158921.4809


Train Epoch 12: 100%|██████████| 36/36 [00:22<00:00,  1.60it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.97it/s]


Epoch 12 | Loss: 3964.3289 | Dev Top1/2/3: 0.2222/0.4444/0.5556 | PPL: 150580.7370


Train Epoch 13: 100%|██████████| 36/36 [00:22<00:00,  1.60it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.97it/s]


Epoch 13 | Loss: 4219.7947 | Dev Top1/2/3: 0.2222/0.4444/0.5556 | PPL: 141303.5443


Train Epoch 14: 100%|██████████| 36/36 [00:22<00:00,  1.60it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.97it/s]


Epoch 14 | Loss: 3073.9286 | Dev Top1/2/3: 0.2222/0.4444/0.5556 | PPL: 136156.4844


Train Epoch 15: 100%|██████████| 36/36 [00:22<00:00,  1.59it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.97it/s]


Epoch 15 | Loss: 3942.9524 | Dev Top1/2/3: 0.2222/0.4444/0.5556 | PPL: 129626.9080


Train Epoch 16: 100%|██████████| 36/36 [00:22<00:00,  1.60it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.98it/s]


Epoch 16 | Loss: 2915.5466 | Dev Top1/2/3: 0.2222/0.4444/0.6111 | PPL: 124414.7873


Train Epoch 17: 100%|██████████| 36/36 [00:22<00:00,  1.60it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.98it/s]


Epoch 17 | Loss: 2727.9395 | Dev Top1/2/3: 0.2222/0.3889/0.7222 | PPL: 120183.7587


Train Epoch 18: 100%|██████████| 36/36 [00:22<00:00,  1.60it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.98it/s]


Epoch 18 | Loss: 2547.3020 | Dev Top1/2/3: 0.2222/0.3889/0.7222 | PPL: 115745.5825


Train Epoch 19: 100%|██████████| 36/36 [00:22<00:00,  1.60it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.98it/s]


Epoch 19 | Loss: 2866.0914 | Dev Top1/2/3: 0.2222/0.4444/0.7222 | PPL: 111889.1415


Train Epoch 20: 100%|██████████| 36/36 [00:22<00:00,  1.60it/s]
Evaluating: 100%|██████████| 5/5 [00:01<00:00,  4.98it/s]

Epoch 20 | Loss: 2305.1473 | Dev Top1/2/3: 0.2222/0.4444/0.7222 | PPL: 108137.4618





In [None]:
"""
NLL-Ranking Trainer

Trains a GPT-2 model to pick the correct option by minimizing cross-entropy
over the negated per-choice Negative-Log-Likelihood scores. Reports Top-1/2/3 accuracy and average NLL on the validation split after each epoch.
"""

import argparse, os, random, sys, math
import numpy as np, pandas as pd, torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import GPT2Tokenizer
from models.gpt2 import GPT2Model     # custom GPT-2 implementation

# Forwarded as `disable=` to every tqdm progress bar in this cell.
TQDM_DISABLE = False  # set True in Colab if bar flickers

# --------------------------------------------------------------------- #
# 0. Utils
# --------------------------------------------------------------------- #
def seed_everything(seed: int = 42):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN behaviour."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Covers every visible CUDA device.
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# --------------------------------------------------------------------- #
# 1. Dataset
# --------------------------------------------------------------------- #
class AnswerDataset(Dataset):
    """Yields, per row, the five candidate texts (prefix + choice + suffix) and the 0-based answer index."""

    def __init__(self, df: pd.DataFrame, tokenizer: GPT2Tokenizer, max_len=512):
        self.df = df.reset_index(drop=True)
        self.tok = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(texts, label)`` for row ``idx``; ``label`` is ``answer - 1``."""
        row = self.df.iloc[idx]
        texts = []
        for k in range(1, 6):
            option = row[f"choice_{k}"]
            texts.append(f"{row.prefix}{option}{row.suffix}")
        return texts, int(row.answer) - 1

    def collate_fn(self, batch):
        """Tokenize all candidates of a batch and reshape to (B, 5, max_len) tensors."""
        n_items = len(batch)
        grouped_texts, answers = zip(*batch)
        # Flatten to a single list of B*5 strings for the tokenizer.
        all_texts = [text for group in grouped_texts for text in group]
        encoded = self.tok(
            all_texts,
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )
        collated = {}
        for key in ("input_ids", "attention_mask"):
            collated[key] = encoded[key].view(n_items, 5, -1)
        collated["labels"] = torch.tensor(answers, dtype=torch.long)
        return collated

# --------------------------------------------------------------------- #
# 2. Model: GPT-2 + LM-head to compute NLL
# --------------------------------------------------------------------- #
class NLLRankModel(nn.Module):
    """GPT-2 backbone + LM head that scores each candidate by its mean token NLL.

    A freshly created ``nn.Linear`` head is randomly initialised, so the NLL it
    produces is unrelated to the pretrained language model. To start from the
    pretrained projection, the head shares its weight with the backbone's token
    embedding whenever the backbone exposes a ``word_embedding`` table of the
    same shape (assumes the course implementation uses that attribute name --
    TODO confirm). Otherwise the head keeps its random initialisation, exactly
    as before.
    """

    def __init__(self, args):
        """Build backbone and head from ``args.model_size``, ``args.d``, ``args.l``, ``args.num_heads``."""
        super().__init__()
        self.gpt = GPT2Model.from_pretrained(
            model=args.model_size, d=args.d, l=args.l, num_heads=args.num_heads
        )
        vocab = GPT2Tokenizer.from_pretrained(args.model_size).vocab_size
        self.lm_head = nn.Linear(args.d, vocab, bias=False)
        self._tie_lm_head(self.gpt, self.lm_head)
        # fine-tune everything (backbone and head)
        for p in self.parameters(): p.requires_grad = True

    @staticmethod
    def _tie_lm_head(backbone, lm_head):
        """Share ``backbone.word_embedding.weight`` with ``lm_head`` when compatible.

        Returns:
            True if the weight was tied, False if the backbone has no
            ``word_embedding`` parameter or its shape differs from the head's.
        """
        embedding = getattr(backbone, "word_embedding", None)
        weight = getattr(embedding, "weight", None)
        if not isinstance(weight, nn.Parameter):
            return False
        if weight.shape != lm_head.weight.shape:
            return False
        lm_head.weight = weight
        return True

    def forward(self, input_ids, attention_mask):
        """
        input_ids: (B, C, L)   attention_mask: (B, C, L)
        returns nll: (B, C)  -- lower is better
        """
        B, C, L = input_ids.shape
        flat_ids  = input_ids.view(B*C, L)
        flat_mask = attention_mask.view(B*C, L)

        h = self.gpt(flat_ids, attention_mask=flat_mask)["last_hidden_state"]
        logits = self.lm_head(h)                       # (B*C, L, V)

        # shift for LM loss: step t predicts token t+1
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = flat_ids[:, 1:].contiguous()
        shift_mask   = flat_mask[:, 1:].contiguous()

        # per-token loss, zeroed wherever the shifted mask is 0
        loss_ = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            reduction="none"
        ).view(flat_ids.size(0), -1) * shift_mask

        # token average NLL; clamp_min(1) avoids dividing by zero on fully masked rows
        nll = loss_.sum(1) / shift_mask.sum(1).clamp_min(1)   # (B*C,)
        return nll.view(B, C)                                 # (B,5)

# --------------------------------------------------------------------- #
# 3. Train / Eval
# --------------------------------------------------------------------- #
def evaluate(loader, model, device):
    """Rank choices by descending ``-nll``; return (top1, top2, top3, mean NLL) over the loader."""
    model.eval()
    seen = 0
    hit1 = hit2 = hit3 = 0
    nll_total = 0.0
    with torch.no_grad():
        for batch in tqdm(loader, disable=TQDM_DISABLE, desc="Eval"):
            input_ids = batch["input_ids"].to(device)
            attn = batch["attention_mask"].to(device)
            nll = model(input_ids, attn)                       # (B, 5)
            order = (-nll).argsort(dim=1, descending=True)     # best-scoring choice first
            gold = batch["labels"].to(device)
            n_batch = gold.size(0)
            # match[b, r] is True when the r-th ranked choice of example b is the gold one.
            match = order == gold.unsqueeze(1)
            hit1 += match[:, 0].sum().item()
            hit2 += match[:, :2].any(1).sum().item()
            hit3 += match[:, :3].any(1).sum().item()
            seen += n_batch
            # Weight the batch mean by batch size to get a per-example average at the end.
            nll_total += nll.mean().item() * n_batch
    return hit1 / seen, hit2 / seen, hit3 / seen, nll_total / seen

def train(args):
    """Fine-tune ``NLLRankModel`` with cross-entropy over ``-nll``; validate and checkpoint each epoch.

    A checkpoint is written via ``save`` whenever validation Top-1 strictly improves.
    """
    seed_everything(args.seed)
    use_cuda = args.use_gpu and torch.cuda.is_available()
    dev = torch.device("cuda" if use_cuda else "cpu")

    print("Device:", dev)

    # load data
    frame = pd.read_csv(args.data_path, encoding="utf-8-sig")
    train_rows = frame[frame.split == "train"]
    valid_rows = frame[frame.split == "valid"]

    tokenizer = GPT2Tokenizer.from_pretrained(args.model_size)
    # Reuse EOS as the padding token so the tokenizer can pad.
    tokenizer.pad_token = tokenizer.eos_token
    train_set = AnswerDataset(train_rows, tokenizer, args.max_length)
    valid_set = AnswerDataset(valid_rows, tokenizer, args.max_length)
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True,
                              collate_fn=train_set.collate_fn)
    valid_loader = DataLoader(valid_set, batch_size=args.batch_size, shuffle=False,
                              collate_fn=valid_set.collate_fn)

    model = NLLRankModel(args).to(dev)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    best_top1 = 0.0

    for epoch in range(1, args.epochs + 1):
        model.train()
        running_loss = 0.0
        for batch in tqdm(train_loader, disable=TQDM_DISABLE, desc=f"Train {epoch}"):
            optimizer.zero_grad()
            input_ids = batch["input_ids"].to(dev)
            attn = batch["attention_mask"].to(dev)
            nll = model(input_ids, attn)           # (B,5)
            # Negated NLL acts as the per-choice logit: lower NLL -> higher score.
            loss = F.cross_entropy(-nll, batch["labels"].to(dev))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch}  loss={running_loss/len(train_loader):.4f}")

        top1, top2, top3, mean_nll = evaluate(valid_loader, model, dev)
        print(f"  Valid Top1/2/3={top1:.3f}/{top2:.3f}/{top3:.3f}  AvgNLL={mean_nll:.3f}")
        if top1 > best_top1:
            best_top1 = top1
            save(model, optimizer, args)

def save(m, o, args):
    """Write model/optimizer state dicts and ``args`` to ``args.save_path``.

    Args:
        m: module whose ``state_dict()`` is stored under ``"model"``.
        o: optimizer whose ``state_dict()`` is stored under ``"optim"``.
        args: run configuration; must provide ``save_path``. Stored under ``"args"``.
    """
    target_dir = os.path.dirname(args.save_path)
    # A bare filename has no directory part, and os.makedirs("") raises
    # FileNotFoundError, so only create directories when there is one.
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    torch.save({"model":m.state_dict(),"optim":o.state_dict(),"args":args}, args.save_path)
    print("Saved to", args.save_path)

def get_args():
    """Build the run configuration from defaults (an empty argv is parsed, so Jupyter's own flags are ignored)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, default="/content/drive/MyDrive/CSEG321/dataset/answer_margin.csv")
    parser.add_argument("--model_size", type=str, default="gpt2", choices=["gpt2","gpt2-medium","gpt2-large"])
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--use_gpu", action="store_true")
    # Jupyter friendly
    args = parser.parse_args([])
    if torch.cuda.is_available():
        args.use_gpu = True
    # (d, l, num_heads) per checkpoint name; any other choice falls through to gpt2-large.
    model_dims = {
        "gpt2": (768, 12, 12),
        "gpt2-medium": (1024, 24, 16),
    }
    args.d, args.l, args.num_heads = model_dims.get(args.model_size, (1280, 36, 20))
    args.save_path = f"/content/drive/MyDrive/CSEG321/models/answer_nll_{args.model_size}_{args.epochs}e.pt"
    return args

# Entry point: build the default configuration and run training.
if __name__=="__main__":
    train(get_args())


Device: cuda


Train 1: 100%|██████████| 36/36 [00:22<00:00,  1.57it/s]


Epoch 1  loss=1.6539


Eval: 100%|██████████| 5/5 [00:01<00:00,  4.88it/s]


  Valid Top1/2/3=0.278/0.667/0.778  AvgNLL=17.481
Saved to /content/drive/MyDrive/CSEG321/models/answer_nll_gpt2_20e.pt


Train 2: 100%|██████████| 36/36 [00:22<00:00,  1.58it/s]


Epoch 2  loss=1.6049


Eval: 100%|██████████| 5/5 [00:01<00:00,  4.94it/s]


  Valid Top1/2/3=0.333/0.611/0.833  AvgNLL=16.995
Saved to /content/drive/MyDrive/CSEG321/models/answer_nll_gpt2_20e.pt


Train 3: 100%|██████████| 36/36 [00:22<00:00,  1.59it/s]


Epoch 3  loss=1.5719


Eval: 100%|██████████| 5/5 [00:01<00:00,  4.97it/s]


  Valid Top1/2/3=0.389/0.556/0.833  AvgNLL=17.003
Saved to /content/drive/MyDrive/CSEG321/models/answer_nll_gpt2_20e.pt


Train 4: 100%|██████████| 36/36 [00:22<00:00,  1.59it/s]


Epoch 4  loss=1.5010


Eval: 100%|██████████| 5/5 [00:01<00:00,  4.93it/s]


  Valid Top1/2/3=0.333/0.556/0.778  AvgNLL=20.106


Train 5: 100%|██████████| 36/36 [00:22<00:00,  1.59it/s]


Epoch 5  loss=1.5038


Eval: 100%|██████████| 5/5 [00:01<00:00,  4.96it/s]


  Valid Top1/2/3=0.389/0.611/0.778  AvgNLL=22.044


Train 6: 100%|██████████| 36/36 [00:22<00:00,  1.59it/s]


Epoch 6  loss=1.4081


Eval: 100%|██████████| 5/5 [00:01<00:00,  4.97it/s]


  Valid Top1/2/3=0.333/0.667/0.889  AvgNLL=22.764


Train 7: 100%|██████████| 36/36 [00:22<00:00,  1.59it/s]


Epoch 7  loss=1.2752


Eval: 100%|██████████| 5/5 [00:01<00:00,  4.97it/s]


  Valid Top1/2/3=0.500/0.667/0.833  AvgNLL=23.807
Saved to /content/drive/MyDrive/CSEG321/models/answer_nll_gpt2_20e.pt


Train 8: 100%|██████████| 36/36 [00:22<00:00,  1.59it/s]


Epoch 8  loss=1.1716


Eval: 100%|██████████| 5/5 [00:01<00:00,  4.97it/s]


  Valid Top1/2/3=0.444/0.778/0.833  AvgNLL=25.030


Train 9: 100%|██████████| 36/36 [00:22<00:00,  1.59it/s]


Epoch 9  loss=1.2209


Eval: 100%|██████████| 5/5 [00:01<00:00,  4.97it/s]


  Valid Top1/2/3=0.500/0.722/0.889  AvgNLL=26.052


Train 10: 100%|██████████| 36/36 [00:22<00:00,  1.59it/s]


Epoch 10  loss=1.0447


Eval: 100%|██████████| 5/5 [00:01<00:00,  4.97it/s]


  Valid Top1/2/3=0.444/0.722/0.778  AvgNLL=27.357


Train 11: 100%|██████████| 36/36 [00:22<00:00,  1.58it/s]


Epoch 11  loss=0.9033


Eval: 100%|██████████| 5/5 [00:01<00:00,  4.97it/s]


  Valid Top1/2/3=0.333/0.778/0.889  AvgNLL=28.611


Train 12: 100%|██████████| 36/36 [00:22<00:00,  1.59it/s]


Epoch 12  loss=0.8590


Eval: 100%|██████████| 5/5 [00:01<00:00,  4.98it/s]


  Valid Top1/2/3=0.389/0.611/0.722  AvgNLL=34.321


Train 13: 100%|██████████| 36/36 [00:22<00:00,  1.59it/s]


Epoch 13  loss=0.7393


Eval: 100%|██████████| 5/5 [00:01<00:00,  4.98it/s]


  Valid Top1/2/3=0.278/0.667/0.778  AvgNLL=31.605


Train 14: 100%|██████████| 36/36 [00:22<00:00,  1.59it/s]


Epoch 14  loss=0.7859


Eval: 100%|██████████| 5/5 [00:01<00:00,  4.97it/s]


  Valid Top1/2/3=0.389/0.667/0.722  AvgNLL=33.293


Train 15: 100%|██████████| 36/36 [00:22<00:00,  1.59it/s]


Epoch 15  loss=0.6568


Eval: 100%|██████████| 5/5 [00:01<00:00,  4.97it/s]


  Valid Top1/2/3=0.444/0.611/0.889  AvgNLL=34.517


Train 16: 100%|██████████| 36/36 [00:22<00:00,  1.59it/s]


Epoch 16  loss=0.5888


Eval: 100%|██████████| 5/5 [00:01<00:00,  4.98it/s]


  Valid Top1/2/3=0.389/0.611/0.833  AvgNLL=35.177


Train 17: 100%|██████████| 36/36 [00:22<00:00,  1.59it/s]


Epoch 17  loss=0.5656


Eval: 100%|██████████| 5/5 [00:01<00:00,  4.97it/s]


  Valid Top1/2/3=0.444/0.611/0.833  AvgNLL=38.773


Train 18: 100%|██████████| 36/36 [00:22<00:00,  1.59it/s]


Epoch 18  loss=0.4887


Eval: 100%|██████████| 5/5 [00:01<00:00,  4.97it/s]


  Valid Top1/2/3=0.444/0.611/0.944  AvgNLL=40.133


Train 19: 100%|██████████| 36/36 [00:22<00:00,  1.59it/s]


Epoch 19  loss=0.5030


Eval: 100%|██████████| 5/5 [00:01<00:00,  4.98it/s]


  Valid Top1/2/3=0.389/0.556/0.778  AvgNLL=36.462


Train 20: 100%|██████████| 36/36 [00:22<00:00,  1.59it/s]


Epoch 20  loss=0.4280


Eval: 100%|██████████| 5/5 [00:01<00:00,  4.98it/s]

  Valid Top1/2/3=0.389/0.667/0.778  AvgNLL=38.343





In [None]:
"""
Dot-Product Rank-Head Trainer

Trains a GPT-2 encoder with a linear rank head that scores each candidate answer.
Reports Top-1/2/3 accuracy and average softmax entropy on the validation split after each epoch.
"""

import argparse, os, random, sys
import numpy as np, pandas as pd, torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import GPT2Tokenizer
from models.gpt2 import GPT2Model

# Forwarded as `disable=` to every tqdm progress bar in this cell; set True to silence them.
TQDM_DISABLE=False

# --------------------------------------------------------------------- #
def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN behaviour."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# --------------------------------------------------------------------- #
class AnswerDataset(Dataset):
    """Yields, per row, the five candidate texts (prefix + choice + suffix) and the 0-based answer index."""

    def __init__(self, df, tok, max_len=512):
        self.df = df.reset_index(drop=True)
        self.tok = tok
        self.max_len = max_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(texts, label)`` for row ``idx``; ``label`` is ``answer - 1``."""
        row = self.df.iloc[idx]
        texts = []
        for k in range(1, 6):
            option = row[f"choice_{k}"]
            texts.append(f"{row.prefix}{option}{row.suffix}")
        return texts, int(row.answer) - 1

    def collate_fn(self, b):
        """Tokenize all candidates of a batch and reshape to (B, 5, max_len) tensors."""
        n_items = len(b)
        grouped_texts, answers = zip(*b)
        # Flatten to a single list of B*5 strings for the tokenizer.
        all_texts = [text for group in grouped_texts for text in group]
        encoded = self.tok(
            all_texts,
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )
        token_ids = encoded.input_ids.view(n_items, 5, -1)
        attn = encoded.attention_mask.view(n_items, 5, -1)
        return {
            "input_ids": token_ids,
            "attention_mask": attn,
            "labels": torch.tensor(answers),
        }
# --------------------------------------------------------------------- #
class DotRankModel(nn.Module):
    """GPT-2 encoder with a scalar linear head that assigns one score per candidate sequence."""

    def __init__(self, args):
        """Build encoder and head from ``args.model_size``, ``args.d``, ``args.l``, ``args.num_heads``."""
        super().__init__()
        self.enc = GPT2Model.from_pretrained(
            model=args.model_size,
            d=args.d,
            l=args.l,
            num_heads=args.num_heads,
        )
        self.head = nn.Linear(args.d, 1)
        # fine-tune everything (encoder and head)
        for param in self.parameters():
            param.requires_grad = True

    def forward(self, ids, mask):
        """ids, mask: (B, C, L) -> scores: (B, C); higher means more likely correct."""
        n_batch, n_choices, seq_len = ids.shape
        n_seq = n_batch * n_choices
        tokens = ids.view(n_seq, seq_len)
        attn = mask.view(n_seq, seq_len)
        hidden = self.enc(tokens, attention_mask=attn)["last_hidden_state"]
        # Index of the final unmasked position per sequence (mask sum minus one).
        final_pos = attn.sum(1) - 1
        summary = hidden[torch.arange(hidden.size(0)), final_pos]  # (B*C, d)
        return self.head(summary).view(n_batch, n_choices)         # (B, C)
# --------------------------------------------------------------------- #
def evaluate(loader,model,dev):
    """Rank choices by descending score; return (top1, top2, top3, mean softmax entropy).

    Args:
        loader: iterable of batches with "input_ids", "attention_mask" and "labels".
        model: callable returning per-choice scores of shape (B, C).
        dev: device the batch tensors are moved to.
    """
    model.eval(); tot=top1=top2=top3=0; tot_ent=0
    with torch.no_grad():
        for bt in tqdm(loader,disable=TQDM_DISABLE,desc="Eval"):
            s=model(bt["input_ids"].to(dev),bt["attention_mask"].to(dev))
            rank=s.argsort(dim=1,descending=True)
            y=bt["labels"].to(dev); B=y.size(0)
            top1+=(rank[:,0]==y).sum().item()
            top2+=((rank[:,:2]==y.unsqueeze(1)).any(1)).sum().item()
            top3+=((rank[:,:3]==y.unsqueeze(1)).any(1)).sum().item()
            tot+=B
            # Take log-probabilities straight from log_softmax: softmax(...).log()
            # gives -inf once a probability underflows to 0, and 0 * -inf turns
            # the whole entropy into NaN.
            log_probs=F.log_softmax(s,dim=1)
            ent=-(log_probs.exp()*log_probs).sum(1).mean().item()
            # Weight the batch mean by batch size to get a per-example average at the end.
            tot_ent+=ent*B
    return top1/tot,top2/tot,top3/tot,tot_ent/tot
# --------------------------------------------------------------------- #
def train(args):
    """Fine-tune ``DotRankModel`` with cross-entropy over choice scores; validate and checkpoint each epoch.

    A checkpoint is written via ``save`` whenever validation Top-1 strictly improves.
    """
    seed_everything(args.seed)
    want_cuda = args.use_gpu and torch.cuda.is_available()
    dev = torch.device("cuda" if want_cuda else "cpu")

    frame = pd.read_csv(args.data_path, encoding="utf-8-sig")
    train_rows = frame[frame.split == "train"]
    valid_rows = frame[frame.split == "valid"]

    tokenizer = GPT2Tokenizer.from_pretrained(args.model_size)
    # Reuse EOS as the padding token so the tokenizer can pad.
    tokenizer.pad_token = tokenizer.eos_token
    train_set = AnswerDataset(train_rows, tokenizer, args.max_length)
    valid_set = AnswerDataset(valid_rows, tokenizer, args.max_length)
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True,
                              collate_fn=train_set.collate_fn)
    valid_loader = DataLoader(valid_set, batch_size=args.batch_size, shuffle=False,
                              collate_fn=valid_set.collate_fn)

    model = DotRankModel(args).to(dev)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    best_top1 = 0
    for epoch in range(1, args.epochs + 1):
        model.train()
        running_loss = 0
        for batch in tqdm(train_loader, disable=TQDM_DISABLE, desc=f"Train{epoch}"):
            optimizer.zero_grad()
            scores = model(batch["input_ids"].to(dev), batch["attention_mask"].to(dev))
            loss = F.cross_entropy(scores, batch["labels"].to(dev))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch{epoch} loss={running_loss/len(train_loader):.4f}")
        top1, top2, top3, entropy = evaluate(valid_loader, model, dev)
        print(f"  Valid Top1/2/3={top1:.3f}/{top2:.3f}/{top3:.3f}  AvgEntropy={entropy:.3f}")
        if top1 > best_top1:
            best_top1 = top1
            save(model, optimizer, args)
def save(m,o,args):
    """Write model/optimizer state dicts and ``args`` to ``args.save_path``.

    Args:
        m: module whose ``state_dict()`` is stored under ``"model"``.
        o: optimizer whose ``state_dict()`` is stored under ``"optim"``.
        args: run configuration; must provide ``save_path``. Stored under ``"args"``.
    """
    target_dir=os.path.dirname(args.save_path)
    # A bare filename has no directory part, and os.makedirs("") raises
    # FileNotFoundError, so only create directories when there is one.
    if target_dir:
        os.makedirs(target_dir,exist_ok=True)
    torch.save({"model":m.state_dict(),"optim":o.state_dict(),"args":args},args.save_path)
    print("Saved to",args.save_path)
def get_args():
    """Build the run configuration namespace from the declared defaults.

    The parser is given an empty argument list, so command-line flags are never
    read and every option takes its default. ``use_gpu`` is then forced to True
    when CUDA is available, and the derived attributes ``d``, ``l``,
    ``num_heads`` and ``save_path`` are attached based on ``model_size`` and
    ``epochs``.

    Returns:
        argparse.Namespace holding the parsed options plus the derived attributes.
    """
    option_specs = (
        ("--data_path", dict(type=str, default="/content/drive/MyDrive/CSEG321/dataset/answer_margin.csv")),
        ("--model_size", dict(type=str, default="gpt2", choices=["gpt2", "gpt2-medium", "gpt2-large"])),
        ("--batch_size", dict(type=int, default=4)),
        ("--max_length", dict(type=int, default=512)),
        ("--lr", dict(type=float, default=5e-5)),
        ("--epochs", dict(type=int, default=20)),
        ("--seed", dict(type=int, default=42)),
        ("--use_gpu", dict(action="store_true")),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    # Jupyter friendly: parse an explicit empty argv instead of sys.argv.
    args = parser.parse_args([])
    if torch.cuda.is_available():
        args.use_gpu = True

    # (d, l, num_heads) per model size; any other size falls back to the last tuple.
    dims_by_size = {
        "gpt2": (768, 12, 12),
        "gpt2-medium": (1024, 24, 16),
    }
    args.d, args.l, args.num_heads = dims_by_size.get(args.model_size, (1280, 36, 20))

    args.save_path = f"/content/drive/MyDrive/CSEG321/models/answer_dot_{args.model_size}_{args.epochs}e.pt"
    return args
# Script entry point: build the default configuration and start training.
if __name__ == "__main__":
    train(get_args())


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Train1: 100%|██████████| 36/36 [00:18<00:00,  1.98it/s]


Epoch1 loss=1.6603


Eval: 100%|██████████| 5/5 [00:00<00:00,  5.22it/s]


  Valid Top1/2/3=0.278/0.444/0.556  AvgEntropy=1.609
Saved to /content/drive/MyDrive/CSEG321/models/answer_dot_gpt2_20e.pt


Train2: 100%|██████████| 36/36 [00:17<00:00,  2.11it/s]


Epoch2 loss=1.6133


Eval: 100%|██████████| 5/5 [00:00<00:00,  6.36it/s]


  Valid Top1/2/3=0.278/0.500/0.667  AvgEntropy=1.609


Train3: 100%|██████████| 36/36 [00:17<00:00,  2.12it/s]


Epoch3 loss=1.6049


Eval: 100%|██████████| 5/5 [00:00<00:00,  6.38it/s]


  Valid Top1/2/3=0.278/0.389/0.500  AvgEntropy=1.606


Train4: 100%|██████████| 36/36 [00:16<00:00,  2.12it/s]


Epoch4 loss=1.6261


Eval: 100%|██████████| 5/5 [00:00<00:00,  6.36it/s]


  Valid Top1/2/3=0.444/0.778/0.889  AvgEntropy=1.609
Saved to /content/drive/MyDrive/CSEG321/models/answer_dot_gpt2_20e.pt


Train5: 100%|██████████| 36/36 [00:17<00:00,  2.11it/s]


Epoch5 loss=1.6040


Eval: 100%|██████████| 5/5 [00:00<00:00,  6.37it/s]


  Valid Top1/2/3=0.667/0.889/0.889  AvgEntropy=1.608
Saved to /content/drive/MyDrive/CSEG321/models/answer_dot_gpt2_20e.pt


Train6: 100%|██████████| 36/36 [00:16<00:00,  2.12it/s]


Epoch6 loss=1.5537


Eval: 100%|██████████| 5/5 [00:00<00:00,  6.38it/s]


  Valid Top1/2/3=0.444/0.722/0.889  AvgEntropy=1.570


Train7: 100%|██████████| 36/36 [00:16<00:00,  2.12it/s]


Epoch7 loss=1.4577


Eval: 100%|██████████| 5/5 [00:00<00:00,  6.38it/s]


  Valid Top1/2/3=0.278/0.722/0.944  AvgEntropy=1.517


Train8: 100%|██████████| 36/36 [00:17<00:00,  2.10it/s]


Epoch8 loss=1.4396


Eval: 100%|██████████| 5/5 [00:00<00:00,  6.18it/s]


  Valid Top1/2/3=0.500/0.889/0.944  AvgEntropy=1.527


Train9: 100%|██████████| 36/36 [00:17<00:00,  2.12it/s]


Epoch9 loss=1.3436


Eval: 100%|██████████| 5/5 [00:00<00:00,  6.23it/s]


  Valid Top1/2/3=0.389/0.833/0.889  AvgEntropy=1.405


Train10: 100%|██████████| 36/36 [00:16<00:00,  2.12it/s]


Epoch10 loss=1.2164


Eval: 100%|██████████| 5/5 [00:00<00:00,  6.38it/s]


  Valid Top1/2/3=0.667/0.889/1.000  AvgEntropy=1.318


Train11: 100%|██████████| 36/36 [00:16<00:00,  2.12it/s]


Epoch11 loss=1.1365


Eval: 100%|██████████| 5/5 [00:00<00:00,  6.35it/s]


  Valid Top1/2/3=0.611/0.889/1.000  AvgEntropy=1.349


Train12: 100%|██████████| 36/36 [00:16<00:00,  2.12it/s]


Epoch12 loss=1.1324


Eval: 100%|██████████| 5/5 [00:00<00:00,  6.37it/s]


  Valid Top1/2/3=0.611/1.000/1.000  AvgEntropy=1.337


Train13: 100%|██████████| 36/36 [00:16<00:00,  2.12it/s]


Epoch13 loss=1.0060


Eval: 100%|██████████| 5/5 [00:00<00:00,  6.36it/s]


  Valid Top1/2/3=0.556/0.944/0.944  AvgEntropy=1.290


Train14: 100%|██████████| 36/36 [00:16<00:00,  2.12it/s]


Epoch14 loss=0.8588


Eval: 100%|██████████| 5/5 [00:00<00:00,  6.38it/s]


  Valid Top1/2/3=0.611/0.944/1.000  AvgEntropy=1.209


Train15: 100%|██████████| 36/36 [00:16<00:00,  2.12it/s]


Epoch15 loss=0.9100


Eval: 100%|██████████| 5/5 [00:00<00:00,  6.38it/s]


  Valid Top1/2/3=0.556/0.833/1.000  AvgEntropy=1.253


Train16: 100%|██████████| 36/36 [00:16<00:00,  2.12it/s]


Epoch16 loss=0.7174


Eval: 100%|██████████| 5/5 [00:00<00:00,  6.37it/s]


  Valid Top1/2/3=0.722/0.944/0.944  AvgEntropy=1.200
Saved to /content/drive/MyDrive/CSEG321/models/answer_dot_gpt2_20e.pt


Train17: 100%|██████████| 36/36 [00:17<00:00,  2.12it/s]


Epoch17 loss=0.6561


Eval: 100%|██████████| 5/5 [00:00<00:00,  6.39it/s]


  Valid Top1/2/3=0.722/0.944/1.000  AvgEntropy=1.210


Train18: 100%|██████████| 36/36 [00:16<00:00,  2.12it/s]


Epoch18 loss=0.5209


Eval: 100%|██████████| 5/5 [00:00<00:00,  6.36it/s]


  Valid Top1/2/3=0.611/1.000/1.000  AvgEntropy=1.113


Train19: 100%|██████████| 36/36 [00:16<00:00,  2.12it/s]


Epoch19 loss=0.4985


Eval: 100%|██████████| 5/5 [00:00<00:00,  6.37it/s]


  Valid Top1/2/3=0.556/0.889/1.000  AvgEntropy=1.157


Train20: 100%|██████████| 36/36 [00:16<00:00,  2.12it/s]


Epoch20 loss=0.4446


Eval: 100%|██████████| 5/5 [00:00<00:00,  6.21it/s]

  Valid Top1/2/3=0.556/0.833/0.944  AvgEntropy=1.176



