In [1]:
import math
import random

import torch
from torch.optim import AdamW
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup

## Verifier Class

In [2]:
class MathAnswerVerifier:
    """
    Yes/no answer verifier built on top of a decoder-only causal LM.

    A (question, answer) pair is rendered into a fixed prompt; the model's
    log-likelihoods of two label continuations (`label_yes` / `label_no`)
    are then normalized against each other to give a single probability.
    """

    def __init__(self, model_name: str, device: torch.device,
                 label_yes: str = " y", label_no: str = " n"):
        """
        Load the tokenizer and model, and pre-tokenize the two labels.

        Args:
            model_name: HuggingFace model name, e.g. "Qwen/Qwen2.5-0.5B" or "gpt2".
            device: target device passed to `model.to(...)` (e.g. "cuda" / "cpu").
                If None, the model is left where `from_pretrained` put it.
            label_yes: text label representing "correct" (default: ' y').
            label_no: text label representing "incorrect" (default: ' n').

        Raises:
            ValueError: if either label tokenizes to an empty id list.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto",
        )

        if device is not None:
            self.model.to(device)

        # Store labels (training code reuses them to build targets)
        self.label_yes = label_yes
        self.label_no = label_no

        # Pre-tokenize yes/no label sequences.
        # IMPORTANT: labels are tokenized separately and appended to the prompt
        # by ID, which avoids string-concatenation tokenization artifacts.
        self.yes_ids = self.tokenizer(label_yes, add_special_tokens=False).input_ids
        self.no_ids = self.tokenizer(label_no, add_special_tokens=False).input_ids

        if len(self.yes_ids) == 0 or len(self.no_ids) == 0:
            raise ValueError("Tokenizer produced empty ids for yes/no labels.")

    def build_prompt(self, question: str, answer: str) -> str:
        """
        Build the verifier prompt for one (question, answer) pair.

        IMPORTANT: This should be used consistently in both inference and training.
        """
        prompt = (
            f"Question: {question}\n"
            f"Answer: {answer}\n\n"
            f"Is this answer correct? Answer y(Yes) or n(No)."
        )
        return prompt

    @torch.no_grad()
    def score(self, question: str, answer: str) -> float:
        """
        Return P(correct | question, answer) as a float.

        Both candidates ([prompt + label_yes] and [prompt + label_no]) are
        scored in one batched forward pass. Each label's log-probability is the
        sum of its token log-probabilities given the prompt; the two totals are
        then normalized: P(yes) = exp(yes) / (exp(yes) + exp(no)).

        Raises:
            ValueError: if the prompt tokenizes to zero tokens (there would be
                no position whose logits predict the first label token).
        """
        # 1. Tokenize the prompt once; both candidates share it as context.
        prompt = self.build_prompt(question, answer)
        context_ids = list(self.tokenizer(prompt, add_special_tokens=True).input_ids)
        context_len = len(context_ids)
        if context_len == 0:
            raise ValueError("Tokenizer produced an empty prompt encoding.")

        # 2. Build [context + label] sequences by ID concatenation.
        device = self.model.device
        sequences = [
            torch.tensor(context_ids + list(label_ids), dtype=torch.long, device=device)
            for label_ids in (self.yes_ids, self.no_ids)
        ]

        # 3. Right-pad to a common length (labels may differ in token count).
        pad_val = self.tokenizer.pad_token_id
        if pad_val is None:
            pad_val = self.tokenizer.eos_token_id
        if pad_val is None:
            pad_val = 0  # any id works: padded positions are masked out below
        input_ids = torch.nn.utils.rnn.pad_sequence(
            sequences, batch_first=True, padding_value=pad_val
        )
        # Derive the mask from the true lengths, NOT from `input_ids != pad_val`:
        # the pad id often equals EOS (or BOS), so a value-based mask would also
        # hide genuine prompt tokens that happen to share that id.
        seq_lens = torch.tensor([seq.numel() for seq in sequences], device=device)
        positions = torch.arange(input_ids.size(1), device=device)
        attention_mask = (positions.unsqueeze(0) < seq_lens.unsqueeze(1)).long()

        # 4. Forward pass (batch size = 2); logits: [2, seq_len, vocab_size]
        logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits

        # 5. Sum the label-token log-probabilities for one batch row.
        def label_logprob(row: int, label_len: int) -> float:
            start, end = context_len, context_len + label_len
            # Logits at position p-1 predict the token at position p. Only the
            # label positions are normalized, and in float32, so low-precision
            # model dtypes do not degrade the log-softmax.
            step_logits = logits[row, start - 1:end - 1].float()
            step_log_probs = F.log_softmax(step_logits, dim=-1)
            targets = input_ids[row, start:end].unsqueeze(-1)
            return step_log_probs.gather(-1, targets).sum().item()

        logp_yes = label_logprob(0, len(self.yes_ids))
        logp_no = label_logprob(1, len(self.no_ids))

        # 6. Normalize over the two labels (shift by the max for stability).
        max_logp = max(logp_yes, logp_no)
        p_yes_score = math.exp(logp_yes - max_logp)
        p_no_score = math.exp(logp_no - max_logp)
        prob_correct = p_yes_score / (p_yes_score + p_no_score)
        return float(prob_correct)


## (Optional) Test Verifier Inference (Score Correctness)

### Copy Answer Generation Code from experiment.ipynb

In [None]:
from datasets import load_dataset
import re

# Checkpoint used to generate candidate answers
model_name = "Qwen/Qwen3-0.6B"

print("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")

# GSM8K, "main" configuration
print("Loading GSM8K dataset...")
ds = load_dataset("openai/gsm8k", "main")

print(f"Dataset loaded: {len(ds['test'])} test examples, {len(ds['train'])} train examples")
print(f"Model loaded on device: {model.device}")

# ---- Device report ----
section_rule = "=" * 60
print(section_rule)
print("DEVICE INFORMATION")
print(section_rule)

cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")

if not cuda_available:
    print("Running on CPU")
else:
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    print(f"Current GPU device: {torch.cuda.current_device()}")
    print(f"GPU device name: {torch.cuda.get_device_name(0)}")

    # Memory counters for GPU 0, converted from bytes to GB
    bytes_per_gb = 1024 ** 3
    print(f"\nGPU Memory:")
    print(f"  Allocated: {torch.cuda.memory_allocated(0) / bytes_per_gb:.2f} GB")
    print(f"  Reserved: {torch.cuda.memory_reserved(0) / bytes_per_gb:.2f} GB")

# Where the model reports itself, and in which dtype
print(f"\nModel device: {model.device}")
print(f"Model dtype: {model.dtype}")

# Collect the distinct devices actually holding parameters
# (more than one means the model is sharded)
devices = {str(weight.device) for weight in model.parameters()}

if len(devices) <= 1:
    print(f"\nAll model parameters are on: {list(devices)[0]}")
else:
    print(f"\nModel is distributed across multiple devices: {devices}")

print(section_rule)


# Helper Functions for Answer Extraction, Verification, and Generation
def extract_answer(text):
    """
    Pull the final numeric answer out of a solution string.

    The number directly following a GSM8K-style `####` marker wins; if no such
    marker-plus-number is present, the last number anywhere in the text is
    used. Thousands separators are stripped from the result.

    Returns:
        The number as a string without commas, or None if the text contains
        no number at all.
    """
    marked = re.search(r'####\s*(-?\d+(?:,\d{3})*(?:\.\d+)?)', text)
    if marked is not None:
        raw_number = marked.group(1)
    else:
        # No "#### <number>": fall back to the last number mentioned
        candidates = re.findall(r'-?\d+(?:,\d{3})*(?:\.\d+)?', text)
        if not candidates:
            return None
        raw_number = candidates[-1]

    return raw_number.replace(',', '')

def check_answer_correct(generated_answer, reference_answer):
    """
    Check if the generated answer matches the reference answer.

    Both texts are reduced to their final number with `extract_answer`, then
    compared numerically with an absolute tolerance of 0.01.

    Args:
        generated_answer: model output text.
        reference_answer: reference solution text.

    Returns:
        True if both numbers were found and agree; False if either text
        yields no number or the numbers differ.
    """
    gen = extract_answer(generated_answer)
    ref = extract_answer(reference_answer)

    if gen is None or ref is None:
        return False

    try:
        # Compare as floats to handle different formats ("18" vs "18.0")
        return abs(float(gen) - float(ref)) < 0.01
    except ValueError:
        # Only a failed numeric parse falls back to string comparison; a bare
        # `except:` here would also swallow KeyboardInterrupt and real bugs.
        return gen == ref

def generate_answers(question, num_answers=10, max_new_tokens=512, temperature=0.7):
    """
    Sample `num_answers` completions for one question, one generate call each.

    The question is wrapped in an instruction asking for a GSM8K-style
    "#### <number>" ending, passed through the tokenizer's chat template
    (thinking mode disabled), and sampled with top-p 0.9 at the given
    temperature. Progress and each decoded answer are printed as they arrive.

    Uses the notebook-level `tokenizer` and `model`.

    Returns:
        List of decoded completions (prompt removed, whitespace stripped).
    """
    # Instruction text: explicitly request the GSM8K answer format
    prompt_text = f"""Question: {question}
Answer: Let's solve this step by step concisely. End your answer with #### followed by the final numerical answer."""

    chat_text = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt_text}],
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,  # Disable thinking mode for faster generation
    )

    model_inputs = tokenizer([chat_text], return_tensors="pt").to(model.device)
    # generate() returns prompt + completion; remember where the completion starts
    prompt_len = len(model_inputs.input_ids[0])

    heavy_rule = '=' * 60
    light_rule = '-' * 60

    answers = []
    for trial in range(1, num_answers + 1):
        print(f"\n{heavy_rule}")
        print(f"Generating answer {trial}/{num_answers}...")
        print(heavy_rule)

        with torch.no_grad():
            generated_ids = model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Keep only the newly generated tokens
        completion_ids = generated_ids[0][prompt_len:].tolist()
        generated_text = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
        answers.append(generated_text)

        # Report this trial
        print(f"\nGenerated Answer {trial}:")
        print(light_rule)
        print(generated_text)
        print(light_rule)
        print(f"Extracted Answer: {extract_answer(generated_text)}")

    return answers

# Demo on the first problem of the GSM8K test split
test_example = ds['test'][0]
wide_rule = "=" * 80

print(f"Question: {test_example['question']}")
print(f"\nReference Answer: {test_example['answer']}")
print(f"Reference Final Answer: {extract_answer(test_example['answer'])}")
print("\n" + wide_rule)

# Sample a few candidate solutions for that problem
print("Generating different answers...")
num_answers = 3
generated_answers = generate_answers(test_example['question'], num_answers=num_answers)

# Compare every candidate against the reference answer
print("\n" + wide_rule)
print("VERIFICATION RESULTS:")
print(wide_rule)

correct_answers = []  # 0-based indices of the candidates that match the reference
for position, candidate in enumerate(generated_answers):
    extracted_value = extract_answer(candidate)
    matches_reference = check_answer_correct(candidate, test_example['answer'])

    print(f"\nAnswer {position + 1}:")
    print(f"Extracted value: {extracted_value}")
    print(f"Correct: {matches_reference}")
    print(f"Response preview: {candidate[:200]}...")

    if matches_reference:
        correct_answers.append(position)

print("\n" + wide_rule)
print(f"Summary: {len(correct_answers)}/{num_answers} answers were correct")
print(f"Correct answer indices: {correct_answers}")

### Score Answers using New Verifier

In [None]:
print("Loading verifier model...")
verifier_model_name = "Qwen/Qwen3-0.6B"
device = model.device  # Place the verifier on the generator model's device
verifier = MathAnswerVerifier(verifier_model_name, device)

print("Scoring all generated answers...")
scores = []
for rank, candidate in enumerate(generated_answers, start=1):
    candidate_score = verifier.score(test_example['question'], candidate)
    scores.append(candidate_score)
    candidate_ok = check_answer_correct(candidate, test_example['answer'])
    print(f"Answer {rank}: Correctness Score = {candidate_score:.4f}, Correct = {candidate_ok}")

# Pick the candidate the verifier trusts most (first one wins on ties)
best_idx = max(range(len(scores)), key=scores.__getitem__)
best_answer = generated_answers[best_idx]
print(f"\n" + "="*80)
print(f"Best answer according to verifier: Answer {best_idx + 1}")
print(f"Is the best answer correct? {check_answer_correct(best_answer, test_example['answer'])}")
print(f"Best answer: {best_answer[:300]}...")

## Finetune

### Dataset Class for Verifier Finetuning 

In [3]:
class VerifierDataset(Dataset):
    """
    Turns labelled (question, answer) examples into causal-LM training items.

    Each item is [prompt tokens] + [label tokens] (+ EOS when the tokenizer
    defines one). Prompt positions are set to -100 in `labels`, so only the
    label (and EOS) positions carry a training target.
    """

    def __init__(self, examples, verifier: "MathAnswerVerifier", max_length: int = 512):
        """
        Args:
            examples: indexable collection of dicts with keys "question",
                "answer" and "label" (1 selects `label_yes`, anything else
                selects `label_no`).
            verifier: provides `build_prompt`, `tokenizer`, `label_yes`, `label_no`.
            max_length: maximum number of tokens per item; longer sequences are
                truncated from the left so the label is always kept.
        """
        self.examples = examples
        self.verifier = verifier
        self.tokenizer = verifier.tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        """
        Build one training item.

        Returns:
            dict with 1-D LongTensors "input_ids", "attention_mask" (all ones)
            and "labels" (-100 on prompt positions).

        Raises:
            ValueError: if `max_length` leaves no room for any prompt token
                in front of the label.
        """
        ex = self.examples[idx]
        prompt = self.verifier.build_prompt(ex["question"], ex["answer"])

        # Choose "y" or "n" according to label
        target_text = self.verifier.label_yes if ex["label"] == 1 else self.verifier.label_no

        # 1. Tokenize context (special tokens such as BOS are added here)
        context_ids = list(
            self.tokenizer(
                prompt,
                add_special_tokens=True,
                return_attention_mask=False,
            ).input_ids
        )

        # 2. Tokenize label (NO special tokens). Copy into a fresh list so the
        #    tokenizer's own output is never mutated.
        target_ids = list(
            self.tokenizer(
                target_text,
                add_special_tokens=False,
                return_attention_mask=False,
            ).input_ids
        )
        eos_id = self.tokenizer.eos_token_id
        if eos_id is not None:
            # Appending None would make torch.tensor() fail with an obscure error
            target_ids.append(eos_id)

        # A budget this small would cut into the label itself (or leave it with
        # no context at all); fail loudly instead of training on garbage.
        if self.max_length <= len(target_ids):
            raise ValueError(
                f"max_length={self.max_length} is too small: the label needs "
                f"{len(target_ids)} tokens plus at least one prompt token."
            )

        # 3. Concatenate, 4. truncate from the left so the label survives
        full_ids = context_ids + target_ids
        if len(full_ids) > self.max_length:
            full_ids = full_ids[-self.max_length:]
        context_len = len(full_ids) - len(target_ids)

        input_ids = torch.tensor(full_ids, dtype=torch.long)
        attention_mask = torch.ones_like(input_ids)
        labels = input_ids.clone()
        labels[:context_len] = -100  # no loss on prompt tokens

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }


### Function for Finetuning

In [None]:
def finetune_verifier(verifier: MathAnswerVerifier, data, epochs=1, batch_size=4, lr=1e-5, warmup_ratio=0.1):
    """
    Finetune the verifier's language model with AdamW and a linear
    warmup/decay learning-rate schedule.

    Args:
        verifier: verifier whose `model` is trained in place.
        data: examples accepted by `VerifierDataset`
            (dicts with "question", "answer", "label").
        epochs: number of passes over `data`.
        batch_size: examples per optimizer step.
        lr: peak learning rate reached at the end of warmup.
        warmup_ratio: fraction of all steps used for linear warmup.

    Returns:
        The same `verifier`, with its model switched back to eval mode.

    Raises:
        ValueError: if `data` is empty.
    """
    dataset = VerifierDataset(data, verifier)
    if len(dataset) == 0:
        # Otherwise this surfaces later as a ZeroDivisionError / sampler error
        raise ValueError("Cannot finetune the verifier on an empty dataset.")

    # Padding id for input_ids. Use explicit None checks: `pad or eos` would
    # wrongly discard a legitimate pad id of 0.
    pad_id = verifier.tokenizer.pad_token_id
    if pad_id is None:
        pad_id = verifier.tokenizer.eos_token_id
    if pad_id is None:
        pad_id = 0  # harmless: padded positions get attention_mask 0 and label -100

    def collate(batch):
        """Right-pad every field of the batch to the longest item."""
        def pad(key, fill):
            return torch.nn.utils.rnn.pad_sequence(
                [item[key] for item in batch],
                batch_first=True,
                padding_value=fill,
            )

        return {
            "input_ids": pad("input_ids", pad_id),
            "attention_mask": pad("attention_mask", 0),
            "labels": pad("labels", -100),  # -100 = ignored by the LM loss
        }

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate)

    model = verifier.model
    model.train()
    optimizer = AdamW(model.parameters(), lr=lr)

    # Scheduler setup: LR rises linearly from 0 to `lr` over the warmup steps,
    # then decays linearly to 0 by the last training step.
    total_steps = len(dataloader) * epochs
    num_warmup_steps = int(total_steps * warmup_ratio)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=total_steps
    )
    print(f"Training for {total_steps} steps with {num_warmup_steps} warmup steps.")

    for epoch in range(epochs):
        total_loss = 0.0
        for step, batch in enumerate(dataloader):
            batch = {k: v.to(model.device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()

            # Every 50 steps: running mean loss of this epoch and current LR
            if (step + 1) % 50 == 0:
                current_lr = scheduler.get_last_lr()[0]
                print(f"Epoch {epoch+1}, step {step+1}, loss = {total_loss / (step+1):.4f}, lr = {current_lr:.2e}")

        print(f"Epoch {epoch+1} finished, avg loss = {total_loss / len(dataloader):.4f}")

    model.eval()
    return verifier

### Finetuning Example

In [None]:
# 0. Reproducibility
# This cell shuffles the data (Python RNG) and trains with a shuffled
# DataLoader and dropout (torch RNG). Seed both so a re-run starts from the
# same random state (best effort: some GPU kernels remain non-deterministic).
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# 1. Setup Device (Prioritize GPU)
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Using device: {device}")

# 2. Initialize Verifier
# use gpt2 as a demo because it's small and fast.
model_name = "gpt2"

print(f"Loading model {model_name}...")
verifier = MathAnswerVerifier(
    model_name=model_name,
    device=device,
    label_yes=" y",
    label_no=" n"
)

# Some Decoder-only models (e.g., GPT2) don't have a pad_token by default; manually set it to eos_token
if verifier.tokenizer.pad_token is None:
    verifier.tokenizer.pad_token = verifier.tokenizer.eos_token
    verifier.model.config.pad_token_id = verifier.model.config.eos_token_id

# 3. Construct a Tiny Dataset (Training Data)
# Includes simple math problems, both correct (label=1) and incorrect (label=0)
train_data = [
    {"question": "1 + 1 = ?", "answer": "2", "label": 1},
    {"question": "1 + 1 = ?", "answer": "3", "label": 0},
    {"question": "2 * 3 = ?", "answer": "6", "label": 1},
    {"question": "2 * 3 = ?", "answer": "8", "label": 0},
    {"question": "10 - 5 = ?", "answer": "5", "label": 1},
    {"question": "10 - 5 = ?", "answer": "4", "label": 0},
    {"question": "3 + 3 = ?", "answer": "6", "label": 1},
    {"question": "3 + 3 = ?", "answer": "9", "label": 0},
]

# Duplicate data to simulate a slightly larger epoch
train_data = train_data * 4  # Total 32 samples
random.shuffle(train_data)

# 4. Test Case Definition (question not present in train_data)
test_q = "2 + 2 = ?"
test_a_correct = "4"
test_a_wrong = "5"

print("\n" + "="*80)
print("PRE-TRAINING TEST")
print("="*80)

score_correct = verifier.score(test_q, test_a_correct)
score_wrong = verifier.score(test_q, test_a_wrong)

print(f"Q: {test_q}")
print(f"A: {test_a_correct} (Correct) -> Score: {score_correct:.4f}")
print(f"A: {test_a_wrong} (Wrong)   -> Score: {score_wrong:.4f}")

# At this point, the model likely hasn't learned to output 'y' or 'n',
# so scores might be random (around 0.5).

# 5. Start Fine-tuning
print("\n" + "="*80)
print("START FINE-TUNING")
print("="*80)

# Train for 5 epochs with a slightly higher LR for demonstration purposes
verifier = finetune_verifier(
    verifier,
    train_data,
    epochs=5,
    batch_size=4,
    lr=5e-5,
    warmup_ratio=0.1,
)

# 6. Post-training Test (same pair as the pre-training test, for comparison)
print("\n" + "="*80)
print("POST-TRAINING TEST")
print("="*80)

score_correct_post = verifier.score(test_q, test_a_correct)
score_wrong_post = verifier.score(test_q, test_a_wrong)

print(f"Q: {test_q}")
print(f"A: {test_a_correct} (Correct) -> Score: {score_correct_post:.4f}")
print(f"A: {test_a_wrong} (Wrong)   -> Score: {score_wrong_post:.4f}")

# 7. Test Generalization (Unseen Data)
print("\nTesting generalization (Unseen Data):")
unseen_q = "5 + 5 = ?"
unseen_a_corr = "10"
unseen_a_wrong = "12"

s_corr = verifier.score(unseen_q, unseen_a_corr)
s_wrong = verifier.score(unseen_q, unseen_a_wrong)
print(f"Q: {unseen_q}, A: {unseen_a_corr} -> {s_corr:.4f}")
print(f"Q: {unseen_q}, A: {unseen_a_wrong} -> {s_wrong:.4f}")

# 8. Save Model (Optional)
# verifier.model.save_pretrained("./my_math_verifier")
# verifier.tokenizer.save_pretrained("./my_math_verifier")

Using device: cuda
Loading model gpt2...


2025-11-23 21:47:19.653948: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-11-23 21:47:19.682239: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-11-23 21:47:20.576605: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
`torch_dtype` is deprecated! Use `dtype` instead!
`loss_type=None` was set in the config bu


PRE-TRAINING TEST
Q: 2 + 2 = ?
A: 4 (Correct) -> Score: 0.4586
A: 5 (Wrong)   -> Score: 0.4260

START FINE-TUNING
Training for 40 steps with 4 warmup steps.
Epoch 1 finished, avg loss = 3.7910
Epoch 2 finished, avg loss = 0.9050
Epoch 3 finished, avg loss = 0.8529
Epoch 4 finished, avg loss = 0.6503
Epoch 5 finished, avg loss = 0.5194

POST-TRAINING TEST
Q: 2 + 2 = ?
A: 4 (Correct) -> Score: 0.6832
A: 5 (Wrong)   -> Score: 0.6585

Testing generalization (Unseen Data):
Q: 5 + 5 = ?, A: 10 -> 0.6299
Q: 5 + 5 = ?, A: 12 -> 0.6061
