In [9]:
from google.colab import drive

drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [10]:
!git clone https://github.com/ItWasAllYellow/public_cs224n_gpt.git
%cd public_cs224n_gpt

Cloning into 'public_cs224n_gpt'...
remote: Enumerating objects: 69, done.[K
remote: Counting objects: 100% (37/37), done.[K
remote: Compressing objects: 100% (24/24), done.[K
remote: Total 69 (delta 21), reused 13 (delta 13), pack-reused 32 (from 1)[K
Receiving objects: 100% (69/69), 30.87 MiB | 25.66 MiB/s, done.
Resolving deltas: 100% (22/22), done.
/content/public_cs224n_gpt/public_cs224n_gpt


In [11]:
"""
Answer Margin Ranking Trainer

Trains a GPT-2 based model to rank multiple-choice answers by minimizing margin loss
on perplexity scores. Generates Top-1/2/3 accuracy and average perplexity on dev/test splits.
"""

import argparse
import os
import random
import sys

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import GPT2Tokenizer
from models.gpt2 import GPT2Model      # public_cs224n_gpt implementation
from transformers import GPT2LMHeadModel

# suppress TQDM if needed
TQDM_DISABLE = False

def seed_everything(seed: int = 42):
    """Fix random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

class AnswerMarginModel(nn.Module):
    """Wrapper around GPT2Model + LM head for computing token logits."""
    def __init__(self, args):
        super().__init__()
        # backbone
        self.gpt = GPT2Model.from_pretrained(
            model=args.model_size,
            d=args.d,
            l=args.l,
            num_heads=args.num_heads
        )
        # LM head
        self.lm_head = nn.Linear(args.d, GPT2Tokenizer.from_pretrained(args.model_size).vocab_size, bias=False)
        # fine-tune everything
        for p in self.gpt.parameters():
            p.requires_grad = True
        for p in self.lm_head.parameters():
            p.requires_grad = True

    def forward(self, input_ids, attention_mask):
        """Returns token-level logits of shape (B, T, V)."""
        out = self.gpt(input_ids=input_ids, attention_mask=attention_mask)
        hidden = out["last_hidden_state"]
        logits = self.lm_head(hidden)
        return logits

def calculate_perplexity(logits, labels, attention_mask):
    """
    Compute per-example perplexity:
    - logits: (B*T, V)
    - labels & mask: (B*T)
    Returns perp of shape (B*T) then can be reshaped.
    """
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    shift_mask   = attention_mask[..., 1:].contiguous()

    loss_fn = nn.CrossEntropyLoss(reduction="none")
    loss = loss_fn(shift_logits.view(-1, shift_logits.size(-1)),
                   shift_labels.view(-1))
    loss = loss.view_as(shift_labels)
    loss = loss * shift_mask

    sum_loss = loss.sum(dim=1)
    token_count = shift_mask.sum(dim=1).clamp_min(1)
    mean_loss = sum_loss / token_count
    return torch.exp(mean_loss)

class MarginLoss(nn.Module):
    """Hinge-style margin loss between correct and incorrect PPLs."""
    def __init__(self, margin: float = 1.0):
        super().__init__()
        self.margin = margin

    def forward(self, correct_ppl, incorrect_ppls):
        # correct: (B,), incorrect: (B, C-1)
        correct = correct_ppl.unsqueeze(1).expand_as(incorrect_ppls)
        loss = F.relu(correct - incorrect_ppls + self.margin)
        return loss.mean()

class AnswerMarginDataset(Dataset):
    """Dataset for packing 5 candidate texts per example."""
    def __init__(self, df: pd.DataFrame, tokenizer: GPT2Tokenizer, max_length: int = 512):
        self.df = df.reset_index(drop=True)
        self.tok = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        prefix = row["prefix"]
        suffix = row["suffix"]
        choices = [row[f"choice_{i}"] for i in range(1,6)]
        example_texts = [f"{prefix}{c}{suffix}" for c in choices]
        label = int(row["answer"]) - 1
        return example_texts, label

    def collate_fn(self, batch):
        texts, labels = zip(*batch)
        flat_texts = sum(texts, [])
        enc = self.tok.batch_encode_plus(
            flat_texts,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        B = len(batch)
        input_ids = enc["input_ids"].view(B, 5, -1)
        attention_mask = enc["attention_mask"].view(B, 5, -1)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": torch.tensor(labels, dtype=torch.long)
        }

def load_data(path: str, split: str = None):
    """Load CSV and optionally filter by split column."""
    df = pd.read_csv(path, encoding="utf-8-sig")
    return df[df.split == split].reset_index(drop=True) if split else df

def save_model(model: nn.Module, optimizer, args, path: str):
    """Persist the model + optimizer state."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save({
        "model": model.state_dict(),
        "optim": optimizer.state_dict(),
        "args": args
    }, path)
    print(f"Saved model to {path}")

def evaluate(dataloader, model, lm_head_model, device):
    """Compute Top-1/2/3 accuracies and average PPL on a split."""
    model.eval()
    total = top1 = top2 = top3 = 0
    total_ppl = 0.0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating", disable=TQDM_DISABLE):
            B, C, L = batch["input_ids"].shape
            ids = batch["input_ids"].view(B*C, L).to(device)
            mask = batch["attention_mask"].view(B*C, L).to(device)
            logits = model(ids, mask)
            ppl = calculate_perplexity(logits, ids, mask).view(B, C)

            total += B
            ranked = torch.argsort(ppl, dim=1)
            labels = batch["labels"].to(device)
            top1 += (ranked[:,0] == labels).sum().item()
            top2 += ((ranked[:,:2] == labels.unsqueeze(1)).any(1)).sum().item()
            top3 += ((ranked[:,:3] == labels.unsqueeze(1)).any(1)).sum().item()
            total_ppl += ppl.mean().item() * B

    return top1/total, top2/total, top3/total, total_ppl/total

def train(args):
    seed_everything(args.seed)
    device = torch.device("cuda" if args.use_gpu and torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 1) 전체 CSV 로드 (split·grade 컬럼 포함)
    df = pd.read_csv(args.data_path, encoding="utf-8-sig")

    tokenizer = GPT2Tokenizer.from_pretrained(args.model_size)
    tokenizer.pad_token = tokenizer.eos_token
    model = AnswerMarginModel(args).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    loss_fn = MarginLoss(margin=args.margin)

    best_top1 = 0.0

    # 2) Curriculum: grade 2 ➔ grade 3
    for grade in [2, 3]:
        print(f"\n===== Training on grade {grade} =====")
        train_df = df[(df.split == "train") & (df.grade == grade)]
        dev_df   = df[(df.split == "valid") & (df.grade == grade)]
        if len(train_df) == 0:
            print(f"  (no train data for grade {grade}, skip)")
            continue

        train_ds = AnswerMarginDataset(train_df, tokenizer, args.max_length)
        dev_ds   = AnswerMarginDataset(dev_df, tokenizer, args.max_length)
        train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                                  collate_fn=train_ds.collate_fn)
        dev_loader   = DataLoader(dev_ds, batch_size=args.batch_size, shuffle=False,
                                  collate_fn=dev_ds.collate_fn)

        # 3) 기존 epoch loop
        for epoch in range(1, args.epochs + 1):
            model.train()
            epoch_loss = 0.0
            for batch in tqdm(train_loader, desc=f"Grade {grade} Epoch {epoch}", disable=TQDM_DISABLE):
                optimizer.zero_grad()
                B, C, L = batch["input_ids"].shape
                ids  = batch["input_ids"].view(B*C, L).to(device)
                mask = batch["attention_mask"].view(B*C, L).to(device)
                logits = model(ids, mask)
                ppl = calculate_perplexity(logits, ids, mask).view(B, C)
                correct = ppl[torch.arange(B), batch["labels"].to(device)]
                wrongs  = torch.stack([ppl[i, torch.arange(C)!=batch["labels"][i]] for i in range(B)])
                loss = loss_fn(correct, wrongs)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()

            avg_loss = epoch_loss / len(train_loader)
            top1, top2, top3, avg_ppl = evaluate(dev_loader, model, None, device)
            print(f"[G{grade} E{epoch}] Loss: {avg_loss:.4f} | Dev Top1/2/3: {top1:.4f}/{top2:.4f}/{top3:.4f} | PPL: {avg_ppl:.4f}")

            if top1 > best_top1:
                best_top1 = top1
                save_model(model, optimizer, args, args.save_path)


def get_args():
    """Parse command-line arguments (ignores Jupyter args)."""
    p = argparse.ArgumentParser()
    p.add_argument("--data_path",    type=str, default="/content/drive/MyDrive/CSEG321/dataset/answer_margin.csv")
    p.add_argument("--model_size",   type=str, default="gpt2", choices=["gpt2","gpt2-medium","gpt2-large"])
    p.add_argument("--batch_size",   type=int, default=4)
    p.add_argument("--max_length",   type=int, default=512)
    p.add_argument("--lr",           type=float, default=5e-5)
    p.add_argument("--epochs",       type=int, default=50)
    p.add_argument("--margin",       type=float, default=1000.0)
    p.add_argument("--seed",         type=int, default=42)
    p.add_argument("--use_gpu",      action="store_true")
    # Jupyter friendly
    args = p.parse_args([])
    if torch.cuda.is_available(): args.use_gpu = True
    # set model dims
    if args.model_size == "gpt2":
        args.d, args.l, args.num_heads = 768,12,12
    elif args.model_size == "gpt2-medium":
        args.d, args.l, args.num_heads = 1024,24,16
    else:
        args.d, args.l, args.num_heads = 1280,36,20
    # save path
    args.save_path = f"/content/drive/MyDrive/CSEG321/models/answer_margin_{args.model_size}_{args.epochs}e_{args.lr}lr_curriculum.pt"
    return args

if __name__ == "__main__":
    args = get_args()
    train(args)

Using device: cuda

===== Training on grade 2 =====


Grade 2 Epoch 1: 100%|██████████| 7/7 [00:04<00:00,  1.56it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.73it/s]


[G2 E1] Loss: 6165411.5580 | Dev Top1/2/3: 0.0000/0.0000/0.6667 | PPL: 2117524.0000


Grade 2 Epoch 2: 100%|██████████| 7/7 [00:04<00:00,  1.60it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.88it/s]


[G2 E2] Loss: 21581.6561 | Dev Top1/2/3: 0.0000/0.0000/0.3333 | PPL: 432843.0312


Grade 2 Epoch 3: 100%|██████████| 7/7 [00:04<00:00,  1.60it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.87it/s]


[G2 E3] Loss: 16764.6968 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 309505.9375


Grade 2 Epoch 4: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.80it/s]


[G2 E4] Loss: 17707.9414 | Dev Top1/2/3: 0.0000/0.0000/0.3333 | PPL: 278344.2500


Grade 2 Epoch 5: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.86it/s]


[G2 E5] Loss: 17896.6770 | Dev Top1/2/3: 0.0000/0.0000/0.3333 | PPL: 261020.1406


Grade 2 Epoch 6: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.88it/s]


[G2 E6] Loss: 21485.9679 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 242004.6562


Grade 2 Epoch 7: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.84it/s]


[G2 E7] Loss: 12104.4681 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 225571.7812


Grade 2 Epoch 8: 100%|██████████| 7/7 [00:04<00:00,  1.60it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.86it/s]


[G2 E8] Loss: 9882.5342 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 214465.3438


Grade 2 Epoch 9: 100%|██████████| 7/7 [00:04<00:00,  1.60it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.83it/s]


[G2 E9] Loss: 10472.4441 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 205768.2656


Grade 2 Epoch 10: 100%|██████████| 7/7 [00:04<00:00,  1.60it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.87it/s]


[G2 E10] Loss: 10800.5639 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 197926.4375


Grade 2 Epoch 11: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.89it/s]


[G2 E11] Loss: 8093.2079 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 190883.7031


Grade 2 Epoch 12: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.87it/s]


[G2 E12] Loss: 8359.8241 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 185550.7031


Grade 2 Epoch 13: 100%|██████████| 7/7 [00:04<00:00,  1.60it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.88it/s]


[G2 E13] Loss: 6867.1967 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 181809.8125


Grade 2 Epoch 14: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.87it/s]


[G2 E14] Loss: 5113.9969 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 179099.4688


Grade 2 Epoch 15: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.87it/s]


[G2 E15] Loss: 5031.2580 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 177254.5938


Grade 2 Epoch 16: 100%|██████████| 7/7 [00:04<00:00,  1.60it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.89it/s]


[G2 E16] Loss: 6262.8067 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 175334.3125


Grade 2 Epoch 17: 100%|██████████| 7/7 [00:04<00:00,  1.60it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.88it/s]


[G2 E17] Loss: 4199.3024 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 173640.0156


Grade 2 Epoch 18: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.87it/s]


[G2 E18] Loss: 7729.1819 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 171736.4375


Grade 2 Epoch 19: 100%|██████████| 7/7 [00:04<00:00,  1.60it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.88it/s]


[G2 E19] Loss: 3387.9133 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 170248.4062


Grade 2 Epoch 20: 100%|██████████| 7/7 [00:04<00:00,  1.60it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.88it/s]


[G2 E20] Loss: 5114.2681 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 169128.4062


Grade 2 Epoch 21: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.85it/s]


[G2 E21] Loss: 3824.8952 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 167768.0156


Grade 2 Epoch 22: 100%|██████████| 7/7 [00:04<00:00,  1.60it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.87it/s]


[G2 E22] Loss: 4857.2440 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 166606.8438


Grade 2 Epoch 23: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.88it/s]


[G2 E23] Loss: 5250.3555 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 165334.6719


Grade 2 Epoch 24: 100%|██████████| 7/7 [00:04<00:00,  1.60it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.86it/s]


[G2 E24] Loss: 4705.2095 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 164216.3281


Grade 2 Epoch 25: 100%|██████████| 7/7 [00:04<00:00,  1.60it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.88it/s]


[G2 E25] Loss: 6237.2081 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 163028.9062


Grade 2 Epoch 26: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.86it/s]


[G2 E26] Loss: 8086.3982 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 161619.5781


Grade 2 Epoch 27: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.86it/s]


[G2 E27] Loss: 3932.7049 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 160467.6406


Grade 2 Epoch 28: 100%|██████████| 7/7 [00:04<00:00,  1.60it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.87it/s]


[G2 E28] Loss: 2727.1856 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 159600.5469


Grade 2 Epoch 29: 100%|██████████| 7/7 [00:04<00:00,  1.60it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.88it/s]


[G2 E29] Loss: 3861.8666 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 158788.8438


Grade 2 Epoch 30: 100%|██████████| 7/7 [00:04<00:00,  1.60it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.86it/s]


[G2 E30] Loss: 5006.4152 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 157835.3438


Grade 2 Epoch 31: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.84it/s]


[G2 E31] Loss: 3969.7880 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 157027.6250


Grade 2 Epoch 32: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.88it/s]


[G2 E32] Loss: 4289.8228 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 156083.2344


Grade 2 Epoch 33: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.84it/s]


[G2 E33] Loss: 2493.2833 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 155345.2031


Grade 2 Epoch 34: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.85it/s]


[G2 E34] Loss: 2888.7756 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 154701.8125


Grade 2 Epoch 35: 100%|██████████| 7/7 [00:04<00:00,  1.60it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.88it/s]


[G2 E35] Loss: 3927.2633 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 153900.0469


Grade 2 Epoch 36: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.81it/s]


[G2 E36] Loss: 4170.8018 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 153065.7812


Grade 2 Epoch 37: 100%|██████████| 7/7 [00:04<00:00,  1.60it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.86it/s]


[G2 E37] Loss: 2878.9086 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 152226.3750


Grade 2 Epoch 38: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.86it/s]


[G2 E38] Loss: 2965.6642 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 151426.6406


Grade 2 Epoch 39: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.86it/s]


[G2 E39] Loss: 2280.2045 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 150732.2031


Grade 2 Epoch 40: 100%|██████████| 7/7 [00:04<00:00,  1.60it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.85it/s]


[G2 E40] Loss: 4154.2773 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 150003.1719


Grade 2 Epoch 41: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.88it/s]


[G2 E41] Loss: 4170.8717 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 149086.3281


Grade 2 Epoch 42: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.87it/s]


[G2 E42] Loss: 3366.9833 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 148247.3594


Grade 2 Epoch 43: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.86it/s]


[G2 E43] Loss: 3166.0767 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 147377.9688


Grade 2 Epoch 44: 100%|██████████| 7/7 [00:04<00:00,  1.60it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.86it/s]


[G2 E44] Loss: 3171.7959 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 146662.8125


Grade 2 Epoch 45: 100%|██████████| 7/7 [00:04<00:00,  1.60it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.87it/s]


[G2 E45] Loss: 3224.9565 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 146129.3438


Grade 2 Epoch 46: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.87it/s]


[G2 E46] Loss: 3427.5983 | Dev Top1/2/3: 0.0000/0.3333/0.3333 | PPL: 145521.8594


Grade 2 Epoch 47: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.86it/s]


[G2 E47] Loss: 1722.1593 | Dev Top1/2/3: 0.0000/0.0000/0.3333 | PPL: 144974.5781


Grade 2 Epoch 48: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.88it/s]


[G2 E48] Loss: 3236.7701 | Dev Top1/2/3: 0.0000/0.0000/0.3333 | PPL: 144389.5938


Grade 2 Epoch 49: 100%|██████████| 7/7 [00:04<00:00,  1.60it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.85it/s]


[G2 E49] Loss: 4100.9717 | Dev Top1/2/3: 0.0000/0.0000/0.3333 | PPL: 143655.3125


Grade 2 Epoch 50: 100%|██████████| 7/7 [00:04<00:00,  1.60it/s]
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  5.84it/s]


[G2 E50] Loss: 2976.3236 | Dev Top1/2/3: 0.0000/0.0000/0.3333 | PPL: 142942.2812

===== Training on grade 3 =====


Grade 3 Epoch 1: 100%|██████████| 29/29 [00:18<00:00,  1.58it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.70it/s]


[G3 E1] Loss: 4062.8550 | Dev Top1/2/3: 0.3333/0.4000/0.7333 | PPL: 143993.9948
Saved model to /content/drive/MyDrive/CSEG321/models/answer_margin_gpt2_50e_5e-05lr_curriculum.pt


Grade 3 Epoch 2: 100%|██████████| 29/29 [00:18<00:00,  1.59it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.74it/s]


[G3 E2] Loss: 4263.2326 | Dev Top1/2/3: 0.3333/0.4000/0.8000 | PPL: 140054.9115


Grade 3 Epoch 3: 100%|██████████| 29/29 [00:18<00:00,  1.59it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.79it/s]


[G3 E3] Loss: 2832.0170 | Dev Top1/2/3: 0.3333/0.4667/0.8667 | PPL: 137003.7687


Grade 3 Epoch 4: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.78it/s]


[G3 E4] Loss: 3098.3566 | Dev Top1/2/3: 0.3333/0.5333/0.8667 | PPL: 133921.2437


Grade 3 Epoch 5: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.79it/s]


[G3 E5] Loss: 3114.4134 | Dev Top1/2/3: 0.3333/0.5333/0.8667 | PPL: 130861.8365


Grade 3 Epoch 6: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.79it/s]


[G3 E6] Loss: 3113.6907 | Dev Top1/2/3: 0.3333/0.5333/0.8667 | PPL: 127498.7469


Grade 3 Epoch 7: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.79it/s]


[G3 E7] Loss: 2538.1718 | Dev Top1/2/3: 0.3333/0.5333/0.8667 | PPL: 124381.1453


Grade 3 Epoch 8: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.77it/s]


[G3 E8] Loss: 2837.8499 | Dev Top1/2/3: 0.3333/0.5333/0.8667 | PPL: 121864.7156


Grade 3 Epoch 9: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.78it/s]


[G3 E9] Loss: 2744.5582 | Dev Top1/2/3: 0.3333/0.5333/0.8667 | PPL: 118631.5203


Grade 3 Epoch 10: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.78it/s]


[G3 E10] Loss: 2403.0224 | Dev Top1/2/3: 0.3333/0.6000/0.8667 | PPL: 115724.2901


Grade 3 Epoch 11: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.77it/s]


[G3 E11] Loss: 2628.5801 | Dev Top1/2/3: 0.3333/0.6000/0.8667 | PPL: 113131.5312


Grade 3 Epoch 12: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.79it/s]


[G3 E12] Loss: 2141.6130 | Dev Top1/2/3: 0.3333/0.5333/0.8667 | PPL: 111068.0208


Grade 3 Epoch 13: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.79it/s]


[G3 E13] Loss: 2107.7468 | Dev Top1/2/3: 0.3333/0.5333/0.8667 | PPL: 109065.9156


Grade 3 Epoch 14: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.79it/s]


[G3 E14] Loss: 2016.1845 | Dev Top1/2/3: 0.3333/0.5333/0.8667 | PPL: 106860.6453


Grade 3 Epoch 15: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.79it/s]


[G3 E15] Loss: 2570.6336 | Dev Top1/2/3: 0.3333/0.5333/0.8667 | PPL: 105061.7411


Grade 3 Epoch 16: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.77it/s]


[G3 E16] Loss: 2042.8964 | Dev Top1/2/3: 0.3333/0.5333/0.8667 | PPL: 102920.2495


Grade 3 Epoch 17: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.78it/s]


[G3 E17] Loss: 1992.1069 | Dev Top1/2/3: 0.2667/0.5333/0.8667 | PPL: 100660.5766


Grade 3 Epoch 18: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.78it/s]


[G3 E18] Loss: 1963.9209 | Dev Top1/2/3: 0.2667/0.5333/0.8667 | PPL: 98547.8328


Grade 3 Epoch 19: 100%|██████████| 29/29 [00:18<00:00,  1.59it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.78it/s]


[G3 E19] Loss: 1767.1093 | Dev Top1/2/3: 0.2667/0.5333/0.8667 | PPL: 95957.7385


Grade 3 Epoch 20: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.79it/s]


[G3 E20] Loss: 1666.6322 | Dev Top1/2/3: 0.2667/0.5333/0.8667 | PPL: 94236.2385


Grade 3 Epoch 21: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.79it/s]


[G3 E21] Loss: 1555.4948 | Dev Top1/2/3: 0.2667/0.5333/0.8667 | PPL: 93104.4943


Grade 3 Epoch 22: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.78it/s]


[G3 E22] Loss: 1605.5888 | Dev Top1/2/3: 0.2667/0.5333/0.8667 | PPL: 91135.6516


Grade 3 Epoch 23: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.76it/s]


[G3 E23] Loss: 1496.5890 | Dev Top1/2/3: 0.2667/0.5333/0.8667 | PPL: 89903.2859


Grade 3 Epoch 24: 100%|██████████| 29/29 [00:18<00:00,  1.59it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.77it/s]


[G3 E24] Loss: 1490.9900 | Dev Top1/2/3: 0.2667/0.5333/0.8667 | PPL: 89144.9068


Grade 3 Epoch 25: 100%|██████████| 29/29 [00:18<00:00,  1.59it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.77it/s]


[G3 E25] Loss: 1525.2930 | Dev Top1/2/3: 0.2667/0.4667/0.8667 | PPL: 87862.6271


Grade 3 Epoch 26: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.79it/s]


[G3 E26] Loss: 1394.6156 | Dev Top1/2/3: 0.2667/0.4667/0.8667 | PPL: 86655.1359


Grade 3 Epoch 27: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.79it/s]


[G3 E27] Loss: 1387.9893 | Dev Top1/2/3: 0.2667/0.4667/0.8667 | PPL: 85227.1953


Grade 3 Epoch 28: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.75it/s]


[G3 E28] Loss: 1405.8783 | Dev Top1/2/3: 0.2667/0.4667/0.8667 | PPL: 83963.8292


Grade 3 Epoch 29: 100%|██████████| 29/29 [00:18<00:00,  1.59it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.77it/s]


[G3 E29] Loss: 1422.4267 | Dev Top1/2/3: 0.2667/0.4667/0.8667 | PPL: 82707.7990


Grade 3 Epoch 30: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.77it/s]


[G3 E30] Loss: 1182.5393 | Dev Top1/2/3: 0.2667/0.4667/0.8667 | PPL: 81504.7583


Grade 3 Epoch 31: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.79it/s]


[G3 E31] Loss: 1181.4760 | Dev Top1/2/3: 0.2667/0.4667/0.8667 | PPL: 80518.4875


Grade 3 Epoch 32: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.79it/s]


[G3 E32] Loss: 1086.6819 | Dev Top1/2/3: 0.2667/0.4667/0.8667 | PPL: 80093.7703


Grade 3 Epoch 33: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.78it/s]


[G3 E33] Loss: 1256.1257 | Dev Top1/2/3: 0.2667/0.4667/0.8667 | PPL: 79256.7375


Grade 3 Epoch 34: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.78it/s]


[G3 E34] Loss: 742.7885 | Dev Top1/2/3: 0.2667/0.4667/0.8667 | PPL: 78497.5271


Grade 3 Epoch 35: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.77it/s]


[G3 E35] Loss: 863.5891 | Dev Top1/2/3: 0.2667/0.4667/0.8667 | PPL: 77944.9573


Grade 3 Epoch 36: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.78it/s]


[G3 E36] Loss: 884.3779 | Dev Top1/2/3: 0.2667/0.4667/0.8667 | PPL: 77667.8161


Grade 3 Epoch 37: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.78it/s]


[G3 E37] Loss: 934.9459 | Dev Top1/2/3: 0.2667/0.4667/0.8667 | PPL: 76650.1958


Grade 3 Epoch 38: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.78it/s]


[G3 E38] Loss: 921.9475 | Dev Top1/2/3: 0.2667/0.5333/0.8667 | PPL: 76278.6818


Grade 3 Epoch 39: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.79it/s]


[G3 E39] Loss: 914.6718 | Dev Top1/2/3: 0.2667/0.5333/0.8667 | PPL: 75356.6542


Grade 3 Epoch 40: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.78it/s]


[G3 E40] Loss: 825.3791 | Dev Top1/2/3: 0.2667/0.5333/0.8667 | PPL: 74681.4010


Grade 3 Epoch 41: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.78it/s]


[G3 E41] Loss: 677.1782 | Dev Top1/2/3: 0.2667/0.5333/0.8667 | PPL: 74424.4073


Grade 3 Epoch 42: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.77it/s]


[G3 E42] Loss: 741.6247 | Dev Top1/2/3: 0.2667/0.5333/0.8667 | PPL: 73926.4000


Grade 3 Epoch 43: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.78it/s]


[G3 E43] Loss: 598.1801 | Dev Top1/2/3: 0.2667/0.5333/0.8667 | PPL: 72548.8859


Grade 3 Epoch 44: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.78it/s]


[G3 E44] Loss: 670.4879 | Dev Top1/2/3: 0.2000/0.5333/0.8667 | PPL: 71955.4875


Grade 3 Epoch 45: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.78it/s]


[G3 E45] Loss: 732.1686 | Dev Top1/2/3: 0.2000/0.5333/0.8667 | PPL: 71249.0844


Grade 3 Epoch 46: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.78it/s]


[G3 E46] Loss: 805.8449 | Dev Top1/2/3: 0.2000/0.5333/0.8667 | PPL: 70815.6984


Grade 3 Epoch 47: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.79it/s]


[G3 E47] Loss: 602.0596 | Dev Top1/2/3: 0.2000/0.5333/0.8667 | PPL: 70726.9911


Grade 3 Epoch 48: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.78it/s]


[G3 E48] Loss: 649.1087 | Dev Top1/2/3: 0.2000/0.5333/0.8667 | PPL: 70184.9724


Grade 3 Epoch 49: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.76it/s]


[G3 E49] Loss: 683.1294 | Dev Top1/2/3: 0.2000/0.5333/0.8667 | PPL: 69676.2745


Grade 3 Epoch 50: 100%|██████████| 29/29 [00:18<00:00,  1.60it/s]
Evaluating: 100%|██████████| 4/4 [00:00<00:00,  4.78it/s]

[G3 E50] Loss: 564.1792 | Dev Top1/2/3: 0.2000/0.5333/0.8667 | PPL: 69234.3208





In [None]:
"""
NLL-Ranking Trainer

Trains a GPT-2 model to pick the correct option by minimizing cross-entropy
on Negative-Log-Likelihood scores. Top-1/2/3 accuracy + average NLL.
"""

import argparse, os, random, sys, math
import numpy as np, pandas as pd, torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import GPT2Tokenizer
from models.gpt2 import GPT2Model     # custom GPT-2 implementation

TQDM_DISABLE = False  # set True in Colab if bar flickers

# --------------------------------------------------------------------- #
# 0. Utils
# --------------------------------------------------------------------- #
def seed_everything(seed: int = 42):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# --------------------------------------------------------------------- #
# 1. Dataset
# --------------------------------------------------------------------- #
class AnswerDataset(Dataset):
    def __init__(self, df: pd.DataFrame, tokenizer: GPT2Tokenizer, max_len=512):
        self.df, self.tok, self.max_len = df.reset_index(drop=True), tokenizer, max_len

    def __len__(self): return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        choices = [row[f"choice_{i}"] for i in range(1, 6)]
        texts   = [f"{row.prefix}{c}{row.suffix}" for c in choices]
        label   = int(row.answer) - 1
        return texts, label

    def collate_fn(self, batch):
        texts, labels  = zip(*batch)
        flat = sum(texts, [])          # (B*5,) list
        enc  = self.tok(
            flat, padding="max_length", truncation=True,
            max_length=self.max_len, return_tensors="pt"
        )
        B = len(batch)
        return {
            "input_ids":     enc["input_ids"].view(B, 5, -1),
            "attention_mask":enc["attention_mask"].view(B, 5, -1),
            "labels":        torch.tensor(labels, dtype=torch.long)
        }

# --------------------------------------------------------------------- #
# 2. Model: GPT-2 + LM-head to compute NLL
# --------------------------------------------------------------------- #
class NLLRankModel(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.gpt = GPT2Model.from_pretrained(
            model=args.model_size, d=args.d, l=args.l, num_heads=args.num_heads
        )
        vocab = GPT2Tokenizer.from_pretrained(args.model_size).vocab_size
        self.lm_head = nn.Linear(args.d, vocab, bias=False)
        for p in self.parameters(): p.requires_grad = True

    def forward(self, input_ids, attention_mask):
        """
        input_ids: (B, C, L)   attention_mask: (B, C, L)
        returns nll: (B, C)  -- lower is better
        """
        B, C, L = input_ids.shape
        flat_ids  = input_ids.view(B*C, L)
        flat_mask = attention_mask.view(B*C, L)

        h = self.gpt(flat_ids, attention_mask=flat_mask)["last_hidden_state"]
        logits = self.lm_head(h)                       # (B*C, L, V)

        # shift for LM loss
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = flat_ids[:, 1:].contiguous()
        shift_mask   = flat_mask[:, 1:].contiguous()

        loss_ = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            reduction="none"
        ).view(flat_ids.size(0), -1) * shift_mask

        # token average NLL
        nll = loss_.sum(1) / shift_mask.sum(1).clamp_min(1)   # (B*C,)
        return nll.view(B, C)                                 # (B,5)

# --------------------------------------------------------------------- #
# 3. Train / Eval
# --------------------------------------------------------------------- #
def evaluate(loader, model, device):
    model.eval()
    tot = top1 = top2 = top3 = 0; tot_nll = 0.0
    with torch.no_grad():
        for batch in tqdm(loader, disable=TQDM_DISABLE, desc="Eval"):
            ids, mask = batch["input_ids"].to(device), batch["attention_mask"].to(device)
            nll = model(ids, mask)                 # (B, 5)
            scores = -nll                          # higher better
            rank = scores.argsort(dim=1, descending=True)
            y = batch["labels"].to(device)
            B = y.size(0)
            top1 += (rank[:, 0] == y).sum().item()
            top2 += ((rank[:, :2] == y.unsqueeze(1)).any(1)).sum().item()
            top3 += ((rank[:, :3] == y.unsqueeze(1)).any(1)).sum().item()
            tot  += B
            tot_nll += nll.mean().item()*B
    return top1/tot, top2/tot, top3/tot, tot_nll/tot

def train(args):
    seed_everything(args.seed)
    dev = torch.device("cuda" if args.use_gpu and torch.cuda.is_available() else "cpu")
    print("Device:", dev)

    # load full CSV once
    df  = pd.read_csv(args.data_path, encoding="utf-8-sig")

    # prepare tokenizer (unchanged)
    tok = GPT2Tokenizer.from_pretrained(args.model_size)
    tok.pad_token = tok.eos_token

    # model & optimizer (unchanged)
    model = NLLRankModel(args).to(dev)
    optim = torch.optim.AdamW(model.parameters(), lr=args.lr)
    best = 0.0

    # ────────────────────────────────────────────────────
    #  Curriclum Learning: train on grade=2 then grade=3
    # ────────────────────────────────────────────────────
    for grade in [2, 3]:
        print(f"\n===== Curriculum Stage: Grade {grade} =====")
        # filter train/valid by grade
        trn = df[(df.split == "train") & (df.grade == grade)]
        val = df[(df.split == "valid") & (df.grade == grade)]
        if len(trn) == 0:
            print(f"No grade {grade} training data, skipping.")
            continue

        train_ds = AnswerDataset(trn, tok, args.max_length)
        val_ds   = AnswerDataset(val, tok, args.max_length)
        tr_loader= DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,  collate_fn=train_ds.collate_fn)
        val_loader=DataLoader(val_ds,   batch_size=args.batch_size, shuffle=False, collate_fn=val_ds.collate_fn)

        # now run your existing epoch loop on this subset
        for ep in range(1, args.epochs+1):
            model.train(); ep_loss = 0.0
            for batch in tqdm(tr_loader, disable=TQDM_DISABLE, desc=f"Train G{grade} E{ep}"):
                optim.zero_grad()
                ids, mask = batch["input_ids"].to(dev), batch["attention_mask"].to(dev)
                nll = model(ids, mask)           # (B,5)
                loss = F.cross_entropy(-nll, batch["labels"].to(dev))
                loss.backward(); optim.step()
                ep_loss += loss.item()
            print(f"Grade {grade} Epoch {ep}  loss={ep_loss/len(tr_loader):.4f}")

            # evaluation (unchanged)
            t1,t2,t3,av_nll = evaluate(val_loader, model, dev)
            print(f"  Valid G{grade} Top1/2/3={t1:.3f}/{t2:.3f}/{t3:.3f}  AvgNLL={av_nll:.3f}")

            # checkpoint
            if t1 > best:
                best = t1
                save(model, optim, args)

def save(m, o, args):
    os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
    torch.save({"model":m.state_dict(),"optim":o.state_dict(),"args":args}, args.save_path)
    print("Saved to", args.save_path)

def get_args():
    p=argparse.ArgumentParser();  # identical to original
    p.add_argument("--data_path", type=str, default="/content/drive/MyDrive/CSEG321/dataset/answer_margin.csv")
    p.add_argument("--model_size",type=str, default="gpt2", choices=["gpt2","gpt2-medium","gpt2-large"])
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--max_length", type=int, default=512)
    p.add_argument("--lr",        type=float, default=5e-5)
    p.add_argument("--epochs",    type=int, default=50)
    p.add_argument("--seed",      type=int, default=42)
    p.add_argument("--use_gpu",   action="store_true")
    # Jupyter friendly
    args = p.parse_args([])
    if torch.cuda.is_available(): args.use_gpu = True
    if args.model_size=="gpt2":    args.d,args.l,args.num_heads=768,12,12
    elif args.model_size=="gpt2-medium": args.d,args.l,args.num_heads=1024,24,16
    else: args.d,args.l,args.num_heads=1280,36,20
    args.save_path=f"/content/drive/MyDrive/CSEG321/models/answer_nll_{args.model_size}_{args.epochs}e_curriculum.pt"
    return args

if __name__=="__main__":
    train(get_args())


Device: cuda

===== Curriculum Stage: Grade 2 =====


Train G2 E1: 100%|██████████| 7/7 [00:04<00:00,  1.56it/s]


Grade 2 Epoch 1  loss=1.6303


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.73it/s]


  Valid G2 Top1/2/3=0.333/0.667/1.000  AvgNLL=19.350
Saved to /content/drive/MyDrive/CSEG321/models/answer_nll_gpt2_50e_curriculum.pt


Train G2 E2: 100%|██████████| 7/7 [00:04<00:00,  1.57it/s]


Grade 2 Epoch 2  loss=1.6451


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.85it/s]


  Valid G2 Top1/2/3=0.333/0.667/1.000  AvgNLL=18.667


Train G2 E3: 100%|██████████| 7/7 [00:04<00:00,  1.58it/s]


Grade 2 Epoch 3  loss=1.5591


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.70it/s]


  Valid G2 Top1/2/3=0.333/0.667/1.000  AvgNLL=18.394


Train G2 E4: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 4  loss=1.5412


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.84it/s]


  Valid G2 Top1/2/3=0.333/0.667/1.000  AvgNLL=18.370


Train G2 E5: 100%|██████████| 7/7 [00:04<00:00,  1.58it/s]


Grade 2 Epoch 5  loss=1.5055


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.79it/s]


  Valid G2 Top1/2/3=0.333/0.667/1.000  AvgNLL=18.383


Train G2 E6: 100%|██████████| 7/7 [00:04<00:00,  1.58it/s]


Grade 2 Epoch 6  loss=1.3937


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.86it/s]


  Valid G2 Top1/2/3=0.333/0.667/0.667  AvgNLL=18.324


Train G2 E7: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 7  loss=1.3379


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.84it/s]


  Valid G2 Top1/2/3=0.333/0.667/1.000  AvgNLL=19.482


Train G2 E8: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 8  loss=1.1698


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.80it/s]


  Valid G2 Top1/2/3=0.333/0.667/1.000  AvgNLL=20.185


Train G2 E9: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 9  loss=1.0373


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.82it/s]


  Valid G2 Top1/2/3=0.333/0.667/1.000  AvgNLL=20.431


Train G2 E10: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 10  loss=0.9105


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.84it/s]


  Valid G2 Top1/2/3=0.333/0.667/1.000  AvgNLL=21.898


Train G2 E11: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 11  loss=0.7529


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.87it/s]


  Valid G2 Top1/2/3=0.333/0.667/1.000  AvgNLL=23.950


Train G2 E12: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 12  loss=0.6445


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.86it/s]


  Valid G2 Top1/2/3=0.333/1.000/1.000  AvgNLL=22.716


Train G2 E13: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 13  loss=0.6028


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.82it/s]


  Valid G2 Top1/2/3=0.333/0.333/0.667  AvgNLL=26.169


Train G2 E14: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 14  loss=0.6098


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.84it/s]


  Valid G2 Top1/2/3=0.667/0.667/1.000  AvgNLL=23.698
Saved to /content/drive/MyDrive/CSEG321/models/answer_nll_gpt2_50e_curriculum.pt


Train G2 E15: 100%|██████████| 7/7 [00:04<00:00,  1.57it/s]


Grade 2 Epoch 15  loss=0.5060


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.85it/s]


  Valid G2 Top1/2/3=0.333/0.333/1.000  AvgNLL=25.074


Train G2 E16: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 16  loss=0.4382


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.80it/s]


  Valid G2 Top1/2/3=0.333/0.333/0.667  AvgNLL=27.135


Train G2 E17: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 17  loss=0.4803


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.85it/s]


  Valid G2 Top1/2/3=0.333/0.667/1.000  AvgNLL=27.120


Train G2 E18: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 18  loss=0.4910


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.85it/s]


  Valid G2 Top1/2/3=0.333/0.667/1.000  AvgNLL=27.461


Train G2 E19: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 19  loss=0.3415


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.86it/s]


  Valid G2 Top1/2/3=0.333/0.667/1.000  AvgNLL=27.062


Train G2 E20: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 20  loss=0.3904


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.86it/s]


  Valid G2 Top1/2/3=0.333/0.667/1.000  AvgNLL=28.750


Train G2 E21: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 21  loss=0.3658


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.85it/s]


  Valid G2 Top1/2/3=0.333/0.667/1.000  AvgNLL=28.136


Train G2 E22: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 22  loss=0.3495


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.79it/s]


  Valid G2 Top1/2/3=0.333/0.667/1.000  AvgNLL=27.822


Train G2 E23: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 23  loss=0.3724


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.83it/s]


  Valid G2 Top1/2/3=0.333/0.667/1.000  AvgNLL=29.453


Train G2 E24: 100%|██████████| 7/7 [00:04<00:00,  1.58it/s]


Grade 2 Epoch 24  loss=0.3673


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.85it/s]


  Valid G2 Top1/2/3=0.333/0.667/1.000  AvgNLL=28.965


Train G2 E25: 100%|██████████| 7/7 [00:04<00:00,  1.58it/s]


Grade 2 Epoch 25  loss=0.3390


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.85it/s]


  Valid G2 Top1/2/3=0.333/0.667/1.000  AvgNLL=29.717


Train G2 E26: 100%|██████████| 7/7 [00:04<00:00,  1.58it/s]


Grade 2 Epoch 26  loss=0.2844


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.84it/s]


  Valid G2 Top1/2/3=0.333/0.667/1.000  AvgNLL=30.043


Train G2 E27: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 27  loss=0.3041


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.82it/s]


  Valid G2 Top1/2/3=0.333/0.667/1.000  AvgNLL=30.838


Train G2 E28: 100%|██████████| 7/7 [00:04<00:00,  1.58it/s]


Grade 2 Epoch 28  loss=0.2605


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.81it/s]


  Valid G2 Top1/2/3=0.333/0.667/1.000  AvgNLL=31.487


Train G2 E29: 100%|██████████| 7/7 [00:04<00:00,  1.58it/s]


Grade 2 Epoch 29  loss=0.2921


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.86it/s]


  Valid G2 Top1/2/3=0.333/0.333/0.667  AvgNLL=34.546


Train G2 E30: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 30  loss=0.2555


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.74it/s]


  Valid G2 Top1/2/3=0.667/0.667/1.000  AvgNLL=30.943


Train G2 E31: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 31  loss=0.1885


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.86it/s]


  Valid G2 Top1/2/3=0.333/0.667/1.000  AvgNLL=34.139


Train G2 E32: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 32  loss=0.2272


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.85it/s]


  Valid G2 Top1/2/3=0.667/0.667/1.000  AvgNLL=32.264


Train G2 E33: 100%|██████████| 7/7 [00:04<00:00,  1.58it/s]


Grade 2 Epoch 33  loss=0.2079


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.82it/s]


  Valid G2 Top1/2/3=0.667/0.667/1.000  AvgNLL=32.231


Train G2 E34: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 34  loss=0.1654


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.85it/s]


  Valid G2 Top1/2/3=0.667/0.667/1.000  AvgNLL=34.720


Train G2 E35: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 35  loss=0.2124


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.85it/s]


  Valid G2 Top1/2/3=0.667/0.667/1.000  AvgNLL=34.678


Train G2 E36: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 36  loss=0.1816


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.86it/s]


  Valid G2 Top1/2/3=0.667/0.667/1.000  AvgNLL=33.604


Train G2 E37: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 37  loss=0.2072


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.85it/s]


  Valid G2 Top1/2/3=0.667/0.667/1.000  AvgNLL=33.253


Train G2 E38: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 38  loss=0.1870


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.80it/s]


  Valid G2 Top1/2/3=0.667/0.667/1.000  AvgNLL=35.979


Train G2 E39: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 39  loss=0.1742


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.86it/s]


  Valid G2 Top1/2/3=0.667/0.667/1.000  AvgNLL=37.808


Train G2 E40: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 40  loss=0.1705


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.86it/s]


  Valid G2 Top1/2/3=0.333/0.667/1.000  AvgNLL=37.971


Train G2 E41: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 41  loss=0.1486


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.84it/s]


  Valid G2 Top1/2/3=0.667/0.667/1.000  AvgNLL=37.491


Train G2 E42: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 42  loss=0.1981


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.84it/s]


  Valid G2 Top1/2/3=0.667/0.667/1.000  AvgNLL=36.930


Train G2 E43: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 43  loss=0.1471


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.84it/s]


  Valid G2 Top1/2/3=0.667/0.667/1.000  AvgNLL=37.921


Train G2 E44: 100%|██████████| 7/7 [00:04<00:00,  1.58it/s]


Grade 2 Epoch 44  loss=0.1429


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.86it/s]


  Valid G2 Top1/2/3=0.667/0.667/1.000  AvgNLL=39.087


Train G2 E45: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 45  loss=0.1194


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.86it/s]


  Valid G2 Top1/2/3=0.667/0.667/1.000  AvgNLL=38.745


Train G2 E46: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 46  loss=0.1711


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.84it/s]


  Valid G2 Top1/2/3=0.667/0.667/1.000  AvgNLL=39.401


Train G2 E47: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 47  loss=0.1409


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.86it/s]


  Valid G2 Top1/2/3=0.667/0.667/1.000  AvgNLL=40.854


Train G2 E48: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 48  loss=0.1342


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.81it/s]


  Valid G2 Top1/2/3=0.667/0.667/1.000  AvgNLL=41.019


Train G2 E49: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 49  loss=0.1216


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.84it/s]


  Valid G2 Top1/2/3=0.667/0.667/1.000  AvgNLL=39.738


Train G2 E50: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


Grade 2 Epoch 50  loss=0.1456


Eval: 100%|██████████| 1/1 [00:00<00:00,  5.85it/s]


  Valid G2 Top1/2/3=0.667/0.667/1.000  AvgNLL=41.649

===== Curriculum Stage: Grade 3 =====


Train G3 E1: 100%|██████████| 29/29 [00:18<00:00,  1.58it/s]


Grade 3 Epoch 1  loss=1.6744


Eval: 100%|██████████| 4/4 [00:00<00:00,  4.70it/s]


  Valid G3 Top1/2/3=0.200/0.267/0.667  AvgNLL=46.338


Train G3 E2: 100%|██████████| 29/29 [00:18<00:00,  1.59it/s]


Grade 3 Epoch 2  loss=1.6376


Eval: 100%|██████████| 4/4 [00:00<00:00,  4.78it/s]


  Valid G3 Top1/2/3=0.200/0.267/0.600  AvgNLL=45.755


Train G3 E3: 100%|██████████| 29/29 [00:18<00:00,  1.59it/s]


Grade 3 Epoch 3  loss=1.6203


Eval: 100%|██████████| 4/4 [00:00<00:00,  4.77it/s]


  Valid G3 Top1/2/3=0.200/0.333/0.733  AvgNLL=45.747


Train G3 E4: 100%|██████████| 29/29 [00:18<00:00,  1.59it/s]


Grade 3 Epoch 4  loss=1.5707


Eval: 100%|██████████| 4/4 [00:00<00:00,  4.76it/s]


  Valid G3 Top1/2/3=0.267/0.333/0.667  AvgNLL=44.904


Train G3 E5: 100%|██████████| 29/29 [00:18<00:00,  1.59it/s]


Grade 3 Epoch 5  loss=1.5290


Eval: 100%|██████████| 4/4 [00:00<00:00,  4.77it/s]


  Valid G3 Top1/2/3=0.333/0.400/0.667  AvgNLL=43.661


Train G3 E6: 100%|██████████| 29/29 [00:18<00:00,  1.59it/s]


Grade 3 Epoch 6  loss=1.4961


Eval: 100%|██████████| 4/4 [00:00<00:00,  4.77it/s]


  Valid G3 Top1/2/3=0.400/0.467/0.733  AvgNLL=42.592


Train G3 E7: 100%|██████████| 29/29 [00:18<00:00,  1.59it/s]


Grade 3 Epoch 7  loss=1.3578


Eval: 100%|██████████| 4/4 [00:00<00:00,  4.77it/s]


  Valid G3 Top1/2/3=0.267/0.400/0.533  AvgNLL=37.860


Train G3 E8: 100%|██████████| 29/29 [00:18<00:00,  1.59it/s]


Grade 3 Epoch 8  loss=1.2100


Eval: 100%|██████████| 4/4 [00:00<00:00,  4.78it/s]


  Valid G3 Top1/2/3=0.333/0.533/0.667  AvgNLL=42.611


Train G3 E9: 100%|██████████| 29/29 [00:18<00:00,  1.59it/s]


Grade 3 Epoch 9  loss=1.2390


Eval: 100%|██████████| 4/4 [00:00<00:00,  4.77it/s]


  Valid G3 Top1/2/3=0.267/0.400/0.667  AvgNLL=41.378


Train G3 E10: 100%|██████████| 29/29 [00:18<00:00,  1.59it/s]


Grade 3 Epoch 10  loss=1.0039


Eval: 100%|██████████| 4/4 [00:00<00:00,  4.77it/s]


  Valid G3 Top1/2/3=0.467/0.600/0.600  AvgNLL=40.016


Train G3 E11: 100%|██████████| 29/29 [00:18<00:00,  1.59it/s]


Grade 3 Epoch 11  loss=0.8908


Eval: 100%|██████████| 4/4 [00:00<00:00,  4.78it/s]


  Valid G3 Top1/2/3=0.267/0.533/0.600  AvgNLL=36.967


Train G3 E12: 100%|██████████| 29/29 [00:18<00:00,  1.59it/s]


Grade 3 Epoch 12  loss=0.8215


Eval: 100%|██████████| 4/4 [00:00<00:00,  4.77it/s]


  Valid G3 Top1/2/3=0.333/0.467/0.667  AvgNLL=40.249


Train G3 E13: 100%|██████████| 29/29 [00:18<00:00,  1.59it/s]


Grade 3 Epoch 13  loss=0.6887


Eval: 100%|██████████| 4/4 [00:00<00:00,  4.77it/s]


  Valid G3 Top1/2/3=0.200/0.467/0.733  AvgNLL=39.298


Train G3 E14: 100%|██████████| 29/29 [00:18<00:00,  1.59it/s]


Grade 3 Epoch 14  loss=0.6230


Eval: 100%|██████████| 4/4 [00:00<00:00,  4.77it/s]


  Valid G3 Top1/2/3=0.333/0.467/0.667  AvgNLL=40.350


Train G3 E15: 100%|██████████| 29/29 [00:18<00:00,  1.59it/s]


Grade 3 Epoch 15  loss=0.5527


Eval: 100%|██████████| 4/4 [00:00<00:00,  4.77it/s]


  Valid G3 Top1/2/3=0.467/0.467/0.600  AvgNLL=41.803


Train G3 E16:  93%|█████████▎| 27/29 [00:17<00:01,  1.59it/s]

In [13]:
"""
Dot-Product Rank-Head Trainer

Trains a GPT-2 encoder + linear rank head.
Top-1/2/3 accuracy + average softmax entropy on dev/test splits.
"""

import argparse, os, random, sys
import numpy as np, pandas as pd, torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import GPT2Tokenizer
from models.gpt2 import GPT2Model

TQDM_DISABLE=False

# --------------------------------------------------------------------- #
def seed_everything(seed=42):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False
# --------------------------------------------------------------------- #
class AnswerDataset(Dataset):
    def __init__(self, df, tok, max_len=512):
        self.df, self.tok, self.max_len = df.reset_index(drop=True), tok, max_len
    def __len__(self): return len(self.df)
    def __getitem__(self, idx):
        r=self.df.iloc[idx]
        choices=[r[f"choice_{i}"] for i in range(1,6)]
        texts=[f"{r.prefix}{c}{r.suffix}" for c in choices]
        lab=int(r.answer)-1
        return texts, lab
    def collate_fn(self,b):
        txts, labs=zip(*b); flat=sum(txts,[])
        enc=self.tok(flat,padding="max_length",truncation=True,
                     max_length=self.max_len,return_tensors="pt")
        B=len(b)
        return {"input_ids":enc.input_ids.view(B,5,-1),
                "attention_mask":enc.attention_mask.view(B,5,-1),
                "labels":torch.tensor(labs)}
# --------------------------------------------------------------------- #
class DotRankModel(nn.Module):
    def __init__(self,args):
        super().__init__()
        self.enc=GPT2Model.from_pretrained(model=args.model_size,
                                           d=args.d,l=args.l,num_heads=args.num_heads)
        self.head=nn.Linear(args.d,1)
        for p in self.parameters(): p.requires_grad=True
    def forward(self,ids,mask):
        B,C,L=ids.shape
        flat_ids,flat_mask=ids.view(B*C,L),mask.view(B*C,L)
        h=self.enc(flat_ids,attention_mask=flat_mask)["last_hidden_state"]
        last_idx=flat_mask.sum(1)-1
        pooled=h[torch.arange(h.size(0)), last_idx]  # (B*C,d)
        score=self.head(pooled).view(B,C)            # (B,5)
        return score
# --------------------------------------------------------------------- #
def evaluate(loader,model,dev):
    model.eval(); tot=top1=top2=top3=0; tot_ent=0
    with torch.no_grad():
        for bt in tqdm(loader,disable=TQDM_DISABLE,desc="Eval"):
            s=model(bt["input_ids"].to(dev),bt["attention_mask"].to(dev))
            rank=s.argsort(dim=1,descending=True)
            y=bt["labels"].to(dev); B=y.size(0)
            top1+=(rank[:,0]==y).sum().item()
            top2+=((rank[:,:2]==y.unsqueeze(1)).any(1)).sum().item()
            top3+=((rank[:,:3]==y.unsqueeze(1)).any(1)).sum().item()
            tot+=B
            probs=F.softmax(s,dim=1)
            ent=-(probs*probs.log()).sum(1).mean().item()
            tot_ent+=ent*B
    return top1/tot,top2/tot,top3/tot,tot_ent/tot
# --------------------------------------------------------------------- #
def train(args):
    seed_everything(args.seed)
    device = torch.device("cuda" if args.use_gpu and torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # 1) CSV 한 번만 로드
    df = pd.read_csv(args.data_path, encoding="utf-8-sig")

    # 2) 토크나이저 & model/optimizer 초기화
    tok = GPT2Tokenizer.from_pretrained(args.model_size)
    tok.pad_token = tok.eos_token
    model = DotRankModel(args).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=args.lr)

    best_top1 = 0.0

    # 3) Curriculum: grade 2 → grade 3 순차 학습
    for grade in [2, 3]:
        print(f"\n=== Curriculum Stage: Grade {grade} ===")
        train_df = df[(df.split == "train") & (df.grade == grade)]
        valid_df = df[(df.split == "valid") & (df.grade == grade)]

        if len(train_df) == 0:
            print(f"No train data for grade {grade}, skipping.")
            continue

        # 4) DataLoader 생성
        tr_ds = AnswerDataset(train_df, tok, args.max_length)
        va_ds = AnswerDataset(valid_df, tok, args.max_length)
        tr_ld = DataLoader(tr_ds, batch_size=args.batch_size, shuffle=True, collate_fn=tr_ds.collate_fn)
        va_ld = DataLoader(va_ds, batch_size=args.batch_size, shuffle=False, collate_fn=va_ds.collate_fn)

        # 5) 기존 epoch 루프 재사용
        for ep in range(1, args.epochs+1):
            model.train()
            ep_loss = 0.0
            for bt in tqdm(tr_ld, desc=f"[G{grade}] Train Ep{ep}", disable=TQDM_DISABLE):
                opt.zero_grad()
                s = model(bt["input_ids"].to(device), bt["attention_mask"].to(device))
                loss = F.cross_entropy(s, bt["labels"].to(device))
                loss.backward()
                opt.step()
                ep_loss += loss.item()
            print(f"[G{grade}] Ep{ep} train_loss = {ep_loss/len(tr_ld):.4f}")

            # 6) Validation
            t1, t2, t3, ent = evaluate(va_ld, model, device)
            print(f"[G{grade}] Ep{ep} valid Top1/2/3 = {t1:.3f}/{t2:.3f}/{t3:.3f} | Ent={ent:.3f}")

            # 7) Checkpoint
            if t1 > best_top1:
                best_top1 = t1
                save(model, opt, args)

def save(m,o,args):
    os.makedirs(os.path.dirname(args.save_path),exist_ok=True)
    torch.save({"model":m.state_dict(),"optim":o.state_dict(),"args":args},args.save_path)
    print("Saved to",args.save_path)
def get_args():
    p=argparse.ArgumentParser()
    p.add_argument("--data_path",type=str,default="/content/drive/MyDrive/CSEG321/dataset/answer_margin.csv")
    p.add_argument("--model_size",type=str,default="gpt2",choices=["gpt2","gpt2-medium","gpt2-large"])
    p.add_argument("--batch_size",type=int,default=4)
    p.add_argument("--max_length",type=int,default=512)
    p.add_argument("--lr",type=float,default=5e-5)
    p.add_argument("--epochs",type=int,default=50)
    p.add_argument("--seed",type=int,default=42)
    p.add_argument("--use_gpu",action="store_true")
    # Jupyter friendly
    args = p.parse_args([])
    if torch.cuda.is_available(): args.use_gpu = True
    if args.model_size=="gpt2": args.d,args.l,args.num_heads=768,12,12
    elif args.model_size=="gpt2-medium": args.d,args.l,args.num_heads=1024,24,16
    else: args.d,args.l,args.num_heads=1280,36,20
    args.save_path=f"/content/drive/MyDrive/CSEG321/models/answer_dot_{args.model_size}_{args.epochs}e_curriculum.pt"
    return args
if __name__=="__main__": train(get_args())


Using device: cuda

=== Curriculum Stage: Grade 2 ===


[G2] Train Ep1: 100%|██████████| 7/7 [00:03<00:00,  2.07it/s]


[G2] Ep1 train_loss = 1.8211


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.29it/s]


[G2] Ep1 valid Top1/2/3 = 0.333/0.333/0.333 | Ent=1.601
Saved to /content/drive/MyDrive/CSEG321/models/answer_dot_gpt2_50e_curriculum.pt


[G2] Train Ep2: 100%|██████████| 7/7 [00:03<00:00,  2.09it/s]


[G2] Ep2 train_loss = 1.6105


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.50it/s]


[G2] Ep2 valid Top1/2/3 = 0.333/0.333/0.667 | Ent=1.607


[G2] Train Ep3: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep3 train_loss = 1.6059


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.51it/s]


[G2] Ep3 valid Top1/2/3 = 0.333/0.667/0.667 | Ent=1.607


[G2] Train Ep4: 100%|██████████| 7/7 [00:03<00:00,  2.11it/s]


[G2] Ep4 train_loss = 1.6466


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.48it/s]


[G2] Ep4 valid Top1/2/3 = 0.667/0.667/0.667 | Ent=1.608
Saved to /content/drive/MyDrive/CSEG321/models/answer_dot_gpt2_50e_curriculum.pt


[G2] Train Ep5: 100%|██████████| 7/7 [00:03<00:00,  2.11it/s]


[G2] Ep5 train_loss = 1.6094


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.50it/s]


[G2] Ep5 valid Top1/2/3 = 0.333/0.333/0.667 | Ent=1.608


[G2] Train Ep6: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep6 train_loss = 1.5632


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.53it/s]


[G2] Ep6 valid Top1/2/3 = 0.333/0.333/0.667 | Ent=1.607


[G2] Train Ep7: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep7 train_loss = 1.5557


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.49it/s]


[G2] Ep7 valid Top1/2/3 = 0.333/0.333/0.667 | Ent=1.604


[G2] Train Ep8: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep8 train_loss = 1.6459


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.53it/s]


[G2] Ep8 valid Top1/2/3 = 0.333/0.667/0.667 | Ent=1.605


[G2] Train Ep9: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep9 train_loss = 1.5269


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.52it/s]


[G2] Ep9 valid Top1/2/3 = 0.667/0.667/0.667 | Ent=1.595


[G2] Train Ep10: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep10 train_loss = 1.5071


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.48it/s]


[G2] Ep10 valid Top1/2/3 = 0.667/0.667/0.667 | Ent=1.584


[G2] Train Ep11: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep11 train_loss = 1.5205


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.51it/s]


[G2] Ep11 valid Top1/2/3 = 0.667/0.667/0.667 | Ent=1.557


[G2] Train Ep12: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep12 train_loss = 1.3866


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.50it/s]


[G2] Ep12 valid Top1/2/3 = 0.667/0.667/0.667 | Ent=1.512


[G2] Train Ep13: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep13 train_loss = 1.4275


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.52it/s]


[G2] Ep13 valid Top1/2/3 = 0.000/0.333/0.667 | Ent=1.453


[G2] Train Ep14: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep14 train_loss = 1.5513


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.50it/s]


[G2] Ep14 valid Top1/2/3 = 0.000/0.333/0.333 | Ent=1.457


[G2] Train Ep15: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep15 train_loss = 1.3931


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.54it/s]


[G2] Ep15 valid Top1/2/3 = 0.333/0.333/0.333 | Ent=1.563


[G2] Train Ep16: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep16 train_loss = 1.3053


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.50it/s]


[G2] Ep16 valid Top1/2/3 = 0.333/0.333/0.333 | Ent=1.565


[G2] Train Ep17: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep17 train_loss = 1.2270


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.53it/s]


[G2] Ep17 valid Top1/2/3 = 0.333/0.333/0.333 | Ent=1.514


[G2] Train Ep18: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep18 train_loss = 1.0195


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.52it/s]


[G2] Ep18 valid Top1/2/3 = 0.000/0.333/0.333 | Ent=1.347


[G2] Train Ep19: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep19 train_loss = 0.8202


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.46it/s]


[G2] Ep19 valid Top1/2/3 = 0.333/0.333/0.667 | Ent=1.240


[G2] Train Ep20: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep20 train_loss = 0.5813


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.51it/s]


[G2] Ep20 valid Top1/2/3 = 0.333/0.667/0.667 | Ent=1.097


[G2] Train Ep21: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep21 train_loss = 0.5564


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.51it/s]


[G2] Ep21 valid Top1/2/3 = 0.333/0.667/0.667 | Ent=0.825


[G2] Train Ep22: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep22 train_loss = 0.4090


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.52it/s]


[G2] Ep22 valid Top1/2/3 = 0.333/0.667/0.667 | Ent=0.790


[G2] Train Ep23: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep23 train_loss = 0.4164


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.48it/s]


[G2] Ep23 valid Top1/2/3 = 0.667/0.667/0.667 | Ent=0.805


[G2] Train Ep24: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep24 train_loss = 0.2312


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.52it/s]


[G2] Ep24 valid Top1/2/3 = 0.333/0.667/0.667 | Ent=0.856


[G2] Train Ep25: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep25 train_loss = 0.3817


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.53it/s]


[G2] Ep25 valid Top1/2/3 = 0.333/0.667/0.667 | Ent=0.712


[G2] Train Ep26: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep26 train_loss = 0.1461


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.51it/s]


[G2] Ep26 valid Top1/2/3 = 0.333/0.667/0.667 | Ent=0.736


[G2] Train Ep27: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep27 train_loss = 0.2510


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.52it/s]


[G2] Ep27 valid Top1/2/3 = 0.000/0.333/0.667 | Ent=0.992


[G2] Train Ep28: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep28 train_loss = 0.1710


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.49it/s]


[G2] Ep28 valid Top1/2/3 = 0.333/0.667/0.667 | Ent=0.979


[G2] Train Ep29: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep29 train_loss = 0.0712


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.49it/s]


[G2] Ep29 valid Top1/2/3 = 0.333/0.667/0.667 | Ent=1.145


[G2] Train Ep30: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep30 train_loss = 0.1096


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.48it/s]


[G2] Ep30 valid Top1/2/3 = 0.333/0.333/0.667 | Ent=1.059


[G2] Train Ep31: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep31 train_loss = 0.0229


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.50it/s]


[G2] Ep31 valid Top1/2/3 = 0.000/0.333/0.667 | Ent=1.053


[G2] Train Ep32: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep32 train_loss = 0.0291


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.50it/s]


[G2] Ep32 valid Top1/2/3 = 0.000/0.333/0.667 | Ent=0.970


[G2] Train Ep33: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep33 train_loss = 0.0038


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.47it/s]


[G2] Ep33 valid Top1/2/3 = 0.000/0.000/0.667 | Ent=0.956


[G2] Train Ep34: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep34 train_loss = 0.0055


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.51it/s]


[G2] Ep34 valid Top1/2/3 = 0.000/0.000/0.667 | Ent=0.949


[G2] Train Ep35: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep35 train_loss = 0.0104


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.53it/s]


[G2] Ep35 valid Top1/2/3 = 0.000/0.333/0.667 | Ent=0.940


[G2] Train Ep36: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep36 train_loss = 0.0035


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.49it/s]


[G2] Ep36 valid Top1/2/3 = 0.000/0.333/0.667 | Ent=0.949


[G2] Train Ep37: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep37 train_loss = 0.0020


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.52it/s]


[G2] Ep37 valid Top1/2/3 = 0.000/0.333/0.667 | Ent=0.943


[G2] Train Ep38: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep38 train_loss = 0.0006


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.46it/s]


[G2] Ep38 valid Top1/2/3 = 0.000/0.333/0.667 | Ent=0.941


[G2] Train Ep39: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep39 train_loss = 0.0017


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.50it/s]


[G2] Ep39 valid Top1/2/3 = 0.000/0.333/0.667 | Ent=0.937


[G2] Train Ep40: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep40 train_loss = 0.0004


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.45it/s]


[G2] Ep40 valid Top1/2/3 = 0.000/0.333/0.667 | Ent=0.934


[G2] Train Ep41: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep41 train_loss = 0.0016


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.53it/s]


[G2] Ep41 valid Top1/2/3 = 0.000/0.333/0.667 | Ent=0.933


[G2] Train Ep42: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep42 train_loss = 0.0007


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.47it/s]


[G2] Ep42 valid Top1/2/3 = 0.000/0.333/0.667 | Ent=0.934


[G2] Train Ep43: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep43 train_loss = 0.0003


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.53it/s]


[G2] Ep43 valid Top1/2/3 = 0.000/0.333/0.667 | Ent=0.933


[G2] Train Ep44: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep44 train_loss = 0.0004


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.46it/s]


[G2] Ep44 valid Top1/2/3 = 0.000/0.333/0.667 | Ent=0.933


[G2] Train Ep45: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep45 train_loss = 0.0011


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.51it/s]


[G2] Ep45 valid Top1/2/3 = 0.000/0.333/0.667 | Ent=0.934


[G2] Train Ep46: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep46 train_loss = 0.0005


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.38it/s]


[G2] Ep46 valid Top1/2/3 = 0.000/0.333/0.667 | Ent=0.934


[G2] Train Ep47: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep47 train_loss = 0.0002


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.53it/s]


[G2] Ep47 valid Top1/2/3 = 0.000/0.333/0.667 | Ent=0.934


[G2] Train Ep48: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep48 train_loss = 0.0005


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.53it/s]


[G2] Ep48 valid Top1/2/3 = 0.000/0.333/0.667 | Ent=0.932


[G2] Train Ep49: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep49 train_loss = 0.0002


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.49it/s]


[G2] Ep49 valid Top1/2/3 = 0.000/0.333/0.667 | Ent=0.931


[G2] Train Ep50: 100%|██████████| 7/7 [00:03<00:00,  2.12it/s]


[G2] Ep50 train_loss = 0.0004


Eval: 100%|██████████| 1/1 [00:00<00:00,  7.53it/s]


[G2] Ep50 valid Top1/2/3 = 0.000/0.333/0.667 | Ent=0.931

=== Curriculum Stage: Grade 3 ===


[G3] Train Ep1: 100%|██████████| 29/29 [00:13<00:00,  2.10it/s]


[G3] Ep1 train_loss = 1.7475


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.02it/s]


[G3] Ep1 valid Top1/2/3 = 0.333/0.467/0.667 | Ent=1.609


[G3] Train Ep2: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep2 train_loss = 1.5771


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.17it/s]


[G3] Ep2 valid Top1/2/3 = 0.600/0.933/0.933 | Ent=1.608


[G3] Train Ep3: 100%|██████████| 29/29 [00:13<00:00,  2.13it/s]


[G3] Ep3 train_loss = 1.6196


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.15it/s]


[G3] Ep3 valid Top1/2/3 = 0.200/0.400/0.533 | Ent=1.606


[G3] Train Ep4: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep4 train_loss = 1.5764


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.17it/s]


[G3] Ep4 valid Top1/2/3 = 0.400/0.667/0.800 | Ent=1.608


[G3] Train Ep5: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep5 train_loss = 1.5650


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.16it/s]


[G3] Ep5 valid Top1/2/3 = 0.533/0.733/0.867 | Ent=1.605


[G3] Train Ep6: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep6 train_loss = 1.5042


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.10it/s]


[G3] Ep6 valid Top1/2/3 = 0.533/0.733/0.933 | Ent=1.595


[G3] Train Ep7: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep7 train_loss = 1.4173


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.15it/s]


[G3] Ep7 valid Top1/2/3 = 0.533/0.933/0.933 | Ent=1.590


[G3] Train Ep8: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep8 train_loss = 1.2806


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.17it/s]


[G3] Ep8 valid Top1/2/3 = 0.600/0.800/0.800 | Ent=1.575


[G3] Train Ep9: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep9 train_loss = 1.3068


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.15it/s]


[G3] Ep9 valid Top1/2/3 = 0.733/0.867/0.867 | Ent=1.586
Saved to /content/drive/MyDrive/CSEG321/models/answer_dot_gpt2_50e_curriculum.pt


[G3] Train Ep10: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep10 train_loss = 1.1981


Eval: 100%|██████████| 4/4 [00:00<00:00,  5.82it/s]


[G3] Ep10 valid Top1/2/3 = 0.267/0.733/0.867 | Ent=1.552


[G3] Train Ep11: 100%|██████████| 29/29 [00:13<00:00,  2.10it/s]


[G3] Ep11 train_loss = 1.1230


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.17it/s]


[G3] Ep11 valid Top1/2/3 = 0.533/0.867/0.933 | Ent=1.546


[G3] Train Ep12: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep12 train_loss = 1.2092


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.17it/s]


[G3] Ep12 valid Top1/2/3 = 0.667/0.867/0.933 | Ent=1.541


[G3] Train Ep13: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep13 train_loss = 1.0993


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.16it/s]


[G3] Ep13 valid Top1/2/3 = 0.733/0.800/0.867 | Ent=1.499


[G3] Train Ep14: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep14 train_loss = 1.0105


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.13it/s]


[G3] Ep14 valid Top1/2/3 = 0.667/0.867/0.867 | Ent=1.539


[G3] Train Ep15: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep15 train_loss = 0.8463


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.18it/s]


[G3] Ep15 valid Top1/2/3 = 0.667/0.733/0.867 | Ent=1.471


[G3] Train Ep16: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep16 train_loss = 0.8436


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.16it/s]


[G3] Ep16 valid Top1/2/3 = 0.600/0.733/0.800 | Ent=1.439


[G3] Train Ep17: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep17 train_loss = 0.7978


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.15it/s]


[G3] Ep17 valid Top1/2/3 = 0.600/0.733/0.867 | Ent=1.440


[G3] Train Ep18: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep18 train_loss = 0.7890


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.18it/s]


[G3] Ep18 valid Top1/2/3 = 0.533/0.800/0.800 | Ent=1.436


[G3] Train Ep19: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep19 train_loss = 0.5424


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.15it/s]


[G3] Ep19 valid Top1/2/3 = 0.467/0.800/0.867 | Ent=1.279


[G3] Train Ep20: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep20 train_loss = 0.5796


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.16it/s]


[G3] Ep20 valid Top1/2/3 = 0.467/0.800/0.867 | Ent=1.406


[G3] Train Ep21: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep21 train_loss = 0.4428


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.16it/s]


[G3] Ep21 valid Top1/2/3 = 0.533/0.933/0.933 | Ent=1.323


[G3] Train Ep22: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep22 train_loss = 0.3986


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.17it/s]


[G3] Ep22 valid Top1/2/3 = 0.467/0.867/0.933 | Ent=1.334


[G3] Train Ep23: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep23 train_loss = 0.3178


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.14it/s]


[G3] Ep23 valid Top1/2/3 = 0.467/0.867/0.933 | Ent=1.343


[G3] Train Ep24: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep24 train_loss = 0.3377


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.13it/s]


[G3] Ep24 valid Top1/2/3 = 0.333/0.800/0.867 | Ent=1.315


[G3] Train Ep25: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep25 train_loss = 0.3842


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.15it/s]


[G3] Ep25 valid Top1/2/3 = 0.600/0.800/0.867 | Ent=1.352


[G3] Train Ep26: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep26 train_loss = 0.2547


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.17it/s]


[G3] Ep26 valid Top1/2/3 = 0.600/0.800/0.933 | Ent=1.268


[G3] Train Ep27: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep27 train_loss = 0.1675


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.17it/s]


[G3] Ep27 valid Top1/2/3 = 0.667/0.800/0.933 | Ent=1.285


[G3] Train Ep28: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep28 train_loss = 0.1656


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.17it/s]


[G3] Ep28 valid Top1/2/3 = 0.467/0.733/0.867 | Ent=1.219


[G3] Train Ep29: 100%|██████████| 29/29 [00:13<00:00,  2.13it/s]


[G3] Ep29 train_loss = 0.1058


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.11it/s]


[G3] Ep29 valid Top1/2/3 = 0.400/0.667/0.800 | Ent=1.471


[G3] Train Ep30: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep30 train_loss = 0.2067


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.16it/s]


[G3] Ep30 valid Top1/2/3 = 0.333/0.733/0.800 | Ent=1.341


[G3] Train Ep31: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep31 train_loss = 0.1203


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.18it/s]


[G3] Ep31 valid Top1/2/3 = 0.333/0.667/0.800 | Ent=1.304


[G3] Train Ep32: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep32 train_loss = 0.0942


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.12it/s]


[G3] Ep32 valid Top1/2/3 = 0.400/0.600/0.867 | Ent=1.196


[G3] Train Ep33: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep33 train_loss = 0.1865


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.15it/s]


[G3] Ep33 valid Top1/2/3 = 0.467/0.733/0.800 | Ent=1.283


[G3] Train Ep34: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep34 train_loss = 0.0884


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.15it/s]


[G3] Ep34 valid Top1/2/3 = 0.467/0.733/0.800 | Ent=1.211


[G3] Train Ep35: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep35 train_loss = 0.0645


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.17it/s]


[G3] Ep35 valid Top1/2/3 = 0.467/0.733/0.800 | Ent=1.233


[G3] Train Ep36: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep36 train_loss = 0.0308


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.15it/s]


[G3] Ep36 valid Top1/2/3 = 0.200/0.733/0.800 | Ent=1.247


[G3] Train Ep37: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep37 train_loss = 0.0230


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.16it/s]


[G3] Ep37 valid Top1/2/3 = 0.333/0.733/0.733 | Ent=1.223


[G3] Train Ep38: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep38 train_loss = 0.0128


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.16it/s]


[G3] Ep38 valid Top1/2/3 = 0.400/0.800/0.867 | Ent=1.193


[G3] Train Ep39: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep39 train_loss = 0.0328


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.13it/s]


[G3] Ep39 valid Top1/2/3 = 0.333/0.733/0.800 | Ent=1.207


[G3] Train Ep40: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep40 train_loss = 0.0244


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.17it/s]


[G3] Ep40 valid Top1/2/3 = 0.400/0.800/0.933 | Ent=1.157


[G3] Train Ep41: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep41 train_loss = 0.0290


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.16it/s]


[G3] Ep41 valid Top1/2/3 = 0.267/0.667/0.933 | Ent=1.191


[G3] Train Ep42: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep42 train_loss = 0.0810


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.15it/s]


[G3] Ep42 valid Top1/2/3 = 0.333/0.667/0.867 | Ent=1.190


[G3] Train Ep43: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep43 train_loss = 0.0268


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.16it/s]


[G3] Ep43 valid Top1/2/3 = 0.267/0.733/0.800 | Ent=1.187


[G3] Train Ep44: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep44 train_loss = 0.0381


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.13it/s]


[G3] Ep44 valid Top1/2/3 = 0.267/0.733/0.867 | Ent=1.109


[G3] Train Ep45: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep45 train_loss = 0.0235


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.14it/s]


[G3] Ep45 valid Top1/2/3 = 0.267/0.733/0.800 | Ent=1.074


[G3] Train Ep46: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep46 train_loss = 0.0397


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.13it/s]


[G3] Ep46 valid Top1/2/3 = 0.267/0.733/0.800 | Ent=1.107


[G3] Train Ep47: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep47 train_loss = 0.0719


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.14it/s]


[G3] Ep47 valid Top1/2/3 = 0.200/0.667/0.867 | Ent=1.148


[G3] Train Ep48: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep48 train_loss = 0.0198


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.14it/s]


[G3] Ep48 valid Top1/2/3 = 0.200/0.667/0.733 | Ent=1.187


[G3] Train Ep49: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep49 train_loss = 0.0564


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.14it/s]


[G3] Ep49 valid Top1/2/3 = 0.133/0.667/0.867 | Ent=1.128


[G3] Train Ep50: 100%|██████████| 29/29 [00:13<00:00,  2.12it/s]


[G3] Ep50 train_loss = 0.0323


Eval: 100%|██████████| 4/4 [00:00<00:00,  6.16it/s]

[G3] Ep50 valid Top1/2/3 = 0.267/0.667/0.933 | Ent=1.106



