In [8]:
import math
import random

import torch
from torch.optim import AdamW
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup

## Verifier Class

In [2]:
class MathAnswerVerifier:
    """
    Yes/no answer verifier built on a decoder-only causal LM.

    The correctness probability of (question, answer) is read off the LM's
    relative likelihood of the "yes" label versus the "no" label appended after
    a fixed verifier prompt (see `build_prompt`).
    """

    def __init__(self, model_name: str, device: torch.device,
                 label_yes: str = " y", label_no: str = " n"):
        """
        A simple verifier built on top of a decoder-only LLM
        (e.g., Qwen / GPT-2 style AutoModelForCausalLM).

        Given (question, answer), it estimates:
            P(correct | question, answer) in (0, 1)

        Args:
            model_name: HuggingFace model name, e.g. "Qwen/Qwen2.5-0.5B" or "gpt2".
            device: device to move the model to (e.g. torch.device("cuda") or "cpu").
                If None, the model stays where `from_pretrained` placed it.
            label_yes: text label representing "correct" (here: ' y').
            label_no: text label representing "incorrect" (here: ' n').

        Raises:
            ValueError: if either label tokenizes to an empty id sequence.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            dtype="auto",
        )

        if device is not None:
            self.model.to(device)

        # Store labels (the finetuning dataset reuses them as training targets)
        self.label_yes = label_yes
        self.label_no = label_no

        # Pre-tokenize yes/no label sequences.
        # IMPORTANT: they are tokenized separately and appended later by ID,
        # avoiding tokenization artifacts from string concatenation.
        self.yes_ids = self.tokenizer(label_yes, add_special_tokens=False).input_ids
        self.no_ids = self.tokenizer(label_no, add_special_tokens=False).input_ids

        if len(self.yes_ids) == 0 or len(self.no_ids) == 0:
            raise ValueError("Tokenizer produced empty ids for yes/no labels.")

    def build_prompt(self, question: str, answer: str) -> str:
        """
        Build the verifier prompt.
        IMPORTANT: This should be used consistently in both inference and training.
        """
        prompt = (
            f"Question: {question}\n"
            f"Answer: {answer}\n\n"
            f"Is this answer correct? Answer y(Yes) or n(No)."
        )
        return prompt

    @torch.no_grad()
    def score(self, question: str, answer: str, max_length=None) -> float:
        """
        Return P(correct | question, answer) in (0, 1).

        Both candidate continuations, [context + yes_ids] and [context + no_ids],
        are scored in one batch-of-2 forward pass. Label ids are appended by id
        (not by string concatenation) to avoid boundary tokenization artifacts.
        The two label log-likelihoods are renormalized against each other:
            P(yes) = exp(lp_yes) / (exp(lp_yes) + exp(lp_no)).

        Args:
            question: question text.
            answer: candidate answer text.
            max_length: optional cap on the total sequence length (context plus
                the longer label). When set, the context is left-truncated (its
                last tokens are kept), like the left truncation applied to
                training sequences. None (default) disables truncation.

        Raises:
            ValueError: if no context tokens remain (there would be no position
                from which to predict the first label token), e.g. when
                max_length <= the longer label's length.
        """
        # 1. Context ids; add_special_tokens=True lets the tokenizer add BOS if it uses one.
        prompt = self.build_prompt(question, answer)
        context_ids = list(self.tokenizer(prompt, add_special_tokens=True).input_ids)

        if max_length is not None:
            keep = max_length - max(len(self.yes_ids), len(self.no_ids))
            # Guard keep <= 0 explicitly: context_ids[-0:] would keep everything.
            context_ids = context_ids[-keep:] if keep > 0 else []
        if not context_ids:
            raise ValueError(
                "Verifier context is empty; cannot score the labels "
                "(check the prompt or increase max_length)."
            )

        # 2. Build the two right-padded sequences and an attention mask from their
        #    true lengths. The mask must NOT be derived by comparing ids to pad_val:
        #    a real token equal to the pad id (e.g. when pad falls back to EOS)
        #    would otherwise be hidden from the model.
        ctx_len = len(context_ids)
        sequences = (context_ids + list(self.yes_ids), context_ids + list(self.no_ids))
        total_len = max(len(seq) for seq in sequences)

        pad_val = self.tokenizer.pad_token_id
        if pad_val is None:
            pad_val = self.tokenizer.eos_token_id
        if pad_val is None:
            pad_val = 0  # padded slots are masked out and never scored, so any id works

        device = self.model.device
        input_ids = torch.full((2, total_len), pad_val, dtype=torch.long, device=device)
        attention_mask = torch.zeros((2, total_len), dtype=torch.long, device=device)
        for row, seq in enumerate(sequences):
            input_ids[row, :len(seq)] = torch.tensor(seq, dtype=torch.long, device=device)
            attention_mask[row, :len(seq)] = 1

        # 3. Forward pass (batch size 2): logits shape [2, total_len, vocab_size].
        logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits

        def label_logprob(row, label_ids):
            # The token at position p is predicted from logits at p - 1, so the label
            # span [ctx_len, ctx_len + L) is scored by logits[ctx_len - 1 : ctx_len - 1 + L].
            # Only that slice is normalized, and in float32: the logits may be half
            # precision (dtype="auto"), and normalizing the whole
            # [2, seq_len, vocab] tensor is wasteful.
            span = logits[row, ctx_len - 1:ctx_len - 1 + len(label_ids)].float()
            log_probs = F.log_softmax(span, dim=-1)
            targets = torch.tensor(label_ids, dtype=torch.long, device=log_probs.device)
            return log_probs.gather(-1, targets.unsqueeze(-1)).sum().item()

        logp_yes = label_logprob(0, list(self.yes_ids))
        logp_no = label_logprob(1, list(self.no_ids))

        # 4. Two-way softmax, shifted by the max for numerical stability.
        max_logp = max(logp_yes, logp_no)
        p_yes_score = math.exp(logp_yes - max_logp)
        p_no_score = math.exp(logp_no - max_logp)
        return float(p_yes_score / (p_yes_score + p_no_score))


## (Optional) Test Verifier Inference (Score Correctness)

### Copy Answer Generation Code from experiment.ipynb

In [3]:
from datasets import load_dataset
import re

# Choose the compute device *before* loading the model, so it is placed on exactly
# one device. (Loading with device_map="auto" and then calling model.to(device)
# mixes automatic placement with a manual move, which breaks if the model is ever
# split across several devices.)
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    # No CUDA: fall back to MPS (Apple Silicon GPU).
    device = torch.device("mps")

# Load the tokenizer and the model
model_name = "Qwen/Qwen3-0.6B"

print("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype="auto",
)
model.to(device)

# Load GSM8K dataset
print("Loading GSM8K dataset...")
ds = load_dataset("openai/gsm8k", "main")

print(f"Dataset loaded: {len(ds['test'])} test examples, {len(ds['train'])} train examples")
print(f"Model loaded on device: {model.device}")

# Report device information
print("="*60)
print("DEVICE INFORMATION")
print("="*60)

print(f"PyTorch version: {torch.__version__}")

if device.type == "cuda":
    print("Backend: CUDA (NVIDIA GPU)")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    print(f"Current GPU device: {torch.cuda.current_device()}")
    print(f"GPU device name: {torch.cuda.get_device_name(0)}")

    print("\nGPU Memory:")
    print(f"  Allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
    print(f"  Reserved:  {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB")
elif device.type == "mps":
    print("Backend: MPS (Apple Silicon GPU)")
    print("MPS is available and will be used as the device.")
else:
    print("No GPU backend available. Running on CPU.")

print(f"\nUsing device: {device}")
print(f"\nModel device: {model.device}")

# Sanity check: every parameter should live on the selected device.
devices = {str(param.device) for param in model.parameters()}

if len(devices) > 1:
    print(f"\nModel is distributed across multiple devices: {devices}")
else:
    print(f"\nAll model parameters are on: {next(iter(devices))}")

print("="*60)


# Helper Functions for Answer Extraction, Verification, and Generation
def extract_answer(text):
    """
    Extract the final numerical answer from a GSM8K-style solution.

    GSM8K answers end with ``#### <number>``. The marker may be followed by an
    optional ``$`` sign; models often write ``#### $20``, which the marker
    pattern must still match. If no marker is found, the last number anywhere
    in the text is used instead. Thousands separators (commas) are removed.

    Args:
        text: solution text; None or "" yields None.

    Returns:
        The number as a string (e.g. "18", "-3.5"), or None if no number is found.
    """
    if not text:
        return None

    number = r'-?\d+(?:,\d{3})*(?:\.\d+)?'

    # Prefer the number after ####, allowing an optional currency sign.
    match = re.search(r'####\s*\$?\s*(' + number + r')', text)
    if match:
        return match.group(1).replace(',', '')

    # Fallback: the last number in the text
    numbers = re.findall(number, text)
    if numbers:
        return numbers[-1].replace(',', '')

    return None

def check_answer_correct(generated_answer, reference_answer, tol=0.01):
    """
    Check whether a generated solution's final number matches the reference's.

    Both texts go through `extract_answer`; the two numbers are compared as floats
    so that e.g. "18" and "18.0" match.

    Args:
        generated_answer: model-produced solution text.
        reference_answer: reference solution text (GSM8K format).
        tol: absolute tolerance for the numeric comparison (default 0.01).

    Returns:
        True if both answers contain a number and they agree within `tol`;
        False otherwise.
    """
    gen = extract_answer(generated_answer)
    ref = extract_answer(reference_answer)

    if gen is None or ref is None:
        return False

    try:
        return abs(float(gen) - float(ref)) < tol
    except ValueError:
        # Defensive fallback: extract_answer's pattern should only yield numeric
        # strings, but compare textually if conversion ever fails.
        return gen == ref

def generate_answers(question, num_answers=10, max_new_tokens=512, temperature=0.7):
    """
    Sample `num_answers` independent solutions to `question` from the global chat model.

    The question is wrapped in a GSM8K-style instruction (asking for a final
    ``#### <number>`` line), rendered with the tokenizer's chat template with
    thinking disabled, and sampled once per answer (nucleus sampling, top_p=0.9).
    Each decoded completion is printed along with its extracted final number.

    Returns:
        List of the decoded completions (prompt tokens excluded), in sampling order.
    """
    instruction = (
        f"Question: {question}\n"
        "Answer: Let's solve this step by step concisely. "
        "End your answer with #### followed by the final numerical answer."
    )

    chat_text = tokenizer.apply_chat_template(
        [{"role": "user", "content": instruction}],
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,  # skip the thinking block for faster generation
    )

    encoded = tokenizer([chat_text], return_tensors="pt").to(model.device)
    # Number of prompt tokens, used to strip the prompt off each generated sequence.
    prompt_len = len(encoded.input_ids[0])
    banner = '=' * 60
    divider = '-' * 60

    results = []
    for idx in range(1, num_answers + 1):
        print(f"\n{banner}")
        print(f"Generating answer {idx}/{num_answers}...")
        print(banner)

        with torch.no_grad():
            sampled = model.generate(
                **encoded,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
            )

        completion_ids = sampled[0][prompt_len:].tolist()
        completion = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
        results.append(completion)

        print(f"\nGenerated Answer {idx}:")
        print(divider)
        print(completion)
        print(divider)
        print(f"Extracted Answer: {extract_answer(completion)}")

    return results

# Sample a few answers for the first GSM8K test question and grade each one
# against the reference by comparing final numbers.
test_example = ds['test'][0]
print(f"Question: {test_example['question']}")
print(f"\nReference Answer: {test_example['answer']}")
print(f"Reference Final Answer: {extract_answer(test_example['answer'])}")
print("\n" + "=" * 80)

print("Generating different answers...")
num_answers = 3
generated_answers = generate_answers(test_example['question'], num_answers=num_answers)

print("\n" + "=" * 80)
print("VERIFICATION RESULTS:")
print("=" * 80)

correct_answers = []
for i, answer in enumerate(generated_answers):
    extracted = extract_answer(answer)
    is_correct = check_answer_correct(answer, test_example['answer'])
    print(
        f"\nAnswer {i+1}:\n"
        f"Extracted value: {extracted}\n"
        f"Correct: {is_correct}\n"
        f"Response preview: {answer[:200]}..."
    )
    if is_correct:
        correct_answers.append(i)

print("\n" + "=" * 80)
print(
    f"Summary: {len(correct_answers)}/{num_answers} answers were correct\n"
    f"Correct answer indices: {correct_answers}"
)

Loading tokenizer and model...
Loading GSM8K dataset...
Dataset loaded: 1319 test examples, 7473 train examples
Model loaded on device: mps:0
DEVICE INFORMATION
PyTorch version: 2.9.0
Backend: MPS (Apple Silicon GPU)
MPS is available and will be used as the device.

Using device: mps

Model device: mps:0

All model parameters are on: mps:0
Question: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?

Reference Answer: Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.
She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market.
#### 18
Reference Final Answer: 18

Generating different answers...

Generating answer 1/3...

Generated Answer 1:
------------------------------------------------------------
To find out how much Janet makes at the farmers' m

### Score the Generated Answers with the Verifier (Before Finetuning)

In [4]:
# Score each sampled answer with a verifier built from the base (not yet finetuned) model.
print("Loading verifier model...")
verifier_model_name = "Qwen/Qwen3-0.6B"
device = model.device  # reuse the device the generator model lives on
verifier = MathAnswerVerifier(verifier_model_name, device)

print("Scoring all generated answers...")
scores = []
for i, answer in enumerate(generated_answers):
    is_correct = check_answer_correct(answer, test_example['answer'])
    score = verifier.score(test_example['question'], answer)
    scores.append(score)
    print(f"Answer {i+1}: Correctness Score = {score:.4f}, Correct = {is_correct}")

# Highest verifier score wins; max() returns the first index on ties,
# exactly like scores.index(max(scores)).
best_idx = max(range(len(scores)), key=scores.__getitem__)
print("\n" + "=" * 80)
print(
    f"Best answer according to verifier: Answer {best_idx + 1}\n"
    f"Is the best answer correct? {check_answer_correct(generated_answers[best_idx], test_example['answer'])}\n"
    f"Best answer: {generated_answers[best_idx][:300]}..."
)

Loading verifier model...
Scoring all generated answers...
Answer 1: Correctness Score = 0.1561, Correct = True
Answer 2: Correctness Score = 0.4844, Correct = False
Answer 3: Correctness Score = 0.3630, Correct = False

Best answer according to verifier: Answer 2
Is the best answer correct? False
Best answer: To find how much Janet makes at the farmers' market, we need to calculate:

1. **Eggs laid per day**: 16 eggs  
2. **Eggs eaten**: 3 for breakfast + 3 for baking = 6 eggs  
3. **Eggs sold**: 16 - 6 = 10 eggs  
4. **Eggs sold per day**: 10 eggs × $2 = $20

### Final Answer:
#### $20...


## Finetune

### Dataset Class for Verifier Finetuning 

In [None]:
class VerifierDataset(Dataset):
    """
    Flattened (question, answer, label) samples for causal-LM verifier finetuning.

    Each item is the verifier prompt followed by the label text (label_yes for
    label 1, label_no otherwise) plus EOS. Loss is computed only on the label
    tokens; prompt positions carry the ignore index -100.
    """

    def __init__(self, raw_data, verifier, max_length: int = 512):
        """
        Args:
            raw_data: List of dicts. Each dict contains 'question', 'answers' (list),
                      'answer_labels' (list, same length as 'answers'), and optionally
                      'reference_answer' (appended as an extra positive sample when
                      present and not None).
            verifier: The verifier model wrapper (provides tokenizer, build_prompt,
                      label_yes, label_no).
            max_length: Token limit per sample; longer samples are left-truncated.

        Raises:
            ValueError: if an entry's 'answers' and 'answer_labels' differ in length
                (zipping them would silently drop samples or mislabel them).
        """
        self.verifier = verifier
        self.tokenizer = verifier.tokenizer
        self.max_length = max_length
        self.samples = []

        # Flatten 1 question -> N (+1 reference) answers into independent samples.
        for entry in raw_data:
            question = entry["question"]
            answers = list(entry["answers"])
            labels = list(entry["answer_labels"])

            if len(answers) != len(labels):
                raise ValueError(
                    f"'answers' ({len(answers)}) and 'answer_labels' ({len(labels)}) "
                    f"have different lengths for question: {question!r}"
                )

            # Data augmentation: the reference answer is a guaranteed positive.
            ref_answer = entry.get("reference_answer")
            if ref_answer is not None:
                answers.append(ref_answer)
                labels.append(1)
            else:
                print("Warning: No reference answer provided for question:", question)

            self.samples.extend(
                {"question": question, "answer": ans, "label": lbl}
                for ans, lbl in zip(answers, labels)
            )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Return {"input_ids", "attention_mask", "labels"} 1-D long tensors for sample idx.
        """
        ex = self.samples[idx]

        # Same prompt as inference so training and scoring see identical contexts.
        prompt = self.verifier.build_prompt(ex["question"], ex["answer"])
        target_text = self.verifier.label_yes if ex["label"] == 1 else self.verifier.label_no

        # Tokenize context and label separately and join by id (no string concatenation).
        context_ids = list(self.tokenizer(
            prompt,
            add_special_tokens=True,
            return_attention_mask=False
        ).input_ids)
        target_ids = list(self.tokenizer(
            target_text,
            add_special_tokens=False,
            return_attention_mask=False
        ).input_ids)
        # Presumably appended so the model learns to stop after the label;
        # skipped when the tokenizer defines no EOS (None is not a valid id).
        eos_id = self.tokenizer.eos_token_id
        if eos_id is not None:
            target_ids.append(eos_id)

        full_ids = context_ids + target_ids

        if len(full_ids) > self.max_length:
            # Left-truncate so the label tokens at the end are always kept.
            full_ids = full_ids[-self.max_length:]
            new_context_len = len(full_ids) - len(target_ids)
        else:
            new_context_len = len(context_ids)

        input_ids = torch.tensor(full_ids, dtype=torch.long)
        attention_mask = torch.ones_like(input_ids)

        labels = input_ids.clone()
        if new_context_len > 0:
            labels[:new_context_len] = -100  # ignore index: no loss on prompt tokens

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }


### Function for Finetuning

In [6]:
def finetune_verifier(verifier, data, epochs=1, batch_size=4, lr=1e-5, warmup_ratio=0.1,
                      max_length=512):
    """
    Finetune the verifier's causal LM with AdamW and a linear warmup/decay LR schedule.

    Args:
        verifier: MathAnswerVerifier-like wrapper (uses .model and .tokenizer).
        data: raw entries in the format expected by VerifierDataset; reference
            answers are appended as positives by the dataset itself.
        epochs: number of passes over the flattened dataset.
        batch_size: samples per optimizer step.
        lr: peak learning rate.
        warmup_ratio: fraction of total steps spent linearly warming up.
        max_length: per-sample token limit passed to VerifierDataset.

    Returns:
        The same verifier object. Its model is left in eval mode, even if
        training raises.

    Raises:
        ValueError: if `data` yields no training samples.
    """
    dataset = VerifierDataset(data, verifier, max_length=max_length)
    if len(dataset) == 0:
        raise ValueError("No training samples were built from `data`.")

    tok = verifier.tokenizer
    # Use `is not None` (not `or`) so a legitimate pad id of 0 is respected.
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id
    if pad_id is None:
        pad_id = 0  # padded positions get attention_mask 0 and label -100, so any id works

    def _collate(batch):
        """Right-pad a list of VerifierDataset items into batch tensors."""
        def _pad(key, value):
            return torch.nn.utils.rnn.pad_sequence(
                [b[key] for b in batch], batch_first=True, padding_value=value
            )
        return {
            "input_ids": _pad("input_ids", pad_id),
            "attention_mask": _pad("attention_mask", 0),
            "labels": _pad("labels", -100),
        }

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=_collate)

    model = verifier.model
    optimizer = AdamW(model.parameters(), lr=lr)

    total_steps = len(dataloader) * epochs
    num_warmup_steps = int(total_steps * warmup_ratio)

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=total_steps
    )

    print(f"Dataset size (augmented & flattened): {len(dataset)}")
    print(f"Training for {total_steps} steps with {num_warmup_steps} warmup steps.")

    model.train()
    try:
        for epoch in range(epochs):
            total_loss = 0.0
            for step, batch in enumerate(dataloader):
                batch = {k: v.to(model.device) for k, v in batch.items()}
                loss = model(**batch).loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()

                total_loss += loss.item()

                if (step + 1) % 50 == 0:
                    current_lr = scheduler.get_last_lr()[0]
                    print(f"Epoch {epoch+1}, step {step+1}, loss = {total_loss / (step+1):.4f}, lr = {current_lr:.2e}")

            print(f"Epoch {epoch+1} finished, avg loss = {total_loss / len(dataloader):.4f}")
    finally:
        # Restore eval mode (no dropout) so later scoring is deterministic,
        # including after an interrupted or failed run.
        model.eval()

    return verifier


### Finetuning Example

In [7]:
# 0. Reproducibility: seed Python's `random` (used by random.shuffle below) and
#    torch's RNG (used by the shuffled DataLoader inside finetune_verifier and by
#    any dropout in the model), so re-running this cell shuffles the same way.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# 1. Setup Device (Prioritize GPU)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

# 2. Initialize Verifier
# use gpt2 as a demo because it's small and fast.
model_name = "gpt2"

print(f"Loading model {model_name}...")
verifier = MathAnswerVerifier(
    model_name=model_name,
    device=device,
    label_yes=" y",
    label_no=" n"
)

# Some Decoder-only models (e.g., GPT2) don't have a pad_token by default; manually set it to eos_token
if verifier.tokenizer.pad_token is None:
    verifier.tokenizer.pad_token = verifier.tokenizer.eos_token
    verifier.model.config.pad_token_id = verifier.model.config.eos_token_id

# 3. Construct a Complex Dummy Dataset
# The updated VerifierDataset expects:
# - 'question': str
# - 'answers': List[str]
# - 'answer_labels': List[int] (0 or 1)
# - 'reference_answer': str (will be automatically added as a positive sample)
raw_train_data = [
    {
        "question": "What is 3 + 3?",
        "answers": [
            "The answer is 5.",    # Wrong
            "It is 6.",            # Correct
            "3 + 3 equals 7."      # Wrong
        ],
        "answer_labels": [0, 1, 0],
        "reference_answer": "6"
    },
    {
        "question": "Solve: 10 - 2",
        "answers": [
            "10 minus 2 is 8.",    # Correct
            "The result is 0."     # Wrong
        ],
        "answer_labels": [1, 0],
        "reference_answer": "8"
    },
    {
        "question": "Calculate 4 * 2",
        "answers": [
            "4 * 2 = 6",           # Wrong
            "It is 8"              # Correct
        ],
        "answer_labels": [0, 1],
        "reference_answer": "8"
    },
    {
        "question": "What is 10 / 2?",
        "answers": [
            "5",                   # Correct
            "2"                    # Wrong
        ],
        "answer_labels": [1, 0],
        "reference_answer": "5"
    }
]

# Duplicate data to simulate a larger dataset for the scheduler
# (In real scenarios, use more diverse data)
train_data = raw_train_data * 5
random.shuffle(train_data)

# 4. Test Case Definition (Unseen Data)
test_q = "What is 5 + 5?"
test_a_correct = "The answer is 10."
test_a_wrong = "The answer is 12."

print("\n" + "="*80)
print("PRE-TRAINING TEST")
print("="*80)

# Baseline scores from the verifier before any finetuning
score_correct = verifier.score(test_q, test_a_correct)
score_wrong = verifier.score(test_q, test_a_wrong)

print(f"Q: {test_q}")
print(f"A: {test_a_correct} (Correct) -> Score: {score_correct:.4f}")
print(f"A: {test_a_wrong} (Wrong)   -> Score: {score_wrong:.4f}")

# 5. Start Fine-tuning
print("\n" + "="*80)
print("START FINE-TUNING")
print("="*80)

# Note: The dataset class will automatically:
# 1. Append 'reference_answer' to 'answers'
# 2. Append 1 to 'answer_labels'
# 3. Flatten the list so one question becomes multiple training samples
verifier = finetune_verifier(
    verifier,
    train_data,       # Passing the complex structure directly
    epochs=3,
    batch_size=2,     # Smaller batch size for this dummy example
    lr=5e-5,
    warmup_ratio=0.1,
)

# 6. Post-training Test
print("\n" + "="*80)
print("POST-TRAINING TEST")
print("="*80)

score_correct_post = verifier.score(test_q, test_a_correct)
score_wrong_post = verifier.score(test_q, test_a_wrong)

print(f"Q: {test_q}")
print(f"A: {test_a_correct} (Correct) -> Score: {score_correct_post:.4f}")
print(f"A: {test_a_wrong} (Wrong)   -> Score: {score_wrong_post:.4f}")

# 7. Test Generalization (Another Unseen Case)
print("\nTesting generalization (Unseen Data):")
unseen_q = "What is 6 * 2?"
unseen_a_corr = "It is 12."
unseen_a_wrong = "It is 100."

s_corr = verifier.score(unseen_q, unseen_a_corr)
s_wrong = verifier.score(unseen_q, unseen_a_wrong)
print(f"Q: {unseen_q}")
print(f"A: {unseen_a_corr} -> {s_corr:.4f}")
print(f"A: {unseen_a_wrong} -> {s_wrong:.4f}")

Using device: mps
Loading model gpt2...

PRE-TRAINING TEST
Q: What is 5 + 5?
A: The answer is 10. (Correct) -> Score: 0.4030
A: The answer is 12. (Wrong)   -> Score: 0.3993

START FINE-TUNING
Dataset size (augmented & flattened): 65
Training for 99 steps with 9 warmup steps.


`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


Epoch 1 finished, avg loss = 1.5966
Epoch 2 finished, avg loss = 0.5632
Epoch 3 finished, avg loss = 0.3404

POST-TRAINING TEST
Q: What is 5 + 5?
A: The answer is 10. (Correct) -> Score: 0.6403
A: The answer is 12. (Wrong)   -> Score: 0.6341

Testing generalization (Unseen Data):
Q: What is 6 * 2?
A: It is 12. -> 0.6514
A: It is 100. -> 0.5871
