# Part 1

Throughout the notebook, we will explore several inference methods, including:

- **Chain-of-Thought (CoT):** A method where the model generates intermediate reasoning steps before providing the final answer.
- **Best-of-n Sampling:** An approach that generates multiple candidate responses and selects the best one based on a scoring function.
- **Beam Search:** A technique that expands several possible sequences simultaneously, choosing the most promising ones based on probability.
- **Self-Refinement:** An iterative process where the model revises its output to improve accuracy and coherence.

The **Math Benchmark** is a suite of challenging mathematical problems designed to test the reasoning and problem-solving capabilities of LLMs. The benchmark includes a variety of questions ranging from basic arithmetic and algebra to more advanced topics such as geometry and calculus. For example, you might be asked to solve an equation like `2x + 5 = 15` or compute the derivative of a function, tasks that assess the model's ability to handle both straightforward and complex mathematical queries.

Let's dive into the notebook and begin exploring how these methods perform on a challenging set of math problems!


# Installing Dependencies

In [None]:
!pip install -q vllm
!pip install -q transformers accelerate datasets

from IPython.display import clear_output
clear_output()

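* Run the cell below only if you are using Google Colab: it upgrades NumPy and then kills the current process so that the runtime restarts with the new version. If you are using Kaggle, skip this cell.

In [None]:
!pip install --upgrade numpy
import os
# Kill the current kernel process (signal 9) so the runtime restarts and picks up the upgraded NumPy
os.kill(os.getpid(), 9)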

## vLLM: Accelerated Inference Engine for LLMs

vLLM is an open-source project designed to optimize the loading and inference of large language models. By leveraging advanced memory management techniques and dynamic batching, vLLM significantly speeds up inference, making it easier to deploy and experiment with LLMs even on hardware with limited resources. We use vLLM in this notebook to get results faster.

## vLLM Server Setup and Initialization

In this section, we ensure that only one server instance is running and start the vLLM server using the model `deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B`.

**Installation and Cleanup:**
- The necessary packages (`vllm`, `transformers`, `accelerate`, and `datasets`) are installed in an earlier cell with hidden output to keep the notebook clean.
- Any previously running vLLM server instances are terminated before starting a new one. This prevents multiple servers from running simultaneously.

**Server Initialization:**
- The server is launched as a background process using `subprocess.Popen`.
- **Initialization Time:**  
  The server typically takes about **1 minute** to fully initialize. The cell below waits 200 seconds to leave a safety margin.
- **GPU Memory Utilization:**  
  Monitor your GPU memory usage. Initially, it will be at **0 GB**, and then it will gradually increase until it reaches approximately **12 GB** when the server is fully up and running.

Please wait until the GPU memory stabilizes around **12 GB** before proceeding to the next steps.
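
As an alternative to a fixed sleep, readiness can be checked by polling the server. The sketch below is a minimal, hypothetical helper (the names `server_is_up` and `wait_for_server` are not part of the notebook). It assumes the server exposes the OpenAI-compatible model listing at `/v1/models` on port 8000 and treats an HTTP 200 response as "ready".

```python
import http.client
import time
import urllib.request
from typing import Callable


def server_is_up(base_url: str = "http://localhost:8000") -> bool:
    """Return True if the (assumed) /v1/models endpoint answers with HTTP 200."""
    try:
        with urllib.request.urlopen(f"{base_url}/v1/models", timeout=5) as resp:
            return resp.status == 200
    except (OSError, http.client.HTTPException):
        # Connection refused, timeout, or HTTP error: the server is not ready yet
        return False


def wait_for_server(probe: Callable[[], bool] = server_is_up,
                    timeout_s: float = 300.0,
                    poll_s: float = 5.0) -> bool:
    """Call `probe` every `poll_s` seconds until it succeeds or `timeout_s` elapses."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        if probe():
            return True
        time.sleep(poll_s)
    return False


# Usage (after launching the server):
# if wait_for_server():
#     print("Server is ready")
```

Passing the probe as a parameter keeps the waiting logic independent of the HTTP call, so it can be exercised without a running server.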


In [None]:
import subprocess
import time

# Kill any running vLLM server instances for the specified model
kill_cmd = "pkill -f 'vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B'"
subprocess.run(kill_cmd, shell=True)

# Command to start the vLLM server
cmd = [
    "vllm", "serve", "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    "--port", "8000", "--dtype=half", "--max-model-len", "5192"
]

print('============= run vllm server')
server_process = subprocess.Popen(
    cmd,
    stdin=subprocess.DEVNULL,
    stdout=subprocess.DEVNULL,
    stderr=subprocess.DEVNULL
)


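print('============= wait for 200 sec')
time.sleep(200)

# poll() returns None while the process is still running.
# A running process is not proof that the model has finished loading,
# but an early exit is a sure sign that startup failed.
if server_process.poll() is None:
    print("============= Server started!")
else:
    print(f"============= Server exited early with code {server_process.returncode}")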



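* If the previous cell does not start the server correctly, you can run the cell below instead to debug it: it runs the server in the foreground, so its logs appear in the cell output. If the previous cell worked, do NOT run this cell.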

In [None]:
!vllm serve "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"   --port 8000   --dtype=half   --max-model-len 5192

# Helper Functions Overview

This section contains a series of helper functions designed to facilitate the evaluation of mathematical problem solving using the MATH-500 dataset and a local LLM server. These functions handle tasks such as dataset loading, answer extraction, normalization of various mathematical expressions, answer comparison, and result management. Below is an explanation of each group of functions:

---

## Dataset Loading

- **`load_math500_dataset()`**  
  Loads the test split of the MATH-500 dataset from the Hugging Face repository (`HuggingFaceH4/MATH-500`). This dataset provides the math problems and corresponding solutions used for evaluation.

---

## Answer Extraction

- **`extract_answer(response: str) -> Optional[str]`**  
  Searches the provided text for the last occurrence of the LaTeX command `\boxed{...}` and extracts the content within it. This function is essential for retrieving the final answer from the formatted solutions.
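
To illustrate why brace counting is needed here, the following sketch uses a simplified stand-in (`last_boxed`, a hypothetical name, not the notebook's `extract_answer`). A plain regular expression would stop at the first `}`, which breaks on answers such as `\boxed{\frac{1}{3}}`, so the helper tracks nesting depth instead.

```python
from typing import Optional


def last_boxed(text: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...}, honoring nested braces."""
    marker = "\\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return None
    body_start = start + len(marker)
    depth = 1  # we are inside the opening brace of \boxed{
    for pos in range(body_start, len(text)):
        ch = text[pos]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return text[body_start:pos].strip()
    return None  # unbalanced braces, e.g. a response cut off mid-answer


print(last_boxed(r"First \boxed{2}, then \boxed{\frac{1}{3}}"))  # \frac{1}{3}
print(last_boxed(r"cut off mid-answer: \boxed{\frac{1}{"))       # None
```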

---

## Normalization Functions

Normalization functions standardize the format of mathematical expressions to enable accurate comparisons between predicted and correct answers. These functions account for various representations, ensuring that equivalent answers written in different formats are recognized as equal.

- **`normalize_number(num_str: str) -> str`**  
  Cleans and normalizes numeric strings by removing extraneous characters (e.g., commas, currency symbols, and measurement units) and formatting them into a consistent number format.

- **`numerically_equal(str1: str, str2: str) -> bool`**  
  Checks if two numeric strings represent the same value within a small tolerance, accounting for floating point precision issues.

- **`normalize_fraction(fraction_str: str) -> str`**  
  Converts various representations of fractions (with or without braces or using a slash) into a standard LaTeX format: `\frac{numerator}{denominator}`.

- **`normalize_matrix_entry(entry: str) -> str`**  
  Standardizes individual matrix entries, especially handling fractions and slash-separated numbers, to ensure consistency within matrix representations.

- **`normalize_matrix(matrix_str: str) -> str`**  
  Processes a LaTeX matrix (formatted with `\begin{pmatrix}` and `\end{pmatrix}`) by normalizing each row and each entry using the matrix entry normalization.

- **`normalize_algebraic_expression(expr: str) -> str`**  
  Standardizes algebraic expressions by handling coefficients, variables, exponents, and special terms like π (pi). This helps compare algebraic answers regardless of minor formatting differences.

- **`normalize_interval_bound(bound: str) -> str`**  
  Normalizes the boundary of an interval, ensuring that symbols like infinity (`\infty`) and other numeric boundaries are consistently formatted.

- **`normalize_interval(interval_str: str) -> str`**  
  Standardizes an interval provided in LaTeX, ensuring that both bounds are normalized and that the overall format (including brackets) is consistent.

- **`normalize_ordered_tuple(tuple_str: str) -> str`**  
  Normalizes an ordered tuple by splitting its elements and applying answer normalization to each component, ensuring a standard tuple representation.

- **`normalize_answer(answer: str) -> str`**  
  The central normalization function that applies the various normalization steps to a given answer. It cleans up LaTeX formatting, removes unnecessary spaces, and calls the specialized normalization functions to standardize numeric, fractional, algebraic, and other mathematical expressions.
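
The idea behind these functions can be shown with a much smaller example. The sketch below is a simplified stand-in (`canonical_fraction` is a hypothetical name, not one of the notebook's functions) that maps a few common spellings of a numeric fraction to one canonical LaTeX form, so that string equality becomes a meaningful comparison. It deliberately handles only digits.

```python
import re


def canonical_fraction(s: str) -> str:
    """Map a few common spellings of a numeric fraction to \\frac{a}{b}."""
    s = "".join(s.split()).replace("\\dfrac", "\\frac")
    patterns = [
        r"^\\frac\{(\d+)\}\{(\d+)\}$",  # \frac{1}{3}
        r"^\\frac(\d)(\d)$",            # \frac13 (LaTeX takes one token per argument)
        r"^(\d+)/(\d+)$",               # 1/3
    ]
    for pattern in patterns:
        m = re.match(pattern, s)
        if m:
            num, den = m.groups()
            return f"\\frac{{{num}}}{{{den}}}"
    return s  # unknown form: hand back the cleaned string unchanged


for spelling in [r"\dfrac{1}{3}", r"\frac13", "1 / 3"]:
    print(canonical_fraction(spelling))  # \frac{1}{3} each time
```

Returning the cleaned input for unrecognized forms, rather than `None`, means two identical but unrecognized answers still compare equal.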

---

## Answer Comparison

- **`compare_answers(correct_answer: str, predicted_answer: Optional[str]) -> bool`**  
  Compares the normalized versions of the correct answer and the predicted answer. This function ensures that answers are compared in a standardized format so that minor differences in formatting do not affect the evaluation outcome.

---

## Result Management Functions

These functions handle saving and analyzing the results of the evaluation process.

- **`load_existing_results(filename: str) -> list[Dict]`**  
  Loads previously saved evaluation results from a JSON file. If the file does not exist, it returns an empty list.

- **`save_result(filename: str, result: Dict)`**  
  Appends a single evaluation result (including problem details, the LLM response, and correctness) to the results file in JSON format.

- **`analyze_results(results: list[Dict])`**  
  Analyzes the evaluation outcomes by summarizing the total number of problems, counting the correct answers, calculating the accuracy, and printing details for any problems that were answered incorrectly.
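
The save-and-reload pattern is what makes an interrupted evaluation resumable. The sketch below reproduces it in a temporary directory with hypothetical helper names (`load_results`, `append_result`) and toy records, then derives the list of problems that still need to be processed.

```python
import json
import os
import tempfile
from typing import Dict, List


def load_results(path: str) -> List[Dict]:
    """Return previously saved results, or an empty list if the file is missing."""
    try:
        with open(path, "r") as f:
            return json.load(f)
    except FileNotFoundError:
        return []


def append_result(path: str, result: Dict) -> None:
    """Append one result by rewriting the whole JSON list."""
    results = load_results(path)
    results.append(result)
    with open(path, "w") as f:
        json.dump(results, f, indent=2)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "demo_results.json")
    append_result(path, {"index": 0, "is_correct": True})
    append_result(path, {"index": 1, "is_correct": False})
    done = {r["index"] for r in load_results(path)}
    todo = [i for i in range(4) if i not in done]
    print(todo)  # [2, 3]
```

Rewriting the full file on every save costs more as the list grows, but for a few hundred problems it is negligible and the file stays a single valid JSON document.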

---

## Main Evaluation and Response Handling

- **`evaluate()`**  
  The primary function that orchestrates the evaluation process:
  - Creates a results directory if it doesn't already exist.
  - Loads the MATH-500 dataset.
  - Iterates over each problem (while skipping already processed ones).
  - Sends the problem text to the local LLM server using `get_llm_response`.
  - Extracts and compares the answers, then saves the result.
  - Finally, it analyzes and prints a summary of the evaluation.

- **`get_llm_response(prompt: str) -> str`**  
  Sends a prompt to the locally running LLM server (via an HTTP POST request to `http://localhost:8000/v1/chat/completions`) and returns the server's response. This function is key to obtaining the model's predicted answer.


In [None]:
import json
import os
import re
from typing import Dict, Optional, Union
from datasets import load_dataset
from tqdm import tqdm
import torch

# Load the MATH-500 dataset
def load_math500_dataset():
    dataset = load_dataset("HuggingFaceH4/MATH-500")["test"]
    return dataset

# Extract the last boxed answer from text
def extract_answer(response: str) -> Optional[str]:
    if not response:
        return None
    start_idx = response.rfind('\\boxed{')
    if start_idx == -1:
        return None
    brace_count = 1
    pos = start_idx + 7  # length of '\boxed{'
    while pos < len(response) and brace_count > 0:
        if response[pos] == '{':
            brace_count += 1
        elif response[pos] == '}':
            brace_count -= 1
        pos += 1
    if brace_count == 0:
        answer = response[start_idx + 7:pos - 1]
        return answer.strip()
    return None

# Normalization and comparison functions (unchanged from original)
def normalize_number(num_str: str) -> str:
    try:
        cleaned = re.sub(r'[,\$\\]|\s*(?:cm|m|kg|ft|in|lb|oz|ml|L)$|\s*\\text{[^}]+}', '', num_str).strip()
        if cleaned.startswith('.'):
            cleaned = '0' + cleaned
        num = float(cleaned)
        if abs(num) < 1 and '.' in cleaned:
            decimal_places = len(cleaned.split('.')[1])
            format_str = f"{{:.{decimal_places}f}}"
            result = format_str.format(num)
        else:
            result = str(num)
        return result
    except:
        return num_str

def numerically_equal(str1: str, str2: str) -> bool:
    try:
        return abs(float(str1) - float(str2)) < 1e-10
    except:
        return False

def normalize_fraction(fraction_str: str) -> str:
    try:
        fraction_str = fraction_str.replace('\\dfrac', '\\frac')
        fraction_str = ''.join(fraction_str.split())
        fraction_str = re.sub(r'\s*\\text{[^}]+}', '', fraction_str)
        mixed_brace = re.match(r'^\\frac(\d+)\{(\d+)\}$', fraction_str)
        if mixed_brace:
            num, den = mixed_brace.groups()
            return f"\\frac{{{num}}}{{{den}}}"
        no_braces = re.match(r'^\\frac(\d+)(\d+)$', fraction_str)
        if no_braces:
            num, den = no_braces.groups()
            return f"\\frac{{{num}}}{{{den}}}"
        if '/' in fraction_str and not any(c in fraction_str for c in '\\{}'):
            num, den = fraction_str.split('/')
            return f"\\frac{{{num.strip()}}}{{{den.strip()}}}"
        standard = re.match(r'^\\frac\{([^{}]+)\}\{([^{}]+)\}$', fraction_str)
        if standard:
            num, den = standard.groups()
            return f"\\frac{{{num}}}{{{den}}}"
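        # No known pattern matched: return the cleaned string instead of an implicit None
        return fraction_str
    except Exception:
        return fraction_str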

def normalize_matrix_entry(entry: str) -> str:
    entry = ''.join(entry.split())
    if '/' in entry and not any(c in entry for c in '\\{}'):
        if entry.startswith('-'):
            num, den = entry[1:].split('/')
            return f"-{num.strip()}/{den.strip()}"
        else:
            num, den = entry.split('/')
            return f"{num.strip()}/{den.strip()}"
    entry = entry.replace('\\dfrac', '\\frac')
    frac_match = re.match(r'^(-)?\\frac\{(\d+)\}\{(\d+)\}$', entry)
    if frac_match:
        sign, num, den = frac_match.groups()
        sign = sign if sign else ''
        return f"{sign}{num}/{den}"
    return entry

def normalize_matrix(matrix_str: str) -> str:
    try:
        matrix_str = ''.join(matrix_str.split())
        match = re.match(r'^\\begin\{pmatrix\}(.*?)\\end\{pmatrix\}$', matrix_str)
        if not match:
            return matrix_str
        content = match.group(1)
        rows = content.split('\\\\')
        normalized_rows = []
        for row in rows:
            if '&' in row:
                entries = [normalize_matrix_entry(entry) for entry in row.split('&')]
            else:
                entries = [normalize_matrix_entry(row)]
            normalized_rows.append('&'.join(entries))
        result = "\\begin{pmatrix}" + "\\\\".join(normalized_rows) + "\\end{pmatrix}"
        return result
    except:
        return matrix_str

def normalize_algebraic_expression(expr: str) -> str:
    try:
        expr = ''.join(expr.split())
        monomial_match = re.match(r'^(-?\d*\.?\d*)?([a-zA-Z])(?:\^(-?\d+))?$', expr)
        if monomial_match:
            coeff, var, exp = monomial_match.groups()
            coeff = coeff if coeff and coeff not in ['+', '-'] else ('1' if not coeff else '-1')
            exp = exp if exp else '1'
            if coeff == '1' and exp == '1':
                return var
            elif coeff == '1':
                return f"{var}^{exp}"
            elif coeff == '-1' and exp == '1':
                return f"-{var}"
            elif coeff == '-1':
                return f"-{var}^{exp}"
            elif exp == '1':
                return f"{coeff}{var}"
            else:
                return f"{coeff}{var}^{exp}"
        pi_term_match = re.match(r'^(-?\d*\.?\d*)\\?pi$', expr)
        if pi_term_match:
            coeff = pi_term_match.group(1)
            if not coeff or coeff == '-':
                coeff = '-1' if coeff == '-' else '1'
            return f"{coeff}\\pi"
        frac_pi_match = re.match(r'^\\frac{([^{}]+)}{([^{}]+)}\\?pi$', expr)
        if frac_pi_match:
            num, den = frac_pi_match.groups()
            return f"\\frac{{{num}}}{{{den}}}\\pi"
        frac_match = re.match(r'^\\frac{([^{}]+)}{([^{}]+)}$', expr)
        if frac_match:
            num, den = frac_match.groups()
            return f"\\frac{{{num}}}{{{den}}}"
    except:
        return expr.lower()

def normalize_interval_bound(bound: str) -> str:
    if '\\infty' in bound:
        sign = '-' if bound.startswith('-') else ''
        return f"{sign}\\infty"
    return normalize_answer(bound) or bound

def normalize_interval(interval_str: str) -> str:
    try:
        interval_str = ''.join(interval_str.split())
        match = re.match(r'^\\left?([\[\(])(.*?),(.*?)\\right?([\]\)])$', interval_str)
        if not match:
            match = re.match(r'^([\[\(])(.*?),(.*?)([\]\)])$', interval_str)
            if not match:
                return interval_str
        left_bracket, left_bound, right_bound, right_bracket = match.groups()
        norm_left = normalize_interval_bound(left_bound)
        norm_right = normalize_interval_bound(right_bound)
        return f"\\left{left_bracket}{norm_left},{norm_right}\\right{right_bracket}"
    except:
        return interval_str

def normalize_ordered_tuple(tuple_str: str) -> str:
    try:
        tuple_str = tuple_str.replace('\\dfrac', '\\frac')
        tuple_str = tuple_str.replace('\\left', '').replace('\\right', '')
        tuple_str = re.sub(r'\\?\s+', '', tuple_str)
        inner = tuple_str.strip('()')
        parts = inner.split(',')
        normalized_parts = [normalize_answer(part.strip()) for part in parts if normalize_answer(part.strip())]
        return f"({','.join(normalized_parts)})"
    except:
        return None

def normalize_answer(answer: str) -> str:
    if answer is None:
        return ""
    answer = re.sub(r'\\text{[^}]+(?:inches|feet|meters|cm|m|kg|ft|in|lb|oz|ml|L|per|second|minute|hour)[^}]*}', '', answer)
    answer = re.sub(r'(?<!\\)\s+', '', answer)
    ordered_pair_match = re.match(r'^(?:\\left)?\((.*?)(?:\\right)?\)$', answer)
    if ordered_pair_match:
        content = ordered_pair_match.group(1)
        parts = content.split(',')
        normalized_parts = [normalize_answer(part) for part in parts if normalize_answer(part)]
        return f"({','.join(normalized_parts)})"
    answer = ''.join(answer.split())
    if not answer:
        return None
    pm_match = re.match(r'^(.*?)(?:\\pm|-)(.*?)$', answer)
    if pm_match:
        left, right = pm_match.groups()
        norm_left = normalize_answer(left) if left else ""
        norm_right = normalize_answer(right) if right else ""
        if norm_left or norm_right:
            return f"{norm_left}\\pm{norm_right}"
    trig_match = re.match(r'^\\(?:sin|cos|tan|cot|sec|csc)\s*([a-zA-Z])$', answer)
    if trig_match:
        variable = trig_match.group(1)
        func_name = re.match(r'^\\(.*?)(?:\s|$)', answer).group(1)
        return f"\\{func_name}{variable}"
    text_match = re.match(r'^(?:\\text{)?([A-Za-z]+)(?:})?$', answer)
    if text_match:
        return text_match.group(1).lower()
    if (answer.startswith('\\left[') or answer.startswith('\\left(') or
        answer.startswith('[') or answer.startswith('(')) and \
       (answer.endswith('\\right]') or answer.endswith('\\right)') or
        answer.endswith(']') or answer.endswith(')')):
        return normalize_interval(answer)
    if answer.startswith('\\begin{pmatrix}') and answer.endswith('\\end{pmatrix}'):
        return normalize_matrix(answer)
    answer = answer.replace('\\dfrac', '\\frac')
    if '\\frac' in answer or '/' in answer:
        return normalize_fraction(answer)
    neg_sqrt_match = re.match(r'^-\\sqrt\{?(\d+)\}?$', answer)
    if neg_sqrt_match:
        num = neg_sqrt_match.group(1)
        return f"-\\sqrt{{{num}}}"
    sqrt_match = re.match(r'^(\d*)?\\sqrt\{?(\d+)\}?$', answer)
    if sqrt_match:
        coeff, num = sqrt_match.groups()
        coeff = coeff if coeff else '1'
        return f"\\sqrt{{{num}}}" if coeff == '1' else f"{coeff}\\sqrt{{{num}}}"
    sqrt_with_coeff_match = re.match(r'^(\d+)\\sqrt\{?(\d+)\}?$', answer)
    if sqrt_with_coeff_match:
        coeff, num = sqrt_with_coeff_match.groups()
        return f"{coeff}\\sqrt{{{num}}}"
    base_match = re.match(r'^(\d+)(?:_\{?(\d+)\}?|_(\d+))$', answer)
    if base_match:
        number, base1, base2 = base_match.groups()
        base = base1 if base1 else base2
        return f"{number}_{base}"
    percent_match = re.match(r'^(\d+(?:\.\d*)?)\s*\\?%$', answer)
    if percent_match:
        return normalize_number(percent_match.group(1))
    unit_match = re.match(r'^(\d+(?:\.\d*)?)\s*(?:(?:\\[,\s])|,)?\s*(?:\\\\)?(?:\\text{(\w+)}|\\?(?:cm|m|kg|ft|in|lb|oz|ml|L))$', answer)
    if unit_match:
        return normalize_number(unit_match.group(1))
    currency_match = re.match(r'^\\?\$?([\d,]+\.?\d*)$', answer)
    if currency_match:
        return normalize_number(currency_match.group(1))
    if re.match(r'^-?[\d,]+$', answer):
        return normalize_number(answer)
    unit_match = re.match(r'^(-?[\d,]+(?:\.\d*)?)\s*(?:\\(?:mbox|text|hbox|displaystyle)\{[^}]+\})?(?:\^?\d)?$', answer)
    if unit_match:
        return normalize_number(unit_match.group(1))
    mc_match = re.match(r'^\\text{\(?([A-Za-z])\)?}$|^\(?([A-Za-z])\)?$', answer)
    if mc_match:
        return (mc_match.group(1) or mc_match.group(2)).lower()
    degree_match = re.match(r'^(-?[\d,]+(?:\.\d*)?)\s*(?:(?:\^?\\circ)|(?:{\\circ})|(?:°))?$', answer)
    if degree_match:
        return normalize_number(degree_match.group(1))
    answer = re.sub(r'\\text{([^{}]+)}', r'\1', answer)
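    # normalize_algebraic_expression returns None when no pattern matches;
    # in that case fall through to the generic cleanup below.
    try:
        algebraic = normalize_algebraic_expression(answer)
        if algebraic is not None:
            return algebraic
    except Exception:
        pass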
    answer = answer.replace('\\left', '').replace('\\right', '')
    answer = answer.replace('\\(', '(').replace('\\)', ')')
    answer = answer.replace('\\[', '[').replace('\\]', ']')
    answer = answer.replace('\\{', '{').replace('\\}', '}')
    answer = re.sub(r'\\sqrt\{?(\d+)\}?', r'\\sqrt{\1}', answer)
    answer = re.sub(r'\\sqrt{([^{}]+)}', r'\\sqrt\1', answer)
    if re.match(r'^\d+\\%$', answer) or re.match(r'^\d+$', answer):
        answer = re.sub(r'\\%$', '', answer)
    answer = re.sub(r'\\text{([^{}]+)}', r'\1', answer)
    while len(answer) >= 2 and answer[0] == '{' and answer[-1] == '}':
        if '\\frac' in answer:
            break
        answer = answer[1:-1]
    return answer.lower() if answer else None

def compare_answers(correct_answer: str, predicted_answer: Optional[str]) -> bool:
    if predicted_answer is None:
        return False
    if numerically_equal(correct_answer, predicted_answer):
        return True
    normalized_correct = normalize_answer(correct_answer)
    normalized_predicted = normalize_answer(predicted_answer)
    if not normalized_correct or not normalized_predicted:
        return False
    if normalized_correct == "" and normalized_predicted == "":
        return False
    if ('\\left[' in normalized_correct or '\\left(' in normalized_correct) and \
       ('\\left[' in normalized_predicted or '\\left(' in normalized_predicted):
        return normalized_correct == normalized_predicted
    return normalized_correct == normalized_predicted

# Load existing results
def load_existing_results(filename: str) -> list[Dict]:
    try:
        with open(filename, 'r') as f:
            return json.load(f)
    except FileNotFoundError:
        return []

# Save a single result
def save_result(filename: str, result: Dict):
    results = load_existing_results(filename)
    results.append(result)
    with open(filename, 'w') as f:
        json.dump(results, f, indent=2)

# Analyze and print results
def analyze_results(results: list[Dict]):
    total = len(results)
    correct = sum(1 for r in results if r['is_correct'])
    accuracy = correct / total if total > 0 else 0
    print("\n=== Results Summary ===")
    print(f"Total problems: {total}")
    print(f"Correct answers: {correct}")
    print(f"Accuracy: {accuracy:.2%}")
    print("\n=== Incorrect Problems ===")
    for r in results:
        if not r['is_correct']:
            print(f"Problem {r['index']}:")
            print(f"Expected: {r['correct_answer']}")
            print(f"Predicted: {r['predicted_answer']}")
            print("---")

# Main evaluation function
def evaluate():
    os.makedirs("results", exist_ok=True)
    results_file = "evaluation_results_math500_deepseek.json"
    dataset = load_math500_dataset()
    existing_results = load_existing_results(results_file)
    processed_indexes = {result['index'] for result in existing_results}
    cnt = 0  # correct answers in this run
    t = 0    # problems processed in this run
    for idx, item in enumerate(tqdm(dataset, desc="Evaluating problems")):
        if idx in processed_indexes:
            continue
        t += 1
        problem_text = item['problem']
        correct_answer = extract_answer(item['solution'])  # Extract from 'solution', not 'answer'
        response = get_llm_response(problem_text)
        predicted_answer = extract_answer(response)
        is_correct = compare_answers(correct_answer, predicted_answer)
        result = {
            "index": idx,
            "problem": problem_text,
            "response": response,
            "correct_answer": correct_answer,
            "predicted_answer": predicted_answer,
            "is_correct": is_correct
        }
        save_result(results_file, result)
        if is_correct:
            cnt += 1
        print(f"cnt :  {cnt} idx: {t}")
    final_results = load_existing_results(results_file)
    analyze_results(final_results)



## LLM Query Function

* This Python function sends a prompt to the locally hosted LLM API and returns the generated response.
* You can change `max_tokens` and `temperature` as needed.

In [None]:
import requests
def get_llm_response(prompt):
    url = "http://localhost:8000/v1/chat/completions"

    payload = {
        "model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        "messages": [
            {
                "role": "user",
                "content": prompt
            }
        ],
        "max_tokens": 500,
        "temperature": 0.6
    }
    response = requests.post(url, json=payload)
    response.raise_for_status()  # surface HTTP errors instead of failing later with a KeyError
    return response.json()['choices'][0]['message']['content'].strip()

## Test Prompt: Evaluating an Integral

In this cell, we define a new math benchmark question to verify that the LLM server is correctly set up and that responses can be retrieved.

**Question:**  
What is the value of the integral  
$$\int_0^1 x^2\,dx$$  

**Expected Answer:**  
$$\boxed{\frac{1}{3}}$$

The cell sends this prompt to the LLM server using the `get_llm_response` function and prints the response. This helps confirm that the integration between the notebook and the LLM server is working properly.

In [None]:
# Define a new math benchmark question for testing
question = "What is the value of the integral $$\\int_0^1 x^2 dx$$ answer it directly in one sentence?"
# Real answer: \boxed{\frac{1}{3}}

# Get response from the LLM server using the provided get_llm_response function
response = get_llm_response(question)

# Print the response to verify that the setup is working correctly
print("Response:", response)


Response: To evaluate the integral of \( x^2 \) from 0 to 1, I can use the power rule for integration. 

First, I find the antiderivative of \( x^2 \), which is \( \frac{x^3}{3} \).

Next, I apply the Fundamental Theorem of Calculus by substituting the upper and lower limits into the antiderivative and subtracting.

Finally, I subtract \( \frac{0^3}{3} \) from \( \frac{1^3}{3} \) to get the value of the integral.
</think>

The value of the integral is calculated as follows:

\[
\int_0^1 x^2 \, dx = \left[ \frac{x^3}{3} \right]_0^1 = \frac{1^3}{3} - \frac{0^3}{3} = \frac{1}{3}
\]

So, the final answer is:

\[
\boxed{\dfrac{1}{3}}
\]
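
The response above contains a closing `</think>` tag that separates the model's reasoning trace from its final answer. When only the final part is of interest, the two can be split apart. The sketch below is a hypothetical helper (`split_reasoning` is not part of the notebook) and assumes the tag appears as it does in the output above.

```python
from typing import Tuple


def split_reasoning(response: str, marker: str = "</think>") -> Tuple[str, str]:
    """Split a response into (reasoning, final_answer) at the last marker."""
    head, sep, tail = response.rpartition(marker)
    if not sep:  # marker absent: treat the whole response as the final answer
        return "", response.strip()
    return head.strip(), tail.strip()


reasoning, final = split_reasoning(
    "I find the antiderivative.\n</think>\n\nSo, the final answer is \\boxed{\\dfrac{1}{3}}"
)
print(final)  # So, the final answer is \boxed{\dfrac{1}{3}}
```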


# Customizable CoT Prompt Template
* Modify the CoT prompt, then evaluate it on the math benchmark.


In [None]:
import requests

# Define the CoT instructions (prepended to the problem inside the user message)
COT_PROMPT = '''You are solving mathematics problems.

Please think step by step.

Important: Always end your solution with the final answer in this format:

\\[
\\boxed{your_answer_here}
\\]

The entire answer should be contained completely within the \\boxed{} command.'''

def get_COT_response(problem):
    prompt = COT_PROMPT + "\n" + problem
    url = "http://localhost:8000/v1/chat/completions"

    payload = {
        "model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        "messages": [
            {
                "role": "user",
                "content": prompt
            }
        ],
        "max_tokens": 1900,
        "temperature": 0.3
    }
    response = requests.post(url, json=payload)
    response.raise_for_status()  # surface HTTP errors instead of failing later with a KeyError
    return response.json()['choices'][0]['message']['content'].strip()

# Evaluate CoT
* Modify the response-generation part to evaluate this method.

In [None]:
def evaluate_cot():
    os.makedirs("results", exist_ok=True)
    results_file = "evaluation_results_math500_deepseek_cot.json"
    dataset = load_math500_dataset()
    existing_results = load_existing_results(results_file)
    processed_indexes = {result['index'] for result in existing_results}
    cnt = 0
    for idx, item in enumerate(tqdm(dataset, desc="Evaluating problems")):
        if idx in processed_indexes:
            continue
        if idx >= 30:  # evaluate only the first 30 problems
            break
        problem_text = item['problem']
        correct_answer = extract_answer(item['solution'])
        ##########################################################
        response = get_COT_response(problem_text)
        predicted_answer = extract_answer(response)
        ##########################################################
        is_correct = compare_answers(correct_answer, predicted_answer)
        result = {
            "index": idx,
            "problem": problem_text,
            "response": response,
            "correct_answer": correct_answer,
            "predicted_answer": predicted_answer,
            "is_correct": is_correct
        }
        save_result(results_file, result)
        if is_correct:
            cnt += 1
        print(f"corrects :  {cnt} idx: {idx}")
    final_results = load_existing_results(results_file)
    analyze_results(final_results)

In [None]:
evaluate_cot()

Evaluating problems:   0%|          | 1/500 [00:07<1:05:19,  7.86s/it]

corrects :  1 idx: 0


Evaluating problems:   0%|          | 2/500 [00:40<3:06:39, 22.49s/it]

corrects :  1 idx: 1


Evaluating problems:   1%|          | 3/500 [00:52<2:26:19, 17.66s/it]

corrects :  2 idx: 2


Evaluating problems:   1%|          | 4/500 [00:59<1:52:32, 13.61s/it]

corrects :  3 idx: 3


Evaluating problems:   1%|          | 5/500 [01:10<1:44:39, 12.69s/it]

corrects :  4 idx: 4


Evaluating problems:   1%|          | 6/500 [01:17<1:27:23, 10.61s/it]

corrects :  5 idx: 5


Evaluating problems:   1%|▏         | 7/500 [01:51<2:28:45, 18.10s/it]

corrects :  5 idx: 6


Evaluating problems:   2%|▏         | 8/500 [02:19<2:54:24, 21.27s/it]

corrects :  6 idx: 7


Evaluating problems:   2%|▏         | 9/500 [02:38<2:48:26, 20.58s/it]

corrects :  7 idx: 8


Evaluating problems:   2%|▏         | 10/500 [03:11<3:20:56, 24.61s/it]

corrects :  7 idx: 9


Evaluating problems:   2%|▏         | 11/500 [03:45<3:42:43, 27.33s/it]

corrects :  7 idx: 10


Evaluating problems:   2%|▏         | 12/500 [04:18<3:57:37, 29.22s/it]

corrects :  7 idx: 11


Evaluating problems:   3%|▎         | 13/500 [04:28<3:08:10, 23.18s/it]

corrects :  7 idx: 12


Evaluating problems:   3%|▎         | 14/500 [04:34<2:27:01, 18.15s/it]

corrects :  8 idx: 13


Evaluating problems:   3%|▎         | 15/500 [05:08<3:04:29, 22.82s/it]

corrects :  8 idx: 14


Evaluating problems:   3%|▎         | 16/500 [05:41<3:30:22, 26.08s/it]

corrects :  8 idx: 15


Evaluating problems:   3%|▎         | 17/500 [05:50<2:46:42, 20.71s/it]

corrects :  9 idx: 16


Evaluating problems:   4%|▎         | 18/500 [06:23<3:17:38, 24.60s/it]

corrects :  9 idx: 17


Evaluating problems:   4%|▍         | 19/500 [06:57<3:38:57, 27.31s/it]

corrects :  9 idx: 18


Evaluating problems:   4%|▍         | 20/500 [07:29<3:50:33, 28.82s/it]

corrects :  10 idx: 19


Evaluating problems:   4%|▍         | 21/500 [07:48<3:26:10, 25.83s/it]

corrects :  10 idx: 20


Evaluating problems:   4%|▍         | 22/500 [08:22<3:44:15, 28.15s/it]

corrects :  10 idx: 21


Evaluating problems:   5%|▍         | 23/500 [08:55<3:56:45, 29.78s/it]

corrects :  10 idx: 22


Evaluating problems:   5%|▍         | 24/500 [09:23<3:50:19, 29.03s/it]

corrects :  10 idx: 23


Evaluating problems:   5%|▌         | 25/500 [09:33<3:06:38, 23.58s/it]

corrects :  10 idx: 24


Evaluating problems:   5%|▌         | 26/500 [10:07<3:29:52, 26.57s/it]

corrects :  10 idx: 25


Evaluating problems:   5%|▌         | 27/500 [10:41<3:45:56, 28.66s/it]

corrects :  10 idx: 26


Evaluating problems:   6%|▌         | 28/500 [10:50<2:59:42, 22.84s/it]

corrects :  11 idx: 27


Evaluating problems:   6%|▌         | 29/500 [11:11<2:55:03, 22.30s/it]

corrects :  11 idx: 28


Evaluating problems:   6%|▌         | 30/500 [11:19<2:57:20, 22.64s/it]

corrects :  12 idx: 29

=== Results Summary ===
Total problems: 30
Correct answers: 12
Accuracy: 40.00%

=== Incorrect Problems ===
Problem 1:
Expected: p - q
Predicted: None
---
Problem 6:
Expected: 27
Predicted: None
---
Problem 9:
Expected: 4
Predicted: None
---
Problem 10:
Expected: 2220
Predicted: None
---
Problem 11:
Expected: \frac{3}{56}
Predicted: None
---
Problem 12:
Expected: 284
Predicted: 280
---
Problem 14:
Expected: \sqrt{51}
Predicted: None
---
Problem 15:
Expected: 6 - 5i
Predicted: None
---
Problem 17:
Expected: \pi
Predicted: 
---
Problem 18:
Expected: 28
Predicted: None
---
Problem 20:
Expected: 6+9i
Predicted: 6 + 9i
---
Problem 21:
Expected: 13535
Predicted: None
---
Problem 22:
Expected: 5
Predicted: None
---
Problem 23:
Expected: x=5
Predicted: 5
---
Problem 24:
Expected: 10
Predicted: 10.32\%
---
Problem 25:
Expected: 1,-2
Predicted: None
---
Problem 26:
Expected: 144
Predicted: None
---
Problem 28:
Expected: -2 + 7i
Predicted: -2 + 7i
---





## Best-of-N

The **Best-of-N** approach aims to improve math problem-solving by generating *N* solutions and selecting the one with the highest average token log-likelihood. Each solution is generated from a prompt that asks for step-by-step reasoning and a final answer in `\boxed{}` format. The selected response is the one the model itself scores as most likely, which is a useful proxy for quality but not a guarantee of correctness.

### Steps
1. **Generate**: Produce *N* responses using a structured guiding prompt.
2. **Evaluate**: Compute the average log-likelihood for each response based on token probabilities.
3. **Select**: Identify and choose the response with the highest score.

This method favours the solution the model is most confident in and keeps the final answer in a consistent, parseable format.
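
The three steps above can be tried without a model server. The snippet below is a minimal sketch, assuming each candidate arrives as a dictionary with the hypothetical keys `text` and `token_logprobs`; the log-probabilities are made-up values for illustration, and this is not the notebook's own implementation.

```python
from typing import Dict, List


def mean_logprob(token_logprobs: List[float]) -> float:
    """Average token log-likelihood; an empty response gets the worst score."""
    if not token_logprobs:
        return float("-inf")
    return sum(token_logprobs) / len(token_logprobs)


def select_best(candidates: List[Dict]) -> Dict:
    """Return the candidate with the highest average token log-likelihood."""
    return max(candidates, key=lambda c: mean_logprob(c["token_logprobs"]))


# Toy candidates with invented log-probabilities (illustrative only).
candidates = [
    {"text": "... \\boxed{4}", "token_logprobs": [-0.9, -1.2, -0.6]},
    {"text": "... \\boxed{5}", "token_logprobs": [-0.1, -0.3, -0.2]},
    {"text": "", "token_logprobs": []},
]
best = select_best(candidates)
print(best["text"])
```

Scoring an empty response as negative infinity keeps it from winning by default, since 0 would otherwise be the best possible log-probability.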


## Verification Methods in Best‑of‑N Evaluation

When sampling multiple candidate solutions for each math problem, we need a reliable way to choose the single best answer. We support two complementary approaches:

### Log‑Probability Scoring

**Concept**  
Each generated solution comes with token‑level log‑likelihoods. By averaging these values across all tokens in the response, we obtain a single score reflecting how “confident” the model is in that entire output.

**Why Use It**  
- **Self‑Contained & Fast**: Requires no external calls or additional models.  
- **Cost‑Effective**: Purely internal computation, so it adds negligible expense.  

**Limitations**  
- A high likelihood does not always imply a correct or well‑reasoned solution, especially on complex math problems.
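
Averaging, rather than summing, matters because a summed log-likelihood becomes more negative with every extra token and therefore penalises long derivations. The following sketch uses invented token log-probabilities (an assumption for illustration) to show the two scores ranking the same pair of responses differently.

```python
from typing import List


def total_logprob(lps: List[float]) -> float:
    """Sum of token log-probabilities (log-likelihood of the whole sequence)."""
    return sum(lps)


def average_logprob(lps: List[float]) -> float:
    """Length-normalised score: mean log-probability per token."""
    return sum(lps) / len(lps)


# Invented values: a short, uncertain guess and a long, confident derivation.
short_guess = [-0.5] * 4
long_derivation = [-0.1] * 40

# The sum prefers the short response simply because it has fewer tokens.
print(total_logprob(short_guess) > total_logprob(long_derivation))
# The average prefers the response whose tokens are individually more likely.
print(average_logprob(long_derivation) > average_logprob(short_guess))
```

Neither score checks the mathematics itself, which is the limitation noted above.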

---

### LLM‑Based Verification

**Concept**  
Instead of trusting raw likelihoods, we hand the sampled responses off to a second, high‑quality language model (here, a Gemini model). That model reads the original problem and the list of candidate boxed answers, then selects the one it judges to be correct.

**Why Use It**  
- **Deeper Reasoning**: A dedicated verifier can compare alternative answers and catch subtle mistakes.  
- **Improved Robustness**: Mitigates cases where a flawed but high‑probability output would otherwise be chosen.

**Trade‑Offs**  
- **Slower**: Requires additional API calls and round‑trip latency.  
- **External Cost**: Incurs usage fees on the verification model.
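
To experiment with this flow offline, the verifier can be passed in as a plain callable. The sketch below is an illustration under several assumptions: `last_boxed` is a simplified, hypothetical extractor that ignores nested braces, `verify` and `fake_verifier` are hypothetical names, and the two guards (skipping the call when only one distinct answer exists, and rejecting replies that match no candidate) are optional design choices rather than something the notebook prescribes.

```python
import re
from typing import Callable, List, Optional

# Simplified pattern: matches \boxed{...} without nested braces.
BOXED = re.compile(r"\\boxed\{([^{}]*)\}")


def last_boxed(text: str) -> Optional[str]:
    """Return the content of the last \\boxed{...} in `text`, if any."""
    matches = BOXED.findall(text)
    return matches[-1].strip() if matches else None


def verify(problem: str, outputs: List[str], ask: Callable[[str], str]) -> Optional[str]:
    # Deduplicate boxed answers while preserving first-seen order.
    answers = list(dict.fromkeys(a for a in map(last_boxed, outputs) if a is not None))
    if not answers:
        return None
    if len(answers) == 1:
        return answers[0]  # nothing to choose between, so skip the verifier call
    options = "\n".join(f"{i}. \\boxed{{{a}}}" for i, a in enumerate(answers, 1))
    reply = ask(f"Problem:\n{problem}\n\nCandidates:\n{options}\n\nReply with one boxed answer.")
    chosen = last_boxed(reply)
    # Only accept a reply that matches one of the candidates.
    return chosen if chosen in answers else None


def fake_verifier(prompt: str) -> str:
    """Stand-in for a real API call; always answers 4."""
    return "\\boxed{4}"


outputs = ["so \\boxed{4}", "thus \\boxed{5}", "again \\boxed{4}", "ran out of tokens"]
picked = verify("What is 2 + 2?", outputs, fake_verifier)
print(picked)
```

Deduplicating first keeps the verification prompt short and reduces the number of paid calls when the samples agree.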

---

### Balancing Speed, Cost, and Accuracy

By exposing a simple toggle between these two methods, you can:

- **Optimize for Speed**: Use log‑prob scoring when you need rapid, low‑cost evaluation.  
- **Optimize for Accuracy**: Use LLM‑based verification when correctness is paramount.  

Experiment on your dataset to find the right trade‑off for your needs.

In [None]:
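import os
import google.generativeai as genai

# Read the Gemini API key from the environment instead of hardcoding it in the notebook.
# Create a key at: https://aistudio.google.com/app/apikey
# (GEMINI_API_KEY is the variable name chosen for this notebook.)
api_key = os.environ["GEMINI_API_KEY"]
genai.configure(api_key=api_key)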

# A fast, low-cost Gemini model; swap in another model name if you like
LLM_API_MODEL = "gemini-2.0-flash"
api_model = genai.GenerativeModel(
    model_name=LLM_API_MODEL,
    generation_config={
        "temperature": 0.0,
        "max_output_tokens": 1024,
    }
)

def get_api_response(prompt: str) -> str:
    """
    Send `prompt` to the configured Gemini model and return its reply.
    """
    response = api_model.generate_content(prompt)
    return response.text
# ───────────────────────────────────────────────────────────────────────────────

In [None]:
get_api_response('salam')

'Wa alaikum assalam. How can I help you today?\n'

In [None]:
import re
import time
import math
import heapq
import requests
from typing import List, Optional


VERIFY_SYSTEM_PROMPT = """You are an expert math solver. Your task is to verify the correct answer to a problem based on several candidate solutions.

Problem:
{problem}

Candidate final answers (in boxed format):
{answers}

Determine which boxed answer is correct based on the problem. Reply with only the correct final answer in boxed form, such as: \\boxed{{...}}.

Do not include any explanation—only the final boxed answer."""


def verify_with_gpt(problem: str, outputs: List[str]) -> Optional[str]:
    """
    Given the original problem and a list of candidate full-response texts,
    asks the verifier LLM (via `get_api_response`) to pick the correct final answer (boxed).
    """
    # Extract each candidate's boxed answer and deduplicate, preserving order.
    # Responses without a boxed answer are represented by "<no_boxed_answer>".
    unique_answers = []
    seen_answers = set()
    for output in outputs:
        answer = extract_answer(output)
        if answer is None:
            answer = "<no_boxed_answer>"
        if answer not in seen_answers:
            seen_answers.add(answer)
            unique_answers.append(answer)
    if not unique_answers:
        return None

    # Present the unique answers as numbered lines of the form "1. \\boxed{...}".
    options = "\n".join(f"{i+1}. \\boxed{{{ans}}}" for i, ans in enumerate(unique_answers))

    # Fill the verification prompt with the problem and the options.
    verify_prompt = VERIFY_SYSTEM_PROMPT.format(problem=problem, answers=options)

    # Query the verifier and strip surrounding whitespace.
    chosen = get_api_response(verify_prompt).strip()

    # Re-extract the boxed answer from the reply to normalize formatting.
    return extract_answer(chosen)






SYSTEM_PROMPT = '''You are solving mathematics problems.

Please think step by step.

Important: Always end your solution with the final answer in this format:

\\[
\\boxed{your_answer_here}
\\]

The entire answer should be contained completely within the \\boxed{} command.'''



def best_of_n_response(
    problem: str,
    N: int = 5,
    use_logprob: bool = True,
    model: str = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    port: int = 8000,
    temp: float = 0.2
) -> Optional[str]:
    """
    Run N samples on your vLLM server, then:
    - if use_logprob: pick the candidate with the highest average log-prob
    - else: hand all N outputs to the verifier LLM to choose the best boxed answer
    """
    url = f"http://localhost:{port}/v1/chat/completions"
    prompt = SYSTEM_PROMPT + "\n" + problem

    samples = []
    for _ in range(N):
        # Request one completion with per-token log-probabilities.
        payload = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 500,
            "temperature": temp,
            "logprobs": True,
            "top_logprobs": 5
        }

        resp = requests.post(url, json=payload).json()

        # Generated text of the first (only) choice.
        text = resp['choices'][0]['message']['content'].strip()

        # Collect the log-probability of every generated token.
        logprobs = [
            token['logprob']
            for token in resp['choices'][0]['logprobs']['content']
        ]
        # An empty response scores -inf so it can never win the selection
        # (0 would be the best possible log-probability).
        avg_lp = sum(logprobs) / len(logprobs) if logprobs else float('-inf')

        samples.append({
            "text": text,
            "avg_lp": avg_lp
        })

    if use_logprob:
        # Pick the sample with the highest average log-probability.
        best = max(samples, key=lambda s: s['avg_lp'])
        return extract_answer(best["text"])

    else:
        # Let the verifier LLM choose among the sampled outputs.
        outs = [s["text"] for s in samples]
        return verify_with_gpt(problem, outs)


# Evaluate best of n

* Modify the response-generation part to evaluate this method.
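
The incorrect-problem lists printed by the evaluation runs include pairs such as Expected `6+9i` / Predicted `6 + 9i` and Expected `x=5` / Predicted `5`, which differ only in formatting. `compare_answers` is defined elsewhere in the notebook, so its behaviour is not shown here. The helper below is a hypothetical normalization step (`normalize_answer` and `answers_match` are illustrative names, not part of the notebook) sketching how such formatting-only mismatches could be tolerated.

```python
import re
from typing import Optional


def normalize_answer(ans: Optional[str]) -> Optional[str]:
    """Hypothetical normalizer: removes formatting-only differences."""
    if ans is None:
        return None
    s = ans.strip()
    s = re.sub(r"^[A-Za-z]\s*=\s*", "", s)  # drop a leading "x =" style prefix
    s = s.replace("\\dfrac", "\\frac")       # treat \dfrac and \frac alike
    s = re.sub(r"\s+", "", s)                # ignore all whitespace
    return s


def answers_match(expected: Optional[str], predicted: Optional[str]) -> bool:
    """True when both answers exist and agree after normalization."""
    a, b = normalize_answer(expected), normalize_answer(predicted)
    return a is not None and a == b


print(answers_match("6+9i", "6 + 9i"))
print(answers_match("x=5", "5"))
print(answers_match("284", "280"))
```

Removing all whitespace is a blunt choice: it would also merge words inside `\text{...}` answers, so a real comparison may need to be more selective.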

In [None]:
MAX_SAMPLE_TEST = 30

def evaluate_best_of_n(use_logprob: bool = True, N: int = 3):
    os.makedirs("results", exist_ok=True)
    results_file = (
        "evaluation_results_math500_deepseek_best_of_n_logprob.json"
        if use_logprob else
        "evaluation_results_math500_deepseek_best_of_n_gpt.json"
    )

    dataset = load_math500_dataset()
    existing = load_existing_results(results_file)
    seen = {r['index'] for r in existing}
    correct = 0

    for idx, item in enumerate(tqdm(dataset, desc="Evaluating problems")):
        if idx in seen or idx >= MAX_SAMPLE_TEST:
            continue

        prob = item['problem']
        true_ans = extract_answer(item['solution'])
        pred_ans = best_of_n_response(prob, N=N, use_logprob=use_logprob, temp=0.12)
        is_corr = compare_answers(true_ans, pred_ans)

        save_result(results_file, {
            "index": idx,
            "problem": prob,
            "correct_answer": true_ans,
            "predicted_answer": pred_ans,
            "is_correct": is_corr
        })
        if is_corr:
            correct += 1
        print(f"corrects: {correct} / {idx+1}")

    analyze_results(load_existing_results(results_file))


In [None]:
# clear results history
# !rm /kaggle/working/evaluation_results_math500_deepseek_best_of_n_logprob.json
# !rm /kaggle/working/evaluation_results_math500_deepseek_best_of_n_gpt.json

In [None]:
evaluate_best_of_n(use_logprob=True, N=3)

Evaluating problems:   0%|          | 1/500 [00:25<3:28:21, 25.05s/it]

corrects: 1 / 1


Evaluating problems:   0%|          | 2/500 [00:52<3:39:03, 26.39s/it]

corrects: 1 / 2


Evaluating problems:   1%|          | 3/500 [01:19<3:41:40, 26.76s/it]

corrects: 1 / 3


Evaluating problems:   1%|          | 4/500 [01:40<3:23:31, 24.62s/it]

corrects: 2 / 4


Evaluating problems:   1%|          | 5/500 [02:08<3:31:45, 25.67s/it]

corrects: 2 / 5


Evaluating problems:   1%|          | 6/500 [02:31<3:24:45, 24.87s/it]

corrects: 3 / 6


Evaluating problems:   1%|▏         | 7/500 [02:58<3:30:34, 25.63s/it]

corrects: 3 / 7


Evaluating problems:   2%|▏         | 8/500 [03:26<3:34:57, 26.22s/it]

corrects: 3 / 8


Evaluating problems:   2%|▏         | 9/500 [03:50<3:28:20, 25.46s/it]

corrects: 4 / 9


Evaluating problems:   2%|▏         | 10/500 [04:17<3:32:45, 26.05s/it]

corrects: 4 / 10


Evaluating problems:   2%|▏         | 11/500 [04:44<3:35:18, 26.42s/it]

corrects: 4 / 11


Evaluating problems:   2%|▏         | 12/500 [05:12<3:36:51, 26.66s/it]

corrects: 4 / 12


Evaluating problems:   3%|▎         | 13/500 [05:39<3:37:10, 26.76s/it]

corrects: 4 / 13


Evaluating problems:   3%|▎         | 14/500 [05:57<3:17:35, 24.39s/it]

corrects: 5 / 14


Evaluating problems:   3%|▎         | 15/500 [06:25<3:24:17, 25.27s/it]

corrects: 5 / 15


Evaluating problems:   3%|▎         | 16/500 [06:52<3:29:03, 25.92s/it]

corrects: 5 / 16


Evaluating problems:   3%|▎         | 17/500 [07:19<3:30:12, 26.11s/it]

corrects: 6 / 17


Evaluating problems:   4%|▎         | 18/500 [07:46<3:33:10, 26.54s/it]

corrects: 6 / 18


Evaluating problems:   4%|▍         | 19/500 [08:14<3:34:50, 26.80s/it]

corrects: 6 / 19


Evaluating problems:   4%|▍         | 20/500 [08:41<3:35:30, 26.94s/it]

corrects: 6 / 20


Evaluating problems:   4%|▍         | 21/500 [08:57<3:09:29, 23.74s/it]

corrects: 6 / 21


Evaluating problems:   4%|▍         | 22/500 [09:24<3:17:16, 24.76s/it]

corrects: 6 / 22


Evaluating problems:   5%|▍         | 23/500 [09:52<3:23:27, 25.59s/it]

corrects: 6 / 23


Evaluating problems:   5%|▍         | 24/500 [10:19<3:26:45, 26.06s/it]

corrects: 6 / 24


Evaluating problems:   5%|▌         | 25/500 [10:46<3:29:08, 26.42s/it]

corrects: 6 / 25


Evaluating problems:   5%|▌         | 26/500 [11:14<3:30:38, 26.66s/it]

corrects: 6 / 26


Evaluating problems:   5%|▌         | 27/500 [11:41<3:31:28, 26.82s/it]

corrects: 6 / 27


Evaluating problems:   6%|▌         | 28/500 [12:08<3:32:03, 26.96s/it]

corrects: 6 / 28


Evaluating problems:   6%|▌         | 29/500 [12:35<3:31:58, 27.00s/it]

corrects: 6 / 29


Evaluating problems: 100%|██████████| 500/500 [12:57<00:00,  1.55s/it] 

corrects: 7 / 30

=== Results Summary ===
Total problems: 30
Correct answers: 7
Accuracy: 23.33%

=== Incorrect Problems ===
Problem 1:
Expected: p - q
Predicted: None
---
Problem 2:
Expected: \frac{14}{3}
Predicted: None
---
Problem 4:
Expected: \text{Evelyn}
Predicted: None
---
Problem 6:
Expected: 27
Predicted: None
---
Problem 7:
Expected: 90^\circ
Predicted: None
---
Problem 9:
Expected: 4
Predicted: None
---
Problem 10:
Expected: 2220
Predicted: None
---
Problem 11:
Expected: \frac{3}{56}
Predicted: None
---
Problem 12:
Expected: 284
Predicted: None
---
Problem 14:
Expected: \sqrt{51}
Predicted: None
---
Problem 15:
Expected: 6 - 5i
Predicted: None
---
Problem 17:
Expected: \pi
Predicted: None
---
Problem 18:
Expected: 28
Predicted: None
---
Problem 19:
Expected: 3
Predicted: None
---
Problem 20:
Expected: 6+9i
Predicted: 6 + 9i
---
Problem 21:
Expected: 13535
Predicted: None
---
Problem 22:
Expected: 5
Predicted: None
---
Problem 23:
Expected: x=5
Predicted: None
---
Problem 24:




In [None]:
evaluate_best_of_n(use_logprob=False,  N=3)

Evaluating problems:   0%|          | 1/500 [00:24<3:27:12, 24.91s/it]

corrects: 1 / 1


Evaluating problems:   0%|          | 2/500 [00:52<3:40:18, 26.54s/it]

corrects: 1 / 2


Evaluating problems:   1%|          | 3/500 [01:20<3:43:57, 27.04s/it]

corrects: 1 / 3


Evaluating problems:   1%|          | 4/500 [01:42<3:27:06, 25.05s/it]

corrects: 2 / 4


Evaluating problems:   1%|          | 5/500 [02:10<3:35:27, 26.12s/it]

corrects: 3 / 5


Evaluating problems:   1%|          | 6/500 [02:29<3:15:58, 23.80s/it]

corrects: 4 / 6


Evaluating problems:   1%|▏         | 7/500 [02:59<3:30:45, 25.65s/it]

corrects: 5 / 7


Evaluating problems:   2%|▏         | 8/500 [03:26<3:36:12, 26.37s/it]

corrects: 6 / 8


Evaluating problems:   2%|▏         | 9/500 [03:53<3:35:04, 26.28s/it]

corrects: 7 / 9


Evaluating problems:   2%|▏         | 10/500 [04:20<3:38:16, 26.73s/it]

corrects: 7 / 10


Evaluating problems:   2%|▏         | 11/500 [04:48<3:39:55, 26.98s/it]

corrects: 8 / 11


Evaluating problems:   2%|▏         | 12/500 [05:15<3:40:53, 27.16s/it]

corrects: 8 / 12


Evaluating problems:   3%|▎         | 13/500 [05:42<3:39:47, 27.08s/it]

corrects: 8 / 13


Evaluating problems:   3%|▎         | 14/500 [06:02<3:20:33, 24.76s/it]

corrects: 9 / 14


Evaluating problems:   3%|▎         | 15/500 [06:29<3:27:26, 25.66s/it]

corrects: 10 / 15


Evaluating problems:   3%|▎         | 16/500 [06:58<3:33:04, 26.42s/it]

corrects: 10 / 16


Evaluating problems:   3%|▎         | 17/500 [07:23<3:31:24, 26.26s/it]

corrects: 11 / 17


Evaluating problems:   4%|▎         | 18/500 [07:51<3:34:13, 26.67s/it]

corrects: 12 / 18


Evaluating problems:   4%|▍         | 19/500 [08:19<3:36:23, 26.99s/it]

corrects: 12 / 19


Evaluating problems:   4%|▍         | 20/500 [08:46<3:37:25, 27.18s/it]

corrects: 13 / 20


Evaluating problems:   4%|▍         | 21/500 [09:02<3:09:02, 23.68s/it]

corrects: 13 / 21


Evaluating problems:   4%|▍         | 22/500 [09:34<3:27:57, 26.10s/it]

corrects: 14 / 22


Evaluating problems:   5%|▍         | 23/500 [10:02<3:31:36, 26.62s/it]

corrects: 14 / 23


Evaluating problems:   5%|▍         | 24/500 [10:29<3:33:22, 26.90s/it]

corrects: 14 / 24


Evaluating problems:   5%|▌         | 25/500 [11:00<3:41:42, 28.01s/it]

corrects: 15 / 25


Evaluating problems:   5%|▌         | 26/500 [11:27<3:40:21, 27.89s/it]

corrects: 15 / 26


Evaluating problems:   5%|▌         | 27/500 [11:55<3:39:03, 27.79s/it]

corrects: 15 / 27


Evaluating problems:   6%|▌         | 28/500 [12:22<3:38:04, 27.72s/it]

corrects: 16 / 28


Evaluating problems:   6%|▌         | 29/500 [12:50<3:36:52, 27.63s/it]

corrects: 16 / 29


Evaluating problems: 100%|██████████| 500/500 [13:14<00:00,  1.59s/it] 

corrects: 17 / 30

=== Results Summary ===
Total problems: 30
Correct answers: 17
Accuracy: 56.67%

=== Incorrect Problems ===
Problem 1:
Expected: p - q
Predicted: 2q - p
---
Problem 2:
Expected: \frac{14}{3}
Predicted: \frac{1}{3}
---
Problem 9:
Expected: 4
Predicted: 11
---
Problem 11:
Expected: \frac{3}{56}
Predicted: \frac{2}{15}
---
Problem 12:
Expected: 284
Predicted: 286
---
Problem 15:
Expected: 6 - 5i
Predicted: 3 - \sqrt{2} - (2 + \sqrt{2})i
---
Problem 18:
Expected: 28
Predicted: 62
---
Problem 20:
Expected: 6+9i
Predicted: 6 + 9i
---
Problem 22:
Expected: 5
Predicted: 4
---
Problem 23:
Expected: x=5
Predicted: 5
---
Problem 25:
Expected: 1,-2
Predicted: -1, 0, 1, 2, 3
---
Problem 26:
Expected: 144
Predicted: 48
---
Problem 28:
Expected: -2 + 7i
Predicted: -2 + 7i
---





## Beam Search

This cell implements a beam search strategy for generating candidate reasoning chains. The method generates multiple continuations at each reasoning step, scoring each candidate based on its average token log-likelihood. By retaining and expanding only the top candidates, the approach efficiently searches for the most promising chain-of-thought that leads to the final answer in the required format.

**Key Components:**

- **Model Invocation & Token Scoring:**  
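  The `call_qwen_model_raw` function sends a prompt to the local endpoint serving `deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B`, with a token budget that depends on the reasoning step. It returns the generated text together with the average token log-probability, which is used as a quality metric, and the number of generated tokens.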

- **Candidate Representation:**  
  The `BeamCandidate` class encapsulates a reasoning chain. It stores the generated text (sequence), cumulative log-probability, per-step scores, token count, and a finished flag (indicating if the candidate contains the final answer).

- **Step-wise Reasoning Generation:**  
  The `generate_reasoning_steps` function creates multiple candidate continuations for each reasoning step. Step-specific prompts guide the generation: understanding the problem, planning a strategy, and producing the final answer (which the prompt asks the model to enclose in a `\boxed{}` block).

- **Beam Search Process:**  
  The `beam_search` function expands candidate chains over several steps. At each step, a candidate is extended with the new reasoning text, and its score is updated as a token-weighted average of the step scores (weighted by each step's `num_token`). Only the top `beam_width` candidates by this score are retained for further expansion.

- **Final Answer Extraction:**  
  The `run_qwen_beam_search` function initializes the prompt with the problem statement, runs the beam search, and extracts the final answer from the best candidate if it is complete.
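
The token-weighted update used when a chain is extended can be checked by hand. The sketch below is a standalone illustration with a hypothetical helper name (`update_running_mean`) and invented step scores; it mirrors the idea of weighting each step's average log-probability by its token count.

```python
from typing import Tuple


def update_running_mean(
    prev_mean: float, prev_tokens: int, step_mean: float, step_tokens: int
) -> Tuple[float, int]:
    """Merge a step's mean log-prob into the chain's token-weighted mean."""
    total = prev_tokens + step_tokens
    if total == 0:
        return prev_mean, 0  # nothing generated yet, keep the previous score
    merged = (prev_mean * prev_tokens + step_mean * step_tokens) / total
    return merged, total


# Invented example: a 100-token step at -0.2, then a 300-token step at -0.5.
mean, n = 0.0, 0
for step_mean, step_tokens in [(-0.2, 100), (-0.5, 300)]:
    mean, n = update_running_mean(mean, n, step_mean, step_tokens)

# (-0.2 * 100 + -0.5 * 300) / 400 = -0.425, up to floating-point rounding
print(mean, n)
```

Weighting by token count makes the chain score equal to the average over all generated tokens, so a long step influences it more than a short one.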

This structured approach explores several reasoning paths in parallel while spending further computation only on the most promising ones, with the aim of arriving at a final answer in the expected format.
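
The pruning idea can be seen on a toy problem with no model calls. The following sketch is a simplification under stated assumptions: `toy_beam_search` and the `SCORES` table are hypothetical, and chain scores are plain sums of invented step scores rather than the token-weighted averages described above.

```python
from typing import Callable, Dict, List, Tuple

Chain = Tuple[str, float]  # (text so far, accumulated score)


def toy_beam_search(
    start: str,
    expand: Callable[[str], List[Tuple[str, float]]],
    beam_width: int,
    max_steps: int,
) -> Chain:
    """Keep the `beam_width` best chains at every step and return the best one."""
    beams: List[Chain] = [(start, 0.0)]
    for _ in range(max_steps):
        grown: List[Chain] = []
        for text, score in beams:
            for piece, piece_score in expand(text):
                grown.append((text + piece, score + piece_score))
        if not grown:
            break  # no chain can be extended further
        grown.sort(key=lambda c: c[1], reverse=True)
        beams = grown[:beam_width]
    return beams[0]


# Invented scores: "A" looks best at first but leads to poor continuations.
SCORES: Dict[str, List[Tuple[str, float]]] = {
    "": [("A", -0.2), ("B", -0.3)],
    "A": [("x", -2.0), ("y", -2.5)],
    "B": [("x", -0.1), ("y", -0.4)],
}


def expand(text: str) -> List[Tuple[str, float]]:
    return SCORES.get(text, [])


print(toy_beam_search("", expand, beam_width=1, max_steps=3)[0])
print(toy_beam_search("", expand, beam_width=2, max_steps=3)[0])
```

With a beam width of 1 the search is greedy and commits to the locally best first step; widening the beam to 2 keeps the alternative alive long enough for its better continuation to win.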


In [None]:
import os
import requests
from typing import Optional, List
from tqdm import tqdm


def score_with_gpt(problem: str, reasoning_step: str) -> float:
    """
    Ask a high‑quality LLM (via get_api_response) to rate the given
    reasoning step on a 0–1 scale. Returns the numeric score.
    """
    prompt = (
        "You are a rigorous math reasoning evaluator.\n\n"
        f"Problem:\n{problem}\n\n"
        "Candidate reasoning step:\n"
        f"\"\"\"\n{reasoning_step}\n\"\"\"\n\n"
        "On a scale from 0 (completely incorrect) to 1 (perfectly correct), "
        "rate how valid and useful this step is toward solving the problem. "
        "Reply with only a number between 0 and 1."
    )
    # Query the evaluator and strip surrounding whitespace.
    response = get_api_response(prompt).strip()

    # Parse the reply as a float clamped to [0, 1]; fall back to 0.0 if it is not a number.
    try:
        score = float(response)
    except ValueError:
        return 0.0
    return max(0.0, min(1.0, score))


def call_qwen_model_raw(prompt: str, step_num: int, temperature: float = 0.3):
    """
    Sends a request to the local Qwen endpoint and returns the generated text
    along with the average token log-probability and token count.
    """
    max_tokens = {1: 500, 2: 800, 3: 1700}.get(step_num, 500)
    url = "http://localhost:8000/v1/chat/completions"
    payload = {
        "model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": temperature,
        "logprobs": True,
    }
    resp = requests.post(url, json=payload).json()

    # Generated text of the first (only) choice.
    text = resp['choices'][0]['message']['content'].strip()

    # Per-token log-probabilities of the generated tokens.
    token_logprobs = [tok['logprob'] for tok in resp['choices'][0]['logprobs']['content']]

    # Average log-probability and token count for this generation.
    num_token = len(token_logprobs)
    avg_logprob = sum(token_logprobs) / num_token if num_token > 0 else 0.0
    return text, avg_logprob, num_token


class BeamCandidate:
    def __init__(
        self,
        sequence: str,
        cumulative_log_prob: float,
        step_scores: List[float],
        finished: bool,
        num_token: int
    ):
        self.sequence = sequence
        self.cumulative_log_prob = cumulative_log_prob
        self.step_scores = step_scores
        self.finished = finished
        self.num_token = num_token

    def __repr__(self):
        return (
            f"BeamCandidate(score={self.cumulative_log_prob:.3f}, "
            f"finished={self.finished}, sequence=[...])"
        )


def generate_reasoning_steps(
    context: str,
    step_num: int,
    top_k: int,
    use_logprob: bool = True,
    problem: Optional[str] = None
):
    """
    Generate top_k candidate continuations for the current reasoning step.
    """
    # Step-specific instruction appended to the context (1=understand, 2=plan, 3=solve).
    suffix = {
        1: "Let's understand the problem.",
        2: "Let's plan how to solve it.",
        3: "Let's solve it and put the final answer within \\boxed{}."
    }.get(step_num, "")
    prompt = context + "\n" + suffix

    candidates = []
    for _ in range(top_k):
        # Sample one continuation; with temperature > 0, repeated calls can differ.
        output, avg_token_prob, num_token = call_qwen_model_raw(prompt, step_num)

        # Score by token log-probability or by the external evaluator.
        if use_logprob:
            score = avg_token_prob
        else:
            assert problem is not None, "Problem must be provided for GPT scoring"
            score = score_with_gpt(problem, output)

        # A candidate is finished once it contains a boxed final answer.
        finished = "\\boxed{" in output

        candidates.append((output.strip(), score, num_token, finished))

    return candidates


def beam_search(
    init_problem_prompt: str,
    beam_width: int = 3,
    max_steps: int = 3,
    top_k: int = 2,
    use_logprob: bool = True
):
    """
    Beam search over reasoning steps. If use_logprob=False, uses GPT verifier
    to score each node instead of token log-prob.
    """
    # Recover the problem text: everything after the first line of the prompt
    # ("Consider this problem:\n{problem}"), so multi-line problems stay intact.
    problem = init_problem_prompt.split('\n', 1)[-1]

    initial = BeamCandidate(sequence=init_problem_prompt, cumulative_log_prob=0.0, step_scores=[], finished=False, num_token=0)
    beams = [initial]

    for step in range(1, max_steps + 1):
        new_beams = []
        for cand in beams:
            # Finished candidates are carried over unchanged.
            if cand.finished:
                new_beams.append(cand)
                continue

            step_cands = generate_reasoning_steps(cand.sequence, step, top_k, use_logprob, problem)

            for text, score, n_tok, finished in step_cands:
                seq = cand.sequence + "\n" + text
                total_tokens = cand.num_token + n_tok
                # Token-weighted running average of the step scores;
                # keep the previous score if no tokens have been generated yet.
                if total_tokens > 0:
                    cum = ((cand.cumulative_log_prob * cand.num_token) + score * n_tok) / total_tokens
                else:
                    cum = cand.cumulative_log_prob
                new_beams.append(BeamCandidate(seq, cum, cand.step_scores + [score], finished, total_tokens))

        # Keep the beam_width highest-scoring candidates.
        new_beams.sort(key=lambda x: x.cumulative_log_prob, reverse=True)
        beams = new_beams[:beam_width]

        # Stop early once every surviving candidate has a final answer.
        if all(b.finished for b in beams):
            break

    # Prefer the best finished candidate; otherwise fall back to the top beam.
    finished = [b for b in beams if b.finished]
    best = max(finished, key=lambda b: b.cumulative_log_prob) if finished else beams[0]
    return best


def run_qwen_beam_search(
    problem: str,
    beam_width: int,
    max_steps: int,
    top_k: int,
    log_level,
    use_logprob: bool = True
):
    """
    Performs beam search and extracts the final boxed answer.
    """
    prompt = f"Consider this problem:\n{problem}"

    best = beam_search(prompt, beam_width, max_steps, top_k, use_logprob)
    if best.finished:
        ans = extract_answer(best.sequence)
        print(f"\nExtracted Final Answer: {ans}")
        return ans
    else:
        print("No final answer found.")
        return None


# Evaluate beam search
* Modify the response-generation part to evaluate this method.

In [None]:
MAX_SAMPLE_TEST = 30


def evaluate_beam_search(use_logprob: bool = True):
    """
    Evaluate beam search on MATH‑500, toggling between log‑prob scoring
    and GPT/Gemini–based verification for each reasoning node.
    """
    os.makedirs("results", exist_ok=True)

    suffix = "logprob" if use_logprob else "gpt"
    results_file = f"evaluation_results_math500_deepseek_beam_search_{suffix}.json"

    dataset = load_math500_dataset()
    existing_results = load_existing_results(results_file)
    processed_indexes = {r['index'] for r in existing_results}

    cnt = 0
    for idx, item in enumerate(tqdm(dataset, desc="Evaluating problems")):
        if idx in processed_indexes or idx >= MAX_SAMPLE_TEST:
            continue

        problem_text   = item['problem']
        correct_answer = extract_answer(item['solution'])

        # Run beam search with the desired scoring method
        response = run_qwen_beam_search(
            problem     = problem_text,
            beam_width  = 3,
            max_steps   = 3,
            top_k       = 2,
            log_level   = 1,              # existing parameter
            use_logprob = use_logprob     # True = token log‑prob, False = GPT verifier
        )
        predicted_answer = response

        is_correct = compare_answers(correct_answer, predicted_answer)
        save_result(results_file, {
            "index": idx,
            "problem": problem_text,
            "response": response,
            "correct_answer": correct_answer,
            "predicted_answer": predicted_answer,
            "is_correct": is_correct
        })

        if is_correct:
            cnt += 1
        print(f"corrects: {cnt} / {idx+1}")

    final_results = load_existing_results(results_file)
    analyze_results(final_results)



In [None]:
# Example usage:
evaluate_beam_search(use_logprob=True)   # pick by avg token log‑prob

Evaluating problems:   0%|          | 1/500 [01:53<15:43:51, 113.49s/it]


Extracted Final Answer: \left(3, \frac{\pi}{2}\right)
corrects: 1 / 1


Evaluating problems:   0%|          | 2/500 [06:04<26:53:18, 194.37s/it]


Extracted Final Answer: p - q
corrects: 2 / 2


Evaluating problems:   1%|          | 3/500 [08:22<23:18:34, 168.84s/it]


Extracted Final Answer: \dfrac{14}{3}
corrects: 3 / 3


Evaluating problems:   1%|          | 4/500 [10:46<21:54:45, 159.04s/it]


Extracted Final Answer: 9
corrects: 4 / 4


Evaluating problems:   1%|          | 5/500 [14:48<25:57:00, 188.73s/it]


Extracted Final Answer: Evelyn
corrects: 5 / 5


Evaluating problems:   1%|          | 6/500 [16:22<21:30:17, 156.72s/it]


Extracted Final Answer: 42
corrects: 6 / 6


Evaluating problems:   1%|▏         | 7/500 [20:28<25:25:28, 185.66s/it]


Extracted Final Answer: 27
corrects: 7 / 7


Evaluating problems:   2%|▏         | 8/500 [23:39<25:36:33, 187.39s/it]


Extracted Final Answer: 90
corrects: 8 / 8


Evaluating problems:   2%|▏         | 9/500 [25:30<22:19:13, 163.65s/it]


Extracted Final Answer: 3\sqrt{13}
corrects: 9 / 9


Evaluating problems:   2%|▏         | 10/500 [29:48<26:14:43, 192.82s/it]

No final answer found.
corrects: 9 / 10


Evaluating problems:   2%|▏         | 11/500 [33:57<28:30:21, 209.86s/it]


Extracted Final Answer: 2220
corrects: 10 / 11


Evaluating problems:   2%|▏         | 12/500 [38:17<30:32:34, 225.32s/it]

No final answer found.
corrects: 10 / 12


Evaluating problems:   3%|▎         | 13/500 [42:33<31:42:45, 234.43s/it]


Extracted Final Answer: 284
corrects: 11 / 13


Evaluating problems:   3%|▎         | 14/500 [43:16<23:50:32, 176.61s/it]


Extracted Final Answer: 5
corrects: 12 / 14


Evaluating problems:   3%|▎         | 15/500 [46:05<23:28:23, 174.23s/it]


Extracted Final Answer: \sqrt{51}
corrects: 13 / 15


Evaluating problems:   3%|▎         | 16/500 [50:24<26:53:25, 200.01s/it]


Extracted Final Answer: 6 - 5i
corrects: 14 / 16


Evaluating problems:   3%|▎         | 17/500 [53:16<25:41:00, 191.43s/it]


Extracted Final Answer: -50
corrects: 15 / 17


Evaluating problems:   4%|▎         | 18/500 [57:34<28:19:03, 211.50s/it]

No final answer found.
corrects: 15 / 18


Evaluating problems:   4%|▍         | 19/500 [1:01:51<30:04:08, 225.05s/it]

No final answer found.
corrects: 15 / 19


Evaluating problems:   4%|▍         | 20/500 [1:06:07<31:14:42, 234.34s/it]

No final answer found.
corrects: 15 / 20


Evaluating problems:   4%|▍         | 21/500 [1:07:35<25:19:33, 190.34s/it]


Extracted Final Answer: 6 + 9i
corrects: 15 / 21


Evaluating problems:   4%|▍         | 22/500 [1:11:50<27:52:59, 210.00s/it]

No final answer found.
corrects: 15 / 22


Evaluating problems:   5%|▍         | 23/500 [1:14:57<26:53:08, 202.91s/it]


Extracted Final Answer: 5
corrects: 16 / 23


Evaluating problems:   5%|▍         | 24/500 [1:18:44<27:48:34, 210.32s/it]


Extracted Final Answer: 5
corrects: 16 / 24


Evaluating problems:   5%|▌         | 25/500 [1:23:01<29:36:11, 224.36s/it]


Extracted Final Answer: 10
corrects: 17 / 25


Evaluating problems:   5%|▌         | 26/500 [1:27:20<30:53:09, 234.58s/it]


Extracted Final Answer: 1
corrects: 17 / 26


Evaluating problems:   5%|▌         | 27/500 [1:31:37<31:43:22, 241.44s/it]

No final answer found.
corrects: 17 / 27


Evaluating problems:   6%|▌         | 28/500 [1:33:01<25:25:57, 193.98s/it]


Extracted Final Answer: \$78
corrects: 18 / 28


Evaluating problems:   6%|▌         | 29/500 [1:36:13<25:19:41, 193.59s/it]


Extracted Final Answer: -2 + 7i
corrects: 18 / 29


Evaluating problems: 100%|██████████| 500/500 [1:39:13<00:00, 11.91s/it]   


Extracted Final Answer: 225
corrects: 19 / 30

=== Results Summary ===
Total problems: 30
Correct answers: 19
Accuracy: 63.33%

=== Incorrect Problems ===
Problem 9:
Expected: 4
Predicted: None
---
Problem 11:
Expected: \frac{3}{56}
Predicted: None
---
Problem 17:
Expected: \pi
Predicted: None
---
Problem 18:
Expected: 28
Predicted: None
---
Problem 19:
Expected: 3
Predicted: None
---
Problem 20:
Expected: 6+9i
Predicted: 6 + 9i
---
Problem 21:
Expected: 13535
Predicted: None
---
Problem 23:
Expected: x=5
Predicted: 5
---
Problem 25:
Expected: 1,-2
Predicted: 1
---
Problem 26:
Expected: 144
Predicted: None
---
Problem 28:
Expected: -2 + 7i
Predicted: -2 + 7i
---





In [None]:
evaluate_beam_search(use_logprob=False)  # pick by GPT/Gemini verification

Evaluating problems:   0%|          | 1/500 [01:42<14:10:50, 102.31s/it]


Extracted Final Answer: (3, \frac{\pi}{2})
corrects: 1 / 1


Evaluating problems:   0%|          | 2/500 [05:42<25:24:43, 183.70s/it]


Extracted Final Answer: p - q
corrects: 2 / 2


Evaluating problems:   1%|          | 3/500 [07:58<22:18:56, 161.64s/it]


Extracted Final Answer: \dfrac{14}{3}
corrects: 3 / 3


Evaluating problems:   1%|          | 4/500 [09:56<19:55:50, 144.66s/it]


Extracted Final Answer: 9
corrects: 4 / 4


Evaluating problems:   1%|          | 5/500 [14:17<25:37:05, 186.31s/it]


Extracted Final Answer: Evelyn
corrects: 5 / 5


Evaluating problems:   1%|          | 6/500 [15:30<20:16:36, 147.77s/it]


Extracted Final Answer: 42
corrects: 6 / 6


Evaluating problems:   1%|▏         | 7/500 [19:16<23:46:04, 173.56s/it]


Extracted Final Answer: 27
corrects: 7 / 7


Evaluating problems:   2%|▏         | 8/500 [23:22<26:51:20, 196.50s/it]


Extracted Final Answer: 90
corrects: 8 / 8


Evaluating problems:   2%|▏         | 9/500 [25:55<24:56:54, 182.92s/it]


Extracted Final Answer: 3\sqrt{13}
corrects: 9 / 9


Evaluating problems:   2%|▏         | 10/500 [30:23<28:27:49, 209.12s/it]

No final answer found.
corrects: 9 / 10


Evaluating problems:   2%|▏         | 11/500 [34:15<29:21:48, 216.17s/it]


Extracted Final Answer: 2220
corrects: 10 / 11


Evaluating problems:   2%|▏         | 12/500 [38:43<31:27:00, 232.01s/it]


Extracted Final Answer: \dfrac{1}{9}
corrects: 10 / 12


Evaluating problems:   3%|▎         | 13/500 [42:52<32:03:30, 236.98s/it]


Extracted Final Answer: 284
corrects: 11 / 13


Evaluating problems:   3%|▎         | 14/500 [43:24<23:39:12, 175.21s/it]


Extracted Final Answer: 5
corrects: 12 / 14


Evaluating problems:   3%|▎         | 15/500 [46:25<23:49:56, 176.90s/it]


Extracted Final Answer: \sqrt{51}
corrects: 13 / 15


Evaluating problems:   3%|▎         | 16/500 [50:50<27:22:16, 203.59s/it]


Extracted Final Answer: 6 - 5i
corrects: 14 / 16


Evaluating problems:   3%|▎         | 17/500 [55:10<29:34:33, 220.44s/it]


Extracted Final Answer: -50
corrects: 15 / 17


Evaluating problems:   4%|▎         | 18/500 [59:37<31:22:49, 234.38s/it]

No final answer found.
corrects: 15 / 18


Evaluating problems:   4%|▍         | 19/500 [1:04:04<32:38:36, 244.32s/it]

No final answer found.
corrects: 15 / 19


Evaluating problems:   4%|▍         | 20/500 [1:08:29<33:23:16, 250.41s/it]

No final answer found.
corrects: 15 / 20


Evaluating problems:   4%|▍         | 21/500 [1:10:30<28:09:05, 211.58s/it]


Extracted Final Answer: 6 + 9i
corrects: 15 / 21


Evaluating problems:   4%|▍         | 22/500 [1:14:57<30:18:50, 228.31s/it]

No final answer found.
corrects: 15 / 22


Evaluating problems:   5%|▍         | 23/500 [1:18:21<29:17:16, 221.04s/it]


Extracted Final Answer: 5
corrects: 16 / 23


Evaluating problems:   5%|▍         | 24/500 [1:21:38<28:14:31, 213.60s/it]


Extracted Final Answer: 5
corrects: 16 / 24


Evaluating problems:   5%|▌         | 25/500 [1:25:37<29:11:29, 221.24s/it]


Extracted Final Answer: 10
corrects: 17 / 25


Evaluating problems:   5%|▌         | 26/500 [1:30:05<30:58:30, 235.25s/it]


Extracted Final Answer: 1
corrects: 17 / 26


Evaluating problems:   5%|▌         | 27/500 [1:34:32<32:09:28, 244.75s/it]

No final answer found.
corrects: 17 / 27


Evaluating problems:   6%|▌         | 28/500 [1:36:25<26:54:32, 205.24s/it]


Extracted Final Answer: 78
corrects: 18 / 28


Evaluating problems:   6%|▌         | 29/500 [1:39:11<25:19:32, 193.57s/it]


Extracted Final Answer: -2 + 7i
corrects: 18 / 29


Evaluating problems: 100%|██████████| 500/500 [1:42:27<00:00, 12.30s/it]   


Extracted Final Answer: 225
corrects: 19 / 30

=== Results Summary ===
Total problems: 30
Correct answers: 19
Accuracy: 63.33%

=== Incorrect Problems ===
Problem 9:
Expected: 4
Predicted: None
---
Problem 11:
Expected: \frac{3}{56}
Predicted: \dfrac{1}{9}
---
Problem 17:
Expected: \pi
Predicted: None
---
Problem 18:
Expected: 28
Predicted: None
---
Problem 19:
Expected: 3
Predicted: None
---
Problem 20:
Expected: 6+9i
Predicted: 6 + 9i
---
Problem 21:
Expected: 13535
Predicted: None
---
Problem 23:
Expected: x=5
Predicted: 5
---
Problem 25:
Expected: 1,-2
Predicted: 1
---
Problem 26:
Expected: 144
Predicted: None
---
Problem 28:
Expected: -2 + 7i
Predicted: -2 + 7i
---





## Self-Refinement

This cell implements a self-refinement approach to solving math problems. It first generates a solution using a fixed system prompt that enforces step-by-step reasoning and a final answer enclosed in `\boxed{}`. The model is then asked to critique its own output and, if the critique calls for it, to revise the solution. The loop stops when the critique reports that no refinement is needed, when a revision leaves the solution unchanged, or when the iteration limit is reached. The aim is a final answer that is more likely to be correct and is consistently formatted; the loop does not guarantee correctness.
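
To make the control flow concrete before the real implementation, here is a minimal sketch of the generate, critique, revise loop. The three callables (`toy_generate`, `toy_critique`, `toy_revise`) are hypothetical stand-ins for model calls and are not part of this notebook; only the loop structure is the point.

```python
from typing import Callable, Tuple


def refine_loop(
    problem: str,
    generate: Callable[[str], str],
    critique: Callable[[str, str], Tuple[str, bool]],
    revise: Callable[[str, str, str], str],
    max_iter: int = 2,
) -> str:
    """Generate a draft, then alternate critique and revision up to max_iter times."""
    draft = generate(problem)
    for _ in range(max_iter):
        feedback, needs_refinement = critique(problem, draft)
        if not needs_refinement:
            break  # the critic is satisfied
        new_draft = revise(problem, draft, feedback)
        if new_draft.strip() == draft.strip():
            break  # the revision changed nothing, so stop early
        draft = new_draft
    return draft


# Toy stand-ins for the model calls (hypothetical, for illustration only).
def toy_generate(problem: str) -> str:
    return "x = 5"


def toy_critique(problem: str, draft: str) -> Tuple[str, bool]:
    if "\\boxed" in draft:
        return "Looks good.", False
    return "Wrap the final answer in \\boxed{}.", True


def toy_revise(problem: str, draft: str, feedback: str) -> str:
    return "\\boxed{5}"


result = refine_loop("Solve 2x + 5 = 15", toy_generate, toy_critique, toy_revise)
print(result)
```

Passing the model calls in as functions keeps the loop easy to test without a running inference server.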


In [None]:
import re
from typing import Optional

import requests

SYSTEM_PROMPT = '''You are solving mathematics problems.

Please think step by step.

Important: Always end your solution with the final answer in this format:

\\[
\\boxed{your_answer_here}
\\]

The entire answer should be contained completely within the \\boxed{} command.'''


def generate_content(prompt: str) -> str:
    """
    Sends `prompt` to the local Qwen endpoint and returns the generated text.
    """
    url = "http://localhost:8000/v1/chat/completions"
    payload = {
        "model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 1504,
        "temperature": 0.3,
    }
    # TODO: Send HTTP POST to `url` with `payload`, parse the JSON response
    resp = requests.post(url, json=payload).json()

    # TODO: Extract `content` from `resp['choices'][0]['message']['content']` and strip whitespace
    text = resp['choices'][0]['message']['content'].strip()

    # TODO: Return the resulting string
    return text


FEEDBACK_SYSTEM_PROMPT = """
Given the problem and the current solution attempt, provide constructive feedback.

[PROBLEM]
{prompt}

[CURRENT_SOLUTION]
{current_output}

Please provide:
1. [FEEDBACK] Detailed feedback about any issues or improvements.
2. [REFINEMENT_NEEDED] Answer "yes" if the solution needs refinement, otherwise "no".

Use the exact format:
[FEEDBACK] your feedback here
[REFINEMENT_NEEDED] yes/no
"""

REFINE_SYSTEM_PROMPT = """
You previously solved the following problem:

{prompt}

Here was your attempt:
{current_output}

Here is the feedback you received:
{feedback}

Please revise your solution accordingly.
"""

def self_refine(problem: str, max_iter: int = 2) -> Optional[str]:
    """
    Iteratively refines the model’s output on `problem` using feedback loops.
    """
    # TODO: Build initial `prompt` by concatenating SYSTEM_PROMPT and `problem`
    prompt = SYSTEM_PROMPT + "\n" + problem

    # TODO: Call `generate_content(prompt)` to get `current_output`
    current_output = generate_content(prompt)

    for iteration in range(max_iter):
        # TODO: Construct `feedback_prompt` that includes:
        #         - the original `prompt`
        #         - the `current_output`
        #         - instructions to output [FEEDBACK] and [REFINEMENT_NEEDED]
        feedback_prompt = FEEDBACK_SYSTEM_PROMPT.format(prompt=prompt, current_output=current_output)

        # TODO: Call `generate_content(feedback_prompt)` → `feedback_response`
        feedback_response = generate_content(feedback_prompt)

        # TODO: Use `re.search` with pattern r"\[(?i:feedback)\](.*?)\[(?i:refinement_needed)\](.*)"
        #       to extract `feedback` and `refinement_flag`
        # TODO: Determine `refinement_needed` (default True, set False if flag is "no" or feedback contains stop phrases)
        match = re.search(r"\[(?i:feedback)\](.*?)\[(?i:refinement_needed)\](.*)", feedback_response, re.DOTALL)
        if match:
            feedback = match.group(1).strip()
            refinement_flag = match.group(2).strip().lower()
            # Accept replies such as "no", "No." or "no, the solution is correct".
            refinement_needed = re.match(r"\W*(no|false)\b", refinement_flag) is None
        else:
            feedback = ""
            refinement_needed = True

        # TODO: If `not refinement_needed`: break out of loop
        if not refinement_needed:
            break

        # TODO: Build `refine_prompt` that includes:
        #         - the original `prompt`
        #         - the `current_output`
        #         - the extracted `feedback`
        refine_prompt = REFINE_SYSTEM_PROMPT.format(
            prompt = prompt,
            current_output = current_output,
            feedback = feedback
        )

        # TODO: Call `generate_content(refine_prompt)` → `refined_output`
        refined_output = generate_content(refine_prompt)

        # TODO: If `refined_output.strip() == current_output.strip()`: break
        if refined_output.strip() == current_output.strip():
            break

        # TODO: Otherwise, set `current_output = refined_output`
        current_output = refined_output

    # TODO: Use `extract_answer(current_output)` to get the final boxed answer
    # TODO: Return that answer (or None if no boxed answer found)
    final = extract_answer(current_output)
    return final


# Evaluate self refiner
* Modify the response-generation part of the evaluation loop to evaluate this method.

In [None]:
MAX_SAMPLE_TEST = 30

def evaluate_self_refiner():
    os.makedirs("results", exist_ok=True)
    results_file = "evaluation_results_math500_deepseek_self_refiner.json"
    dataset = load_math500_dataset()
    existing_results = load_existing_results(results_file)
    processed_indexes = {result['index'] for result in existing_results}
    cnt = 0
    for idx, item in enumerate(tqdm(dataset, desc="Evaluating problems")):
        if idx in processed_indexes:
            continue
        if idx >= MAX_SAMPLE_TEST:
            break
        problem_text = item['problem']
        correct_answer = extract_answer(item['solution'])
        ##########################################################
        response = self_refine(problem_text,3)
        predicted_answer = response
        ##########################################################
        is_correct = compare_answers(correct_answer, predicted_answer)
        result = {
            "index": idx,
            "problem": problem_text,
            "response": response,
            "correct_answer": correct_answer,
            "predicted_answer": predicted_answer,
            "is_correct": is_correct
        }
        save_result(results_file, result)
        if is_correct:
            cnt += 1
        print(f"corrects :  {cnt} idx: {idx}")
    final_results = load_existing_results(results_file)
    analyze_results(final_results)

In [None]:
evaluate_self_refiner()

Evaluating problems:   0%|          | 1/500 [01:23<11:31:08, 83.10s/it]

corrects :  1 idx: 0


Evaluating problems:   0%|          | 2/500 [02:31<10:19:59, 74.70s/it]

corrects :  2 idx: 1


Evaluating problems:   1%|          | 3/500 [03:35<9:36:23, 69.58s/it] 

corrects :  3 idx: 2


Evaluating problems:   1%|          | 4/500 [04:06<7:28:42, 54.28s/it]

corrects :  4 idx: 3


Evaluating problems:   1%|          | 5/500 [05:59<10:22:52, 75.50s/it]

corrects :  5 idx: 4


Evaluating problems:   1%|          | 6/500 [06:12<7:27:22, 54.34s/it] 

corrects :  6 idx: 5


Evaluating problems:   1%|▏         | 7/500 [07:38<8:50:21, 64.55s/it]

corrects :  7 idx: 6


Evaluating problems:   2%|▏         | 8/500 [09:26<10:44:07, 78.55s/it]

corrects :  8 idx: 7


Evaluating problems:   2%|▏         | 9/500 [09:44<8:07:20, 59.55s/it] 

corrects :  9 idx: 8


Evaluating problems:   2%|▏         | 10/500 [12:53<13:31:28, 99.36s/it]

corrects :  9 idx: 9


Evaluating problems:   2%|▏         | 11/500 [13:30<10:55:24, 80.42s/it]

corrects :  10 idx: 10


Evaluating problems:   2%|▏         | 12/500 [14:22<9:44:29, 71.86s/it] 

corrects :  10 idx: 11


Evaluating problems:   3%|▎         | 13/500 [16:59<13:12:46, 97.67s/it]

corrects :  10 idx: 12


Evaluating problems:   3%|▎         | 14/500 [17:14<9:46:51, 72.45s/it] 

corrects :  11 idx: 13


Evaluating problems:   3%|▎         | 15/500 [17:45<8:06:59, 60.25s/it]

corrects :  12 idx: 14


Evaluating problems:   3%|▎         | 16/500 [20:52<13:12:01, 98.18s/it]

corrects :  12 idx: 15


Evaluating problems:   3%|▎         | 17/500 [21:31<10:48:24, 80.55s/it]

corrects :  12 idx: 16


Evaluating problems:   4%|▎         | 18/500 [23:13<11:38:53, 87.00s/it]

corrects :  13 idx: 17


Evaluating problems:   4%|▍         | 19/500 [26:18<15:33:03, 116.39s/it]

corrects :  13 idx: 18


Evaluating problems:   4%|▍         | 20/500 [29:25<18:19:48, 137.48s/it]

corrects :  13 idx: 19


Evaluating problems:   4%|▍         | 21/500 [29:53<13:55:30, 104.66s/it]

corrects :  13 idx: 20


Evaluating problems:   4%|▍         | 22/500 [32:59<17:09:34, 129.24s/it]

corrects :  13 idx: 21


Evaluating problems:   5%|▍         | 23/500 [36:06<19:23:34, 146.36s/it]

corrects :  13 idx: 22


Evaluating problems:   5%|▍         | 24/500 [38:29<19:14:35, 145.54s/it]

corrects :  13 idx: 23


Evaluating problems:   5%|▌         | 25/500 [39:13<15:09:38, 114.90s/it]

corrects :  14 idx: 24


Evaluating problems:   5%|▌         | 26/500 [42:20<18:00:07, 136.73s/it]

corrects :  14 idx: 25


Evaluating problems:   5%|▌         | 27/500 [45:28<19:57:19, 151.88s/it]

corrects :  14 idx: 26


Evaluating problems:   6%|▌         | 28/500 [46:33<16:30:25, 125.90s/it]

corrects :  15 idx: 27


Evaluating problems:   6%|▌         | 29/500 [47:19<13:19:40, 101.87s/it]

corrects :  15 idx: 28


Evaluating problems:   6%|▌         | 30/500 [48:40<12:42:38, 97.36s/it]

corrects :  16 idx: 29

=== Results Summary ===
Total problems: 30
Correct answers: 16
Accuracy: 53.33%

=== Incorrect Problems ===
Problem 9:
Expected: 4
Predicted: None
---
Problem 11:
Expected: \frac{3}{56}
Predicted: None
---
Problem 12:
Expected: 284
Predicted: None
---
Problem 15:
Expected: 6 - 5i
Predicted: None
---
Problem 16:
Expected: -50
Predicted: None
---
Problem 18:
Expected: 28
Predicted: 56^{\circ}
---
Problem 19:
Expected: 3
Predicted: None
---
Problem 20:
Expected: 6+9i
Predicted: 6 + 9i
---
Problem 21:
Expected: 13535
Predicted: None
---
Problem 22:
Expected: 5
Predicted: None
---
Problem 23:
Expected: x=5
Predicted: 5
---
Problem 25:
Expected: 1,-2
Predicted: None
---
Problem 26:
Expected: 144
Predicted: None
---
Problem 28:
Expected: -2 + 7i
Predicted: -2 + 7i
---





# Part 2: Implementing A\*, Monte Carlo Tree Search (MCTS), and Tree of Thoughts (ToT)

With the baseline methods in place, we can start Part 2. This part explores three search and reasoning algorithms, **A\***, **Monte Carlo Tree Search (MCTS)**, and **Tree of Thoughts (ToT)**, for solving challenging mathematical problems from the MATH-500 dataset with a large language model (LLM). Before diving into the implementation, we give an overview of each algorithm, covering its core mechanism, practical considerations, and potential challenges.

---

## 🌟 1. A* Search Algorithm

**A\*** is an informed search algorithm that finds a least-cost path to a goal by combining the cost incurred so far with a heuristic estimate of the remaining cost.

### 🔹 Core Principles:
- **Best-first search:** Expands nodes based on a cost function, \(f(n)\), prioritizing paths that seem closer to a goal.
- **Heuristic evaluation:** Uses a heuristic function \(h(n)\) to estimate the cost from the current node to the goal.

### 🔹 Components:
- **Cost function \(f(n) = g(n) + h(n)\)**, where:
  - \(g(n)\): Actual cost from start node to node \(n\).
  - \(h(n)\): Estimated cost from node \(n\) to the goal (heuristic).

### 🔹 Practical Considerations:
- The heuristic function must be **efficient** to compute and **accurate**.
- A good heuristic drastically reduces computation time and search complexity.

### ⚠️ Challenges in Implementation:
- **Designing an effective heuristic:**
  - It is hard to estimate the "distance" from a partial chain of reasoning to a finished solution.
  - In practice this means using a language model to score how plausible a partial solution is.
- **Computational efficiency:**
  - Heuristic evaluation via LLM queries can be computationally costly if not managed carefully.
- **Admissibility and consistency:**
  - The heuristic must be admissible (never overestimate the true cost) for A\* to guarantee an optimal solution. A consistent heuristic additionally means that no node has to be re-expanded.
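
As a reference point, the sketch below runs A\* on a small hand-made graph rather than on reasoning steps. The graph `toy_graph` and the heuristic table `toy_h` are invented for illustration; in the LLM setting, nodes would be partial solutions and \(h(n)\) would come from a model-based score, which is an assumption this sketch does not implement.

```python
import heapq
from typing import Callable, Dict, List, Optional, Tuple

Graph = Dict[str, List[Tuple[str, float]]]


def a_star(
    graph: Graph, start: str, goal: str, h: Callable[[str], float]
) -> Optional[Tuple[float, List[str]]]:
    """Return (cost, path) of a least-cost path, or None if the goal is unreachable."""
    # The priority queue is ordered by f(n) = g(n) + h(n).
    frontier = [(h(start), 0.0, start, [start])]
    best_g = {start: 0.0}
    while frontier:
        _, g, node, path = heapq.heappop(frontier)
        if node == goal:
            return g, path
        if g > best_g.get(node, float("inf")):
            continue  # stale queue entry: a cheaper route to this node was found
        for neighbor, cost in graph.get(node, []):
            new_g = g + cost
            if new_g < best_g.get(neighbor, float("inf")):
                best_g[neighbor] = new_g
                heapq.heappush(
                    frontier, (new_g + h(neighbor), new_g, neighbor, path + [neighbor])
                )
    return None


# Toy graph and an admissible heuristic (both made up for this example).
toy_graph: Graph = {
    "S": [("A", 1.0), ("B", 4.0)],
    "A": [("B", 2.0), ("G", 6.0)],
    "B": [("G", 2.0)],
    "G": [],
}
toy_h = {"S": 4.0, "A": 3.0, "B": 2.0, "G": 0.0}

cost, path = a_star(toy_graph, "S", "G", lambda n: toy_h[n])
print(cost, path)
```

The direct edge `S -> B` costs more than the detour through `A`, so the search has to update its estimate for `B` before reaching the goal.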

---

## 🎲 2. Monte Carlo Tree Search (MCTS)

**MCTS** is a probabilistic search algorithm widely used for sequential decision-making. It is particularly useful for problems with large search spaces and uncertain outcomes, such as mathematical reasoning guided by language models.

### 🔹 Core Principles:
MCTS explores decision trees using **randomized simulations** (rollouts) and statistical sampling.

**Four main steps in MCTS**:
1. **Selection**: Starting from the root, repeatedly picks the child with the highest UCT (Upper Confidence bounds applied to Trees) score, which balances exploration and exploitation.
2. **Expansion**: Adds new unexplored nodes to the tree.
3. **Simulation**: Conducts rollouts from newly expanded nodes to estimate potential outcomes.
4. **Backpropagation**: Updates statistical measures based on simulation results.

### 🔹 Components:
- **UCT formula** for selection:
  $$
  \text{UCT} = \frac{w_i}{n_i} + C\sqrt{\frac{\ln N}{n_i}}
  $$
  - \(w_i\): Total reward accumulated by node \(i\).
  - \(n_i\): Number of visits to node \(i\).
  - \(N\): Number of visits to the parent of node \(i\).
  - \(C\): Exploration constant (commonly \(\sqrt{2}\) when rewards lie in \([0, 1]\)).

- **Rollout (simulation)**:
  - Typically involves letting the LLM complete reasoning steps to the end, evaluating correctness.
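
The UCT formula above translates almost directly into code. The following is a small sketch of the selection step only. The function names and the convention of giving unvisited children an infinite score are our assumptions for illustration, not something fixed by the formula.

```python
import math
from typing import List, Tuple


def uct_score(
    total_reward: float, visits: int, parent_visits: int, c: float = math.sqrt(2)
) -> float:
    """UCT = w_i / n_i + C * sqrt(ln N / n_i)."""
    if visits == 0:
        return float("inf")  # assumption: always try unvisited children first
    exploitation = total_reward / visits
    exploration = c * math.sqrt(math.log(parent_visits) / visits)
    return exploitation + exploration


def select_child(stats: List[Tuple[float, int]], parent_visits: int) -> int:
    """Return the index of the child with the highest UCT score."""
    scores = [uct_score(w, n, parent_visits) for w, n in stats]
    return max(range(len(scores)), key=lambda i: scores[i])


# (total_reward, visits) for three children of a node visited 14 times.
children = [(6.0, 10), (3.0, 4), (0.0, 0)]
print(select_child(children, parent_visits=14))      # unvisited child wins
print(select_child(children[:2], parent_visits=14))  # less-visited child wins
```

In the second call, the child with fewer visits is chosen even though its visit count is lower, because its exploration bonus is larger.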

### 🔹 Practical Considerations:
- Effective rollout policies significantly impact accuracy and efficiency.
- Balance exploration (testing new reasoning paths) and exploitation (refining known good solutions).

### ⚠️ Challenges in Implementation:
- **Computational overhead**:
  - Running multiple LLM-based rollouts per node can be slow and computationally expensive.
- **Optimal parameter tuning**:
  - Choosing the exploration constant \(C\) and number of simulations impacts performance significantly.
- **Quality of simulation outcomes**:
  - Poor rollout outcomes (random or inaccurate completions) can misguide the search tree.

---

## 🌳 3. Tree of Thoughts (ToT)

**Tree of Thoughts (ToT)** is specifically designed for structured reasoning tasks with language models, extending their capabilities through explicit evaluation and pruning of reasoning paths.

### 🔹 Core Principles:
- Generate multiple candidate "thoughts" (reasoning paths).
- Evaluate each thought explicitly (often through LLM-based scoring or consistency checking).
- Iteratively prune weaker reasoning paths, keeping the most promising solutions.

### 🔹 Components:
- **Thought generation**: Multiple candidate reasoning steps generated at each node.
- **Thought evaluation**: Explicit scoring (via LLMs) to judge mathematical correctness or plausibility.
- **Pruning**: Remove less promising reasoning branches based on evaluation.

### 🔹 Practical Considerations:
- Explicit node evaluation adds a structured layer of reasoning not present in simpler methods.
- Allows LLMs to reason more deliberately by systematically exploring and eliminating alternatives.

### ⚠️ Challenges in Implementation:
- **Evaluation complexity**:
  - Frequent explicit evaluations by LLM can slow the process.
  - Requires efficient prompting and scoring techniques.
- **Self-consistency**:
  - Maintaining logical consistency across multiple branches can be difficult, especially for complex math problems.
- **Scaling**:
  - Managing multiple branches of reasoning can quickly become computationally expensive without careful control.

---


**Next, we will begin the practical implementation step-by-step.**


# Implementing Tree of Thoughts (ToT)

Before writing any code, it’s essential to map out the steps and functions we need for our Tree of Thoughts (ToT) implementation. In ToT, we aren’t just following one linear chain of reasoning but instead generating a tree of candidate reasoning paths (“thoughts”), evaluating them, and then expanding the most promising ones. We can leverage our existing helper functions (such as those for extracting and normalizing answers) as part of the evaluation process.

Below is an outline of the steps and functions we’ll need:

---

## 1. **Node Representation**

We need a way to represent each node in our reasoning tree. A node could include:
- **Current reasoning text**: The partial solution or thought generated so far.
- **Candidate thoughts**: A list of potential next steps (children nodes).
- **Evaluation score**: A score indicating how promising the node is (based on plausibility or correctness).
- **Metadata**: Such as depth in the tree or a reference to the parent node.

**Potential Functions/Classes:**
- `class Node`: A class that encapsulates the above properties.
- `add_child(self, child_node)`: A method to attach a new candidate thought.

---

## 2. **Candidate Thought Generation**

This function will use the LLM (via our `get_llm_response` function) to generate multiple candidate reasoning steps given a node's current state.

**Key Points:**
- The prompt should be carefully crafted to ask the LLM for alternative reasoning steps.
- We can use techniques such as few-shot prompting to guide the LLM in generating diverse thoughts.

**Potential Function:**
- `def generate_candidate_thoughts(node: Node, num_candidates: int) -> List[str]:`
  - This function takes the current reasoning state from `node` and returns a list of candidate thoughts (as strings).

---

## 3. **Candidate Thought Evaluation**

Once we have multiple candidate thoughts, we need to score them. The evaluation could be based on:
- **Model’s self-assessment:** Ask the LLM to rate each candidate on a scale (e.g., 1 to 10) for mathematical plausibility or correctness.
- **Heuristics based on helper functions:** Use helper functions like `extract_answer` and `normalize_answer` to check whether a candidate thought moves closer to a correct answer or simplifies the expression.

**Potential Function:**
- `def evaluate_candidate_thought(candidate: str) -> float:`
  - This function might prompt the LLM with the candidate reasoning step, asking, “How plausible or correct is this step?” and return a numeric score.
  - Alternatively, it might combine an LLM score with our own heuristic checks.
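
Whatever prompt is used, the model's reply has to be turned into a number. Below is a minimal sketch of that parsing step, assuming the model was asked for a score from 1 to 10. The helper name `parse_score` and its default of `0.0` for unparseable replies are illustrative choices.

```python
import re


def parse_score(
    reply: str, low: float = 1.0, high: float = 10.0, default: float = 0.0
) -> float:
    """Extract the first number in `reply` and clamp it to [low, high]."""
    match = re.search(r"\d+(?:\.\d+)?", reply)
    if match is None:
        return default  # the model did not return a number
    # max(low, min(high, x)) clamps x; swapping low and high would always return high.
    return max(low, min(high, float(match.group())))


print(parse_score("8"))
print(parse_score("Score: 7.5/10"))
print(parse_score("15"))
print(parse_score("I cannot rate this."))
```

Clamping guards against replies outside the requested range, and the explicit default keeps unparseable replies from ranking above real scores.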

---

## 4. **Pruning and Selection**

After evaluating the candidates, we must select the most promising ones for further expansion. Pruning involves:
- Ranking candidate thoughts by their evaluation score.
- Keeping only the top N candidates (to control the tree size).

**Potential Function:**
- `def select_best_candidates(candidates: List[str], scores: List[float], top_n: int) -> List[str]:`
  - This function will combine the candidate list and their scores to select the best ones for further exploration.
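
A possible implementation of the signature above is sketched here. It assumes `candidates` and `scores` are parallel lists, and it breaks ties by keeping the earlier candidate.

```python
from typing import List


def select_best_candidates(
    candidates: List[str], scores: List[float], top_n: int
) -> List[str]:
    """Return the `top_n` candidates with the highest scores."""
    if len(candidates) != len(scores):
        raise ValueError("candidates and scores must have the same length")
    # sorted() is stable, so candidates with equal scores keep their input order.
    ranked = sorted(zip(scores, candidates), key=lambda pair: pair[0], reverse=True)
    return [cand for _, cand in ranked[:top_n]]


print(select_best_candidates(["a", "b", "c"], [3.0, 9.0, 5.0], top_n=2))
```

Sorting on the score alone (rather than on the whole tuple) avoids comparing candidate strings when scores are equal.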

---

## 5. **Node Expansion**

For each node, the process is:
1. **Generate candidate thoughts** using the generation function.
2. **Evaluate each candidate** using the evaluation function.
3. **Select the best candidate(s)** using the pruning/selection function.
4. **Expand the tree** by creating child nodes for each selected candidate.

**Potential Function:**
- `def expand_node(node: Node, num_candidates: int, top_n: int) -> None:`
  - This function integrates candidate generation, evaluation, and pruning to add new child nodes to the given `node`.

---

## 6. **Stopping Criteria**

We need clear criteria for when to stop expanding the tree:
- **Complete solution found:** When a node contains a complete solution (e.g., using `extract_answer` to verify that a boxed answer exists).
- **Depth or resource limits:** When a maximum depth is reached or computational resources are constrained.

**Potential Function:**
- `def is_solution(node: Node) -> bool:`
  - This function checks if a node’s reasoning contains a valid, complete answer.
- `def stop_expansion(node: Node, max_depth: int) -> bool:`
  - This function checks if the node has reached the maximum allowed depth.
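
Both checks can be sketched on plain strings. The helper `find_boxed` below is a hypothetical stand-in for the notebook's `extract_answer`, whose exact behavior may differ. It looks for the last `\boxed{...}` and matches braces so that nested LaTeX such as `\frac{1}{3}` is handled.

```python
from typing import Optional


def find_boxed(text: str) -> Optional[str]:
    """Return the content of the last \\boxed{...} in `text`, or None."""
    marker = "\\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return None
    content_start = start + len(marker)
    depth = 1
    for i in range(content_start, len(text)):
        if text[i] == "{":
            depth += 1
        elif text[i] == "}":
            depth -= 1
            if depth == 0:
                return text[content_start:i]
    return None  # braces never closed, so there is no complete answer


def should_stop(text: str, depth: int, max_depth: int) -> bool:
    """Stop when a boxed answer exists or the depth limit is reached."""
    return find_boxed(text) is not None or depth >= max_depth


print(find_boxed("so the area is \\boxed{\\frac{1}{3}}"))
print(should_stop("still thinking", depth=1, max_depth=3))
```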

---

## 7. **Integration with Helper Functions**

Our previous helper functions play a crucial role in the ToT implementation:
- **Extracting and normalizing answers:**  
  Use `extract_answer` and `normalize_answer` to interpret candidate outputs and compare them against the expected solution.
- **Comparison functions:**  
  Use `compare_answers` to help decide if a candidate thought is moving in the right direction.
- **LLM response function:**  
  `get_llm_response` is used both for generating candidate thoughts and possibly for scoring them.

---

## 8. **Overall ToT Flow**

Putting it all together, here is an outline of the overall ToT process:
1. **Initialize the root node** with the initial problem statement.
2. **While the stopping criteria are not met:**
   - For the current node, generate candidate thoughts.
   - Evaluate each candidate.
   - Select and expand the best candidates to form new child nodes.
3. **Once a candidate thought leads to a complete solution:**
   - Use helper functions to verify correctness.
   - Return or record the successful reasoning path.
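
The flow above can be sketched as a level-by-level search with pluggable components. This is a simplified illustration, not the implementation that follows: `propose`, `score`, and `is_solution` are hypothetical callables, and the toy versions below operate on strings of letters instead of reasoning steps.

```python
from typing import Callable, List, Optional


def tot_search(
    problem: str,
    propose: Callable[[str], List[str]],
    score: Callable[[str], float],
    is_solution: Callable[[str], bool],
    top_n: int = 2,
    max_depth: int = 3,
) -> Optional[str]:
    """Expand the best `top_n` states per level until a solution or the depth limit."""
    frontier = [problem]
    for _ in range(max_depth):
        # Generate candidate thoughts for every state kept from the previous level.
        candidates = [c for state in frontier for c in propose(state)]
        if not candidates:
            return None
        # Evaluate and prune: keep only the highest-scoring candidates.
        frontier = sorted(candidates, key=score, reverse=True)[:top_n]
        for state in frontier:
            if is_solution(state):
                return state
    return None


# Toy components: each step appends a letter, and "a" is the "good" step.
def toy_propose(state: str) -> List[str]:
    return [state + "a", state + "b"]


def toy_score(state: str) -> float:
    return float(state.count("a"))


def toy_is_solution(state: str) -> bool:
    return state.endswith("aaa")


result = tot_search("", toy_propose, toy_score, toy_is_solution)
print(result)
```

With `top_n=2`, at most two states survive each level, which keeps the number of `propose` and `score` calls linear in the depth instead of exponential.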

---

In [None]:
import re
import time
from typing import List, Optional

def score_with_gpt(problem: str, candidate: str) -> float:
    """
    Ask a high‑quality LLM (via get_api_response) to rate the given
    final answer on a scale from 1 (poor) to 10 (excellent).
    """
    # TODO: build `eval_prompt` string using `problem` and `candidate`
    # resp = get_api_response(eval_prompt).strip()
    # TODO: use `re.search` to extract the first numeric score from `resp`
    # return float(match) or 0.0 on failure
    eval_prompt = (
        "You are a precise mathematics grader. "
        f"Problem: {problem}\n\n"
        f"Candidate solution: {candidate}\n\n"
        "On a scale from 1 (completely incorrect) to 10 (perfectly correct), "
        "how accurate is this solution? Reply with only a number between 1 and 10."
    )
    # eval_prompt = f"""You are a math problem solver. Given this problem:
    #     {problem}

    #     Candidate solution step:
    #     {candidate}

    #     Rate the mathematical correctness and progress towards solution from 0 (completely wrong) to 10 (perfectly correct). Reply with only a number between 0 and 10."""

    resp = get_api_response(eval_prompt).strip()
    # match = re.search(r'\d+', resp)
    match = re.search(r'\d+(?:\.\d+)?', resp)
    # return float(match.group()) if match else 0.0
    # Clamp the parsed score to the 1-10 range requested in the prompt.
    return max(1.0, min(10.0, float(match.group()))) if match else 0.0

class Node:
    def __init__(self, state: str, depth: int = 0, parent: Optional['Node'] = None):
        self.state = state
        self.depth = depth
        self.parent = parent
        self.children: List[Node] = []      # easy: list to hold child nodes
        self.score: Optional[float] = None  # will be set once evaluated

    def add_child(self, child_node: 'Node'):
        self.children.append(child_node)

    def __repr__(self):
        # easy: show depth, score, and a truncated preview of `state`
        preview = self.state if len(self.state) < 50 else self.state[:47] + "..."
        return f"Node(depth={self.depth}, score={self.score}, state='{preview}')"



RETRY_PROMPT = """Fix this solution attempt to ensure it ends with a valid boxed answer.

Problem:
{root_state}

Attempt:
{state}

Return the full corrected solution with a \\boxed{{...}} answer."""


class TreeOfThoughts:
    def __init__(
        self,
        num_candidates: int = 3,
        top_n: int = 2,
        max_depth: int = 3,
        verbose: bool = False
    ):
        self.num_candidates = num_candidates
        self.top_n = top_n
        self.max_depth = max_depth
        self.verbose = verbose
        self.root_state: Optional[str] = None

    def generate_candidate_thoughts(self, node: Node) -> List[str]:
        """
        Prompt the LLM to produce `num_candidates` boxed answers for `node.state`.
        """
        # TODO: compose prompt with node.state
        # TODO: call `get_llm_response(prompt)`
        # TODO: split into non-empty lines and return exactly `self.num_candidates` answers
        prompt = (
            "You are a mathematical problem solver. "
            f"Problem: {node.state}\n\n"
            f"Generate {self.num_candidates} distinct possible next steps or solutions "
            "to this problem. Each solution should end with a final answer in the LaTeX boxed format "
            "\\boxed{...}. Provide each solution on a new line."
        )
        response = get_llm_response(prompt)
        # answers = re.findall(r'\\boxed\{(.*?)\}', response) # if did not work test with answers
        # candidates = answers
        candidates = [line.strip() for line in response.split('\n') if line.strip()]

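        # Ensure we have exactly num_candidates: pad by repeating the last
        # candidate, or fall back to the raw response if no lines were parsed.
        if not candidates:
            candidates = [response.strip()]
        if len(candidates) < self.num_candidates:
            candidates += [candidates[-1]] * (self.num_candidates - len(candidates))
        return candidates[:self.num_candidates]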
        # return [f"\\boxed{{{ans}}}" for ans in answers[:self.num_candidates]]

    def evaluate_candidate_thought(self, candidate: str) -> float:
        # simple wrapper around score_with_gpt
        assert self.root_state is not None
        return score_with_gpt(self.root_state, candidate)

    def select_best_candidates(self, candidates: List[str]) -> List[str]:
        # TODO: for each cand in candidates:
        #         - score = self.evaluate_candidate_thought(cand)
        #         - time.sleep(0.5)
        #       sort by score descending and return top `self.top_n`
        scored = []
        for cand in candidates:
            score = self.evaluate_candidate_thought(cand)
            time.sleep(0.5)  # Rate limit API calls
            scored.append((score, cand))

        scored.sort(reverse=True, key=lambda x: x[0])
        return [cand for score, cand in scored[:self.top_n]]

    def expand_node(self, node: Node) -> None:
        if self.verbose:
            print(f"\nExpanding depth {node.depth} state:\n{node.state}\n")
        # TODO: raw = self.generate_candidate_thoughts(node)
        # TODO: best = self.select_best_candidates(raw)
        raw = self.generate_candidate_thoughts(node)
        best = self.select_best_candidates(raw)

        # for ans in []:  # replace [] with `best`
        for ans in best:
            child = Node(state=ans, depth=node.depth + 1, parent=node)
            # TODO: child.score = self.evaluate_candidate_thought(ans)
            # TODO: if not boxed, retry up to 3 times with a strict prompt
            child.score = self.evaluate_candidate_thought(ans)

            # Ensure we have a boxed answer
            retries = 0
            while not self.is_solution(child) and retries < 3:
                fallback_prompt = RETRY_PROMPT.format(root_state=self.root_state, state=child.state)
                child.state = get_llm_response(fallback_prompt).strip()
                retries += 1

            node.add_child(child)
            if self.verbose:
                print(f"Added child: {child}")

    def is_solution(self, node: Node) -> bool:
        # TODO: return True if `extract_answer(node.state)` yields non-empty
        # bool(...) treats both None and an empty string as "no answer".
        return bool(extract_answer(node.state))

    def stop_expansion(self, node: Node) -> bool:
        return node.depth >= self.max_depth

    def search(self, root_state: str) -> Node:
        """
        Build the tree until solutions found or max depth reached.
        """
        self.root_state = root_state
        root = Node(state=root_state, depth=0)
        frontier = [root]
        while frontier:
            node = frontier.pop(0)
            if self.is_solution(node) or self.stop_expansion(node):
                continue
            self.expand_node(node)
            frontier.extend(node.children)
        return root


# Testing the Tree of Thoughts Method with Minimal Hyperparameters

This cell demonstrates a test run of the Tree of Thoughts (ToT) framework using minimal hyperparameters. The goal is to ensure that a complete final answer (in the format `\boxed{...}`) is extracted from the model's output.

**Key Steps:**

- **Instantiate `TreeOfThoughts`:**  
  The instance is created with `num_candidates=1`, `top_n=1`, and `max_depth=1` in verbose mode. This minimal setup is used for quick testing.

- **Run the Search:**  
  The ToT search is executed on the sample problem:  
  *"Solve the integral: \( \int_0^1 x^2 \, dx \)"*

- **Print the ToT Tree:**  
  A recursive function (`print_tree`) prints the entire search tree, allowing inspection of each node's state and depth.

- **Extract Final Answer:**  
  All nodes in the tree are collected. If a node is found that contains a final answer (determined via the helper function `extract_answer`), then the node's state is overwritten to display only the final answer in the correct boxed format.  
  If no such node is found, a fallback prompt is sent through `get_api_response`, asking for the final answer only.

This setup helps verify that the ToT framework correctly isolates and formats the final answer, so the result can be compared directly with the expected output.
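
The final-answer check depends on locating a `\boxed{...}` span in the model's output. As a rough illustration only, a brace-matching extractor might look like the sketch below. This is not the notebook's `extract_answer`, which is defined elsewhere; the name `extract_last_boxed` is hypothetical.

```python
from typing import Optional


def extract_last_boxed(text: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...} in `text`, or None.

    Hypothetical helper for illustration. Braces are matched by counting, so
    nested LaTeX such as \\boxed{\\frac{1}{3}} is returned whole.
    """
    marker = "\\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return None
    content_start = start + len(marker)
    depth = 1
    for i in range(content_start, len(text)):
        if text[i] == "{":
            depth += 1
        elif text[i] == "}":
            depth -= 1
            if depth == 0:
                inner = text[content_start:i].strip()
                return inner or None  # an empty box counts as "no answer"
    return None  # unbalanced braces: treat as an incomplete answer


print(extract_last_boxed(r"So the area is \boxed{\frac{1}{3}}."))  # \frac{1}{3}
print(extract_last_boxed("Therefore, the polar c"))                # None
```

Only the last `\boxed{` is inspected, so a truncated final box yields `None` even if an earlier box was complete. Whether that is the right policy depends on how you want to treat cut-off generations.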


In [None]:
# Test the TreeOfThoughts with minimal settings.
# Generation of candidate answers still uses your primary LLM via get_llm_response,
# while evaluation/scoring uses only the GPT/Gemini verifier.

tot = TreeOfThoughts(num_candidates=1,
                     top_n=1,
                     max_depth=1,
                     verbose=True)

initial_problem = "Solve the integral: \\( \\int_0^1 x^2 \\, dx \\)"
tot_tree = tot.search(initial_problem)

# Helper to print the whole tree
def print_tree(node: Node, indent: str = ""):
    print(indent + repr(node))
    for child in node.children:
        print_tree(child, indent + "  ")

print_tree(tot_tree)

# Collect all nodes and find the first complete solution
def collect_all_nodes(node: Node) -> List[Node]:
    nodes = [node]
    for child in node.children:
        nodes.extend(collect_all_nodes(child))
    return nodes

all_nodes = collect_all_nodes(tot_tree)
solution_node = next((n for n in all_nodes if tot.is_solution(n)), None)

print("---")
if solution_node:
    final = extract_answer(solution_node.state)
    solution_node.state = f"\\boxed{{{final}}}"
    print("\nFinal Answer Found:")
    print(solution_node.state)
else:
    # If no solution node was produced, force a final answer via GPT verifier
    fallback_prompt = (
        f"Based on the problem: \"{initial_problem}\", please provide ONLY your final answer "
        "in the exact format \\boxed{...}."
    )
    forced = get_api_response(fallback_prompt).strip()
    final = extract_answer(forced) or forced
    forced = f"\\boxed{{{final}}}"
    print("\nNo complete final answer was found. Forcing final answer:")
    print(forced)


Expanding depth 0 state:
Solve the integral: \( \int_0^1 x^2 \, dx \)

Added child: Node(depth=1, score=10.0, state='Okay, so I need to fix the solution attempt for...')
Node(depth=0, score=None, state='Solve the integral: \( \int_0^1 x^2 \, dx \)')
  Node(depth=1, score=10.0, state='Okay, so I need to fix the solution attempt for...')
---

No complete final answer was found. Forcing final answer:
\boxed{1/3}


# Evaluation of the Tree of Thoughts (ToT) Method on the Math500 Dataset

This evaluation framework is designed to test the ToT method on a subset of the Math500 dataset. It provides flexibility in hyperparameter configuration and is set up for both debugging with detailed output and large-scale evaluation. The framework saves results and summarizes key metrics, and it includes mechanisms to force the model to output a final answer in the expected format.

## Key Features

- **Unique Results File:**  
  Uses a dedicated results file (e.g., `evaluation_results_tot_test.json`) to store evaluation data. The file is cleared at the beginning of each run to prevent interference from previous evaluations.

- **Configurable Hyperparameters:**  
  You can adjust parameters such as:
  - `num_candidates`: Number of candidate final answers generated per node.
  - `top_n`: Number of top candidates selected for expanding the tree.
  - `max_depth`: Maximum depth of the search tree.
  
  These parameters enable you to test different reasoning strategies and trade-offs between search breadth and depth.

- **Sample Selection:**  
  The `max_samples` parameter controls the number of problems from the dataset to evaluate. This allows you to start by testing on a single sample and then scale up to 100 or even more problems as desired.

- **Debug and Fallback Mechanisms:**  
  - The framework prints the entire ToT tree for each problem to facilitate inspection.
  - If no node contains a complete final answer (i.e., a final answer wrapped in `\boxed{...}`), a fallback prompt is issued to force the model to output the final answer.
  - Detailed debug outputs help track the process and diagnose any issues in final answer extraction.

- **Final Answer Extraction:**  
  After processing the tree, the system extracts just the final answer from the model's output (using a helper function like `extract_answer`), ensuring that only the final answer (and no chain-of-thought explanations) is compared against the correct answer.

- **Result Saving and Analysis:**  
  The framework saves each problem’s evaluation (including problem text, responses, and correctness) and produces a summary report that includes metrics like total problems, correct answers, and overall accuracy.
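
Before scaling up, it can help to estimate how many verifier calls a hyperparameter setting implies. The sketch below is a back-of-the-envelope upper bound, not a measurement. It assumes that no node is recognized as a solution early, that retries are ignored, and that each expansion scores every candidate once in `select_best_candidates` and every kept child once more in `expand_node`, as in the class above. The function name `tot_call_budget` is hypothetical.

```python
def tot_call_budget(num_candidates: int, top_n: int, max_depth: int) -> dict:
    """Worst-case call counts for one ToT search (assumptions in the text above)."""
    kept = min(top_n, num_candidates)  # children actually added per expansion
    # Nodes at depths 0 .. max_depth-1 are expanded; depth d holds kept**d nodes.
    expansions = sum(kept ** d for d in range(max_depth))
    return {
        "expanded_nodes": expansions,
        "verifier_calls": expansions * (num_candidates + kept),
        "min_sleep_seconds": expansions * num_candidates * 0.5,  # time.sleep(0.5) per scored candidate
    }


print(tot_call_budget(num_candidates=1, top_n=1, max_depth=1))
print(tot_call_budget(num_candidates=3, top_n=2, max_depth=3))
```

The second setting already implies 35 verifier calls per problem in the worst case, which is why the minimal setting is used for smoke tests.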

## Encouragement for Further Improvement

**Prompt Engineering & Fallback Strategies:**  
The current method forces the model to provide a final answer using a fallback prompt when the initial generation does not meet the required format. While this approach works, it is not perfect:
- **Prompt Tuning:** Experiment with different wording and structure in the prompts. For example, try different phrasings that emphasize "ONLY your final answer" and "no additional explanations" to see if the model can be nudged into generating a cleaner response.
- **Iterative Refinement:** Consider implementing iterative prompt refinement mechanisms or leveraging additional post-processing steps to filter out unwanted chain-of-thought text.
- **Open Research Problem:** The issue of controlling a language model’s output to include only the final answer (and not intermediate reasoning) is an active area of research. There is significant potential to explore improved strategies that maintain reasoning power while enforcing output constraints.
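
One way to make prompt tuning less anecdotal is to measure how often each prompt variant yields a response that is nothing but a boxed answer. The sketch below is a minimal harness under stated assumptions: `fake_llm` is a made-up stand-in whose behavior says nothing about real models, and `clean_rate` is a hypothetical helper. Swap in your own model call to get meaningful numbers.

```python
import re
from typing import Callable, Dict, List

# Matches a response that is only one \boxed{...}. Nested braces are not
# handled here, so \boxed{\frac{1}{3}} would count as "not clean".
ONLY_BOXED = re.compile(r"^\s*\\boxed\{[^{}]+\}\s*$")


def clean_rate(prompts: List[str], llm: Callable[[str], str], trials: int = 3) -> Dict[str, float]:
    """Fraction of responses per prompt that contain only a boxed answer."""
    rates = {}
    for prompt in prompts:
        hits = sum(bool(ONLY_BOXED.match(llm(prompt))) for _ in range(trials))
        rates[prompt] = hits / trials
    return rates


def fake_llm(prompt: str) -> str:
    """Stand-in for a real model call; its behavior is invented for the demo."""
    if "no additional explanations" in prompt:
        return "\\boxed{1/3}"
    return "Let me think... the answer is \\boxed{1/3}"


variants = [
    "Solve: integral of x^2 from 0 to 1. Give your final answer in \\boxed{...}.",
    "Solve: integral of x^2 from 0 to 1. Provide ONLY your final answer in "
    "\\boxed{...}, with no additional explanations.",
]
for prompt, rate in clean_rate(variants, fake_llm).items():
    print(f"{rate:.0%}  {prompt[:60]}")
```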

**Scalability:**  
Once you have tuned the hyperparameters and the prompting strategy on a small sample, test on a larger set (such as 100 or more problems). Evaluating on a larger dataset can help identify trends and potential improvements that might not be apparent on a smaller scale.

**Experiment and Innovate:**  
Do not hesitate to modify the prompts, fallback mechanisms, and even the underlying structure of the `TreeOfThoughts` class. Each change you try may lead to better results and a deeper understanding of how to steer the model toward producing just the final answer. Your experimentation is key to achieving a more robust and reliable evaluation system.

This framework is designed to be flexible—feel free to tweak the parameters and strategies to suit your research or production needs, and continue to iterate toward better performance!


In [None]:
import os
from tqdm import tqdm

def evaluate_tot(max_samples: int = 1):
    """
    Evaluate the Tree of Thoughts (ToT) method on a subset of the Math500 dataset.
    Uses GPT/Gemini exclusively for verification and fallback.
    Assumes the existence of helper functions:
      - load_math500_dataset()
      - extract_answer(solution_text)
      - compare_answers(correct_answer, predicted_answer)
      - save_result(results_file, result_dict)
      - load_existing_results(results_file)
      - analyze_results(results_list)
      - get_llm_response(prompt)     # for initial candidate generation
      - get_api_response(prompt)     # for GPT/Gemini–based verification
    """
    os.makedirs("results", exist_ok=True)
    results_file = "results/evaluation_results_tot_test.json"
    if os.path.exists(results_file):
        os.remove(results_file)

    dataset = load_math500_dataset()
    existing = load_existing_results(results_file)
    processed = {r['index'] for r in existing}

    tot = TreeOfThoughts(num_candidates=1, top_n=1, max_depth=1, verbose=True)
    correct_count = 0
    evaluated = 0

    def collect_all_nodes(node):
        nodes = [node]
        for child in node.children:
            nodes.extend(collect_all_nodes(child))
        return nodes

    for idx, item in enumerate(tqdm(dataset, desc="Evaluating problems")):
        if evaluated >= max_samples:
            break  # stop early instead of iterating over the rest of the dataset
        if idx in processed:
            continue

        problem_text = item['problem']
        correct_answer = extract_answer(item['solution'])

        # Run the Tree‑of‑Thoughts search
        tot_tree = tot.search(problem_text)

        # Debug print of the entire tree
        print(f"\n--- DEBUG: Full TOT tree for problem index {idx} ---")
        def print_tree(node, indent=""):
            print(indent + repr(node))
            for c in node.children:
                print_tree(c, indent + "  ")
        print_tree(tot_tree)
        print("--- End of TOT tree ---\n")

        # Find first node with a boxed answer
        all_nodes = collect_all_nodes(tot_tree)
        solution_node = next((n for n in all_nodes if tot.is_solution(n)), None)

        if solution_node:
            response = solution_node.state
            print("DEBUG: Found solution node:", response)
        else:
            # Fallback: ask GPT/Gemini directly for final answer
            print("DEBUG: No solution node found. Using GPT fallback.")
            fallback_prompt = (
                f"Based on the problem: \"{problem_text}\", provide ONLY your final answer "
                "in the exact format \\boxed{<final answer>}. Do not include any chain-of-thought."
            )
            response = get_api_response(fallback_prompt)
            print("DEBUG: Fallback response:", response)

        predicted = extract_answer(response) or ""
        if not predicted:
            print("DEBUG: predicted_answer is empty. Raw response was:\n", response)

        is_correct = compare_answers(correct_answer, predicted)
        save_result(results_file, {
            "index": idx,
            "problem": problem_text,
            "response": response,
            "correct_answer": correct_answer,
            "predicted_answer": predicted,
            "is_correct": is_correct
        })

        if is_correct:
            correct_count += 1
        evaluated += 1
        print(f"Correct: {correct_count}/{evaluated} | Index: {idx}")

    final = load_existing_results(results_file)
    analyze_results(final)

# Example: test a single sample
evaluate_tot(max_samples=1)

Evaluating problems:   0%|          | 0/500 [00:00<?, ?it/s]


Expanding depth 0 state:
Convert the point $(0,3)$ in rectangular coordinates to polar coordinates.  Enter your answer in the form $(r,\theta),$ where $r > 0$ and $0 \le \theta < 2 \pi.$



Evaluating problems: 100%|██████████| 500/500 [00:25<00:00, 19.31it/s]

Added child: Node(depth=1, score=10.0, state='To convert the rectangular coordinate (0,3) to ...')

--- DEBUG: Full TOT tree for problem index 0 ---
Node(depth=0, score=None, state='Convert the point $(0,3)$ in rectangular coordi...')
  Node(depth=1, score=10.0, state='To convert the rectangular coordinate (0,3) to ...')
--- End of TOT tree ---

DEBUG: Found solution node: To convert the rectangular coordinate (0,3) to polar coordinates, I'll follow these steps:

1. **Calculate the radius (r):**
   - Use the formula \( r = \sqrt{x^2 + y^2} \).
   - Substituting the values: \( r = \sqrt{0^2 + 3^2} = \sqrt{9} = 3 \).

2. **Determine the angle (θ):**
   - Use the formula \( θ = \arctan\left(\frac{y}{x}\right) \).
   - Since x is 0, this becomes \( θ = \arctan\left(\frac{3}{0}\right) \).
   - Recognizing that dividing by zero is undefined, I know that when x is 0, the point lies on the y-axis.
   - Since y is positive, θ is \( \frac{\pi}{2} \) radians or 90 degrees.

Therefore, the polar c




# A* Search Algorithm for Mathematical Reasoning

The A* (A-star) search algorithm is a popular informed search method that combines both actual cost and an estimated cost to reach the goal. When applied to mathematical problem-solving with language models, each node in the search tree represents a partial reasoning process. The goal is to guide the search toward the correct final answer efficiently.

## Core Concepts

- **Nodes and States:**  
  Each node represents a state in the reasoning process—a partial solution or chain-of-thought step. The root node is the initial problem, and child nodes represent possible next steps in reasoning.

- **Cost Function (g(n)):**  
  This function measures the cost accumulated from the start node to the current node. In our context, it might represent the complexity or length of the reasoning chain so far.

- **Heuristic Function (h(n)):**  
  A heuristic estimates the cost (or “distance”) from the current node to the goal (a complete final answer). For mathematical reasoning, this could be designed to reflect how promising the current partial solution is—possibly by prompting the LLM to provide a confidence or plausibility score.

- **Evaluation Function (f(n)):**  
  A* uses the function:
  \[
  f(n) = g(n) + h(n)
  \]
  to choose which node to expand next. Nodes with lower f(n) are expanded first, steering the search toward the most promising reasoning paths.
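
As a small worked example (the numbers are invented for illustration), take \(g(n)\) to be the reasoning depth and \(h(n)\) an incompleteness score between 0 and 10. For a shallow but vague node \(a\) and a deeper, nearly finished node \(b\):

```latex
\[
\begin{aligned}
f(a) &= g(a) + h(a) = 1 + 6 = 7, \\
f(b) &= g(b) + h(b) = 2 + 1 = 3.
\end{aligned}
\]
```

Since \(f(b) < f(a)\), node \(b\) is expanded first even though it is deeper: its lower estimated remaining cost outweighs the extra step already taken.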

## How A* Works in Mathematical Reasoning

1. **Initialization:**  
   The algorithm starts with the initial problem as the root node, with an initial cost \(g(n)=0\).

2. **Expansion:**  
   From the current node, the model generates several potential next steps (child nodes). Each child node represents a possible continuation of the reasoning process.

3. **Cost and Heuristic Calculation:**  
   - **g(n):** Represents the cost accumulated so far (e.g., the number of reasoning steps taken).
   - **h(n):** An estimate of how “far” the current state is from a complete final answer. This can be derived via LLM-based evaluations or comparisons to known correct patterns.

4. **Priority Queue and Node Selection:**  
   The algorithm uses a priority queue to maintain nodes, sorted by their \( f(n) \) value. The node with the smallest \(f(n)\) (i.e., the most promising combination of current cost and estimated remaining cost) is expanded next.

5. **Goal Test:**  
   The process continues until a node is found that meets the goal—a node whose state contains a complete final answer in the expected format (e.g., a LaTeX expression wrapped in `\boxed{...}`).
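
The five steps above can be sketched on a toy state graph, independent of any language model. This is a generic A* sketch, not the notebook's `AStarSearch`; the state names, step costs, and heuristic values are made up to mimic a reasoning chain.

```python
import heapq
from typing import Callable, Dict, List, Optional, Tuple


def a_star(
    start: str,
    is_goal: Callable[[str], bool],
    successors: Callable[[str], List[Tuple[str, float]]],
    h: Callable[[str], float],
) -> Optional[List[str]]:
    """Return the path to the first goal popped from the frontier, or None."""
    counter = 0  # tie-breaker so the heap never has to compare paths
    frontier = [(h(start), counter, 0.0, start, [start])]  # (f, tie, g, state, path)
    best_g: Dict[str, float] = {start: 0.0}
    while frontier:
        _, _, g, state, path = heapq.heappop(frontier)  # smallest f first
        if is_goal(state):
            return path
        if g > best_g.get(state, float("inf")):
            continue  # stale entry: a cheaper route to this state was found
        for nxt, step_cost in successors(state):
            new_g = g + step_cost
            if new_g < best_g.get(nxt, float("inf")):
                best_g[nxt] = new_g
                counter += 1
                heapq.heappush(frontier, (new_g + h(nxt), counter, new_g, nxt, path + [nxt]))
    return None


# Toy "reasoning" graph: each edge costs one step.
graph = {
    "problem": [("guess", 1.0), ("antiderivative", 1.0)],
    "antiderivative": [("boxed", 1.0)],
    "guess": [],
    "boxed": [],
}
h_values = {"problem": 2.0, "guess": 3.0, "antiderivative": 1.0, "boxed": 0.0}

path = a_star("problem", lambda s: s == "boxed", lambda s: graph[s], lambda s: h_values[s])
print(path)  # ['problem', 'antiderivative', 'boxed']
```

The `guess` branch is pushed onto the frontier but never expanded, because its f-value (4) stays above that of the path through `antiderivative` (2).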

## Challenges in Applying A* to Reasoning

- **Heuristic Design:**  
  Defining an effective h(n) is challenging. The heuristic must correlate well with the true “distance” to a correct final answer. For language models, this might involve model confidence scores or custom prompt evaluations.

- **Balancing Exploration and Exploitation:**  
  If g(n) dominates, the search favors shorter chains even when they are less complete. If h(n) dominates, the search trusts the evaluator's estimate and may pursue partially correct answers that merely look close to finished.

- **Computational Expense:**  
  Evaluating each node’s heuristic (potentially via additional LLM queries) can be computationally expensive, especially in a large search space.

- **Scalability:**  
  The state space for reasoning is vast. With a good heuristic, A* can leave unpromising paths unexpanded, but with a weak one the number of nodes to explore can grow exponentially.
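
One inexpensive way to reduce heuristic cost is to cache scores, since identical candidate strings can recur (for example, when a short response is replicated to fill the candidate list). The sketch below is a minimal illustration under stated assumptions: `CachedHeuristic` and `toy_scorer` are hypothetical names, and the toy scorer stands in for a real LLM-based grader.

```python
from typing import Callable, Dict, Tuple


class CachedHeuristic:
    """Memoize heuristic scores so identical (problem, state) pairs are scored once."""

    def __init__(self, scorer: Callable[[str, str], float]):
        self._scorer = scorer
        self._cache: Dict[Tuple[str, str], float] = {}
        self.calls = 0  # number of real scorer invocations

    def __call__(self, problem: str, state: str) -> float:
        key = (problem, state.strip())
        if key not in self._cache:
            self.calls += 1
            self._cache[key] = self._scorer(problem, state)
        return self._cache[key]


def toy_scorer(problem: str, state: str) -> float:
    """Stand-in for an LLM grader: 0 if a boxed answer is present, else 10."""
    return 0.0 if "\\boxed{" in state else 10.0


h = CachedHeuristic(toy_scorer)
scores = [h("integral", s) for s in ["\\boxed{1/3}", "\\boxed{1/3}", "thinking..."]]
print(scores, "real calls:", h.calls)  # [0.0, 0.0, 10.0] real calls: 2
```

Caching assumes the grader is deterministic enough that re-scoring the same text would not change the search. With a sampled grader, that is a trade-off rather than a free win.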



Next, we will move on to the code implementation of the A* algorithm for our reasoning framework.


In [None]:
import re
import time
import heapq
from typing import List, Optional


class AStarNode:
    def __init__(self, state: str, g: float = 0.0, parent: Optional['AStarNode'] = None):
        self.state = state                      # current problem text or partial answer
        self.g = g                              # cost so far (search depth)
        self.h: float = 0.0                     # set by the heuristic evaluator
        self.f: float = 0.0                     # f = g + h
        self.parent = parent                    # link back to the parent for path reconstruction
        self.children: List[AStarNode] = []     # generated successors

    def __lt__(self, other: 'AStarNode'):
        return self.f < other.f                 # lets heapq order nodes by f-value

    def __repr__(self):
        preview = self.state if len(self.state) < 50 else self.state[:47] + "..."
        return f"AStarNode(g={self.g}, h={self.h}, f={self.f}, state='{preview}')"


GEN_SYSTEM_PROMPT = """
You are a mathematical problem solver.

Problem:
{node}

Generate {num_candidates} possible next steps or solutions.
Each should end with a final answer in the LaTeX boxed format \\boxed{{...}}.
Provide each solution on a new line.
"""

EVAL_PROMPT = """
You are a precise mathematics grader.

Original problem:
{root_problem}

Current reasoning state: {state}

On a scale from 0 (complete final answer in \\boxed{{...}} format) to 10
(completely unrelated to the problem), how incomplete is this reasoning step?
Reply with only a number between 0 and 10.
"""


RETRY_PROMPT = """
Problem:
{root_problem}

Please provide ONLY your final answer in the exact format \\boxed{{...}}
with no additional explanation.
"""

class AStarSearch:
    def __init__(
        self,
        num_candidates: int = 3,
        max_depth: int = 3,
        verbose: bool = False,
        max_fallback: int = 3
    ):
        self.num_candidates = num_candidates    # how many answers to generate per node
        self.max_depth = max_depth              # search will stop at this depth
        self.verbose = verbose                  # if True, print debug info
        self.max_fallback = max_fallback        # retries for enforcing boxed format
        self.root_problem: Optional[str] = None  # will hold the original problem text

    def is_solution(self, state: str) -> bool:
        # TODO: return True if `extract_answer(state)` yields a non-empty string
        # bool(...) treats both None and an empty string as "no answer".
        return bool(extract_answer(state))

    def generate_candidates(self, node: AStarNode) -> List[str]:
        """
        Ask the LLM for `num_candidates` boxed answers to node.state.
        """
        # TODO: compose a prompt with node.state asking for LaTeX \\boxed{...} answers
        prompt = GEN_SYSTEM_PROMPT.format(node=node.state, num_candidates=self.num_candidates)

        # TODO: call `get_llm_response(prompt)`
        response = get_llm_response(prompt)

        # TODO: split the response into non‑empty lines
        candidates = [line.strip() for line in response.split('\n') if line.strip()]

        # TODO: if fewer than `self.num_candidates`, replicate lines to reach that count
        # TODO: return exactly `self.num_candidates` answer strings
        if not candidates:
            # Empty response: use a placeholder so the retry loop in
            # expand_node can still request a boxed answer.
            candidates = [""]
        if len(candidates) < self.num_candidates:
            candidates += [candidates[-1]] * (self.num_candidates - len(candidates))
        return candidates[:self.num_candidates]


    def evaluate_heuristic(self, state: str) -> float:
        """
        Use GPT (via get_api_response) to score how incomplete a candidate is:
        0 means perfect boxed answer; higher means more incomplete.
        """
        assert self.root_problem is not None, "Root problem must be set before heuristic evaluation"
        # TODO: build `eval_prompt` using self.root_problem and the current `state`
        eval_prompt = EVAL_PROMPT.format(root_problem=self.root_problem, state=state)

        # TODO: call `get_api_response(eval_prompt).strip()`
        resp = get_api_response(eval_prompt).strip()

        # TODO: extract the first numeric value with `re.search`
        match = re.search(r'\d+(?:\.\d+)?', resp)

        # TODO: return that float, or a fallback like `self.max_depth * 10` if parsing fails
        if match is None:
            return float(self.max_depth * 10)  # unparseable reply: treat as very incomplete
        return max(0.0, min(10.0, float(match.group())))  # clamp to the 0-10 scale


    def expand_node(self, node: AStarNode) -> List[AStarNode]:
        if self.verbose:
            print(f"\nExpanding node at depth {node.g}:\n{node.state}\n")

        if self.root_problem is None:
            self.root_problem = node.state
        candidates = self.generate_candidates(node)
        children: List[AStarNode] = []
        # TODO: for each `cand` in candidates:
        for cand in candidates:
            attempts = 0
            #      - while not self.is_solution(cand) and attempts < self.max_fallback:
            #            * send strict fallback prompt via get_llm_response
            #            * cand = response.strip()
            #            * attempts += 1
            #            * if verbose: print fallback info
            while not self.is_solution(cand) and attempts < self.max_fallback:
                fallback_prompt = RETRY_PROMPT.format(root_problem=self.root_problem)
                cand = get_llm_response(fallback_prompt).strip()
                attempts += 1
                if self.verbose:
                    print(f"Fallback attempt {attempts}: {cand}")

            child = AStarNode(state=cand, g=node.g + 1, parent=node)
            child.h = self.evaluate_heuristic(child.state)
            child.f = child.g + child.h
            node.children.append(child)
            children.append(child)

            if self.verbose:
                print(f"Generated child: {child}")
        return children

    def search(self, initial_problem: str) -> Optional[AStarNode]:
        """
        Run A* until a boxed solution is found or max_depth is exceeded.
        """
        self.root_problem = initial_problem
        root = AStarNode(state=initial_problem, g=0.0)
        root.h = self.evaluate_heuristic(root.state)
        root.f = root.g + root.h

        frontier: List[AStarNode] = []
        heapq.heappush(frontier, root)

        while frontier:
            node = heapq.heappop(frontier)
            if self.verbose:
                print(f"Expanding: {node}")

            if self.is_solution(node.state):
                return node

            if node.g >= self.max_depth:
                continue

            for child in self.expand_node(node):
                heapq.heappush(frontier, child)

        # TODO: if no solution found, return None
        return None


# Test Code for A* Search Method with Minimal Hyperparameters

Below is a description of the test procedure for the A* search method using minimal hyperparameters. This test ensures that the algorithm returns only the final answer in the proper format, without extra chain-of-thought text.

- **Initialization:**
  - An A* search instance is created with the following settings:
    - `num_candidates`: 1 (only one candidate is generated per node)
    - `max_depth`: 1 (the search tree is kept shallow for testing)
    - `verbose`: True (detailed debug information is printed)
    - `max_fallback`: 3 (up to three fallback attempts are made to force a final answer)
  - The initial problem is set as:
    - "Solve the integral: \( \int_0^1 x^2 \, dx \)"

- **Search Execution:**
  - The A* search is executed on the initial problem to obtain a solution node.
  
- **Tree Printing:**
  - A recursive function (e.g., `print_astar_tree`) is used to print the entire A* search tree, starting from the root. This allows inspection of all nodes and the reasoning process.

- **Final Answer Extraction:**
  - If a solution node is found, the test code follows `parent` links back to the root and prints the entire tree.
  - Then it extracts the final answer from the solution node using a helper function (e.g., `extract_answer`). The state of the solution node is reformatted to display only the final answer in the exact format (e.g., `\boxed{<final answer>}`).

- **Fallback Handling:**
  - If no solution node is found, a fallback prompt is issued. This prompt instructs the model to provide ONLY its final answer in the correct format, with no extra explanation.
  - The fallback final answer is then printed.

This test code is designed to verify that, with minimal hyperparameters, the A* search method consistently returns a complete final answer, ensuring the output is directly comparable with the expected result.
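
Backtracking from a solution node can also recover the reasoning path itself, not just the root. The sketch below is a minimal illustration: `PathNode` is a hypothetical stand-in that only mirrors the `state` and `parent` fields of `AStarNode`, so the block runs on its own.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class PathNode:
    """Minimal stand-in with the same `state`/`parent` fields as `AStarNode`."""
    state: str
    parent: Optional["PathNode"] = None


def reconstruct_path(node: PathNode) -> List[str]:
    """Follow parent links from `node` to the root and return states root-first."""
    path = []
    current: Optional[PathNode] = node
    while current is not None:
        path.append(current.state)
        current = current.parent
    return path[::-1]


root = PathNode("Solve the integral of x^2 from 0 to 1")
leaf = PathNode("\\boxed{1/3}", parent=root)
print(reconstruct_path(leaf))
```

Because the function only touches `state` and `parent`, it should work unchanged on real `AStarNode` instances.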


In [None]:
# Test Code for A* Search Method with Minimal Hyperparameters

# Initialize the A* search instance with minimal settings.
# Candidate generation still uses get_llm_response(...),
# but all verification and final fallback use get_api_response(...) (Gemini).
astar = AStarSearch(
    num_candidates=1,
    max_depth=1,
    verbose=True,
    max_fallback=3
)

initial_problem = "Solve the integral: \\( \\int_0^1 x^2 \\, dx \\)"
solution_node = astar.search(initial_problem)

# Recursive printer for the A* search tree.
def print_astar_tree(node, indent=""):
    print(indent + repr(node))
    for child in node.children:
        print_astar_tree(child, indent + "  ")

if solution_node:
    # Backtrack to the root.
    root = solution_node
    while root.parent is not None:
        root = root.parent

    # Print the entire tree from the root.
    print_astar_tree(root)

    # Extract and normalize the final answer.
    final = extract_answer(solution_node.state)
    if final:
        solution_node.state = f"\\boxed{{{final}}}"
    print("\nFinal Answer Found:")
    print(solution_node.state)
else:
    # No solution found: use GPT/Gemini directly for the final answer.
    fallback_prompt = (
        f"Based on the problem: \"{initial_problem}\", please provide ONLY your final answer "
        "in the exact format \\boxed{<final answer>}. Do not include any chain-of-thought or extra explanation."
    )
    forced_final_answer = get_api_response(fallback_prompt).strip()
    final = extract_answer(forced_final_answer) or forced_final_answer
    forced_final_answer = f"\\boxed{{{final}}}"
    print("\nNo complete solution node found. Forced final answer:")
    print(forced_final_answer)


Expanding: AStarNode(g=0.0, h=10.0, f=10.0, state='Solve the integral: \( \int_0^1 x^2 \, dx \)')

Expanding node at depth 0.0:
Solve the integral: \( \int_0^1 x^2 \, dx \)

Fallback attempt 1: Okay, so I need to solve the integral of x squared from 0 to 1. Hmm, let's see. I remember that integrals are about finding the area under a curve, right? In this case, the curve is y = x squared, and we're looking from x=0 to x=1. 

First, I should recall the basic rules of integration. I think the integral of x^n dx is (x^(n+1))/(n+1) + C, where C is the constant of integration. But wait, since we're doing a definite integral from 0 to 1, I don't need the constant. 

So, applying that formula to x squared. Here, n is 2 because it's x^2. So the integral should be (x^(2+1))/(2+1) = x^3 / 3. 

Now, I need to evaluate this from 0 to 1. That means I plug in the upper limit, which is 1, into the antiderivative, and then subtract the value of the antiderivative at the lower limit, which is 0. 

Let m

# Evaluation of the A* Search Method on the Math500 Dataset

This evaluation framework is designed to test the A* search method on a subset of the Math500 dataset. It provides flexibility in hyperparameter configuration and is set up for both detailed output and large-scale evaluation. The framework saves results, summarizes key metrics, and includes mechanisms to force the model to output a final answer in the expected format.

## Key Features

- **Unique Results File:**  
  Uses a dedicated results file (e.g., `evaluation_results_astar_random_test.json`) to store evaluation data. The file is cleared at the start of each run to prevent interference from previous evaluations.

- **Configurable Hyperparameters:**  
  You can adjust parameters such as:
  - `num_candidates`: Number of candidate final answers generated per node.
  - `max_depth`: Maximum depth of the search tree.
  
  These settings enable you to test different reasoning strategies and trade-offs between search exploration depth and the precision of the final answer.

- **Sample Selection:**  
  The `max_samples` parameter controls the number of problems from the dataset to evaluate. This allows you to start by testing on a single sample and then scale up to 100 or more problems as desired.

- **Fallback Mechanisms:**  
  If no node in the A* search tree contains a complete final answer (i.e., one wrapped in `\boxed{...}`), a fallback prompt is issued to force the model to output only the final answer. This makes it much more likely, though not certain, that every evaluated problem produces a final answer in the proper format.

- **Final Answer Extraction:**  
  After processing the search tree, the framework extracts just the final answer from the model’s output (using a helper function like `extract_answer`). This ensures that only the final answer is compared against the expected solution, without any additional chain-of-thought text.

- **Result Saving and Analysis:**  
  The framework saves detailed evaluation data—including the problem text, the model's raw response, the extracted final answer, and correctness—and produces a summary report with metrics such as total problems evaluated, number of correct answers, and overall accuracy.
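
The summary metrics listed above can be computed from the saved records with a few lines. The sketch below is illustrative only and is not the notebook's `analyze_results`. The function name `summarize` is hypothetical, the record fields follow the dictionaries passed to `save_result`, and the sample records are invented.

```python
from typing import Dict, List


def summarize(results: List[Dict]) -> Dict[str, float]:
    """Compute total, correct count, and accuracy from per-problem records."""
    total = len(results)
    correct = sum(1 for r in results if r.get("is_correct"))
    return {
        "total": total,
        "correct": correct,
        "accuracy": correct / total if total else 0.0,  # avoid dividing by zero
    }


sample = [
    {"index": 0, "predicted_answer": "1/3", "correct_answer": "1/3", "is_correct": True},
    {"index": 1, "predicted_answer": "", "correct_answer": "2", "is_correct": False},
]
print(summarize(sample))  # {'total': 2, 'correct': 1, 'accuracy': 0.5}
```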

## Encouragement for Further Improvement

**Prompt Engineering & Fallback Strategies:**  
The current method forces the model to provide a final answer using a fallback prompt if the initial search does not yield the required format. Experiment with:
- **Prompt Tuning:** Adjust the wording and structure to further emphasize "ONLY your final answer" and "no extra text."
- **Iterative Refinement:** Consider iterative prompt refinement or additional post-processing to isolate the final answer more reliably.
- **Innovative Approaches:** This challenge of extracting only the final answer is an active area of research. Exploring new strategies may lead to better performance and more robust results.

**Scalability:**  
Once the hyperparameters and prompting strategy are fine-tuned on a small set of problems, scale the evaluation to a larger sample (e.g., 100+ problems). A broader evaluation can reveal trends and help identify further improvements.

**Experiment and Innovate:**  
Feel free to modify prompts, adjust hyperparameters, and refine fallback mechanisms. Comparing the A* search method with other approaches, such as the Tree of Thoughts method, can provide valuable insights. Your experimentation is key to developing a more robust and reliable evaluation system.

This flexible framework is designed to meet diverse research and production needs—keep iterating and exploring until you achieve optimal results!


In [None]:
import os
import random
from tqdm import tqdm

def evaluate_astar_random(max_samples: int = 1):
    """
    Evaluate the A* search method on randomly selected problems from the Math500 dataset.
    Uses GPT/Gemini (via get_api_response) for any fallback final-answer requests.
    Assumes existence of:
      - load_math500_dataset()
      - extract_answer(solution_text)
      - compare_answers(correct_answer, predicted_answer)
      - save_result(results_file, result_dict)
      - load_existing_results(results_file)
      - analyze_results(results_list)
      - get_llm_response(prompt)   # for candidate generation
      - get_api_response(prompt)   # for GPT/Gemini–based fallback
    """
    os.makedirs("results", exist_ok=True)
    results_file = "results/evaluation_results_astar_random_test.json"

    # Clear previous test results
    if os.path.exists(results_file):
        os.remove(results_file)

    dataset = load_math500_dataset()
    total = len(dataset)

    # Pick random unique indices. random.sample guarantees uniqueness, and capping
    # at the dataset size avoids an endless loop when max_samples > total.
    selected = random.sample(range(total), min(max_samples, total))

    astar = AStarSearch(num_candidates=1, max_depth=1, verbose=True, max_fallback=3)
    correct_count = 0
    evaluated = 0

    def collect_all_nodes(node):
        nodes = [node]
        for c in node.children:
            nodes.extend(collect_all_nodes(c))
        return nodes

    for idx in selected:
        item = dataset[idx]
        problem_text = item['problem']
        correct_answer = extract_answer(item['solution'])

        solution_node = astar.search(problem_text)
        if solution_node:
            response = solution_node.state
            print("DEBUG: Found solution node:", response)
        else:
            print("DEBUG: No solution node found. Using GPT fallback.")
            fallback_prompt = (
                f"Based on the problem: \"{problem_text}\", provide ONLY your final answer "
                "in the exact format \\boxed{<final answer>}. Do not include any chain-of-thought or explanation."
            )
            response = get_api_response(fallback_prompt).strip()
            print("DEBUG: Fallback response:", response)

        predicted_answer = extract_answer(response)
        if not predicted_answer:
            print("DEBUG: predicted_answer is empty. Raw response was:\n", response)

        is_correct = compare_answers(correct_answer, predicted_answer or "")

        save_result(results_file, {
            "index": idx,
            "problem": problem_text,
            "response": response,
            "correct_answer": correct_answer,
            "predicted_answer": predicted_answer,
            "is_correct": is_correct
        })

        if is_correct:
            correct_count += 1
        evaluated += 1
        print(f"Correct: {correct_count}/{evaluated} | Index: {idx}")

    final_results = load_existing_results(results_file)
    analyze_results(final_results)

# Example usage:
evaluate_astar_random(max_samples=1)
# To test with different parameters, adjust max_samples.


Expanding: AStarNode(g=0.0, h=10.0, f=10.0, state='The expression $10x^2-x-24$ can be written as $...')

Expanding node at depth 0.0:
The expression $10x^2-x-24$ can be written as $(Ax-8)(Bx+3),$ where $A$ and $B$ are integers. What is $AB + B$?

Fallback attempt 1: Okay, so I have this problem where I need to express the quadratic expression \(10x^2 - x - 24\) as a product of two binomials in the form \((Ax - 8)(Bx + 3)\), where \(A\) and \(B\) are integers. Then, I have to find the value of \(AB + B\). Hmm, okay, let me break this down step by step.

First, I remember that when you multiply two binomials, you can use the FOIL method to expand them and then compare the coefficients with the original expression. The original expression is \(10x^2 - x - 24\), and the target form is \((Ax - 8)(Bx + 3)\). So, let me write that out using the FOIL method.

FOIL stands for First, Outer, Inner, Last. That means I'll multiply the first terms, then the outer terms, then the inner terms, and fin

# Monte Carlo Tree Search (MCTS) for Mathematical Reasoning

Monte Carlo Tree Search (MCTS) is a probabilistic search algorithm particularly well-suited for decision-making tasks in large, complex search spaces. In the context of mathematical problem solving with language models, each node in the search tree represents a partial reasoning step. MCTS incrementally builds the tree by exploring the most promising reasoning paths through random simulations, balancing exploration of new paths with exploitation of those that appear promising.

## Core Components of MCTS

MCTS operates in four main stages:

1. **Selection:**  
   Starting at the root (the initial problem), the algorithm traverses the tree by selecting child nodes based on a policy that balances two factors:  
   - **Exploitation:** Favoring nodes that have already shown high promise (i.e., those with a good reward or low cost).  
   - **Exploration:** Giving a chance to less-visited nodes to discover potentially promising new paths.  
   This balance is often managed by a criterion such as the Upper Confidence Bound (UCB).

2. **Expansion:**  
   When a leaf node is reached—one that has not been fully expanded—the algorithm expands it by generating one or more child nodes. Each new node represents a possible next step in the reasoning process, such as a candidate final answer generated by the language model.

3. **Simulation (Rollout):**  
   From the newly expanded node, the algorithm performs a simulation (or rollout) to estimate the outcome if that reasoning path were followed to completion. For mathematical reasoning, this may involve prompting the language model to complete the remaining reasoning and produce a final answer. The outcome of the simulation provides an estimated reward or cost for that path. The implementation in this notebook simplifies this stage: instead of a blind rollout, `simulate` asks a verifier LLM to score the node's boxed answer directly.

4. **Backpropagation:**  
   The result of the simulation is then propagated back up the tree. Each node along the path has its evaluation updated based on the outcome, which in turn refines the selection policy for future iterations. This backpropagation ensures that nodes contributing to more promising outcomes are prioritized.
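
The four stages above can be sketched on a toy problem that needs no language model. In this illustrative example, a "reasoning path" is a sequence of bits and the reward is the fraction of ones; the names `Node`, `ucb`, and `toy_mcts`, as well as the toy reward, are assumptions made for this sketch and are not the notebook's `MCTSSearch`.

```python
import math
import random

class Node:
    def __init__(self, state, parent=None):
        self.state = state            # tuple of bits chosen so far
        self.parent = parent
        self.children = []
        self.visits = 0
        self.total_reward = 0.0

def ucb(child, parent_visits, c=1.41):
    if child.visits == 0:
        return float("inf")           # unvisited children are always tried first
    mean = child.total_reward / child.visits
    return mean + c * math.sqrt(math.log(parent_visits) / child.visits)

def toy_mcts(depth=4, num_simulations=200, seed=0):
    rng = random.Random(seed)
    root = Node(state=())
    for _ in range(num_simulations):
        # 1. Selection: descend by UCB until reaching a node with no children.
        node = root
        while node.children:
            node = max(node.children, key=lambda ch: ucb(ch, node.visits))
        # 2. Expansion: add one child per possible next bit (unless terminal).
        if len(node.state) < depth:
            node.children = [Node(node.state + (bit,), node) for bit in (0, 1)]
            node = rng.choice(node.children)
        # 3. Simulation: fill the remaining bits at random and score the result.
        rollout = list(node.state)
        while len(rollout) < depth:
            rollout.append(rng.randint(0, 1))
        reward = sum(rollout) / depth
        # 4. Backpropagation: update statistics on the path back to the root.
        while node is not None:
            node.visits += 1
            node.total_reward += reward
            node = node.parent
    return root

root = toy_mcts()
for child in root.children:
    print(child.state, child.visits, round(child.total_reward / child.visits, 3))
```

Because paths starting with a one earn more reward on average, one would expect the search to spend most of its visits on that branch, though the exact counts depend on the seed and the exploration constant.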

## How MCTS Applies to Mathematical Reasoning

- **State Representation:**  
  Each node encapsulates a partial solution or chain-of-thought produced by the language model. The goal is to eventually obtain a final answer formatted in a concise manner (e.g., wrapped in `\boxed{...}`).

- **Reward and Heuristic:**  
  The reward signal derived from the simulation reflects how close a reasoning path is to a correct final answer. In this notebook, the reward is a score between 0 and 1 that a verifier LLM assigns to the boxed answer (`evaluate_with_gpt`); low scores steer the search away from incomplete or incorrect paths.

- **Balancing Exploration and Exploitation:**  
  MCTS effectively manages the trade-off between exploring new, untested reasoning paths and exploiting those paths that have already demonstrated potential. This balance is critical in navigating the vast space of possible reasoning steps.
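
For reference, the exploration/exploitation balance is usually expressed with the UCB1 score, which is also what the `ucb_score` method in the implementation below computes. The symbols here are notation chosen for this note: $W_i$ is the total reward accumulated by child $i$, $N_i$ its visit count, $N_p$ the visit count of its parent, and $c$ the exploration constant.

$$
\mathrm{UCB}(i) = \underbrace{\frac{W_i}{N_i}}_{\text{exploitation}} + \underbrace{c\,\sqrt{\frac{\ln N_p}{N_i}}}_{\text{exploration}}
$$

As a worked example with made-up numbers, a child with $W_i = 1.5$, $N_i = 3$, $N_p = 10$, and $c = 1.41$ scores

$$
\frac{1.5}{3} + 1.41\,\sqrt{\frac{\ln 10}{3}} \approx 0.5 + 1.41 \times 0.876 \approx 1.74 .
$$

A child with $N_i = 0$ is treated as having an infinite score, so every child is tried at least once before the formula is applied.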

## Challenges and Considerations

- **Simulation Cost:**  
  Running multiple simulations per node can be computationally intensive, especially when each simulation involves multiple language model calls.

- **Reward Signal Design:**  
  Defining an accurate and meaningful reward (or evaluation) for partial solutions is challenging. The reward must correlate well with the likelihood of ultimately arriving at a complete and correct final answer.

- **Parameter Tuning:**  
  Effective implementation of MCTS requires careful tuning of parameters like the exploration constant and the number of simulations per node. This tuning is essential to balance the search effectively.

- **Ensuring Concise Final Answers:**  
  One of the key objectives is to force the model to output only the final answer (without additional chain-of-thought text). This requires precise prompt engineering and robust fallback strategies in both candidate generation and simulation phases.
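
Even with strict prompts, responses often contain extra text, so post-processing that isolates the boxed answer is a useful safety net. The function below is a minimal sketch of brace-aware extraction; `extract_last_boxed` is a hypothetical helper written for illustration and is not the notebook's `extract_answer`, whose implementation may differ.

```python
from typing import Optional

def extract_last_boxed(text: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...} in `text`, or None."""
    marker = "\\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return None
    depth = 1                      # we are inside the opening brace of \boxed{
    content = []
    for ch in text[start + len(marker):]:
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:         # matching close brace found
                return "".join(content).strip() or None
        content.append(ch)
    return None                    # braces never closed, e.g. a truncated response

print(extract_last_boxed(r"So the integral equals \boxed{\frac{1}{3}}."))
```

Counting braces rather than using a simple regular expression lets nested LaTeX such as `\frac{1}{3}` survive intact, and returning `None` for unclosed braces flags truncated generations so a fallback prompt can be issued.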

## Conclusion

MCTS provides a powerful framework for exploring the reasoning process in language models by combining random simulation with informed backpropagation. Through iterative exploration of promising reasoning paths and careful balancing of exploration and exploitation, MCTS aims to identify the most promising route to a final answer. With a strong focus on ensuring that only a concise final answer is produced (wrapped in a format like `\boxed{...}`), this method offers significant potential for enhancing mathematical problem solving.

By adjusting hyperparameters and refining the simulation and evaluation processes, you can experiment with different configurations to improve the efficiency and accuracy of the final answers. This makes MCTS a flexible and promising approach for further research and practical applications in guided reasoning with language models.


In [None]:
import re
import time
import math
import heapq
from typing import List, Optional


def evaluate_with_gpt(problem: str, candidate: str) -> float:
    """
    Ask the verifier LLM (via get_api_response) whether the boxed answer is correct.
    Returns 1.0 for “yes”, 0.0 for “no”, or an intermediate score if provided.
    """
    prompt = (
        "You are a precise math grader.\n\n"
        f"Problem:\n\"{problem}\"\n\n"
        "Final answer to check (in \\boxed{...} format):\n"
        f"\"{candidate}\"\n\n"
        "Is this answer correct? Reply with a number between 0 (incorrect) and 1 (fully correct)."
    )
    # TODO: Send `prompt` to `get_api_response`, strip whitespace
    resp = get_api_response(prompt).strip()

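    # TODO: Extract the first numeric match with `re.search`
    # The lookarounds keep the pattern from matching digits inside larger
    # numbers (e.g. the "1" in "12"), which would otherwise inflate the score.
    match = re.search(r'(?<![\d.])(?:0(?:\.\d+)?|1(?:\.0+)?)(?!\d|\.\d)', resp)

    # TODO: Convert to float (fallback to 0.0 on failure)
    # TODO: Clamp the result to [0.0, 1.0] and return it
    score = float(match.group(0)) if match else 0.0
    return max(0.0, min(1.0, score))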


class MCTSNode:
    def __init__(self, state: str, parent: Optional['MCTSNode'] = None):
        self.state = state
        self.parent = parent
        self.children: List[MCTSNode] = []
        self.visits = 0
        self.total_reward = 0.0

    def add_child(self, child: 'MCTSNode'):
        self.children.append(child)

    def __repr__(self):
        preview = self.state if len(self.state) < 50 else self.state[:47] + "..."
        return f"MCTSNode(visits={self.visits}, reward={self.total_reward:.2f}, state='{preview}')"




class MCTSSearch:
    def __init__(
        self,
        num_simulations: int = 10,
        exploration_const: float = 1.41,
        max_depth: int = 3,
        max_fallback: int = 3,
        num_candidates: int = 1,
        verbose: bool = False
    ):
        self.num_simulations = num_simulations
        self.exploration_const = exploration_const
        self.max_depth = max_depth
        self.max_fallback = max_fallback
        self.num_candidates = num_candidates
        self.verbose = verbose
        self.root_problem: Optional[str] = None

    def is_solution(self, state: str) -> bool:
        # TODO: return True if `extract_answer(state)` yields a non-empty boxed answer
        # bool(...) treats both None and an empty string as "no answer".
        return bool(extract_answer(state))

    def generate_candidates(self, state: str) -> List[str]:
        # TODO: build a prompt asking for `self.num_candidates` LaTeX \\boxed{...} answers to `state`
        prompt = (
            f"Given the problem or partial solution: {state}\n\n"
            f"Provide {self.num_candidates} possible complete solutions, each with a boxed final answer. "
            "Provide each solution on a separate line."
        )

        # TODO: call `get_llm_response(prompt)`
        response = get_llm_response(prompt)

        # TODO: split response into non-empty lines
        candidates = [line.strip() for line in response.split("\n") if line.strip()]

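        # TODO: if fewer than `self.num_candidates`, replicate lines to match
        # TODO: return exactly `self.num_candidates` strings
        if not candidates:
            # Empty response: use a placeholder so `candidates[-1]` cannot raise
            # IndexError; expand() will then request an answer via its fallback prompt.
            candidates = [""]
        if len(candidates) < self.num_candidates:
            candidates += [candidates[-1]] * (self.num_candidates - len(candidates))
        return candidates[:self.num_candidates]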

    def expand(self, node: MCTSNode) -> List[MCTSNode]:
        if self.verbose:
            print(f"\nExpanding node (depth {self._depth(node)}):\n{node.state}\n")

        # ensure we remember the original problem
        if self.root_problem is None:
            self.root_problem = node.state

        # TODO: candidates = self.generate_candidates(node.state)
        candidates = self.generate_candidates(node.state)
        children: List[MCTSNode] = []

        # TODO: for each `cand` in `candidates`:
        for cand in candidates:
            final = cand
            attempts = 0
            while not self.is_solution(final) and attempts < self.max_fallback:
                fallback_prompt = (
                    f"The following solution is incomplete or incorrect: {final}\n\n"
                    f"Please provide a correct, complete solution with a boxed final answer for the problem: {self.root_problem} \n\n"
                    "Please provide ONLY your final answer in the exact format \\boxed{...} "
                    "with no additional explanation."
                )
                # send strict fallback prompt via `get_llm_response`
                # update `final = response.strip()`
                final = get_llm_response(fallback_prompt).strip()
                attempts += 1
                if self.verbose: print(f"Fallback attempt {attempts}: {final}")

            child = MCTSNode(state=final, parent=node)
            node.add_child(child)
            children.append(child)
        return children


    def simulate(self, node: MCTSNode) -> float:
        """
        Instead of a blind rollout, directly evaluate the final boxed answer.
        """
        # TODO: if `self.root_problem` is None, return 0.0
        if self.root_problem is None:
            return 0.0

        # TODO: if not self.is_solution(node.state):
        #         * send fallback prompt to `get_llm_response`
        #         * node.state = response.strip()
        if not self.is_solution(node.state):
            fallback_prompt = (
                f"Problem: {self.root_problem}\n\n"
                "Please provide ONLY your final answer in the exact format \\boxed{...} "
                "with no additional explanation."
            )
            node.state = get_llm_response(fallback_prompt).strip()

        # TODO: return evaluate_with_gpt(self.root_problem, node.state)
        return evaluate_with_gpt(self.root_problem, node.state)


    def ucb_score(self, child: MCTSNode, parent_visits: int) -> float:
        # TODO: if child.visits == 0: return float('inf')
        if child.visits == 0: return float('inf')

        exploit = child.total_reward / child.visits
        explore = self.exploration_const * math.sqrt(math.log(parent_visits) / child.visits)
        return exploit + explore


    def select(self, root: MCTSNode) -> MCTSNode:
        # TODO: starting at `root`, repeatedly pick the child with highest `ucb_score`
        #       until you reach a node with no children, then return it
        node = root
        while node.children:
            best_child = max(node.children, key=lambda c: self.ucb_score(c, node.visits))
            node = best_child
        return node

    def backpropagate(self, node: MCTSNode, reward: float):
        while node is not None:
            node.visits += 1
            node.total_reward += reward
            node = node.parent


    def _depth(self, node: MCTSNode) -> int:
        d = 0
        while node.parent:
            d += 1
            node = node.parent
        return d

    def search(self, initial_problem: str) -> Optional[MCTSNode]:
        """
        Perform MCTS using GPT as the verifier:
        - Selection: `select`
        - Expansion: `expand`
        - Simulation/Backprop: `simulate` + `backpropagate`
        """
        # TODO: set `self.root_problem = initial_problem`
        # TODO: create `root = MCTSNode(state=initial_problem)`.
        self.root_problem = initial_problem
        root = MCTSNode(state=initial_problem)

        # TODO: for `_` in range(self.num_simulations):
        for _ in range(self.num_simulations):
            leaf = self.select(root)
            if self._depth(leaf) >= self.max_depth:
                reward = self.simulate(leaf)
                self.backpropagate(leaf, reward)
                continue

            children = self.expand(leaf)
            if not children: continue

            reward = self.simulate(children[0])
            self.backpropagate(children[0], reward)

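        # Return the root's child with the highest mean reward,
        # or the root itself if nothing was expanded.
        if not root.children:
            return root
        return max(
            root.children,
            key=lambda c: c.total_reward / c.visits if c.visits > 0 else 0,
        )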


# Test Code for MCTS Search Method with Minimal Hyperparameters

Below is a description of the test procedure for the MCTS search method using minimal hyperparameters. This test checks that the algorithm returns only the final answer in the proper format, without extra chain-of-thought text.

- **Initialization:**
  - An MCTS search instance is created with the following settings:
    - `num_simulations`: 5 (number of simulations to run for exploring reasoning paths)
    - `exploration_const`: 1.41 (parameter to balance exploration and exploitation)
    - `max_depth`: 1 (the search tree is kept shallow for testing)
    - `max_fallback`: 3 (up to three fallback attempts are made to force a final answer)
    - `num_candidates`: 1 (only one candidate is generated per node)
    - `verbose`: True (detailed debug information is printed)
  - The initial problem is set as:
    - "Solve the integral: \( \int_0^1 x^2 \, dx \)"

- **Search Execution:**
  - The MCTS search is executed on the initial problem to obtain a solution node.

- **Tree Printing:**
  - A recursive function (e.g., `print_mcts_tree`) prints the entire MCTS search tree starting from the root. This allows inspection of all nodes and the reasoning process.

- **Final Answer Extraction:**
  - If a solution node is found, the algorithm backtracks to the root and prints the full tree.
  - The final answer is then extracted from the solution node using a helper function (e.g., `extract_answer`) and reformatted to display only the final answer in the exact format (e.g., `\boxed{<final answer>}`), ensuring that no extra explanation or chain-of-thought text is present.

- **Fallback Handling:**
  - If no solution node is found, a fallback prompt is issued that instructs the model to provide ONLY the final answer in the correct format, with no additional explanation.
  - The fallback final answer is printed.

This test code is designed to verify that, with minimal hyperparameters, the MCTS search method consistently returns a complete final answer. Users are encouraged to experiment with these hyperparameters and refine the prompts to further improve performance and compare the results with other methods.


In [None]:
# Test Code for MCTS Search Method with Minimal Hyperparameters

# Initialize the MCTS search instance with minimal settings.
# Candidate generation and in-search fallback use get_llm_response(...);
# verification and the final forced answer use get_api_response(...).
mcts = MCTSSearch(
    num_simulations=5,
    exploration_const=1.41,
    max_depth=1,
    max_fallback=3,
    num_candidates=1,
    verbose=True
)

initial_problem = "Solve the integral: \\( \\int_0^1 x^2 \\, dx \\)"
solution_node = mcts.search(initial_problem)

# Recursive printer for the MCTS tree.
def print_mcts_tree(node, indent=""):
    print(indent + repr(node))
    for child in node.children:
        print_mcts_tree(child, indent + "  ")

# search() returns the root itself when nothing was expanded, so also check
# that the returned node actually contains a boxed answer.
final = extract_answer(solution_node.state) if solution_node else None

if final:
    # Backtrack to the root node.
    root = solution_node
    while root.parent is not None:
        root = root.parent

    # Print the entire tree from the root.
    print_mcts_tree(root)

    # Normalize the final boxed answer.
    solution_node.state = f"\\boxed{{{final}}}"
    print("\nFinal Answer Found:")
    print(solution_node.state)
else:
    # No solution found: use GPT/Gemini verifier directly for final answer.
    fallback_prompt = (
        f"Based on the problem: \"{initial_problem}\", please provide ONLY your final answer "
        "in the exact format \\boxed{<final answer>}. Do not include any chain-of-thought or extra explanation."
    )
    forced_final_answer = get_api_response(fallback_prompt).strip()
    final = extract_answer(forced_final_answer) or forced_final_answer
    forced_final_answer = f"\\boxed{{{final}}}"
    print("\nNo complete solution found. Forced final answer:")
    print(forced_final_answer)



Expanding node (depth 0):
Solve the integral: \( \int_0^1 x^2 \, dx \)

Fallback attempt 1: Okay, so I need to solve the integral of x squared from 0 to 1. Hmm, integrals can sometimes be tricky, but I remember that integration is essentially finding the area under a curve. In this case, the function is x squared, which is a parabola that opens upwards. The limits are from 0 to 1, so I'm looking for the area under that curve between those two points.

First, I recall that the integral of x^n is a standard formula. I think it's something like (x^(n+1))/(n+1) plus a constant, right? So, for x squared, which is x^2, n would be 2. Applying the formula, the integral should be (x^(2+1))/(2+1) = x^3/3. So, the indefinite integral of x squared is x cubed over 3.

But wait, since we're dealing with a definite integral from 0 to 1, I need to evaluate the indefinite integral at the upper limit and subtract its value at the lower limit. That means I should compute [x^3/3] from 0 to 1. 

Let me pl

# Evaluation of MCTS Search Method on the Math500 Dataset

This evaluation framework is designed to test the MCTS search method on a subset of the Math500 dataset. The framework is highly configurable via hyperparameters and forces the model to output only the final answer in the expected format (e.g., `\boxed{<final answer>}`) with no additional chain-of-thought text.

## Key Features

- **Unique Results File:**  
  A dedicated results file (e.g., `evaluation_results_mcts_test.json`) is used to save evaluation data. The file is cleared at the beginning of each run to ensure a fresh start without interference from previous evaluations.

- **Configurable Hyperparameters:**  
  You can adjust parameters such as:
  - `num_simulations`: Number of MCTS simulations used to explore reasoning paths.
  - `exploration_const`: The constant used in the UCB formula to balance exploration and exploitation.
  - `max_depth`: Maximum depth of the search tree.
  - `max_fallback`: Maximum number of fallback attempts to force the model to output a final answer.
  - `num_candidates`: Number of candidate final answers generated per node.
  
  These settings allow you to experiment with different reasoning strategies and trade-offs between search depth and the precision of the final answer.

- **Sample Selection:**  
  The `max_samples` parameter controls the number of problems from the dataset to evaluate. This lets you start by testing on a single sample and then scale the evaluation to larger subsets (e.g., 100 or more problems) as desired.

- **Final Answer Extraction and Fallback Mechanism:**  
  The evaluation process extracts only the final answer from the model’s output (using a helper like `extract_answer`), ensuring that only a concise final answer is compared against the expected solution. If no complete final answer is found within the search tree, a fallback prompt is issued to force the model to provide the final answer in the correct format.

- **Result Saving and Analysis:**  
  Each problem’s evaluation result—including the problem, the model’s raw response, the extracted final answer, and correctness—is saved to the results file. A summary report is then generated, which includes key metrics such as total evaluated problems, number of correct answers, and overall accuracy.
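
The summary metrics described above can be computed directly from the saved result records. The function below is a minimal sketch under the assumption that each record is a dictionary with an `is_correct` field, as in the evaluation code; `summarize_results` is a hypothetical stand-in and not the notebook's `analyze_results`.

```python
from typing import Dict, List, Union

def summarize_results(results: List[dict]) -> Dict[str, Union[int, float]]:
    """Compute total, correct, and accuracy from a list of result records."""
    total = len(results)
    correct = sum(1 for r in results if r.get("is_correct"))
    accuracy = correct / total if total else 0.0   # avoid division by zero on empty runs
    return {"total": total, "correct": correct, "accuracy": accuracy}

# Made-up records for illustration only.
example = [
    {"index": 0, "predicted_answer": "3", "is_correct": True},
    {"index": 1, "predicted_answer": "", "is_correct": False},
    {"index": 2, "predicted_answer": "7", "is_correct": True},
    {"index": 3, "predicted_answer": "1", "is_correct": True},
]
print(summarize_results(example))
```

Guarding the division keeps the report usable even when a run is interrupted before any problem has been evaluated.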

## Encouragement for Further Improvement

- **Experiment with the Exploration Constant:**  
  Try using different values for the exploration constant (e.g., 0.5, 1.41, 2.0) to see how they affect the balance between exploring new nodes and exploiting known promising ones. Compare the results and observe how the search tree structure and the final answer accuracy change with each setting.

- **Tuning Other Parameters:**  
  Experiment with other hyperparameters such as `num_simulations`, `max_depth`, and `num_candidates`. Adjusting these values can impact the thoroughness of the search and the likelihood of obtaining a complete final answer.

- **Document Your Observations:**  
  As you tweak the parameters, please comment on your expectations and the outcomes:
  - What changes do you observe when you adjust the exploration constant?
  - How does increasing the maximum depth influence the quality and correctness of the final answer?
  - Does generating more candidates per node lead to more accurate results?
  
  Sharing your observations and comparing them with results from other approaches (such as A* or the Tree of Thoughts method) can provide valuable insights and help guide further improvements.
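
Before running a full sweep, it can be instructive to see how the exploration constant alone changes which child the UCB rule prefers. The sketch below uses made-up visit statistics for two children; `ucb_score` mirrors the formula used in `MCTSSearch`, while `preferred_child` and the numbers are assumptions for illustration.

```python
import math

def ucb_score(total_reward, visits, parent_visits, c):
    """UCB1 score: mean reward plus an exploration bonus scaled by c."""
    if visits == 0:
        return float("inf")
    return total_reward / visits + c * math.sqrt(math.log(parent_visits) / visits)

def preferred_child(children, parent_visits, c):
    """Return the name of the child with the highest UCB score."""
    return max(children, key=lambda name: ucb_score(*children[name], parent_visits, c))

# name -> (total_reward, visits); illustrative values only
children = {"well_explored": (9.0, 10), "rarely_visited": (0.8, 2)}
parent_visits = 12

for c in (0.5, 1.41, 2.0):
    print(f"c={c}: select {preferred_child(children, parent_visits, c)}")
```

With these numbers, the small constant favors the child with the higher mean reward, while the two larger constants favor the rarely visited child, which is the kind of shift to look for in the search tree when comparing settings.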

This framework is designed to be flexible, so feel free to adjust the parameters and prompts to suit your research needs and push the performance of the MCTS method further.


In [None]:
import os
from tqdm import tqdm

def evaluate_mcts(max_samples=1):
    """
    Evaluate the MCTS search method on a subset of Math500, using GPT/Gemini
    (via get_api_response) for any fallback final-answer requests.
    Assumes these helpers exist:
      - load_math500_dataset()
      - extract_answer(text)
      - compare_answers(correct, pred)
      - save_result(filename, result_dict)
      - load_existing_results(filename)
      - analyze_results(results_list)
      - get_llm_response(prompt)   # for candidate generation
      - get_api_response(prompt)   # for GPT/Gemini verification/fallback
    """
    os.makedirs("results", exist_ok=True)
    results_file = "results/evaluation_results_mcts_test.json"
    if os.path.exists(results_file):
        os.remove(results_file)

    dataset = load_math500_dataset()
    existing = load_existing_results(results_file)
    processed = {r['index'] for r in existing}

    mcts = MCTSSearch(
        num_simulations=5,
        exploration_const=1.41,
        max_depth=1,
        max_fallback=3,
        num_candidates=1,
        verbose=True
    )

    correct_count = 0
    evaluated = 0

    def collect_all_nodes(node):
        nodes = [node]
        for child in node.children:
            nodes.extend(collect_all_nodes(child))
        return nodes

    for idx, item in enumerate(tqdm(dataset, desc="Evaluating problems")):
        if evaluated >= max_samples:
            break  # stop here instead of iterating over the rest of the dataset
        if idx in processed:
            continue

        problem_text = item['problem']
        correct_answer = extract_answer(item['solution'])

        # Run MCTS search
        solution_node = mcts.search(problem_text)

        if solution_node:
            response = solution_node.state
            print("DEBUG: Found solution node:", response)
        else:
            print("DEBUG: No solution node found. Using GPT fallback.")
            fallback_prompt = (
                f"Based on the problem: \"{problem_text}\", provide ONLY your final answer "
                "in the exact format \\boxed{<final answer>}. Do not include any chain-of-thought or explanation."
            )
            response = get_api_response(fallback_prompt).strip()
            print("DEBUG: Fallback response:", response)

        predicted_answer = extract_answer(response) or ""
        if not predicted_answer:
            print("DEBUG: predicted_answer is empty. Raw response was:\n", response)

        is_correct = compare_answers(correct_answer, predicted_answer)
        save_result(results_file, {
            "index": idx,
            "problem": problem_text,
            "response": response,
            "correct_answer": correct_answer,
            "predicted_answer": predicted_answer,
            "is_correct": is_correct
        })

        if is_correct:
            correct_count += 1
        evaluated += 1
        print(f"Correct: {correct_count}/{evaluated} | Index: {idx}")

    final_results = load_existing_results(results_file)
    analyze_results(final_results)

# Example: evaluate a single sample
evaluate_mcts(max_samples=1)
# To test more problems, increase max_samples


Evaluating problems:   0%|          | 0/500 [00:00<?, ?it/s]


Expanding node (depth 0):
Convert the point $(0,3)$ in rectangular coordinates to polar coordinates.  Enter your answer in the form $(r,\theta),$ where $r > 0$ and $0 \le \theta < 2 \pi.$

Fallback attempt 1: Alright, so I need to convert the rectangular coordinates (0,3) to polar coordinates. Hmm, okay. I remember that polar coordinates are (r, θ), where r is the distance from the origin to the point, and θ is the angle made with the positive x-axis. 

Let me recall the formulas for converting from rectangular to polar coordinates. I think it's something like r = sqrt(x² + y²) and θ = arctan(y/x). Yeah, that sounds right. So, in this case, the point is (0,3). Let me plug those values into the formulas.

First, calculating r. Since x is 0 and y is 3, plugging into the formula for r: r = sqrt(0² + 3²) = sqrt(0 + 9) = sqrt(9) = 3. Okay, so r is 3. That seems straightforward.

Now, for θ. The formula is θ = arctan(y/x). So plugging in the values, θ = arctan(3/0). Wait a second, division 

Evaluating problems: 100%|██████████| 500/500 [01:21<00:00,  6.16it/s] 

DEBUG: Found solution node: Okay, so I need to convert the rectangular coordinates (0, 3) to polar coordinates. Hmm, I remember that polar coordinates are represented as (r, θ), where r is the distance from the origin to the point, and θ is the angle made with the positive x-axis. 

First, let me recall the formulas for converting from rectangular (Cartesian) coordinates to polar. I think they are:

r = √(x² + y²)

θ = arctan(y/x)

Yeah, that sounds right. So, given the point (0, 3), which is on the y-axis, I can plug those values into the formulas.

Starting with r. Since x is 0 and y is 3, let's compute r:

r = √(0² + 3²) = √(0 + 9) = √9 = 3

Okay, so r is 3. That makes sense because the point is 3 units away from the origin, which is on the positive y-axis.

Now, for θ. The formula is θ = arctan(y/x). But wait, since x is 0, we have a division by zero here. Hmm, arctan(∞) or arctan(-∞)? Let me think.

When x is 0, the point is on the y-axis. So, if y is positive, it's on the positiv


