# Step 1: Preparing

In [None]:
import json
import torch

import sys
import random
import os
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
import types
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForSequenceClassification,
    get_linear_schedule_with_warmup,
    DataCollatorWithPadding
)
from peft import get_peft_model, LoraConfig, TaskType
from tqdm import tqdm
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
# Create dataloaders
from transformers import DataCollatorWithPadding
from torch.utils.tensorboard import SummaryWriter
MODEL_NAME = "Qwen/Qwen3-0.6B"

# Resolve the dataset location from the current working directory: when the
# notebook is started from a folder named `scripts` the data folder is looked
# up one level above it, from `notebooks` two levels above it, and otherwise
# directly under the working directory.
_cwd = os.getcwd()
_levels_up = {"scripts": 1, "notebooks": 2}.get(os.path.basename(_cwd), 0)
_project_root = _cwd
for _step in range(_levels_up):
    _project_root = os.path.dirname(_project_root)
DATASET_PATH = os.path.join(_project_root, "data", "verifier_dataset_train.json")

# Device preference order: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Running on device: {device}")

# Step 2: Tokenizer and dataset

In [None]:
# Load the tokenizer that matches MODEL_NAME.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# The dataloaders below pad batches through DataCollatorWithPadding, so a pad
# token must exist; when the checkpoint defines none, reuse the EOS token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

class Adaptive_N_VerifierDataset(Dataset):
    """Question-only dataset whose target is the empirical correct rate.

    Every raw entry must provide ``question``, ``correct_answers_num`` and
    ``total_answers_num``.  The stored label is ``correct / total`` as a
    float, or ``0.0`` when ``total`` is not positive.  Text is tokenized
    lazily in ``__getitem__`` without padding; padding is left to the
    collate function.
    """

    def __init__(self, raw_data_list, tokenizer, max_length=512):
        """Pre-compute the (text, label) pairs; tokenization is deferred."""
        self.samples = [self._to_sample(record) for record in raw_data_list]
        self.tokenizer = tokenizer
        self.max_length = max_length

    @staticmethod
    def _to_sample(record):
        """Turn one raw record into ``{"text": ..., "label": ...}``."""
        question = record['question']
        n_correct = record['correct_answers_num']
        n_total = record['total_answers_num']

        # Empirical probability of a correct answer; guarded against total == 0
        if n_total > 0:
            rate = n_correct / n_total
        else:
            rate = 0.0

        # The model only sees the question, never an answer or solution
        return {"text": f"Question: {question}", "label": float(rate)}

    def __len__(self):
        """Number of questions in the dataset."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and return unpadded ids, mask and float label."""
        sample = self.samples[idx]
        encoded = self.tokenizer(
            sample["text"],
            truncation=True,
            max_length=self.max_length,
            padding=False
        )
        target = torch.tensor(sample["label"], dtype=torch.float)
        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "labels": target,
        }

class VerifierDataset(Dataset):
    """(question, answer) pairs with binary correctness labels.

    Every raw entry must provide ``question``, ``answers`` (list) and
    ``answer_labels`` (list, paired with ``answers`` by position).  An
    optional ``reference_answer`` is added as one more answer with label 1.
    Each (question, answer) pair becomes one sample; text is tokenized lazily
    in ``__getitem__`` without padding.
    """

    def __init__(self, raw_data_list, tokenizer, max_length=512):
        """Flatten the raw entries into one sample per (question, answer) pair.

        Args:
            raw_data_list: Iterable of dicts as described in the class docstring.
            tokenizer: Callable used in ``__getitem__`` to encode the text.
            max_length: Truncation length passed to the tokenizer.
        """
        self.samples = []
        for entry in raw_data_list:
            question = entry['question']
            answers = entry['answers']
            labels = entry['answer_labels']
            # The reference answer is optional: an entry that lacks the key is
            # treated exactly like one whose value is None instead of raising
            # KeyError.
            ref_answer = entry.get("reference_answer")

            # If a reference answer exists, append it with label 1.  `+` builds
            # new lists, so the caller's raw data is not mutated.
            if ref_answer is not None:
                answers = answers + [ref_answer]
                labels = labels + [1]

            # NOTE: zip() stops at the shorter list, so surplus answers or
            # labels in a malformed entry are dropped silently.
            for ans, label in zip(answers, labels):
                text = f"Question: {question}\nAnswer: {ans}"
                # Label must be float for BCE loss (0.0 or 1.0)
                self.samples.append({"text": text, "label": float(label)})
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Number of (question, answer) samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and return unpadded ids, mask and float label."""
        item = self.samples[idx]
        encodings = self.tokenizer(
            item["text"],
            truncation=True,
            max_length=self.max_length,
            padding=False
        )
        return {
            "input_ids": encodings["input_ids"],
            "attention_mask": encodings["attention_mask"],
            "labels": torch.tensor(item["label"], dtype=torch.float)
        }


In [None]:
# Difficulty classification dataset
def probability_to_difficulty_class(prob, num_classes):
    """
    Map a probability value to a difficulty class.

    The unit interval is cut into ``num_classes`` equal-width bins; bin ``i``
    covers ``[i / num_classes, (i + 1) / num_classes)`` and the last bin also
    includes 1.0.

    Args:
        prob: float between 0.0 and 1.0 (empirical correct rate)
        num_classes: int >= 1, number of difficulty levels

    Returns:
        int, class index from 0 to num_classes-1

    Raises:
        ValueError: if ``prob`` is outside [0.0, 1.0] (including NaN) or
            ``num_classes`` is smaller than 1.

    Examples:
        num_classes=2: [0.0, 0.5) → 0, [0.5, 1.0] → 1
        num_classes=3: [0.0, 1/3) → 0, [1/3, 2/3) → 1, [2/3, 1.0] → 2
        num_classes=4: [0.0, 0.25) → 0, [0.25, 0.5) → 1, [0.5, 0.75) → 2, [0.75, 1.0] → 3
    """
    # Without this check num_classes <= 0 would quietly yield -1 or 0.
    if num_classes < 1:
        raise ValueError(f"num_classes must be at least 1, got {num_classes}")

    # Written as a chained comparison so that NaN (for which every comparison
    # is False) is rejected here with a clear message.
    if not 0.0 <= prob <= 1.0:
        raise ValueError(f"Probability must be between 0.0 and 1.0, got {prob}")

    # int() truncates toward zero, giving the bin index for [0, 1); the clamp
    # folds prob == 1.0 into the highest class.
    return min(int(prob * num_classes), num_classes - 1)


class DifficultyClassificationDataset(Dataset):
    """
    Question-only dataset for N-way difficulty classification.

    The empirical correct rate of each question is bucketed into one of
    ``num_classes`` classes via ``probability_to_difficulty_class``.
    """

    def __init__(self, raw_data_list, tokenizer, num_classes=2, max_length=512):
        """
        Args:
            raw_data_list: List of dicts with 'question', 'correct_answers_num', 'total_answers_num'
            tokenizer: Tokenizer instance
            num_classes: Number of difficulty levels (default: 2 for easy/hard)
            max_length: Max sequence length
        """
        self.num_classes = num_classes
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = []

        for record in raw_data_list:
            question = record['question']
            n_correct = record['correct_answers_num']
            n_total = record['total_answers_num']

            # Empirical correct rate, guarded against total == 0
            if n_total > 0:
                rate = n_correct / n_total
            else:
                rate = 0.0

            sample = {
                # Only the question is shown to the model, no answer/solution
                "text": f"Question: {question}",
                # Discrete difficulty bucket used as the training target
                "label": probability_to_difficulty_class(rate, num_classes),
                # Original rate, kept for debugging/analysis
                "prob": rate,
            }
            self.samples.append(sample)

    def __len__(self):
        """Number of questions in the dataset."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and return unpadded ids, mask and class index."""
        sample = self.samples[idx]
        encoded = self.tokenizer(
            sample["text"],
            truncation=True,
            max_length=self.max_length,
            padding=False
        )
        # Class indices are stored as long, the dtype CrossEntropyLoss expects
        class_index = torch.tensor(sample["label"], dtype=torch.long)
        return {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "labels": class_index,
        }

# Step 3: Two-Head Model Definition

In [None]:
# ==============================================================================
# Two-Head Model Implementation
# ==============================================================================
# Builds the objects shared by the cells below: a LoRA-wrapped base
# transformer (`base_model`) and two heads (`head_a`, `head_b`) that read its
# pooled hidden state.

print("Loading base Qwen3Model...")
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Running on device: {device}")

# Load ONLY the base transformer via AutoModel (no classification head); the
# heads are created by hand further down.
base_model = AutoModel.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map=device
)
base_model.config.pad_token_id = tokenizer.pad_token_id

# Apply LoRA to the modules named q_proj / v_proj of the base model
print("Applying LoRA...")
peft_config = LoraConfig(
    task_type=TaskType.FEATURE_EXTRACTION,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],
)
base_model = get_peft_model(base_model, peft_config)
base_model.print_trainable_parameters()

# Width of the pooled representation both heads consume
hidden_size = base_model.config.hidden_size

# Configuration for head_b
NUM_CLASSES = None  # Set to None for probability regression, or n for n-class classification

# Head A: a single bias-free linear layer producing one logit.  Both heads are
# created in the base model's dtype and on the same device.
head_dtype = base_model.dtype
head_a = nn.Linear(hidden_size, 1, bias=False).to(device, dtype=head_dtype)

# Head B: one output for probability regression, NUM_CLASSES outputs otherwise
if NUM_CLASSES is None:
    output_dim = 1
    head_b_mode = "Probability Regression"
else:
    output_dim = NUM_CLASSES
    head_b_mode = f"{NUM_CLASSES}-Class Classification"

# Two-layer MLP head: Linear -> GELU -> LayerNorm -> Dropout -> Linear
head_b = nn.Sequential(
    nn.Linear(hidden_size, hidden_size, bias=True),
    nn.GELU(),
    nn.LayerNorm(hidden_size),
    nn.Dropout(0.1),
    nn.Linear(hidden_size, output_dim, bias=True)
).to(device, dtype=head_dtype)

print("\nTwo heads created:")
print(f"  Head A (Binary Classification): {head_a}")
print(f"  Head B ({head_b_mode}): 2-layer MLP")
# The summary lists every layer that is actually in the Sequential above,
# including the LayerNorm.
print(f"    - Layer 1: Linear({hidden_size}, {hidden_size}) + GELU + LayerNorm + Dropout(0.1)")
print(f"    - Layer 2: Linear({hidden_size}, {output_dim})")
print(f"  NUM_CLASSES = {NUM_CLASSES}")


# ==============================================================================
# Define forward function
# ==============================================================================
def two_head_forward(input_ids, attention_mask, head='both'):
    """
    Forward pass for two-head model.

    Runs the shared base model once, mean-pools its last hidden state over the
    non-masked positions and applies only the head(s) that were requested, so
    a single-head call does not pay for (or draw dropout noise in) the other
    head.

    Args:
        input_ids: Token IDs
        attention_mask: Attention mask, or None to average over all positions
        head: 'a', 'b', or 'both' (any other value is treated as 'both')

    Returns:
        SimpleNamespace with .logits (and .logits_b for 'both')
        - head='a':    .logits is Head A's output, (batch, 1)
        - head='b':    .logits is Head B's output, (batch, 1) when Head B was
                       built for regression, (batch, NUM_CLASSES) otherwise
        - head='both': .logits is Head A's output, .logits_b is Head B's output
    """
    from types import SimpleNamespace

    # Get base model outputs
    outputs = base_model(input_ids=input_ids, attention_mask=attention_mask)

    # Get last hidden state
    last_hidden = outputs.last_hidden_state  # (batch, seq_len, hidden_size)

    # Mean pooling with masking
    if attention_mask is not None:
        mask = attention_mask.unsqueeze(-1).float()  # (batch, seq_len, 1)
        summed = (last_hidden * mask).sum(dim=1)  # (batch, hidden_size)
        # clamp keeps an all-masked row from dividing by zero
        counts = mask.sum(dim=1).clamp(min=1e-9)  # (batch, 1)
        pooled = summed / counts
    else:
        pooled = last_hidden.mean(dim=1)  # (batch, hidden_size)

    # # Last-token pooling (respect attention mask)
    # seq_lens = attention_mask.sum(dim=1) - 1  # index of last non-padding token
    # seq_lens = seq_lens.clamp(min=0)
    # last_tokens = last_hidden[torch.arange(last_hidden.size(0)), seq_lens]  # (batch, hidden_size)
    # pooled = last_tokens

    # The float mask promotes the pooled tensor; cast back to the heads' dtype
    pooled = pooled.to(head_a.weight.dtype)

    # Apply only the requested head(s)
    if head == 'a':
        return SimpleNamespace(logits=head_a(pooled))
    if head == 'b':
        return SimpleNamespace(logits=head_b(pooled))
    # 'both'
    return SimpleNamespace(logits=head_a(pooled), logits_b=head_b(pooled))

# Short usage reminder printed once the model objects exist
for _line in (
    "\nTwo-head model ready!",
    "Usage: two_head_forward(input_ids, attention_mask, head='a'/'b'/'both')",
    "Note: Change NUM_CLASSES to switch between regression (None) and classification (n)",
):
    print(_line)


In [None]:
# ==============================================================================
# Test Two-Head Model
# ==============================================================================
# Smoke test: push one short prompt through each head and print the outputs.
print("Testing Two-Head Model")
print("="*60)

# nn.Module instances start in training mode, so Head B's Dropout would be
# active and the 'b' and 'both' outputs below would differ at random.  Switch
# everything to eval mode; the training cells call .train() again themselves.
base_model.eval()
head_a.eval()
head_b.eval()

# Create test input
test_text = "Question: What is 2+2?"
test_inputs = tokenizer(test_text, return_tensors="pt", truncation=True, max_length=256)
test_inputs = {k: v.to(device) for k, v in test_inputs.items()}

# Only values are printed, so no autograd graph is needed
with torch.no_grad():
    # Test Head A (Binary Classification)
    print("\n1. Head A (Binary Classification):")
    outputs_a = two_head_forward(**test_inputs, head='a')
    prob_a = torch.sigmoid(outputs_a.logits)
    print(f"   Logit: {outputs_a.logits.item():.4f}")
    print(f"   Probability: {prob_a.item():.4f}")

    # Test Head B (Probability Regression or Classification)
    if NUM_CLASSES is None:
        print("\n2. Head B (Probability Regression):")
        outputs_b = two_head_forward(**test_inputs, head='b')
        prob_b = torch.sigmoid(outputs_b.logits)
        print(f"   Logit: {outputs_b.logits.item():.4f}")
        print(f"   Probability: {prob_b.item():.4f}")
    else:
        print(f"\n2. Head B ({NUM_CLASSES}-Class Classification):")
        outputs_b = two_head_forward(**test_inputs, head='b')
        print(f"   Logits shape: {outputs_b.logits.shape}")
        print(f"   Logits: {outputs_b.logits.squeeze()}")
        probs = torch.softmax(outputs_b.logits, dim=-1)
        print(f"   Class probabilities: {probs.squeeze()}")
        print(f"   Predicted class: {probs.argmax(dim=-1).item()}")

    # Test Both Heads
    print("\n3. Both Heads:")
    outputs_both = two_head_forward(**test_inputs, head='both')
    print(f"   Head A Logit: {outputs_both.logits.item():.4f}")
    if NUM_CLASSES is None:
        print(f"   Head B Logit: {outputs_both.logits_b.item():.4f}")
    else:
        print(f"   Head B Logits: {outputs_both.logits_b.squeeze()}")

print("\n" + "="*60)
print("✓ Test Complete! Two-head model working correctly.")


# Step 4: (Option1) Two Stage Training

In [None]:
# ==============================================================================
# Stage headB, p prediction: Train LoRA + Head B (Freeze Head A)
# ==============================================================================
# Goal: Train the model to predict question difficulty (correct rate)
# ==============================================================================
EPOCHS = 10
LEARNING_RATE = 1e-3 #5e-4
DEBUG_SAMPLE_SIZE = None # Set to None for full run
BATCH_SIZE = 2

print("Stage B: Training LoRA + Head B (Head A frozen)")
print("="*60)

# Freeze Head A
for param in head_a.parameters():
    param.requires_grad = False
# Head B must be trainable here even if another cell (the mixed-training loop
# freezes it per batch) left it frozen; otherwise the requires_grad filter
# below would drop it from the optimizer.
for param in head_b.parameters():
    param.requires_grad = True
print("Head A frozen")

# Load dataset
print("\nLoading dataset...")
if not os.path.exists(DATASET_PATH):
    # Fail fast: merely printing here would let the cell run on into a
    # NameError, or silently reuse datasets left over from an earlier cell.
    raise FileNotFoundError(f"Dataset not found: {DATASET_PATH}")

with open(DATASET_PATH, 'r', encoding='utf-8') as f:
    raw_questions = json.load(f)

if DEBUG_SAMPLE_SIZE:
    raw_questions = raw_questions[:DEBUG_SAMPLE_SIZE]

# Fixed seed so the train/val split is reproducible
random.seed(42)
random.shuffle(raw_questions)

# 90/10 split
split_idx = int(0.9 * len(raw_questions))
train_questions = raw_questions[:split_idx]
val_questions = raw_questions[split_idx:]

# Create datasets for Head B (question only, predict correct rate)
train_dataset = Adaptive_N_VerifierDataset(train_questions, tokenizer, max_length=256)
val_dataset = Adaptive_N_VerifierDataset(val_questions, tokenizer, max_length=256)

print(f" Dataset loaded: {len(train_dataset)} train, {len(val_dataset)} val samples")

# The datasets return unpadded sequences; the collator pads each batch
collator = DataCollatorWithPadding(tokenizer=tokenizer)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collator)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collator)

# Set up optimizer (only LoRA + head_b parameters)
trainable_params = [p for p in base_model.parameters() if p.requires_grad] + \
                     [p for p in head_b.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params, lr=LEARNING_RATE)

# Learning rate scheduler: linear warmup over the first 10% of steps, then
# linear decay; the training loop steps it once per batch
num_training_steps = len(train_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * num_training_steps),
    num_training_steps=num_training_steps
)

# Loss function: BCE for comparing two probabilities (prediction vs empirical)
loss_fn = nn.BCEWithLogitsLoss()

print("\n Optimizer and scheduler ready")
print("  Training params: LoRA + Head B")
print("  Loss function: BCEWithLogitsLoss (cross-entropy)")
print(f"  Epochs: {EPOCHS}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Learning rate: {LEARNING_RATE}")

In [None]:
# ==============================================================================
# Training Loop - Stage B
# ==============================================================================
# Trains LoRA + Head B with the optimizer, scheduler and loss prepared by the
# setup cell, validates after every epoch, and writes a checkpoint every third
# epoch and after the final one.

print("\n" + "="*60)
print("Starting Training...")
print("="*60)

OUTPUT_DIR_STAGEB = "../../outputs/two_head_stageB"
os.makedirs(OUTPUT_DIR_STAGEB, exist_ok=True)

for epoch in range(EPOCHS):
    # Training phase
    base_model.train()
    head_b.train()
    total_train_loss = 0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")

    for batch in progress_bar:
        # Move batch to device
        batch = {k: v.to(device) for k, v in batch.items()}

        # Forward pass (Head B only)
        outputs = two_head_forward(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            head='b'
        )

        # The heads run in the base model's dtype (loaded as bfloat16); cast
        # the logits so the loss itself is computed in float32.
        loss = loss_fn(outputs.logits.squeeze(-1).float(), batch['labels'].float())

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Track loss
        total_train_loss += loss.item()
        progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})

    # max(..., 1) avoids ZeroDivisionError when a loader has no batches
    avg_train_loss = total_train_loss / max(len(train_loader), 1)

    # Validation phase
    base_model.eval()
    head_b.eval()
    total_val_loss = 0

    with torch.no_grad():
        for batch in val_loader:
            batch = {k: v.to(device) for k, v in batch.items()}

            outputs = two_head_forward(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                head='b'
            )

            loss = loss_fn(outputs.logits.squeeze(-1).float(), batch['labels'].float())
            total_val_loss += loss.item()

    avg_val_loss = total_val_loss / max(len(val_loader), 1)

    print(f"Epoch {epoch+1}/{EPOCHS} - Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

    # Save a checkpoint every 3 epochs and always after the last epoch, so the
    # final weights are on disk even when EPOCHS is not a multiple of 3.
    is_last_epoch = (epoch + 1) == EPOCHS
    if (epoch + 1) % 3 == 0 or is_last_epoch:
        # The state dicts are gathered only when a file is actually written.
        checkpoint = {
            'epoch': epoch,
            'base_model_state_dict': base_model.state_dict(),
            'head_b_state_dict': head_b.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
        }
        torch.save(checkpoint, os.path.join(OUTPUT_DIR_STAGEB, f'checkpoint_epoch{epoch+1}.pt'))

print("\n" + "="*60)
print(" Stage B Training Complete!")
print(f"Models saved to: {OUTPUT_DIR_STAGEB}")
print("="*60)


In [None]:
# ==============================================================================
# Test Stage B Model
# ==============================================================================
# Qualitative spot check: print the sigmoid of Head B's logit for a few
# hand-picked word problems.

print("\nTesting trained Head B...")
print("="*60)

# Inference mode for the backbone and for Head B
for _module in (base_model, head_b):
    _module.eval()

# Test with a few examples
test_questions = [
    "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?",
    "Josh decides to try flipping a house.  He buys a house for $80,000 and then puts in $50,000 in repairs.  This increased the value of the house by 150%.  How much profit did he make?",
    "A robe takes 2 bolts of blue fiber and half that much white fiber.  How many bolts in total does it take?",
    "Every day, Wendi feeds each of her chickens three cups of mixed chicken feed, containing seeds, mealworms and vegetables to help keep them healthy.  She gives the chickens their feed in three separate meals. In the morning, she gives her flock of chickens 15 cups of feed.  In the afternoon, she gives her chickens another 25 cups of feed.  How many cups of feed does she need to give her chickens in the final meal of the day if the size of Wendi's flock is 20 chickens?"
]

with torch.no_grad():
    for q in test_questions:
        # Same "Question: ..." prompt format the training datasets use
        encoded = tokenizer(f"Question: {q}", return_tensors="pt", truncation=True, max_length=256)
        on_device = {name: tensor.to(device) for name, tensor in encoded.items()}

        head_b_logits = two_head_forward(**on_device, head='b').logits
        predicted_prob = torch.sigmoid(head_b_logits).item()

        print(f"\nQuestion: {q}")
        print(f"  Predicted difficulty (probability): {predicted_prob:.4f}")


In [None]:
# Validation comparison
# Prints Head B's predicted correct rate next to the empirical one for the
# first few entries of the validation split.

base_model.eval()
head_b.eval()

# Number of validation entries to print; the header reports the count that is
# actually used rather than a hard-coded figure.
NUM_VAL_EXAMPLES = 20
val_entries = val_questions[:NUM_VAL_EXAMPLES]  # replace with however you load your validation split

print(f"\nValidating on {len(val_entries)} examples from val set")
print("="*60)

with torch.no_grad():
    for entry in val_entries:
        text = f"Question: {entry['question']}"
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256).to(device)
        logits_b = two_head_forward(**inputs, head='b').logits
        pred_prob = torch.sigmoid(logits_b).item()
        # max(..., 1) guards against entries with zero recorded answers
        true_prob = entry['correct_answers_num'] / max(entry['total_answers_num'], 1)
        print(f"\nQuestion: {entry['question']}")
        print(f"  Predicted difficulty: {pred_prob:.4f}")
        print(f"  True difficulty:      {true_prob:.4f}")

 # (Option2) Classification

In [None]:

# Stage: Multi-Class Difficulty Classification Training Setup
# ==============================================================================
# Goal: Train the model to classify questions into N difficulty levels
# ==============================================================================

# Configuration
NUM_CLASSES = 4  # Number of difficulty levels (2, 3, 4, etc.)
EPOCHS = 5
LEARNING_RATE = 1e-3
BATCH_SIZE = 2
DEBUG_SAMPLE_SIZE = None # Set to None for full dataset

print(f"Classification Training Setup: {NUM_CLASSES}-Class Difficulty")
print("="*60)

# Head B was built by the model-definition cell with whatever NUM_CLASSES was
# set there (None -> a single output).  CrossEntropyLoss needs one logit per
# class, so rebuild the head with the same architecture when its output width
# does not match.  modules() also covers the case where head_b is a plain
# nn.Linear rather than a Sequential.
_head_b_out = [m for m in head_b.modules() if isinstance(m, nn.Linear)][-1].out_features
if _head_b_out != NUM_CLASSES:
    print(f"  Rebuilding Head B: {_head_b_out} -> {NUM_CLASSES} outputs (previous Head B weights are discarded)")
    head_b = nn.Sequential(
        nn.Linear(hidden_size, hidden_size, bias=True),
        nn.GELU(),
        nn.LayerNorm(hidden_size),
        nn.Dropout(0.1),
        nn.Linear(hidden_size, NUM_CLASSES, bias=True)
    ).to(device, dtype=head_dtype)

# Freeze Head A (we're only training Head B for classification)
for param in head_a.parameters():
    param.requires_grad = False
# Make sure Head B is trainable even if another cell left it frozen
for param in head_b.parameters():
    param.requires_grad = True
print("✓ Head A frozen")

# Load dataset
print("\n✓ Loading dataset...")
if not os.path.exists(DATASET_PATH):
    raise FileNotFoundError(f"Dataset not found: {DATASET_PATH}")

with open(DATASET_PATH, 'r', encoding='utf-8') as f:
    raw_questions = json.load(f)

if DEBUG_SAMPLE_SIZE:
    raw_questions = raw_questions[:DEBUG_SAMPLE_SIZE]
    print(f"  Using {DEBUG_SAMPLE_SIZE} samples for debugging")

# Shuffle and split (90/10); fixed seed keeps the split reproducible
random.seed(42)
random.shuffle(raw_questions)
split_idx = int(0.9 * len(raw_questions))
train_questions = raw_questions[:split_idx]
val_questions = raw_questions[split_idx:]

# Create classification datasets
train_dataset = DifficultyClassificationDataset(
    train_questions,
    tokenizer,
    num_classes=NUM_CLASSES,
    max_length=256
)
val_dataset = DifficultyClassificationDataset(
    val_questions,
    tokenizer,
    num_classes=NUM_CLASSES,
    max_length=256
)

print(f"✓ Dataset loaded: {len(train_dataset)} train, {len(val_dataset)} val samples")

# Create dataloaders; the collator pads each batch
collator = DataCollatorWithPadding(tokenizer=tokenizer)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collator)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collator)

print(f"✓ DataLoaders ready: {len(train_loader)} train batches, {len(val_loader)} val batches")

# Set up optimizer (LoRA + head_b parameters only): keep just the parameters
# that require gradients, matching the Stage B setup cell
trainable_params = [p for p in base_model.parameters() if p.requires_grad] + \
                   [p for p in head_b.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params, lr=LEARNING_RATE)

# Learning rate scheduler: linear warmup over the first 10% of steps
num_training_steps = len(train_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * num_training_steps),
    num_training_steps=num_training_steps
)

# Loss function: CrossEntropy for multi-class classification
loss_fn = nn.CrossEntropyLoss()

print("\n✓ Training setup complete:")
print(f"  Model: LoRA + Head B ({NUM_CLASSES} classes)")
print("  Loss: CrossEntropyLoss")
print(f"  Optimizer: AdamW (lr={LEARNING_RATE})")
print("  Scheduler: Linear warmup (10% steps)")
print(f"  Epochs: {EPOCHS}")
print(f"  Batch size: {BATCH_SIZE}")
print("="*60)

In [None]:
# ==============================================================================
# Training Loop: Multi-Class Difficulty Classification
# ==============================================================================
# Trains LoRA + Head B with the optimizer, scheduler and loss prepared by the
# setup cell and reports loss/accuracy per epoch.  The best validation
# accuracy is tracked and printed; no checkpoint is written by this cell.

print("\nStarting training...")
print("="*60)

best_val_acc = 0.0
train_losses = []
val_losses = []
val_accuracies = []

for epoch in range(EPOCHS):
    # ==================== Training ====================
    base_model.train()
    head_b.train()

    epoch_loss = 0.0
    train_correct = 0
    train_total = 0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]")
    for batch in progress_bar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)  # (batch,) with class indices

        # Forward pass
        outputs = two_head_forward(input_ids, attention_mask, head='b')
        logits = outputs.logits.squeeze(-1) if NUM_CLASSES is None else outputs.logits  # (batch, num_classes)

        # The heads run in the base model's dtype (loaded as bfloat16); cast
        # the logits so the loss itself is computed in float32.
        loss = loss_fn(logits.float(), labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Track metrics
        epoch_loss += loss.item()
        preds = logits.argmax(dim=-1)
        train_correct += (preds == labels).sum().item()
        train_total += labels.size(0)

        # Update progress bar
        progress_bar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{train_correct/train_total:.4f}'
        })

    # max(..., 1) avoids ZeroDivisionError when a split is empty
    avg_train_loss = epoch_loss / max(len(train_loader), 1)
    train_acc = train_correct / max(train_total, 1)
    train_losses.append(avg_train_loss)

    # ==================== Validation ====================
    base_model.eval()
    head_b.eval()

    val_loss = 0.0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]  ", leave=False):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Forward pass
            outputs = two_head_forward(input_ids, attention_mask, head='b')
            logits = outputs.logits.squeeze(-1) if NUM_CLASSES is None else outputs.logits

            # Calculate loss (float32, as in training)
            loss = loss_fn(logits.float(), labels)
            val_loss += loss.item()

            # Track accuracy
            preds = logits.argmax(dim=-1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)

    avg_val_loss = val_loss / max(len(val_loader), 1)
    val_acc = val_correct / max(val_total, 1)
    val_losses.append(avg_val_loss)
    val_accuracies.append(val_acc)

    # Print epoch summary
    print(f"\nEpoch {epoch+1}/{EPOCHS} Summary:")
    print(f"  Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"  Val Loss:   {avg_val_loss:.4f} | Val Acc:   {val_acc:.4f}")

    # Track the best validation accuracy seen so far (nothing is saved here)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        print(f"  ✓ New best validation accuracy: {best_val_acc:.4f}")

    print("-"*60)

print("\n" + "="*60)
print("Training Complete!")
print(f"Best Validation Accuracy: {best_val_acc:.4f}")
print("="*60)

 # (Option3) Mixed Training

In [None]:
# ==============================================================================
# Mixed Training: Alternating Head A and Head B
# ==============================================================================
# Prepares data, optimizer, scheduler, loss and a TensorBoard writer for the
# mixed-training loop in the next cell.
# NOTE(review): EPOCHS and LEARNING_RATE are not assigned in this cell; they
# carry whatever value an earlier cell last gave them — confirm this is intended.

import os
import random
import time

import numpy as np

# Configuration
# epoch_ratios = [0.7, 0.5, 0.3, 0.2]  # Ratio of Head A batches per epoch
epoch_ratios = [0, 0, 0, 0]  # Ratio of Head A batches per epoch
BATCH_SIZE = 20
DEBUG_SAMPLE_SIZE = 1000
OUTPUT_DIR_MIXED = "../../outputs/two_head_mixed"
os.makedirs(OUTPUT_DIR_MIXED, exist_ok=True)

_rule = "="*60
print(_rule)
print("Mixed Training Setup")
print(_rule)

# Load dataset
print("\nLoading dataset...")

with open(DATASET_PATH, 'r', encoding='utf-8') as f:
    raw_questions = json.load(f)

# Shuffle with a fixed seed first, then optionally keep only a debug-sized
# prefix, then split 90/10 on whatever remains.
random.seed(42)
random.shuffle(raw_questions)
if DEBUG_SAMPLE_SIZE:
    raw_questions = raw_questions[:DEBUG_SAMPLE_SIZE]
split_idx = int(0.9 * len(raw_questions))

train_questions, val_questions = raw_questions[:split_idx], raw_questions[split_idx:]

# Create both datasets from the same question splits
dataset_a = VerifierDataset(train_questions, tokenizer, max_length=256)  # Head A: binary
dataset_b = Adaptive_N_VerifierDataset(train_questions, tokenizer, max_length=256)  # Head B: regression

val_dataset_a = VerifierDataset(val_questions, tokenizer, max_length=256)
val_dataset_b = Adaptive_N_VerifierDataset(val_questions, tokenizer, max_length=256)

print(f"Head A dataset: {len(dataset_a)} train, {len(val_dataset_a)} val")
print(f"Head B dataset: {len(dataset_b)} train, {len(val_dataset_b)} val")

# Create dataloaders: training loaders shuffle, validation loaders do not
collator = DataCollatorWithPadding(tokenizer=tokenizer)


def _make_loader(dataset, shuffle):
    """Build a padded-batch DataLoader with this cell's batch size."""
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle, collate_fn=collator)


loader_a = _make_loader(dataset_a, True)
loader_b = _make_loader(dataset_b, True)
val_loader_a = _make_loader(val_dataset_a, False)
val_loader_b = _make_loader(val_dataset_b, False)

# Optimizer and scheduler: every parameter of the base model and of both
# heads is handed to the optimizer
trainable_params = [
    *base_model.parameters(),
    *head_a.parameters(),
    *head_b.parameters(),
]
optimizer = AdamW(trainable_params, lr=LEARNING_RATE)

# One scheduler step is budgeted per batch of the longer loader, per epoch
total_batches = max(len(loader_a), len(loader_b)) * EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_batches),
    num_training_steps=total_batches
)

# Loss function
loss_fn = nn.BCEWithLogitsLoss()

print("\n Setup complete")
print(f"  Epochs: {EPOCHS}")
print(f"  Epoch ratios (Head A): {epoch_ratios}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Learning rate: {LEARNING_RATE}")

# TensorBoard writer: one timestamped log directory per run
run_name = time.strftime("%Y%m%d-%H%M%S")   # e.g., "20250204-153232"
log_dir = os.path.join(OUTPUT_DIR_MIXED, "logs", run_name)

writer = SummaryWriter(log_dir=log_dir)

print(f"TensorBoard logs saved to: {log_dir}")


In [None]:
# ==============================================================================
# Mixed Training Loop
# ==============================================================================
# Every optimisation step trains exactly one head: Head A is picked with
# probability `ratio_a` (taken from `epoch_ratios` for the current epoch),
# otherwise Head B.  The head that is not selected has `requires_grad` switched
# off for that step.  After each epoch both heads are validated and a full
# checkpoint is written to OUTPUT_DIR_MIXED.


def _set_requires_grad(module, flag):
    """Enable or disable gradient tracking for every parameter of ``module``."""
    for param in module.parameters():
        param.requires_grad = flag


def _next_batch(iterator, loader):
    """Return ``(batch, iterator)``, restarting ``loader`` once it is exhausted."""
    try:
        return next(iterator), iterator
    except StopIteration:
        iterator = iter(loader)
        return next(iterator), iterator


def _head_loss(batch, head):
    """Move ``batch`` to ``device``, run head ``'a'`` or ``'b'`` and return the loss tensor."""
    batch = {k: v.to(device) for k, v in batch.items()}
    outputs = two_head_forward(
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
        head=head
    )
    return loss_fn(outputs.logits.squeeze(-1), batch['labels'])


def _mean_val_loss(loader, head):
    """Return the mean per-batch loss of ``head`` over ``loader`` (NaN if it is empty)."""
    if len(loader) == 0:
        # An empty validation split used to raise ZeroDivisionError here.
        return float('nan')
    running = 0.0
    with torch.no_grad():
        for batch in loader:
            running += _head_loss(batch, head).item()
    return running / len(loader)


print("\n" + "="*60)
print("Starting Mixed Training")
print("="*60)

for epoch in range(EPOCHS):
    # Get ratio for this epoch (the last ratio is reused if the list is too short)
    ratio_a = epoch_ratios[epoch] if epoch < len(epoch_ratios) else epoch_ratios[-1]

    print(f"\nEpoch {epoch+1}/{EPOCHS} - Head A ratio: {ratio_a:.2f}")

    # Training phase
    base_model.train()
    head_a.train()
    head_b.train()

    # Create iterators
    iter_a = iter(loader_a)
    iter_b = iter(loader_b)
    total_batches = max(len(loader_a), len(loader_b))

    total_loss_a = 0
    total_loss_b = 0
    count_a = 0
    count_b = 0

    progress_bar = tqdm(range(total_batches), desc=f"Epoch {epoch+1}")
    global_step = epoch * total_batches  # Global step counter
    for _ in progress_bar:
        # Decide which head to train this batch
        use_head_a = random.random() < ratio_a

        if use_head_a:
            batch, iter_a = _next_batch(iter_a, loader_a)
            active_head, frozen_head, head_key = head_a, head_b, 'a'
        else:
            batch, iter_b = _next_batch(iter_b, loader_b)
            active_head, frozen_head, head_key = head_b, head_a, 'b'

        # Freeze the head that is not being trained on this step
        _set_requires_grad(frozen_head, False)
        _set_requires_grad(active_head, True)

        loss = _head_loss(batch, head_key)
        loss_value = loss.item()

        if use_head_a:
            total_loss_a += loss_value
            count_a += 1
        else:
            total_loss_b += loss_value
            count_b += 1
        writer.add_scalar(f'Loss/batch_{head_key}', loss_value, global_step)
        global_step += 1

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Update progress
        progress_bar.set_postfix({
            'loss_a': f'{total_loss_a/max(count_a,1):.4f}',
            'loss_b': f'{total_loss_b/max(count_b,1):.4f}',
            'n_a': count_a,
            'n_b': count_b
        })

    avg_loss_a = total_loss_a / max(count_a, 1)
    avg_loss_b = total_loss_b / max(count_b, 1)

    # Validation phase
    base_model.eval()
    head_a.eval()
    head_b.eval()

    # Unfreeze both for validation (no grad anyway)
    _set_requires_grad(head_a, True)
    _set_requires_grad(head_b, True)

    avg_val_loss_a = _mean_val_loss(val_loader_a, 'a')
    avg_val_loss_b = _mean_val_loss(val_loader_b, 'b')

    print(f"Epoch {epoch+1} Complete:")
    print(f"  Head A - Train: {avg_loss_a:.4f}, Val: {avg_val_loss_a:.4f} ({count_a} batches)")
    print(f"  Head B - Train: {avg_loss_b:.4f}, Val: {avg_val_loss_b:.4f} ({count_b} batches)")

    # Save checkpoint
    checkpoint = {
        'epoch': epoch,
        'base_model_state_dict': base_model.state_dict(),
        'head_a_state_dict': head_a.state_dict(),
        'head_b_state_dict': head_b.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss_a': avg_loss_a,
        'loss_b': avg_loss_b,
        'val_loss_a': avg_val_loss_a,
        'val_loss_b': avg_val_loss_b,
    }
    torch.save(checkpoint, os.path.join(OUTPUT_DIR_MIXED, f'checkpoint_epoch{epoch+1}.pt'))

    # Push buffered scalars to disk so an interrupted run keeps its curves.
    writer.flush()

print("\n" + "="*60)
print("✓ Mixed Training Complete!")
print(f"Models saved to: {OUTPUT_DIR_MIXED}")
print("="*60)


# (Option 4) Token logit prediction

In [None]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import math

class MathDifficultyClassifier:
    """Classify math-question difficulty by scoring letter labels with a causal LM.

    Each candidate label (for example " h", " m", " e") is appended to the
    prompt and the log-probability the model assigns to the label tokens is read
    off.  Class index 0 is the hardest level; the last index is "Easy".
    """

    # num_classes -> (letter labels, display names), ordered hardest to easiest.
    _LABEL_SETS = {
        2: (["h", "e"], ["Hard", "Easy"]),
        3: (["h", "m", "e"], ["Hard", "Medium", "Easy"]),
        4: (["v", "h", "m", "e"], ["Very Hard", "Hard", "Medium", "Easy"]),
    }

    # num_classes -> instruction sentence placed after the question in the prompt.
    _INSTRUCTIONS = {
        2: "Is this math question as h (hard) or e (easy)? Answer with one character, h or e. ",
        3: "Is this math question as h (hard), m (medium), or e (easy)? Answer with one character, h, m or e.",
        4: "Is this math question as v (very hard), h (hard), m (medium), or e (easy). Answer with one character, v, h, m or e. ",
    }

    def __init__(self, model_name: str, device: torch.device, num_classes: int = 4):
        """
        A difficulty classifier using letter labels instead of numbers.

        Args:
            model_name: HuggingFace model name
            device: Device the model is moved to (e.g. "cuda" / "cpu");
                ``None`` leaves the model where ``from_pretrained`` put it
            num_classes: Number of difficulty levels (2, 3, or 4)

        Raises:
            ValueError: If ``num_classes`` is unsupported or a label tokenizes
                to an empty id list.
        """
        # Validate before the (expensive) model download/load.
        if num_classes not in self._LABEL_SETS:
            raise ValueError(f"num_classes must be 2, 3, or 4, got {num_classes}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto",
        )

        if device is not None:
            self.model.to(device)

        self.num_classes = num_classes

        labels, names = self._LABEL_SETS[num_classes]
        self.class_labels = list(labels)
        self.class_names = list(names)

        # Pre-tokenize difficulty level labels with space prefix
        self.class_ids = []
        for label in self.class_labels:
            label_text = f" {label}"
            ids = self.tokenizer(label_text, add_special_tokens=False).input_ids
            if len(ids) == 0:
                raise ValueError(f"Tokenizer produced empty ids for label '{label}'")
            self.class_ids.append(ids)

        # Check if all labels tokenize to single tokens
        self.single_token_labels = all(len(ids) == 1 for ids in self.class_ids)
        if self.single_token_labels:
            self.class_token_ids = [ids[0] for ids in self.class_ids]

    def build_prompt(self, question: str) -> str:
        """Build the difficulty prediction prompt.

        Raises:
            ValueError: If ``self.num_classes`` has no instruction template.
        """
        try:
            instruction = self._INSTRUCTIONS[self.num_classes]
        except KeyError:
            raise ValueError(
                f"num_classes must be 2, 3, or 4, got {self.num_classes}"
            ) from None

        prompt = (
            f"Question: {question}\n\n"
            f"{instruction}\n\n"
        )
        return prompt

    def _forward(self, question: str, compute_gradients: bool = False):
        """
        Internal forward pass that can optionally compute gradients.

        Builds one sequence per class (prompt tokens followed by that class's
        label tokens), runs them as a single batch and sums the log-probability
        of each label token given everything before it.

        Args:
            question: Question text inserted into the prompt.
            compute_gradients: If False the model call is wrapped in
                ``torch.no_grad()``; if True the ambient grad mode is used.

        Returns:
            1-D float32 tensor with one summed label log-probability per class
            (normalised over the full vocabulary, not over the classes).
        """
        # 1. Prepare prompt
        prompt = self.build_prompt(question)
        context_ids = self.tokenizer(prompt, add_special_tokens=True).input_ids
        context_len = len(context_ids)

        device = self.model.device
        ctx_tensor = torch.tensor(context_ids, dtype=torch.long, device=device)

        # 2. Create sequences for each class
        sequences = [
            torch.cat([ctx_tensor, torch.tensor(label_ids, dtype=torch.long, device=device)])
            for label_ids in self.class_ids
        ]

        # 3. Batch and pad (right padding; only needed for unequal label lengths)
        pad_val = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
        input_ids = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=pad_val)

        # Build the mask from the true sequence lengths.  Comparing against
        # pad_val would also mask genuine prompt/label tokens whose id happens
        # to equal the pad (often eos) id.
        lengths = torch.tensor([seq.numel() for seq in sequences], device=device)
        positions = torch.arange(input_ids.shape[1], device=device)
        attention_mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()

        # 4. Forward pass (with or without gradients)
        if compute_gradients:
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        else:
            with torch.no_grad():
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

        logits = outputs.logits

        # 5. Extract log probabilities of the label tokens only
        class_log_probs = []
        for batch_idx, label_ids in enumerate(self.class_ids):
            label_len = len(label_ids)
            # Logits at position t score the token at position t + 1.
            step_logits = logits[batch_idx, context_len - 1:context_len - 1 + label_len]
            # Upcast so the softmax is not computed in fp16/bf16.
            step_log_probs = F.log_softmax(step_logits.float(), dim=-1)
            targets = input_ids[batch_idx, context_len:context_len + label_len]
            token_log_probs = step_log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
            class_log_probs.append(token_log_probs.sum())

        # 6. Convert to tensor
        class_log_probs = torch.stack(class_log_probs)
        if not compute_gradients:
            class_log_probs = class_log_probs.detach()

        return class_log_probs

    def predict(self, question: str) -> dict:
        """
        Return difficulty prediction as probability distribution over classes.
        Uses no_grad for inference.

        Returns:
            Dict with keys 'probabilities' (softmax over the class
            log-probabilities), 'predicted_class' (int index),
            'predicted_label' (letter), 'predicted_name' (display name) and
            'log_probs' (the raw class log-probabilities).
        """
        class_log_probs = self._forward(question, compute_gradients=False)
        class_probs = F.softmax(class_log_probs, dim=0)
        predicted_class = class_probs.argmax().item()

        return {
            'probabilities': class_probs,
            'predicted_class': predicted_class,
            'predicted_label': self.class_labels[predicted_class],
            'predicted_name': self.class_names[predicted_class],
            'log_probs': class_log_probs
        }

    def compute_loss(self, question: str, true_class: int) -> tuple:
        """
        Compute the training loss for one question.

        The loss is the negative log-probability of the true class's label
        tokens under the model's full-vocabulary distribution.

        Returns:
            ``(loss, pred_class)`` where ``loss`` is a scalar tensor and
            ``pred_class`` is the argmax class index, for accuracy tracking.
        """
        class_log_probs = self._forward(question, compute_gradients=True)

        # -log P(label tokens of true_class | prompt)
        loss = -class_log_probs[true_class]

        # Also return prediction for tracking accuracy
        with torch.no_grad():
            class_probs = F.softmax(class_log_probs, dim=0)
            pred_class = class_probs.argmax().item()

        return loss, pred_class

# Example usage:
"""
# Initialize classifier
classifier = MathDifficultyClassifier(
    model_name="Qwen/Qwen2.5-0.5B",
    device=torch.device("cuda"),
    num_classes=4
)

# Predict difficulty
question = "What is 2 + 2?"
result = classifier.predict(question)
print(f"Predicted difficulty level: {result['predicted_name']}")
print(f"Class probabilities: {result['probabilities']}")

# Compute loss for training
true_class = 3  # This is an easy question (with num_classes=4: 0 = Very Hard, ..., 3 = Easy)
loss, pred = classifier.compute_loss(question, true_class)
print(f"Loss: {loss.item():.4f}, Predicted: {pred}, True: {true_class}")
"""

In [None]:
# ==============================================================================
# Test Pre-Finetuned Classification Capacity (with model output)
# ==============================================================================
# Goal: measure zero-shot accuracy of the untouched model and inspect whether
# its free-form generation actually follows the instruction.
# ==============================================================================

print("Testing Pre-Finetuned Model on Difficulty Classification")
print("="*60)

# Configuration
NUM_CLASSES = 3
TEST_SAMPLES = 1 # Number of samples to test

# Load dataset
print("\n✓ Loading dataset...")
with open(DATASET_PATH, 'r', encoding='utf-8') as f:
    raw_questions = json.load(f)

# Reproduce the seeded shuffle and 90/10 split, then keep only the first
# TEST_SAMPLES questions of the validation part.
random.seed(42)
random.shuffle(raw_questions)
split_idx = int(0.9 * len(raw_questions))
val_questions = raw_questions[split_idx:][:TEST_SAMPLES]

test_dataset = DifficultyClassificationDataset(
    val_questions, tokenizer, num_classes=NUM_CLASSES, max_length=256
)

print(f"✓ Test dataset loaded: {len(test_dataset)} samples")

# Initialize classifier (loads its own tokenizer and causal LM for MODEL_NAME)
classifier = MathDifficultyClassifier(
    model_name=MODEL_NAME, device=device, num_classes=NUM_CLASSES
)

for _status_line in (
    f"✓ Classifier initialized: {NUM_CLASSES} classes",
    f"✓ Class labels: {classifier.class_labels}",
    f"✓ Class names: {classifier.class_names}",
    f"✓ Single token labels: {classifier.single_token_labels}",
    "\n" + "="*60,
):
    print(_status_line)

# Helper: let the model free-generate so its raw answer can be inspected
def get_model_generation(question, max_new_tokens=10):
    """Greedy-decode up to ``max_new_tokens`` tokens after the difficulty prompt.

    Prints the prompt, then returns the stripped text of the newly generated
    tokens only (the prompt part of the output sequence is cut off).
    """
    prompt = classifier.build_prompt(question)
    print(f"prompt{prompt}")

    encoded = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(device)
    prompt_length = encoded['input_ids'].shape[1]

    generation_kwargs = {
        'max_new_tokens': max_new_tokens,
        'do_sample': False,  # Greedy decoding
        'pad_token_id': tokenizer.eos_token_id,
    }
    with torch.no_grad():
        sequences = classifier.model.generate(**encoded, **generation_kwargs)

    # Everything past the prompt length is what the model produced.
    continuation_ids = sequences[0][prompt_length:]
    return tokenizer.decode(continuation_ids, skip_special_tokens=True).strip()

# Evaluate on test samples: compare the label-probability prediction with the
# dataset label and show the model's free-form generation next to it.
predictions = []
true_labels = []
correct = 0

print("\nEvaluating samples...")
print("="*60)

for idx in range(min(TEST_SAMPLES, len(test_dataset))):
    # Get sample
    # NOTE(review): assumes test_dataset.samples[idx] corresponds to
    # val_questions[idx] (one sample per question, same order) - confirm
    # against DifficultyClassificationDataset.
    sample = test_dataset.samples[idx]
    question = val_questions[idx]['question']
    true_class = sample['label']  # 0-indexed
    true_prob = sample['prob']  # Original probability

    # Predict using probability method
    result = classifier.predict(question)
    pred_class = result['predicted_class']
    pred_label = result['predicted_label']
    pred_name = result['predicted_name']

    # Generate actual model output
    generated_answer = get_model_generation(question)

    predictions.append(pred_class)
    true_labels.append(true_class)

    # Coerce to a plain bool so `correct` stays a Python int even if the
    # label is stored as a tensor.
    is_correct = bool(pred_class == true_class)
    correct += is_correct

    # Print samples in detail
    print(f"\nSample {idx + 1}:")
    print(f"  Question: {question}")
    print(f"  True: Class {true_class} ({classifier.class_names[true_class]}) | Prob: {true_prob:.3f}")
    print(f"  Pred: Class {pred_class} ({pred_name}) | Label: '{pred_label}' | Confidence: {result['probabilities'][pred_class].item():.3f}")
    print(f"  Model Output: '{generated_answer}'")
    print(f"  Full distribution: {dict(zip(classifier.class_names, [f'{p:.3f}' for p in result['probabilities'].tolist()]))}")
    print(f"  Result: {'✓ CORRECT' if is_correct else '✗ WRONG'}")



# Calculate metrics (an empty test set yields 0.0 instead of ZeroDivisionError)
accuracy = correct / len(predictions) if predictions else 0.0

print("\n" + "="*60)
print("EVALUATION RESULTS")
print("="*60)
print(f"Total samples: {len(predictions)}")
print(f"Correct predictions: {correct}")
print(f"Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")

# Confusion analysis
print(f"\nClass Distribution:")
for class_idx in range(NUM_CLASSES):
    true_count = true_labels.count(class_idx)
    pred_count = predictions.count(class_idx)
    print(f"  {classifier.class_names[class_idx]} ({classifier.class_labels[class_idx]}): True={true_count}, Predicted={pred_count}")

# Calculate per-class accuracy (share of each true class that was predicted correctly)
print(f"\nPer-Class Accuracy:")
for class_idx in range(NUM_CLASSES):
    true_positives = sum(1 for t, p in zip(true_labels, predictions) if t == class_idx and p == class_idx)
    total_true = true_labels.count(class_idx)
    if total_true > 0:
        class_acc = true_positives / total_true
        print(f"  {classifier.class_names[class_idx]}: {class_acc:.4f} ({true_positives}/{total_true})")
    else:
        print(f"  {classifier.class_names[class_idx]}: N/A (no samples)")

# Random baseline: uniform guessing over NUM_CLASSES classes
random_accuracy = 1.0 / NUM_CLASSES
print(f"\nRandom baseline: {random_accuracy:.4f} ({random_accuracy*100:.2f}%)")
print(f"Model improvement: {(accuracy - random_accuracy)*100:.2f} percentage points")

print("\n" + "="*60)
print("✓ Pre-Finetuned Evaluation Complete!")
print("="*60)

In [None]:
# Build a single ad-hoc prompt to probe the model's free-form answer.
# The index is clamped to the last available entry because `val_questions`
# may hold fewer than six questions (the evaluation cell keeps only
# TEST_SAMPLES of them), in which case `val_questions[5]` raised IndexError.
question = val_questions[min(5, len(val_questions) - 1)]['question']
# answer = val_questions[2]['answers'][0]
# prompt = (
#     f"Question: {question}\n"
#     f"Is this answer correct? Answer y(yes) or n(no)."
# )
n = 2  # number of difficulty levels offered in the instruction (2, 3 or 4)
# instruction = "Is this question easy or difficult? Answer E(easy) or D(difficult)."
_rating_instructions = {
    2: "Rate this math question as h (hard) or e (easy).",
    3: "Rate this math question as h (hard), m (medium), or e (easy).",
    4: "Rate this math question as v (very hard), h (hard), m (medium), or e (easy).",
}
if n not in _rating_instructions:
    # Previously an unsupported n left `instruction` unbound (or stale from an
    # earlier run of the cell); fail loudly instead.
    raise ValueError(f"n must be 2, 3, or 4, got {n}")
instruction = _rating_instructions[n]
prompt = (
        f"Question: {question}\n\n"
        f"{instruction}\n\n"
)
# Helper: free-generate from an already built prompt
def get_model_result(prompt, max_new_tokens=10):
    """Greedy-decode up to ``max_new_tokens`` tokens after ``prompt``.

    Returns the stripped text of the newly generated tokens only; the prompt
    part of the output sequence is cut off before decoding.
    """
    encoded = tokenizer(prompt, return_tensors="pt", add_special_tokens=True).to(device)
    prompt_length = encoded['input_ids'].shape[1]

    generation_kwargs = {
        'max_new_tokens': max_new_tokens,
        'do_sample': False,  # Greedy decoding
        'pad_token_id': tokenizer.eos_token_id,
    }
    with torch.no_grad():
        sequences = classifier.model.generate(**encoded, **generation_kwargs)

    # Everything past the prompt length is what the model produced.
    continuation_ids = sequences[0][prompt_length:]
    return tokenizer.decode(continuation_ids, skip_special_tokens=True).strip()
generated_answer = get_model_result(prompt)
# No `answer` variable is defined in this cell, so printing it raised
# NameError; report only the question and the model's response.
print(f"question: {question}, model_response: {generated_answer}")

In [None]:
# ==============================================================================
# Training Loop: Difficulty Classification with Letter Labels
# ==============================================================================

print("\nStarting training...")
print("="*60)

# Training configuration
EPOCHS = 10
LEARNING_RATE = 1e-4
BATCH_SIZE = 1  # Process one question at a time due to special forward pass
DEBUG_SAMPLE_SIZE = None  # Set to a small integer to train on a subset for quick testing

# Load and prepare dataset
print("\n✓ Loading dataset...")
with open(DATASET_PATH, 'r', encoding='utf-8') as f:
    raw_questions = json.load(f)

# In this cell the debug truncation happens BEFORE the seeded shuffle.
if DEBUG_SAMPLE_SIZE:
    raw_questions = raw_questions[:DEBUG_SAMPLE_SIZE]

random.seed(42)
random.shuffle(raw_questions)
split_idx = int(0.9 * len(raw_questions))
train_questions, val_questions = raw_questions[:split_idx], raw_questions[split_idx:]

# Train split is built first, then the validation split.
train_dataset, val_dataset = (
    DifficultyClassificationDataset(_split, tokenizer, num_classes=NUM_CLASSES, max_length=256)
    for _split in (train_questions, val_questions)
)

print(f"✓ Dataset loaded: {len(train_dataset)} train, {len(val_dataset)} val samples")

# Initialize classifier and optimizer (all model parameters are optimised)
classifier = MathDifficultyClassifier(MODEL_NAME, device, num_classes=NUM_CLASSES)
optimizer = AdamW(classifier.model.parameters(), lr=LEARNING_RATE)

# Learning rate scheduler: linear warmup from 10% to 100% of LEARNING_RATE
# over the first tenth of all training steps (one step per training sample).
total_steps = EPOCHS * len(train_dataset)
warmup_steps = int(0.1 * total_steps)
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_steps
)

for _setup_line in (
    f"\n✓ Training setup:",
    f"  Model: {MODEL_NAME}",
    f"  Classes: {NUM_CLASSES} ({', '.join(classifier.class_names)})",
    f"  Epochs: {EPOCHS}",
    f"  Learning rate: {LEARNING_RATE}",
    f"  Total steps: {total_steps}",
):
    print(_setup_line)

# Training loop: one optimisation step per training question, followed by a
# full validation pass; the checkpoint with the best validation accuracy is kept.
best_val_acc = 0.0
train_losses, val_losses, val_accuracies = [], [], []

for epoch in range(EPOCHS):
    epoch_tag = f"Epoch {epoch+1}/{EPOCHS}"

    # ==================== Training ====================
    classifier.model.train()
    epoch_loss, train_correct = 0.0, 0

    progress_bar = tqdm(range(len(train_dataset)), desc=f"{epoch_tag} [Train]")

    for idx in progress_bar:
        # The question text comes from the raw list, the label from the dataset.
        sample = train_dataset.samples[idx]
        question = train_questions[idx]['question']
        true_class = sample['label']

        loss, pred_class = classifier.compute_loss(question, true_class)

        # Gradient step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if idx < warmup_steps:
            scheduler.step()

        # Running metrics for this epoch
        epoch_loss += loss.item()
        train_correct += (pred_class == true_class)

        # Refresh the progress bar every 10 samples
        seen = idx + 1
        if seen % 10 == 0:
            progress_bar.set_postfix({
                'loss': f'{epoch_loss/seen:.4f}',
                'acc': f'{train_correct/seen:.4f}'
            })

    n_train = len(train_dataset)
    avg_train_loss = epoch_loss / n_train
    train_acc = train_correct / n_train
    train_losses.append(avg_train_loss)

    # ==================== Validation ====================
    classifier.model.eval()
    val_loss, val_correct = 0.0, 0

    with torch.no_grad():
        val_bar = tqdm(range(len(val_dataset)), desc=f"{epoch_tag} [Val]  ", leave=False)
        for idx in val_bar:
            sample = val_dataset.samples[idx]
            question = val_questions[idx]['question']
            true_class = sample['label']

            loss, pred_class = classifier.compute_loss(question, true_class)
            val_loss += loss.item()
            val_correct += (pred_class == true_class)

    n_val = len(val_dataset)
    avg_val_loss = val_loss / n_val
    val_acc = val_correct / n_val
    val_losses.append(avg_val_loss)
    val_accuracies.append(val_acc)

    # Epoch summary
    print(f"\n{epoch_tag} Summary:")
    print(f"  Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"  Val Loss:   {avg_val_loss:.4f} | Val Acc:   {val_acc:.4f}")

    # Keep only the checkpoint with the best validation accuracy so far
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_checkpoint = {
            'epoch': epoch,
            'model_state_dict': classifier.model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
        }
        torch.save(best_checkpoint, 'best_difficulty_classifier.pt')
        print(f"  ✓ New best validation accuracy: {best_val_acc:.4f} - Model saved!")

    print("-"*60)

# Final results, compared against uniform random guessing
random_baseline = 1.0/NUM_CLASSES
banner = "="*60
print("\n" + banner)
print("TRAINING COMPLETE!")
print(banner)
print(f"Best Validation Accuracy: {best_val_acc:.4f}")
print(f"Random Baseline: {random_baseline:.4f}")
print(f"Improvement: {(best_val_acc - random_baseline)*100:.2f} percentage points")
print(banner)
