# AIME: Competition Math Reasoning Strategies

This notebook compares different strategies on AIME (American Invitational Mathematics Examination) problems.

AIME is significantly harder than GSM8K:
- Competition-level problems
- Requires deeper reasoning
- Answers are integers 0-999

**Strategies compared:**
- Greedy decoding
- Self-consistency (majority voting)
- Tree search methods (Best-first, MCTS)
- Extended thinking (more tokens)

In [2]:
import sys
# Make the project root importable so the local `src` package resolves.
sys.path.insert(0, '..')

import json
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import pandas as pd
from pathlib import Path
from datetime import datetime
from tqdm.notebook import tqdm

from src.models import VLLMModel
from src.datasets import AIMEDataset
from src.samplers import (
    GreedySampler, StandardSampler, DiverseSampler,
    BestFirstTreeSearch, MCTSTreeSearch
)
from src.evaluators import AccuracyEvaluator, MajorityVotingEvaluator
from src.runners import run_evaluation
from src.utils import save_results

# Plotting defaults shared by every figure in this notebook.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update({
    'figure.figsize': (12, 6),
    'font.size': 12,
})

## 1. Setup

In [3]:
# Configuration: everything a reader may want to tune for this comparison.
MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

# Evaluate only the first MAX_PROBLEMS problems. AIME problems are hard and slow
# to solve, so a small subset keeps running every strategy tractable.
# (The full set is roughly 450 problems per the original note; check len(dataset).)
MAX_PROBLEMS = 30

OUTPUT_DIR = "../results/aime_comparison"

# Create the results directory (and any missing parents); no-op if it exists.
Path(OUTPUT_DIR).mkdir(exist_ok=True, parents=True)

In [4]:
# Load the model under evaluation through the project's vLLM wrapper.
# NOTE(review): the recorded output of this cell ends in
# "RuntimeError: Device string must not be empty" right after Triton reports
# 0 active drivers, which suggests no GPU was visible. Confirm before rerunning.
print("Loading model...")
model = VLLMModel(
    model_name=MODEL_NAME,
    dtype="auto",
    gpu_memory_utilization=0.9,  # presumably forwarded to vLLM's memory fraction
)
print(f"Model loaded: {model.name}")

Loading model...


  from .autonotebook import tqdm as notebook_tqdm


INFO 01-10 02:38:56 [importing.py:44] Triton is installed but 0 active driver(s) found (expected 1). Disabling Triton to prevent runtime errors.
INFO 01-10 02:38:56 [importing.py:68] Triton not installed or not compatible; certain GPU-related functions will not be available.
INFO 01-10 02:39:00 [utils.py:253] non-default args: {'trust_remote_code': True, 'seed': 42, 'disable_log_stats': True, 'model': 'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B'}


RuntimeError: Device string must not be empty

In [None]:
# Load AIME (source="hf", presumably the Hugging Face copy) and keep the
# first MAX_PROBLEMS problems as the evaluation subset.
dataset = AIMEDataset(source="hf")
print(f"Dataset: {dataset.name}, {len(dataset)} problems")

problems = dataset.get_problems(limit=MAX_PROBLEMS)
print(f"\nUsing {len(problems)} problems for evaluation")

# Sanity check: show the first problem with its prompt truncated to 300 characters.
print("\nExample problem:")
print("Prompt: " f"{problems[0].prompt[:300]}" "...")
print("Gold answer: " f"{problems[0].gold_answer}")

## 2. Define Experiments

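Because AIME problems are much harder than GSM8K, we expect tree search methods to be more beneficial here. Note that the comparison is not compute-matched: the tree-search configurations below use a much smaller `max_tokens` (512–768) than the greedy and sampling baselines (2048–4096).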

In [None]:
# Each experiment pairs a sampler (how candidate solutions are generated) with an
# evaluator (how they are scored) and the number of samples drawn per problem.
experiments = {
    # --- Baselines: a single greedy decode ---
    "greedy": dict(
        sampler=GreedySampler(max_tokens=2048),
        evaluator=AccuracyEvaluator(dataset),
        n_samples=1,
        description="Greedy decoding",
    ),
    # Same decoder with a doubled token budget ("extended thinking").
    "greedy_long": dict(
        sampler=GreedySampler(max_tokens=4096),
        evaluator=AccuracyEvaluator(dataset),
        n_samples=1,
        description="Greedy with extended tokens (4096)",
    ),

    # --- Sampling + majority vote (self-consistency) ---
    "self_consistency_8": dict(
        sampler=StandardSampler(temperature=0.8, top_p=0.95, max_tokens=2048),
        evaluator=MajorityVotingEvaluator(dataset),
        n_samples=8,
        description="Self-consistency (8 samples)",
    ),
    # Hotter sampling for more exploration.
    "self_consistency_high_temp": dict(
        sampler=StandardSampler(temperature=1.0, top_p=0.95, max_tokens=2048),
        evaluator=MajorityVotingEvaluator(dataset),
        n_samples=8,
        description="Self-consistency (8 samples, temp=1.0)",
    ),
    # NOTE(review): 16 samples over 5 temperatures. How DiverseSampler spreads
    # samples across temperatures is not visible here; confirm.
    "diverse_16": dict(
        sampler=DiverseSampler(
            temperatures=[0.5, 0.7, 0.9, 1.0, 1.2],
            top_p=0.95,
            max_tokens=2048,
        ),
        evaluator=MajorityVotingEvaluator(dataset),
        n_samples=16,
        description="Diverse temperatures (16 samples)",
    ),

    # --- Best-first tree search (one search per problem) ---
    "best_first_small": dict(
        sampler=BestFirstTreeSearch(
            max_expansions=20,
            branch_factor=3,
            max_tokens=512,
            tokens_per_step=64,
            temperature=0.8,
        ),
        evaluator=AccuracyEvaluator(dataset),
        n_samples=1,
        description="Best-first tree (20 expansions)",
    ),
    "best_first_large": dict(
        sampler=BestFirstTreeSearch(
            max_expansions=50,
            branch_factor=4,
            max_tokens=768,
            tokens_per_step=48,
            temperature=0.8,
        ),
        evaluator=AccuracyEvaluator(dataset),
        n_samples=1,
        description="Best-first tree (50 expansions)",
    ),

    # --- Monte Carlo tree search ---
    "mcts_small": dict(
        sampler=MCTSTreeSearch(
            max_iterations=30,
            branch_factor=3,
            max_tokens=512,
            tokens_per_step=64,
            rollout_tokens=128,
            temperature=0.8,
            exploration_constant=1.5,
        ),
        evaluator=AccuracyEvaluator(dataset),
        n_samples=1,
        description="MCTS (30 iterations)",
    ),
    # Larger budget and a larger exploration constant.
    "mcts_large": dict(
        sampler=MCTSTreeSearch(
            max_iterations=80,
            branch_factor=4,
            max_tokens=768,
            tokens_per_step=48,
            rollout_tokens=128,
            temperature=0.8,
            exploration_constant=2.0,
        ),
        evaluator=AccuracyEvaluator(dataset),
        n_samples=1,
        description="MCTS (80 iterations, high exploration)",
    ),
}

print(f"Defined {len(experiments)} experiments:")
for key, cfg in experiments.items():
    print(f"  - {key}: {cfg['description']}")

## 3. Run Experiments

In [None]:
# Run every experiment, persist its outputs, and collect metrics for comparison.
all_results = {}

# Tree-search samplers are run one problem at a time (batch_size=1). Classify by
# sampler type, not by experiment name: the best-first runs are named
# "best_first_*" and contain neither "tree" nor "mcts", so a name check
# silently ran them with batch_size=4.
TREE_SAMPLER_TYPES = (BestFirstTreeSearch, MCTSTreeSearch)

for exp_name, exp_config in experiments.items():
    print(f"\n{'='*60}")
    print(f"Running: {exp_name}")
    print(f"Description: {exp_config['description']}")
    print(f"{'='*60}")

    start_time = datetime.now()

    is_tree = isinstance(exp_config["sampler"], TREE_SAMPLER_TYPES)
    # Non-tree strategies use a smaller batch when several samples are drawn per problem.
    batch_size = 1 if is_tree else (2 if exp_config["n_samples"] > 1 else 4)

    results, metrics, responses, scores = run_evaluation(
        model=model,
        sampler=exp_config["sampler"],
        dataset=dataset,
        evaluator=exp_config["evaluator"],
        batch_size=batch_size,
        n_samples=exp_config["n_samples"],
        max_problems=MAX_PROBLEMS,
        verbose=True,
    )

    elapsed = (datetime.now() - start_time).total_seconds()

    run_dir = save_results(
        output_dir=OUTPUT_DIR,
        run_name=exp_name,
        results=results,
        metrics=metrics,
        config={
            "experiment": exp_name,
            "description": exp_config["description"],
            "n_samples": exp_config["n_samples"],
            "model": MODEL_NAME,
        },
        responses=responses,
        scores=scores,
    )

    all_results[exp_name] = {
        "metrics": metrics,
        "results": results,
        "elapsed": elapsed,
        "n_samples": exp_config["n_samples"],
    }

    print(f"\nCompleted in {elapsed:.1f}s")
    print(f"Accuracy: {metrics.accuracy:.4f} ({metrics.correct}/{metrics.total})")

## 4. Compare Results

In [None]:
# Create comparison dataframe
# One summary row per experiment, sorted so the most accurate strategy comes first.
comparison_data = [
    {
        "Experiment": exp_name,
        "Description": experiments[exp_name]["description"],
        "Accuracy": data["metrics"].accuracy,
        "Correct": data["metrics"].correct,
        "Total": data["metrics"].total,
        "N Samples": data["n_samples"],
        "Time (s)": data["elapsed"],
        "Time/Problem (s)": data["elapsed"] / data["metrics"].total,
    }
    for exp_name, data in all_results.items()
]

df = pd.DataFrame(comparison_data).sort_values("Accuracy", ascending=False)
df

In [None]:
# Visualization
# Side-by-side bars: accuracy (left) and wall-clock time per problem (right),
# coloured by strategy family. `df` is sorted by accuracy, best first.


def _strategy_color(exp_name):
    """Return the bar colour for an experiment's strategy family.

    Tree search is identified by sampler type rather than by name. The
    best-first runs are named "best_first_*", so checking the name for
    'tree'/'mcts' coloured them as greedy.
    """
    sampler = experiments[exp_name]["sampler"]
    if isinstance(sampler, (BestFirstTreeSearch, MCTSTreeSearch)):
        return '#2ecc71'  # tree search
    if 'consistency' in exp_name or 'diverse' in exp_name:
        return '#3498db'  # self-consistency / diverse sampling
    return '#e74c3c'  # greedy


fig, axes = plt.subplots(1, 2, figsize=(14, 6))
colors = [_strategy_color(name) for name in df["Experiment"]]

# Accuracy comparison
ax1 = axes[0]
bars = ax1.barh(df["Experiment"], df["Accuracy"], color=colors)
ax1.invert_yaxis()  # barh draws row 0 at the bottom; put the best strategy on top
ax1.set_xlabel("Accuracy")
ax1.set_title("AIME Accuracy by Strategy")
max_acc = df["Accuracy"].max()
# Leave headroom for the value labels; use a fixed range if nothing was solved.
ax1.set_xlim(0, max_acc * 1.2 if max_acc > 0 else 0.3)

for bar, acc in zip(bars, df["Accuracy"]):
    ax1.text(acc + 0.005, bar.get_y() + bar.get_height() / 2,
             f'{acc:.3f}', va='center', fontsize=10)

# Time comparison (same row order as the accuracy panel)
ax2 = axes[1]
ax2.barh(df["Experiment"], df["Time/Problem (s)"], color=colors)
ax2.invert_yaxis()
ax2.set_xlabel("Time per Problem (seconds)")
ax2.set_title("Compute Time by Strategy")

# Legend: with the inverted axis the lower-right corner holds the shortest bars,
# so the legend no longer covers the best result.
legend_elements = [
    Patch(facecolor='#2ecc71', label='Tree Search'),
    Patch(facecolor='#3498db', label='Self-Consistency'),
    Patch(facecolor='#e74c3c', label='Greedy'),
]
ax1.legend(handles=legend_elements, loc='lower right')

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/aime_comparison.png", dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Accuracy vs Time scatter
# Accuracy vs. compute: one point per experiment, styled by strategy family.
fig, ax = plt.subplots(figsize=(10, 7))

# Strategy family -> (colour, marker, legend label).
family_styles = {
    'mcts': ('#2ecc71', 's', 'MCTS'),                            # green square
    'best_first': ('#27ae60', '^', 'Best-first tree'),           # darker green triangle
    'sampling': ('#3498db', 'o', 'Self-consistency / diverse'),  # blue circle
    'greedy': ('#e74c3c', 'D', 'Greedy'),                        # red diamond
}
labelled_families = set()

for exp_name, data in all_results.items():
    metrics = data["metrics"]
    sampler = experiments[exp_name]["sampler"]

    # Classify by sampler type. The best-first runs are named "best_first_*", so
    # a check for 'tree' in the name never matched them and they were drawn in
    # the greedy style.
    if isinstance(sampler, MCTSTreeSearch):
        family = 'mcts'
    elif isinstance(sampler, BestFirstTreeSearch):
        family = 'best_first'
    elif 'consistency' in exp_name or 'diverse' in exp_name:
        family = 'sampling'
    else:
        family = 'greedy'
    color, marker, label = family_styles[family]

    # Only the first point of each family contributes a legend entry.
    legend_label = '_nolegend_' if family in labelled_families else label
    labelled_families.add(family)

    time_per_problem = data["elapsed"] / metrics.total
    ax.scatter(
        time_per_problem,
        metrics.accuracy,
        s=150,
        c=color,
        marker=marker,
        alpha=0.8,
        edgecolors='black',
        linewidth=1,
        label=legend_label,
    )
    ax.annotate(
        exp_name.replace('_', '\n'),
        (time_per_problem, metrics.accuracy),
        xytext=(8, 0),
        textcoords='offset points',
        fontsize=8,
        va='center',
    )

ax.set_xlabel("Time per Problem (seconds)", fontsize=12)
ax.set_ylabel("Accuracy", fontsize=12)
ax.set_title("AIME: Accuracy vs Compute Tradeoff", fontsize=14)
ax.grid(True, alpha=0.3)
ax.legend(loc='best')

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/aime_pareto.png", dpi=150, bbox_inches='tight')
plt.show()

## 5. Problem Difficulty Analysis

In [None]:
# Analyze which problems each method solves
# Per-method correctness map: {experiment name: {problem_id: correct?}}
method_correct = {
    exp_name: {r.problem_id: r.correct for r in data["results"]}
    for exp_name, data in all_results.items()
}

# For every problem the greedy run attempted, count how many methods solved it.
problem_solve_counts = {
    pid: sum(1 for solved_by in method_correct.values() if solved_by.get(pid, False))
    for pid in method_correct["greedy"]
}

counts = list(problem_solve_counts.values())
n_methods = len(all_results)

# Histogram of "number of methods that solved the problem" (0..n_methods).
fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(counts, bins=range(n_methods + 2), align='left', rwidth=0.8, color='steelblue')
ax.set(
    xlabel="Number of Methods that Solved Problem",
    ylabel="Number of Problems",
    title="Problem Difficulty Distribution",
)
ax.set_xticks(range(n_methods + 1))

plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/problem_difficulty.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"Problems solved by ALL methods: {sum(c == n_methods for c in counts)}")
print(f"Problems solved by NO method: {sum(c == 0 for c in counts)}")
print(f"Problems solved by SOME methods: {sum(0 < c < n_methods for c in counts)}")

In [None]:
# Find problems uniquely solved by tree search
# Problems solved by at least one tree-search method and by no other method.
# Tree search is identified by sampler type. The best-first runs are named
# "best_first_*", so a name check for 'mcts'/'tree' wrongly put them in
# `other_methods`.
tree_methods = [
    name for name in all_results
    if isinstance(experiments[name]["sampler"], (BestFirstTreeSearch, MCTSTreeSearch))
]
other_methods = [name for name in all_results if name not in tree_methods]

unique_tree_solves = [
    pid for pid in method_correct["greedy"]
    if any(method_correct[m].get(pid, False) for m in tree_methods)
    and not any(method_correct[m].get(pid, False) for m in other_methods)
]

print(f"\nProblems uniquely solved by tree search methods: {len(unique_tree_solves)}")

# Show up to two examples (prompt truncated to 400 characters).
if unique_tree_solves:
    problems_by_id = {p.id: p for p in problems}  # O(1) lookup instead of a rescan per id
    print("\nExample problems uniquely solved by tree search:")
    for pid in unique_tree_solves[:2]:
        p = problems_by_id.get(pid)
        if p is None:
            continue
        print(f"\n--- {pid} ---")
        print(f"Question: {p.prompt[:400]}...")
        print(f"Gold answer: {p.gold_answer}")

## 6. Summary

In [None]:
# Final text summary of every strategy, plus tree-search vs. other averages.
print("\n" + "="*80)
print("SUMMARY: AIME Strategy Comparison")
print("="*80)

summary_df = df[["Experiment", "Accuracy", "Correct", "Total", "Time/Problem (s)"]].copy()
summary_df["Accuracy"] = summary_df["Accuracy"].apply(lambda x: f"{x:.4f}")
summary_df["Time/Problem (s)"] = summary_df["Time/Problem (s)"].apply(lambda x: f"{x:.2f}")

print(summary_df.to_string(index=False))

# Analysis: `df` is sorted by accuracy (descending), so row 0 is the best run.
best_exp = df.iloc[0]
greedy_acc = all_results["greedy"]["metrics"].accuracy
improvement = float(best_exp["Accuracy"]) - greedy_acc

print("\n\nKey Findings:")
print(f"  1. Best accuracy: {best_exp['Experiment']} ({float(best_exp['Accuracy']):.4f})")
print(f"  2. Greedy baseline: {greedy_acc:.4f}")
print(f"  3. Best improvement over greedy: {improvement:+.4f}")

# Tree search vs. everything else. Classified by sampler type so the
# "best_first_*" runs count as tree search.
tree_types = (BestFirstTreeSearch, MCTSTreeSearch)
tree_accs = [
    result["metrics"].accuracy
    for name, result in all_results.items()
    if isinstance(experiments[name]["sampler"], tree_types)
]
other_accs = [
    result["metrics"].accuracy
    for name, result in all_results.items()
    if not isinstance(experiments[name]["sampler"], tree_types)
]

if tree_accs and other_accs:
    print(f"\n  Tree search methods avg accuracy: {np.mean(tree_accs):.4f}")
    print(f"  Other methods avg accuracy: {np.mean(other_accs):.4f}")

df.to_csv(f"{OUTPUT_DIR}/summary.csv", index=False)
print(f"\nResults saved to {OUTPUT_DIR}/")