# GSM8K: Comparing Reasoning Strategies

This notebook compares different sampling and evaluation strategies on the GSM8K (Grade School Math) benchmark.

**Strategies compared:**
- Greedy decoding
- Temperature sampling with self-consistency
- Best-of-N sampling
- Tree search (Best-first & MCTS)

**Metrics:**
- Accuracy
- Majority voting accuracy
- Agreement rate
- Compute efficiency (accuracy relative to wall-clock time per problem)

In [None]:
import sys
sys.path.insert(0, '..')

import json
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from datetime import datetime
from tqdm.notebook import tqdm

from src.models import load_model, VLLMModel
from src.datasets import GSM8KDataset
from src.samplers import (
    GreedySampler, StandardSampler, NucleusSampler, 
    DiverseSampler, BestFirstTreeSearch, MCTSTreeSearch
)
from src.evaluators import (
    AccuracyEvaluator, MajorityVotingEvaluator, 
    BestOfNEvaluator, WeightedVotingEvaluator
)
from src.runners import run_evaluation
from src.utils import save_results, load_results, compare_runs

# Plot appearance: apply the whitegrid style sheet, then override the
# default figure size and font size in a single rcParams update.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update({
    'figure.figsize': (12, 6),
    'font.size': 12,
})

## 1. Setup

In [None]:
# Notebook-wide settings shared by every cell below
OUTPUT_DIR = "../results/gsm8k_comparison"  # plots, CSVs and saved runs go here
MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
MAX_PROBLEMS = 200  # Set to None for full evaluation

# Ensure the output folder (and any missing parents) exists; no error if it
# is already there.
Path(OUTPUT_DIR).mkdir(exist_ok=True, parents=True)

In [None]:
# Instantiate the model wrapper for MODEL_NAME (this may take a minute)
print("Loading model...")
model = VLLMModel(model_name=MODEL_NAME, gpu_memory_utilization=0.9, dtype="auto")
print(f"Model loaded: {model.name}")

In [None]:
# Build the GSM8K test split with the chain-of-thought prompt option enabled
dataset = GSM8KDataset(split="test", use_cot_prompt=True)
print(f"Dataset: {dataset.name}, {len(dataset)} problems")

# Fetch the working set of problems (limit=MAX_PROBLEMS) and print the first
# one as a quick sanity check of the prompt and gold answer.
problems = dataset.get_problems(limit=MAX_PROBLEMS)
example = problems[0]
print(f"\nUsing {len(problems)} problems for evaluation")
print(f"\nExample problem:")
print(f"Prompt: {example.prompt[:200]}...")
print(f"Gold answer: {example.gold_answer}")

## 2. Define Experiments

We'll compare several strategies with different compute budgets.

In [None]:
# Experiment registry. Every entry bundles a sampler, an evaluator, the number
# of samples drawn per problem, and a human-readable description. Insertion
# order is the order in which the experiments are later iterated.


def _temp_sampler():
    """Return a new StandardSampler with the settings shared by the temperature-based runs."""
    return StandardSampler(temperature=0.7, top_p=0.95, max_tokens=2048)


def _experiment(sampler, evaluator, n_samples, description):
    """Pack one experiment's components into the dict layout used by this notebook."""
    return {
        "sampler": sampler,
        "evaluator": evaluator,
        "n_samples": n_samples,
        "description": description,
    }


experiments = {
    # Baseline: a single greedy sample
    "greedy": _experiment(
        GreedySampler(max_tokens=2048),
        AccuracyEvaluator(dataset),
        1,
        "Greedy decoding (temperature=0)",
    ),

    # One temperature sample
    "temp_0.7": _experiment(
        _temp_sampler(),
        AccuracyEvaluator(dataset),
        1,
        "Temperature=0.7 single sample",
    ),

    # Self-consistency, 8 samples
    "self_consistency_8": _experiment(
        _temp_sampler(),
        MajorityVotingEvaluator(dataset),
        8,
        "Self-consistency (8 samples, majority vote)",
    ),

    # Self-consistency, 16 samples
    "self_consistency_16": _experiment(
        _temp_sampler(),
        MajorityVotingEvaluator(dataset),
        16,
        "Self-consistency (16 samples, majority vote)",
    ),

    # Majority vote over samples drawn with a spread of temperatures
    "diverse_consistency_8": _experiment(
        DiverseSampler(
            temperatures=[0.3, 0.5, 0.7, 0.9, 1.0],
            top_p=0.95,
            max_tokens=2048,
        ),
        MajorityVotingEvaluator(dataset),
        8,
        "Diverse temperatures + majority vote (8 samples)",
    ),

    # Best-of-N selection
    "best_of_8": _experiment(
        _temp_sampler(),
        BestOfNEvaluator(dataset),
        8,
        "Best-of-8 by log probability",
    ),

    # Probability-weighted voting
    "weighted_voting_8": _experiment(
        _temp_sampler(),
        WeightedVotingEvaluator(dataset),
        8,
        "Weighted voting by probability (8 samples)",
    ),

    # Best-first tree search
    "best_first_tree": _experiment(
        BestFirstTreeSearch(
            max_expansions=30,
            branch_factor=3,
            max_tokens=512,
            tokens_per_step=64,
            temperature=0.7,
        ),
        AccuracyEvaluator(dataset),
        1,
        "Best-first tree search",
    ),

    # MCTS tree search
    "mcts_tree": _experiment(
        MCTSTreeSearch(
            max_iterations=50,
            branch_factor=3,
            max_tokens=512,
            tokens_per_step=64,
            rollout_tokens=128,
            temperature=0.7,
        ),
        AccuracyEvaluator(dataset),
        1,
        "MCTS tree search",
    ),
}

# Echo the registry so the run plan is visible in the notebook output
print(f"Defined {len(experiments)} experiments:")
for name, exp in experiments.items():
    print(f"  - {name}: {exp['description']}")

## 3. Run Experiments

In [None]:
# Run every registered experiment in order; for each one, time the run,
# persist its outputs via save_results, and keep a summary in all_results.
all_results = {}

for exp_name, exp_config in experiments.items():
    n_samples = exp_config["n_samples"]
    description = exp_config["description"]

    print(f"\n{'='*60}")
    print(f"Running: {exp_name}")
    print(f"Description: {description}")
    print(f"{'='*60}")

    # Multi-sample experiments are run with a smaller batch size.
    batch_size = 4 if n_samples > 1 else 8

    start_time = datetime.now()
    results, metrics, responses, scores = run_evaluation(
        model=model,
        sampler=exp_config["sampler"],
        dataset=dataset,
        evaluator=exp_config["evaluator"],
        batch_size=batch_size,
        n_samples=n_samples,
        max_problems=MAX_PROBLEMS,
        verbose=True,
    )
    elapsed = (datetime.now() - start_time).total_seconds()

    # Persist this run; the value returned by save_results is stored as run_dir.
    run_config = {
        "experiment": exp_name,
        "description": description,
        "n_samples": n_samples,
        "model": MODEL_NAME,
    }
    run_dir = save_results(
        output_dir=OUTPUT_DIR,
        run_name=exp_name,
        results=results,
        metrics=metrics,
        config=run_config,
        responses=responses,
        scores=scores,
    )

    all_results[exp_name] = {
        "metrics": metrics,
        "results": results,
        "elapsed": elapsed,
        "run_dir": run_dir,
        "n_samples": n_samples,
    }

    print(f"\nCompleted in {elapsed:.1f}s")
    print(f"Accuracy: {metrics.accuracy:.4f}")

## 4. Compare Results

In [None]:
# Tabulate one row per experiment and order the table by accuracy, best first
import pandas as pd

comparison_data = [
    {
        "Experiment": exp_name,
        "Description": experiments[exp_name]["description"],
        "Accuracy": data["metrics"].accuracy,
        "Correct": data["metrics"].correct,
        "Total": data["metrics"].total,
        "N Samples": data["n_samples"],
        "Time (s)": data["elapsed"],
        "Time/Problem (s)": data["elapsed"] / data["metrics"].total,
    }
    for exp_name, data in all_results.items()
]

df = pd.DataFrame(comparison_data).sort_values("Accuracy", ascending=False)
df

In [None]:
# Two horizontal bar charts side by side: accuracy on the left, accuracy
# divided by time per problem on the right. Both share one colour per row.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
colors = plt.cm.viridis(np.linspace(0.2, 0.8, len(df)))

# Left panel: accuracy per experiment
bars = ax1.barh(df["Experiment"], df["Accuracy"], color=colors)
ax1.set_xlabel("Accuracy")
ax1.set_title("GSM8K Accuracy by Strategy")
ax1.set_xlim(0, 1)

# Write each accuracy value just to the right of its bar, vertically centred
for bar, acc in zip(bars, df["Accuracy"]):
    label_y = bar.get_y() + bar.get_height() / 2
    ax1.text(acc + 0.01, label_y, f'{acc:.3f}', va='center', fontsize=10)

# Right panel: the "Efficiency" column is added to df here and stays on it
df["Efficiency"] = df["Accuracy"] / df["Time/Problem (s)"]
bars = ax2.barh(df["Experiment"], df["Efficiency"], color=colors)
ax2.set_xlabel("Accuracy / Time per Problem")
ax2.set_title("Compute Efficiency")

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/accuracy_comparison.png", dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Scatter of accuracy against time per problem, one labelled bubble per
# experiment; bubble area scales with the experiment's n_samples.
fig, ax = plt.subplots(figsize=(10, 6))

for exp_name, data in all_results.items():
    metrics = data["metrics"]
    # (time per problem, accuracy) is reused for the marker and its label
    point = (data["elapsed"] / metrics.total, metrics.accuracy)
    ax.scatter(*point, s=100 * data["n_samples"], label=exp_name, alpha=0.7)
    ax.annotate(
        exp_name,
        point,
        xytext=(5, 5),
        textcoords='offset points',
        fontsize=9,
    )

ax.set_xlabel("Time per Problem (seconds)")
ax.set_ylabel("Accuracy")
ax.set_title("Accuracy vs Compute Tradeoff\n(bubble size = number of samples)")
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/accuracy_vs_compute.png", dpi=150, bbox_inches='tight')
plt.show()

## 5. Analyze Self-Consistency

In [None]:
# Per-experiment agreement statistics for the voting-based runs. The agreement
# of a problem is the top answer's vote count divided by the sample count.
sc_experiments = ["self_consistency_8", "self_consistency_16", "diverse_consistency_8"]

for exp_name in sc_experiments:
    # Experiments that were not run are skipped
    if exp_name not in all_results:
        continue

    results = all_results[exp_name]["results"]

    agreements = []
    for r in results:
        # Only results carrying non-empty vote counts contribute
        if not r.metadata or "vote_counts" not in r.metadata:
            continue
        vote_counts = r.metadata["vote_counts"]
        if not vote_counts:
            continue

        top_votes = max(vote_counts.values())
        # Use the recorded sample count if present, else the number of votes cast
        n_votes = r.metadata.get("num_samples", sum(vote_counts.values()))
        agreements.append(top_votes / n_votes)

    if not agreements:
        continue

    print(f"\n{exp_name}:")
    print(f"  Mean agreement rate: {np.mean(agreements):.3f}")
    print(f"  Std agreement rate: {np.std(agreements):.3f}")
    print(f"  Min/Max: {np.min(agreements):.3f} / {np.max(agreements):.3f}")

In [None]:
# Agreement vs correctness analysis for the 16-sample self-consistency run:
# split per-problem agreement rates by whether the final answer was correct,
# plot both distributions, and print their means.
if "self_consistency_16" in all_results:
    sc16 = all_results["self_consistency_16"]
    results = sc16["results"]
    # Fallback sample count when a result does not record "num_samples":
    # use the n_samples stored with the run instead of a hard-coded literal.
    default_total = sc16["n_samples"]

    correct_agreements = []
    incorrect_agreements = []

    for r in results:
        # Only results carrying non-empty vote counts contribute
        if not r.metadata or "vote_counts" not in r.metadata:
            continue
        vote_counts = r.metadata["vote_counts"]
        if not vote_counts:
            continue

        max_votes = max(vote_counts.values())
        total = r.metadata.get("num_samples", default_total)
        agreement = max_votes / total

        if r.correct:
            correct_agreements.append(agreement)
        else:
            incorrect_agreements.append(agreement)

    fig, ax = plt.subplots(figsize=(10, 5))

    # Both histograms must share the same bin edges to be comparable when
    # overlaid; with bins=20 each call would derive edges from its own data.
    shared_bins = np.linspace(0, 1, 21)
    ax.hist(correct_agreements, bins=shared_bins, alpha=0.7, label=f'Correct (n={len(correct_agreements)})', color='green')
    ax.hist(incorrect_agreements, bins=shared_bins, alpha=0.7, label=f'Incorrect (n={len(incorrect_agreements)})', color='red')

    ax.set_xlabel("Agreement Rate (max votes / total samples)")
    ax.set_ylabel("Count")
    ax.set_title("Self-Consistency: Agreement Rate Distribution by Correctness")
    ax.legend()

    plt.tight_layout()
    plt.savefig(f"{OUTPUT_DIR}/agreement_vs_correctness.png", dpi=150, bbox_inches='tight')
    plt.show()

    # np.mean of an empty list yields nan plus a RuntimeWarning, so report
    # "n/a" when a group has no members (e.g. every answer was correct).
    correct_mean = f"{np.mean(correct_agreements):.3f}" if correct_agreements else "n/a"
    incorrect_mean = f"{np.mean(incorrect_agreements):.3f}" if incorrect_agreements else "n/a"
    print(f"\nMean agreement when correct: {correct_mean}")
    print(f"Mean agreement when incorrect: {incorrect_mean}")

## 6. Error Analysis

In [None]:
# Compare per-problem correctness of greedy decoding against 8-sample
# self-consistency and count the problems where exactly one of them is right.
greedy_results = {r.problem_id: r.correct for r in all_results["greedy"]["results"]}
sc_results = {r.problem_id: r.correct for r in all_results.get("self_consistency_8", {}).get("results", [])}

if sc_results:
    sc_helped = []  # greedy wrong, self-consistency right
    sc_hurt = []    # greedy right, self-consistency wrong

    for pid, greedy_ok in greedy_results.items():
        # Problems without a self-consistency result count in neither list
        if pid not in sc_results:
            continue
        sc_ok = sc_results[pid]
        if not greedy_ok and sc_ok:
            sc_helped.append(pid)
        elif greedy_ok and not sc_ok:
            sc_hurt.append(pid)

    print(f"Self-consistency helped on {len(sc_helped)} problems")
    print(f"Self-consistency hurt on {len(sc_hurt)} problems")
    print(f"Net improvement: {len(sc_helped) - len(sc_hurt)} problems")

In [None]:
# Analyze difficult problems (wrong across all methods).
#
# Collect the set of incorrectly answered problem ids for every experiment and
# intersect them in one step. An incremental "if the running set is empty,
# restart from this experiment" scheme would mistake an already-empty
# intersection for "not started yet" and over-report hard problems.
wrong_sets = [
    {r.problem_id for r in data["results"] if not r.correct}
    for data in all_results.values()
]
all_wrong = set.intersection(*wrong_sets) if wrong_sets else set()

print(f"\nProblems wrong across ALL strategies: {len(all_wrong)}")

# Show a few examples
if all_wrong:
    # Index problems by id once (first occurrence wins) instead of rescanning
    # the whole list for every example.
    problems_by_id = {}
    for p in problems:
        problems_by_id.setdefault(p.id, p)

    print("\nExample hard problems:")
    # Sort by string form so the same examples are shown on every run
    for pid in sorted(all_wrong, key=str)[:3]:
        p = problems_by_id.get(pid)
        if p is None:
            # Id not among the loaded problems: nothing to display
            continue
        print(f"\n--- {pid} ---")
        print(f"Question: {p.prompt[:300]}...")
        print(f"Gold answer: {p.gold_answer}")

## 7. Summary & Conclusions

In [None]:
# Final summary: print a compact table of the comparison DataFrame, list the
# headline findings, and write the full table to summary.csv.
print("\n" + "="*80)
print("SUMMARY: GSM8K Strategy Comparison")
print("="*80)

# Work on a copy so the string formatting below does not alter df itself
summary_df = df[["Experiment", "Accuracy", "N Samples", "Time/Problem (s)"]].copy()
summary_df["Accuracy"] = summary_df["Accuracy"].apply(lambda x: f"{x:.4f}")
summary_df["Time/Problem (s)"] = summary_df["Time/Problem (s)"].apply(lambda x: f"{x:.2f}")

print(summary_df.to_string(index=False))

# Key findings (df is sorted by accuracy, so its first row is the best run)
best_exp = df.iloc[0]
print(f"\n\nKey Findings:")
print(f"  1. Best accuracy: {best_exp['Experiment']} ({best_exp['Accuracy']:.4f})")
print(f"  2. Greedy baseline: {all_results['greedy']['metrics'].accuracy:.4f}")

if "self_consistency_8" in all_results:
    sc_acc = all_results["self_consistency_8"]["metrics"].accuracy
    greedy_acc = all_results["greedy"]["metrics"].accuracy
    delta = sc_acc - greedy_acc
    # The relative change is undefined when the baseline accuracy is zero
    relative = f"{delta / greedy_acc * 100:.1f}%" if greedy_acc else "n/a"
    # The "+" format flag emits the correct sign; a literal "+" prefix would
    # print "+-0.0100" whenever self-consistency is worse than greedy.
    print(f"  3. Self-consistency improvement: {delta:+.4f} ({relative})")

# Save summary
df.to_csv(f"{OUTPUT_DIR}/summary.csv", index=False)
print(f"\nResults saved to {OUTPUT_DIR}/")