## Step 1: Load Model and Dataset


In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import re

# Load the tokenizer and the model
model_name = "Qwen/Qwen3-0.6B"

print("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)

# Load GSM8K dataset
print("Loading GSM8K dataset...")
ds = load_dataset("openai/gsm8k", "main")

print(f"Dataset loaded: {len(ds['test'])} test examples, {len(ds['train'])} train examples")
print(f"Model loaded on device: {model.device}")


Loading tokenizer and model...


2025-11-23 20:50:31.412822: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-11-23 20:50:31.640679: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-11-23 20:50:32.845433: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
`torch_dtype` is deprecated! Use `dtype` instead!


Loading GSM8K dataset...
Dataset loaded: 1319 test examples, 7473 train examples
Model loaded on device: cuda:0


: 

In [2]:
# Check device information
print("="*60)
print("DEVICE INFORMATION")
print("="*60)

# Check if CUDA is available
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    print(f"Current GPU device: {torch.cuda.current_device()}")
    print(f"GPU device name: {torch.cuda.get_device_name(0)}")
    
    # Memory information
    print(f"\nGPU Memory:")
    print(f"  Allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB")
    print(f"  Reserved: {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB")
else:
    print("Running on CPU")

# Check model device
print(f"\nModel device: {model.device}")
print(f"Model dtype: {model.dtype}")

# Check which device each parameter is on (for distributed models)
devices = set()
for name, param in model.named_parameters():
    devices.add(str(param.device))

if len(devices) > 1:
    print(f"\nModel is distributed across multiple devices: {devices}")
else:
    print(f"\nAll model parameters are on: {list(devices)[0]}")

print("="*60)


DEVICE INFORMATION
CUDA available: True
CUDA version: 12.6
Number of GPUs: 1
Current GPU device: 0
GPU device name: NVIDIA GeForce RTX 3070 Ti Laptop GPU

GPU Memory:
  Allocated: 1.11 GB
  Reserved: 1.40 GB

Model device: cuda:0
Model dtype: torch.bfloat16

All model parameters are on: cuda:0


In [3]:
import torch; print(torch.version.cuda); print(torch.backends.cudnn.version()); print(torch.cuda.is_available())

12.6
90501
True


## Step 2: Helper Functions for Answer Extraction and Verification


In [4]:
def extract_answer(text):
    """
    Extract the numerical answer from the text.
    GSM8K answers typically end with #### followed by the number.
    """
    # Try to find the answer after ####
    match = re.search(r'####\s*(-?\d+(?:,\d{3})*(?:\.\d+)?)', text)
    if match:
        # Remove commas from the number
        return match.group(1).replace(',', '')
    
    # Fallback: try to find the last number in the text
    numbers = re.findall(r'-?\d+(?:,\d{3})*(?:\.\d+)?', text)
    if numbers:
        return numbers[-1].replace(',', '')
    
    return None

def check_answer_correct(generated_answer, reference_answer):
    """
    Check if the generated answer matches the reference answer.
    """
    gen = extract_answer(generated_answer)
    ref = extract_answer(reference_answer)
    
    if gen is None or ref is None:
        return False
    
    try:
        # Compare as floats to handle different formats
        return abs(float(gen) - float(ref)) < 0.01
    except:
        return gen == ref

# Test the extraction function
test_answer = "So the total is 50 + 30 = 80. #### 80"
print(f"Test extraction: '{test_answer}' -> {extract_answer(test_answer)}")


Test extraction: 'So the total is 50 + 30 = 80. #### 80' -> 80


## Step 3: Generate Multiple Answers for a Question


In [None]:
def generate_answers(question, num_answers=10, max_new_tokens=512, temperature=0.7):
    """
    Generate multiple different answers to the same question using Qwen chat template.
    """
    # Format the prompt with explicit instruction to use GSM8K format
    prompt_text = f"""Question: {question}
Answer: Let's solve this step by step concisely. End your answer with #### followed by the final numerical answer."""
    
    # Use Qwen chat template format
    messages = [
        {"role": "user", "content": prompt_text}
    ]
    
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False  # Disable thinking mode for faster generation
    )
    
    # Tokenize
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    
    # Generate multiple answers
    answers = []
    for i in range(num_answers):
        print(f"\n{'='*60}")
        print(f"Generating answer {i+1}/{num_answers}...")
        print('='*60)
        
        with torch.no_grad():
            generated_ids = model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id
            )
        
        # Decode only the generated part (not the input prompt)
        output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
        generated_text = tokenizer.decode(output_ids, skip_special_tokens=True).strip()
        
        answers.append(generated_text)
        
        # Print the result after each trial
        print(f"\nGenerated Answer {i+1}:")
        print('-'*60)
        print(generated_text )
        print('-'*60)
        extracted = extract_answer(generated_text)
        print(f"Extracted Answer: {extracted}")
    
    return answers


## Step 4: Generate and Verify Answers


In [14]:
# Test with the first question from the test set
test_example = ds['test'][0]
print(f"Question: {test_example['question']}")
print(f"\nReference Answer: {test_example['answer']}")
print(f"Reference Final Answer: {extract_answer(test_example['answer'])}")
print("\n" + "="*80)
# Generate answers for the first question
print("Generating different answers...")
num_answers = 3
generated_answers = generate_answers(test_example['question'], num_answers=num_answers)

# Check which answers are correct
print("\n" + "="*80)
print("VERIFICATION RESULTS:")
print("="*80)

correct_answers = []
for i, answer in enumerate(generated_answers):
    extracted = extract_answer(answer)
    is_correct = check_answer_correct(answer, test_example['answer'])
    
    print(f"\nAnswer {i+1}:")
    print(f"Extracted value: {extracted}")
    print(f"Correct: {is_correct}")
    print(f"Response preview: {answer[:200]}...")
    
    if is_correct:
        correct_answers.append(i)

print("\n" + "="*80)
print(f"Summary: {len(correct_answers)}/{num_answers} answers were correct")
print(f"Correct answer indices: {correct_answers}")


Question: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?

Reference Answer: Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.
She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market.
#### 18
Reference Final Answer: 18

Generating different answers...

Generating answer 1/3...

Generated Answer 1:
------------------------------------------------------------
To find how much Janet makes at the farmers' market every day:

- She **eats 3 eggs** for breakfast.
- She **bakes 4 eggs** for friends.
- She **lays 16 eggs** per day.
- She **sells the remaining eggs** at $2 per egg.

**Total eggs per day:**
$$ 16 \text{ eggs/day} $$

**Eggs eaten: 3 (breakfast) + 4 (baking) = 7 eggs**

**Remaining eggs:**
$$ 16 - 7 = 9 \text{ eggs/day} $$

**Earnings fr

## Step 5: Score Answers Using the LLM (Verifier)


In [15]:
def score_answer(question, answer, max_new_tokens=100):
    """
    Use the LLM to score/verify an answer using chat template.
    Returns the log probability or a confidence score.
    """
    # Create a verification prompt using chat template
    prompt_text = f"""Question: {question}
Answer: {answer}

Is this answer correct? Respond with yes or no."""
    
    messages = [
        {"role": "user", "content": prompt_text}
    ]
    
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False
    )
    
    # Tokenize
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    
    # Get the model's confidence by computing log probabilities
    with torch.no_grad():
        outputs = model(**model_inputs)
        logits = outputs.logits
        
        # Calculate the average log probability of the answer tokens
        # This is a simple scoring mechanism
        log_probs = torch.log_softmax(logits, dim=-1)
        
        # Get the average log probability over all tokens in the sequence
        input_ids = model_inputs['input_ids'][0]
        
        if len(input_ids) > 1:
            token_log_probs = []
            for i in range(len(input_ids) - 1):
                token_log_prob = log_probs[0, i, input_ids[i+1]].item()
                token_log_probs.append(token_log_prob)
            
            avg_log_prob = sum(token_log_probs) / len(token_log_probs) if token_log_probs else 0
        else:
            avg_log_prob = 0
    
    return avg_log_prob



In [16]:
# Score all generated answers
print("Scoring all generated answers...")
scores = []
for i, answer in enumerate(generated_answers):
    score = score_answer(test_example['question'], answer)
    scores.append(score)
    is_correct = check_answer_correct(answer, test_example['answer'])
    print(f"Answer {i+1}: Score = {score:.4f}, Correct = {is_correct}")

# Find the best answer according to the verifier
best_idx = scores.index(max(scores))
print(f"\n" + "="*80)
print(f"Best answer according to verifier: Answer {best_idx + 1}")
print(f"Is the best answer correct? {check_answer_correct(generated_answers[best_idx], test_example['answer'])}")
print(f"Best answer: {generated_answers[best_idx][:300]}...")


Scoring all generated answers...
Answer 1: Score = -1.6916, Correct = True
Answer 2: Score = -2.1462, Correct = True
Answer 3: Score = -1.9016, Correct = False

Best answer according to verifier: Answer 1
Is the best answer correct? True
Best answer: To find how much Janet makes at the farmers' market every day:

- She **eats 3 eggs** for breakfast.
- She **bakes 4 eggs** for friends.
- She **lays 16 eggs** per day.
- She **sells the remaining eggs** at $2 per egg.

**Total eggs per day:**
$$ 16 \text{ eggs/day} $$

**Eggs eaten: 3 (breakfast) +...


## Step 6: Run Experiment on Multiple Questions


In [25]:
def run_experiment(num_questions=5, num_answers=10, max_new_tokens=512):
    """
    Run the complete experiment on multiple questions.
    """
    results = {
        'questions': [],
        'base_correct': [],  # First answer correct?
        'any_correct': [],   # Any of the 10 correct?
        'best_of_n_correct': [],  # Best-of-N answer correct?
        'num_correct': []    # How many out of 10 were correct?
    }
    
    for q_idx in range(num_questions):
        print(f"\n{'='*80}")
        print(f"Processing Question {q_idx + 1}/{num_questions}")
        print(f"{'='*80}")
        
        example = ds['test'][q_idx]
        question = example['question']
        reference = example['answer']
        
        print(f"Question: {question[:100]}...")
        
        # Generate multiple answers
        answers = generate_answers(question, num_answers=num_answers, max_new_tokens=max_new_tokens)
        
        # Check correctness
        correctness = [check_answer_correct(ans, reference) for ans in answers]
        
        # Score answers
        print("\nScoring answers...")
        answer_scores = [score_answer(question, ans) for ans in answers]
        
        # Get best answer
        best_idx = answer_scores.index(max(answer_scores))
        
        # Store results
        results['questions'].append(question)
        results['base_correct'].append(correctness[0])
        results['any_correct'].append(any(correctness))
        results['best_of_n_correct'].append(correctness[best_idx])
        results['num_correct'].append(sum(correctness))
        
        print(f"\nResults for Question {q_idx + 1}:")
        print(f"  - Base model (first answer): {'✓' if correctness[0] else '✗'}")
        print(f"  - Any correct: {'✓' if any(correctness) else '✗'}")
        print(f"  - Best-of-{num_answers}: {'✓' if correctness[best_idx] else '✗'}")
        print(f"  - Total correct: {sum(correctness)}/{num_answers}")
    
    return results

# Run the experiment
print("Starting experiment...")
experiment_results = run_experiment(num_questions=5, num_answers=10, max_new_tokens=512)


Starting experiment...

Processing Question 1/5
Question: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for ...

Generating answer 1/10...

Generated Answer 1:
------------------------------------------------------------
Janet's daily earnings at the farmers' market:

- She eats **3 eggs** for breakfast.
- She bakes **4 eggs** for her friends.
- She sells the **remainder** for **$2 per egg**.

**Eggs laid per day = 16**  
**Eggs eaten = 3**  
**Eggs bakes = 4**

Remaining eggs = $16 - 3 - 4 = 9$

**Earnings = 9 eggs × $2 = $18**

Answer: #### 18
------------------------------------------------------------
Extracted Answer: 18

Generating answer 2/10...

Generated Answer 2:
------------------------------------------------------------
Janet's ducks lay **16 eggs per day**, she eats **3 eggs** for breakfast, and she bakes **4 eggs** for her friends. The remaining eggs are sold at **$2 per egg**.

- **Eggs eaten for breakfast**: 3 eggs  
- 

## Step 7: Analyze Results


In [26]:
# Calculate statistics
base_accuracy = sum(experiment_results['base_correct']) / len(experiment_results['base_correct'])
best_of_n_accuracy = sum(experiment_results['best_of_n_correct']) / len(experiment_results['best_of_n_correct'])
any_correct_rate = sum(experiment_results['any_correct']) / len(experiment_results['any_correct'])
avg_correct = sum(experiment_results['num_correct']) / len(experiment_results['num_correct'])

print("\n" + "="*80)
print("FINAL RESULTS")
print("="*80)
print(f"Number of questions tested: {len(experiment_results['questions'])}")
print(f"\nBase model accuracy (first answer): {base_accuracy:.2%}")
print(f"Best-of-10 accuracy (verifier-selected): {best_of_n_accuracy:.2%}")
print(f"Oracle best-of-10 (any correct): {any_correct_rate:.2%}")
print(f"Average correct answers per question: {avg_correct:.2f}/10")

print("\n" + "="*80)
print("RESEARCH QUESTION INSIGHTS:")
print("="*80)
print(f"Q: Are answers with higher scores more likely to be correct?")
print(f"A: Best-of-10 accuracy ({best_of_n_accuracy:.2%}) vs Base accuracy ({base_accuracy:.2%})")
print(f"   Improvement: {best_of_n_accuracy - base_accuracy:+.2%}")
print(f"\nQ: Is the average best-of-10 answer more accurate than the average base answer?")
print(f"A: {'YES' if best_of_n_accuracy > base_accuracy else 'NO'} - Best-of-10 is {best_of_n_accuracy/base_accuracy:.2f}x better" if base_accuracy > 0 else "A: Need non-zero base accuracy to compare")

# Detailed breakdown
print("\n" + "="*80)
print("DETAILED BREAKDOWN BY QUESTION:")
print("="*80)
for i in range(len(experiment_results['questions'])):
    print(f"\nQ{i+1}: {experiment_results['questions'][i][:60]}...")
    print(f"  Base: {'✓' if experiment_results['base_correct'][i] else '✗'} | "
          f"Best-of-10: {'✓' if experiment_results['best_of_n_correct'][i] else '✗'} | "
          f"Correct: {experiment_results['num_correct'][i]}/10")



FINAL RESULTS
Number of questions tested: 5

Base model accuracy (first answer): 60.00%
Best-of-10 accuracy (verifier-selected): 60.00%
Oracle best-of-10 (any correct): 80.00%
Average correct answers per question: 3.80/10

RESEARCH QUESTION INSIGHTS:
Q: Are answers with higher scores more likely to be correct?
A: Best-of-10 accuracy (60.00%) vs Base accuracy (60.00%)
   Improvement: +0.00%

Q: Is the average best-of-10 answer more accurate than the average base answer?
A: NO - Best-of-10 is 1.00x better

DETAILED BREAKDOWN BY QUESTION:

Q1: Janet’s ducks lay 16 eggs per day. She eats three for breakf...
  Base: ✓ | Best-of-10: ✓ | Correct: 5/10

Q2: A robe takes 2 bolts of blue fiber and half that much white ...
  Base: ✓ | Best-of-10: ✓ | Correct: 8/10

Q3: Josh decides to try flipping a house.  He buys a house for $...
  Base: ✗ | Best-of-10: ✗ | Correct: 1/10

Q4: James decides to run 3 sprints 3 times a week.  He runs 60 m...
  Base: ✗ | Best-of-10: ✗ | Correct: 0/10

Q5: Every da