In [None]:
from openai import OpenAI
from dotenv import load_dotenv
from datasets import load_dataset
from transformers import AutoTokenizer
from IPython.display import clear_output

import os
import re
import json
import pandas as pd

from grading import grader

# Read variables from a local .env file into the process environment so the API key
# does not have to be hard-coded in the notebook.
load_dotenv()
openrouter_api_key = os.getenv("OPENROUTER_API_KEY")  # None when the variable is not set

# OpenRouter client: an OpenAI-compatible client pointed at the OpenRouter endpoint.
# This is the client passed to the inference helpers in the Testing and Eval cells.
openrouter_client = OpenAI(
  base_url="https://openrouter.ai/api/v1",
  api_key=openrouter_api_key,
)

# Client for a separately hosted OpenAI-compatible endpoint (labelled Qwen/Qwen2.5-7B-Instruct
# by the original author). It is not referenced by the cells visible in this notebook.
# NOTE(review): the "EMPTY" api key presumably means the endpoint needs no auth - confirm.
qwen_client = OpenAI(
  api_key="EMPTY",
  base_url="https://qwen.stephenxie.com/v1",
)

# Model identifiers: `model` is the name sent to the chat-completions API, and
# `huggingface_model` is the repository the tokenizer is loaded from.
model = 'qwen/qwen3-8b'
huggingface_model = 'Qwen/Qwen3-8B'

# Tokenizer used by count_tokens() to measure the visible part of each completion.
tokenizer = AutoTokenizer.from_pretrained(huggingface_model)

def count_tokens(text):
  """
  Return how many token ids the module-level `Qwen3-8B` tokenizer
  produces when encoding `text`.
  """
  token_ids = tokenizer.encode(text)
  return len(token_ids)


In [None]:
## Utility Functions ##

def _escape_control_chars_in_strings(text):
    """
    Escape raw newline / carriage-return / tab characters that sit inside
    double-quoted string literals, leaving structural whitespace untouched.
    """
    escapes = {'\n': '\\n', '\r': '\\r', '\t': '\\t'}
    pieces = []
    in_string = False
    after_backslash = False

    for ch in text:
        if not in_string:
            # Outside a string, whitespace is legal JSON and must not be rewritten
            if ch == '"':
                in_string = True
            pieces.append(ch)
        elif after_backslash:
            # Second half of an escape pair (e.g. the quote in \"): copy verbatim
            after_backslash = False
            pieces.append(ch)
        elif ch == '\\':
            after_backslash = True
            pieces.append(ch)
        elif ch == '"':
            in_string = False
            pieces.append(ch)
        else:
            pieces.append(escapes.get(ch, ch))

    return ''.join(pieces)


def sanitize_json_string(json_str):
    """
    Cleans up common LLM JSON formatting errors.

    Repairs applied, in order:
      1. Lone backslashes (e.g. LaTeX `\\sqrt`) are doubled; escape pairs that are
         already valid JSON (`\\"`, `\\\\`, `\\/`, `\\b`, `\\f`, `\\n`, `\\r`, `\\t`, `\\u`)
         are left untouched so well-formed input is not corrupted.
      2. Unquoted object keys are wrapped in double quotes.
      3. Single-quoted strings are converted to double-quoted strings.
      4. Raw newline / carriage-return / tab characters are escaped, but only inside
         string literals, so pretty-printed (multi-line) JSON stays valid.

    Parameters:
        json_str (str): candidate JSON text produced by the model.

    Returns:
        str: the repaired text. It is NOT guaranteed to be valid JSON; callers should
        still parse it inside a try/except.

    Note: the key and quote repairs are regex heuristics and can rewrite text inside
    string values that happens to look like a key or contains apostrophes.
    """
    # Walk backslashes pairwise so the second half of `\\\\` is never mistaken for a lone one
    sanitized_str = re.sub(
        r'\\(["\\/bfnrtu])?',
        lambda m: m.group(0) if m.group(1) else '\\\\',
        json_str,
    )

    # Fix unquoted keys
    sanitized_str = re.sub(r'([{,]\s*)([a-zA-Z0-9_]+)(\s*:)', r'\1"\2"\3', sanitized_str)

    # Replace single quotes with double quotes
    sanitized_str = re.sub(r"'([^']*)'", r'"\1"', sanitized_str)

    # Escape control characters only where JSON forbids them (inside strings)
    return _escape_control_chars_in_strings(sanitized_str)

def parse_json_response(response):
    '''
    Robust JSON extraction and parsing from `response`.

    The first `{` through the last `}` of `response` is treated as the candidate
    JSON object. It is parsed as-is first; if that fails, a series of heuristic
    repairs for common LLM formatting mistakes is applied and parsing is retried.

    Parameters:
        response (str): raw model output, possibly with text around the JSON object.

    Returns:
        tuple: `(type, data)` where `data` is the parsed dict and `type` is
        `data.get("type")`; `(None, None)` when no object is found, when the
        response is not a string, or when the text cannot be parsed even after repair.

    Notes:
        - `strict=False` lets `json.loads` accept raw newlines/tabs inside string
          values, so structural whitespace never has to be rewritten.
        - Sequences such as `\\f` or `\\t` are valid JSON escapes, so under-escaped
          LaTeX like `\\frac` decodes to a control character rather than failing.
          This function cannot detect that case.
    '''
    if not isinstance(response, str):
        print(f"No JSON object found in response: {response}")
        return None, None

    json_pattern = re.compile(r'\{[\s\S]*\}', re.MULTILINE)
    match = json_pattern.search(response)

    if match:
        json_str = match.group(0)

        clean_json_str = json_str.strip()

        # Try to directly parse first
        try:
            data = json.loads(clean_json_str, strict=False)
            return data.get("type"), data
        except json.JSONDecodeError as e:
            # If it fails, sanitize
            print(f"Initial JSON decode failed: {e}")
            print(f"Raw JSON string (repr): {repr(clean_json_str)}")

            # Correct unquoted keys
            fixed_keys_str = re.sub(r'([{,]\s*)([a-zA-Z0-9_]+)(\s*:)', r'\1"\2"\3', clean_json_str)

            # Correct single quotes to double quotes.
            fixed_quotes_str = re.sub(r"'([^']*)'", r'"\1"', fixed_keys_str)

            # Double lone backslashes. Backslashes are consumed pairwise so that an
            # already-escaped `\\s` is kept intact instead of becoming three backslashes.
            final_sanitized_str = re.sub(
                r'\\(["\\/bfnrtu])?',
                lambda m: m.group(0) if m.group(1) else '\\\\',
                fixed_quotes_str,
            )

            # Try to parse the fully sanitized string.
            try:
                data = json.loads(final_sanitized_str, strict=False)
                return data.get("type"), data
            except json.JSONDecodeError as e:
                print(f"Sanitized JSON decode failed: {e}")
                print(f"Malformed JSON string: {json_str}")
                return None, None

    print(f"No JSON object found in response: {response}")
    return None, None


def isParallel(inference):
    """Report whether `inference` parses to a JSON block whose "type" is "PARALLEL"."""
    parsed_type = parse_json_response(inference)[0]
    return "PARALLEL" == parsed_type

def isSerial(inference):
    """Report whether `inference` parses to a JSON block whose "type" is "SERIAL"."""
    parsed_type = parse_json_response(inference)[0]
    return "SERIAL" == parsed_type

def isCompleted(inference):
    """Report whether `inference` parses to a JSON block whose "type" is "COMPLETED"."""
    parsed_type = parse_json_response(inference)[0]
    return "COMPLETED" == parsed_type

def track_tokens(level, output_tokens, reasoning_tokens, prompt_tokens, completion_tokens, token_stats):
    """
    Fold the token counts of one model call into `token_stats` (mutated in place).

    Updates the running maxima for prompt and completion tokens, the global
    output / reasoning totals, and the per-`level` output / reasoning / call
    counters (the per-level entry is created on first use). Returns None.
    """
    # Running maxima across all calls
    token_stats['max_prompt_tokens'] = max(token_stats['max_prompt_tokens'], prompt_tokens)
    token_stats['max_completion_tokens'] = max(token_stats['max_completion_tokens'], completion_tokens)

    # Global totals
    token_stats["total_reasoning"] += reasoning_tokens
    token_stats["total_output"] += output_tokens

    # Per-level bucket
    per_level = token_stats["by_level"].setdefault(level, {"output": 0, "reasoning": 0, "calls": 0})
    per_level["output"] += output_tokens
    per_level["reasoning"] += reasoning_tokens
    per_level["calls"] += 1

def _find_boxed_contents(text):
    r"""
    Return the contents of every balanced, non-overlapping \boxed{...} in `text`,
    in order of appearance. Braces are matched by depth counting, so any nesting
    level is supported. An unbalanced \boxed{ is skipped and scanning resumes
    just after it.
    """
    marker = '\\boxed{'
    contents = []
    pos = 0

    while True:
        start = text.find(marker, pos)
        if start == -1:
            return contents

        body_start = start + len(marker)
        depth = 1
        idx = body_start
        while idx < len(text) and depth:
            if text[idx] == '{':
                depth += 1
            elif text[idx] == '}':
                depth -= 1
            idx += 1

        if depth == 0:
            # idx is one past the closing brace of this \boxed{...}
            contents.append(text[body_start:idx - 1])
            pos = idx
        else:
            pos = body_start


def extract_answer(answer, client, extraction_model="qwen/qwen-2.5-7b-instruct"):
    r"""
    Extract the contents of the last \boxed{...} in `answer`.

    The text is scanned for \boxed{ and the matching closing brace is found by
    counting brace depth, so nested LaTeX such as \boxed{\frac{1}{\sqrt{2}}} is
    handled locally (the previous regex only supported one level of nesting and
    fell back to a model call for anything deeper).

    Parameters:
        answer (str): model output to search.
        client: OpenAI-compatible client; only used when no balanced \boxed{...}
            is found.
        extraction_model (str): model name used for the fallback extraction call.

    Returns:
        The contents of the last balanced \boxed{...}; otherwise the message
        content returned by the fallback model call.
    """
    matches = _find_boxed_contents(answer)

    if matches:
        return matches[-1]
    else: # Default to model extraction, if no balanced \boxed{...} is present
        extraction_prompt = ''' 
            Extract the contents of the final \\boxed{} and return the value, and only this value.
        '''

        completion = client.chat.completions.create(
            model=extraction_model,
            messages=[
                {
                    "role": "system", 
                    "content": extraction_prompt
                }, 
                {
                    "role": "user", 
                    "content": answer #TODO: should this be json.dumps(answer) ?
                }
            ]
        )

        return completion.choices[0].message.content

def get_baseline_stats(question, client=None, model=model):
    """
    Run `question` as a single user turn against `model` and report token usage.

    Returns `(inference, token_usage)`: `inference` is the message content of the
    first choice, and `token_usage` holds "total_output" (tokens of `inference`
    as counted by `count_tokens`) and "total_reasoning" (the API-reported
    completion tokens minus that count).
    """
    user_turn = {"role": "user", "content": question}
    completion = client.chat.completions.create(model=model, messages=[user_turn])

    # Visible response text
    inference = completion.choices[0].message.content

    # Whatever the API billed beyond the visible text is attributed to reasoning
    billed_completion_tokens = completion.usage.completion_tokens
    visible_tokens = count_tokens(inference)
    token_usage = {
        "total_output": visible_tokens,
        "total_reasoning": billed_completion_tokens - visible_tokens,
    }

    return inference, token_usage

## Prompt Functions ##

def create_math500_prompt(question, forkJoin=False):
    '''
    Wrap `question` in the MATH-500 answer-format prompt.

    Used by the authors of the math-500 evaluation, in order to use the PRM800K Parsing logic
    https://www.vals.ai/benchmarks/math500-03-11-2025

    Parameters:
        question (str): the problem statement, in LaTeX.
        forkJoin (bool): when True, add the instruction that the boxed answer may only
            appear in the COMPLETED block at recursion level 0.

    Returns:
        str: the full prompt text.
    '''

    forkJoinFormatting = '''
    **IMPORTANT**
        Note that this \\boxed answer term should occur ONLY inside of the COMPLETED block at recursion level 0
    ''' if forkJoin else ''

    # Literal braces must be doubled inside the f-string. A single-braced `{2}` is
    # evaluated as the expression 2 and would render the example as `\boxed2`.
    prompt = f''' 
    Answer the following math question, given in LaTeX format, clearly and concisely, and present the final answer as:
    \\(\\boxed{{x}}\\), where X is the fully simplified solution.

    {forkJoinFormatting}

    Example:
        **Question:** \\(\\int_0^1 (3x^2 + 2x) \\, dx\\)
        **Solution: \\(\\int (3x^2 + 2x) \\,dx = x^3 + x^2 + C\\) 
            Evaluating from 0 to 1: \\((1^3 + 1^2) - (0^3 + 0^2) = 1 + 1 - 0 = 2 \\boxed{{2}}\\)
        ** Answer: \\boxed{{2}}

    Now, solve the following question: 
    {question}
    '''
    return prompt

def createSystemPrompt(current_depth, max_depth=2):
    """
    Build the system prompt for one recursion level of the fork/join protocol.

    Below the depth limit the prompt offers three block types (PARALLEL, SERIAL,
    COMPLETED) together with decomposition guidelines. When `current_depth`
    equals `max_depth` the restricted variant is returned, which only offers
    SERIAL and COMPLETED.
    """
    at_depth_limit = current_depth == max_depth

    if not at_depth_limit:
        # Full protocol: the model may still fork into PARALLEL sub-questions
        return f'''
        You are an expert model that decomposes complex tasks into parallel subtasks, 
        and uses recursive calls, structured in the following JSON format, to evaluate and execute these subtasks.

        Current recursion depth: {current_depth}/{max_depth}

        TASK: Analyze the provided problem and format your response in EXACTLY the provided JSON format;

        **RULES**

        **You MUST respond with only a single, valid JSON object.** Your entire response **MUST** be the JSON object itself and **NOTHING ELSE**.

        0. **ABSOLUTELY CRITICAL**: When writing LaTeX inside a JSON string, you **MUST** escape the backslash character `\\`. This means you must write `\\\\` instead of `\\`. 
            For example, to write `\\sqrt{{x}}`, you must type it as `"\\\\sqrt{{x}}"`
        1.  **ABSOLUTELY CRITICAL**: ALWAYS use double quotes for **ALL** property names and string values, i.e. always use \" instead of \' for JSON formatting
        2. **You MUST respond with only a single, valid JSON object.** Your entire response must be the JSON object itself and nothing else.
        3.  **DO NOT** under any circumstances include introductory phrases like "Here is the JSON:" or any other explanatory text.
        4.  **DO NOT** under any circumstances include any text formatting (e.g. tab or newline characters) in your response; ensuring that your response is pure json
        5.  **ENSURE** that ALL brackets are properly opened and closed
        6.  Continue until COMPLETED at level={current_depth}

        **DECOMPOSITION GUIDELINES**

        - Each fork must be fully self-contained and each fork's input must represent the entirety of the question, solvable without any additional context
        - ONLY decompose tasks when subtasks are independent AND require substantial work (>30 seconds of human effort)
            - PARALLEL tasks may be used to investigate possible avenues for solutions, especially if there is not one clear path forward
        - DO NOT decompose basic mathematical operations, nor single-step algebraic manipulations (substitution, solving for one variable, simple derivatives, etc.)
        - SERIAL blocks may contain either 
            1. an integration of the previous PARALLEL block(s), where work is still required in solving the problem
            2. a step in the solution of the problem that is only possible to execute serially, as it depends on prior steps or informs future steps

        **OUTPUT FORMATTING**

        {{
            "type": "PARALLEL",
            "level": {current_depth},
            "forks": [
                {{"input": "specific self-contained question"}},
                {{"input": "another independent question"}}
            ]
        }}

        {{
            "type": "SERIAL",
            "level": {current_depth},
            "inference": "relevant serial inference"
        }}

        {{
            "type": "COMPLETED",
            "level": {current_depth},
            "result": "The answer is \\\\boxed{{x}}"
        }}

        **DO NOT** deviate from this format. All keys and string values must be enclosed in double quotes. All backslashes must be escaped like this: \\\\"
        '''

    # Depth limit reached: no further forking is offered to the model
    return f''' 
        You are an expert model that decomposes complex tasks into parallel subtasks, 
        and uses recursive calls, structured in the following format, to evaluate and execute these subtasks.

        You are at the maximum recursion depth ({current_depth}/{max_depth}) allowed for this instance, and therefore must abide by more restrictive output formatting guidelines

        TASK: Analyze the provided problem and format your response in EXACTLY the provided format

        **RULES**

        0. **CRUCIAL**: When writing LaTeX inside a JSON string, you **MUST** escape the backslash character `\\`. This means you must write `\\\\` instead of `\\`. 
            For example, to write `\\sqrt{{x}}`, you must type it as `"\\\\sqrt{{x}}"`
        1.  **You MUST respond with only a single, valid JSON object.** Your entire response must be the JSON object itself and nothing else.
        2.  **ALWAYS** use double quotes for all property names and string values, i.e. always use \" instead of \' for JSON formatting
        3.  **DO NOT** under any circumstances include introductory phrases like "Here is the JSON:" or any other explanatory text.
        4.  **DO NOT** under any circumstances include any text formatting (e.g. tab or newline characters) in your response; ensuring that your response is pure json
        5.  **ENSURE** that ALL brackets are properly opened and closed
        6.  Continue until COMPLETED at level={current_depth}

        **OUTPUT FORMATTING**

        {{
            "type": "SERIAL",
            "level": {current_depth},
            "inference": "relevant serial inference"
        }}
        
        {{
            "type": "COMPLETED",
            "level": {current_depth},
            "result": "complete solution to the query"
        }}
        
        Output ONLY valid JSON, no other text.
        '''

def createStatePrompt(question, partial_answer=None):
    """
    Build the user prompt that asks the model for the next step of its solution.

    `question` is embedded as the task, and `partial_answer` (the progress
    accumulated so far) is appended at the end; when it is None an empty string
    is inserted instead.
    """
    progress_so_far = "" if partial_answer is None else partial_answer

    return f'''
    Task:
    
    {question}

    ####
    
    Execute the next logical step {{one of: PARALLEL, SERIAL, COMPLETED}} in solving this problem.
    
    **CRUCIAL**
        Produce exactly one valid JSON object that represents the next step forward.

    The response progress up until this point is shown below:

    {progress_so_far}
    '''

## Parallel/Recursive LLM Functions ##

def processParallelBlock(inference, current_depth, max_depth, client, token_stats, model=None):
    '''
    Extract and process the PARALLEL block in the inference.

    Every fork with a non-empty "input" is solved by a recursive `llmForkJoin`
    call one level deeper, and the fork results are assembled into two versions
    of the PARALLEL block.

    Parameters:
        inference (str): raw model output expected to contain a PARALLEL block.
        current_depth (int): recursion level that produced `inference`.
        max_depth (int): maximum recursion level, forwarded to the recursive calls.
        client: OpenAI-compatible client, forwarded to the recursive calls.
        token_stats (dict): shared token statistics, updated by the recursive calls.
        model (str | None): model name for the forks. When None, `llmForkJoin`'s
            own default is used.

    Returns:
        tuple[str, str]: `(processed, complete)` JSON strings. `processed` maps each
        fork input to the "result" of its COMPLETED block; `complete` maps it to the
        fork's full trace. Both are "" when `inference` is not a PARALLEL block.
    '''
    block_type, data = parse_json_response(inference)

    if block_type != "PARALLEL":
        return "", ""

    forks = data.get("forks", [])
    if not isinstance(forks, list):
        # A malformed "forks" value is treated as "no forks" rather than crashing
        forks = []

    # Forward `model` only when the caller supplied one, so the recursive default still applies
    fork_kwargs = {"token_stats": token_stats}
    if model is not None:
        fork_kwargs["model"] = model

    processed_forks, complete_processed_forks = [], []

    for fork in forks:
        # Skip forks that are not objects or carry no question
        question = fork.get("input", "") if isinstance(fork, dict) else ""
        if not question:
            continue

        # Recursive processing call
        completed_block, complete_trace, _ = llmForkJoin(question, current_depth + 1, max_depth, client, **fork_kwargs)

        # Use the same tolerant parser that accepted the block inside llmForkJoin.
        # A strict json.loads here could reject text that was already accepted.
        _, parsed_block = parse_json_response(completed_block)
        if parsed_block is None:
            final_result = {
                "type": "ERROR",
                "level": current_depth,
                "message": "An error occurred: fork did not return a parseable JSON block",
                "raw": completed_block
            }
        else:
            final_result = parsed_block.get('result')

        # Parse the JSON trace back to an object ("" is returned when max depth was exceeded)
        complete_trace_obj = json.loads(complete_trace) if complete_trace else {}

        processed_forks.append({"input": question, "output": final_result})
        complete_processed_forks.append({"input": question, "output": complete_trace_obj})

    # Construct processed PARALLEL block
    processed_parallel_block = {
        "type": "PARALLEL",
        "level": current_depth,
        "forks": processed_forks
    }

    # Construct complete processed PARALLEL block
    complete_processed_parallel_block = {
        "type": "PARALLEL",
        "level": current_depth,
        "forks": complete_processed_forks
    }

    return json.dumps(processed_parallel_block), json.dumps(complete_processed_parallel_block)

def llmForkJoin(question, current_depth=0, max_depth=2, client=None, model=model, token_stats=None):
    '''
    Main function to handle recursive parallel task decomposition.

    The model is queried repeatedly. Each response must be one JSON block:
      - COMPLETED: finishes this level.
      - PARALLEL: each fork is solved recursively via `processParallelBlock`.
      - SERIAL: recorded as an intermediate step.
    After a PARALLEL or SERIAL block the accumulated progress is sent back to the
    model and the loop continues. Anything else, or any exception, ends the loop
    with an ERROR entry in the trace.

    Parameters:
        question (str): the (already formatted) task prompt.
        current_depth (int): recursion level of this call.
        max_depth (int): deepest level at which the model is still queried.
        client: OpenAI-compatible client used for the chat completions.
        model (str): model name sent to the API and forwarded to forks.
        token_stats (dict | None): shared token statistics; created when None.

    Returns:
        tuple[str, str, dict]: `(completed_block, trace, token_stats)`.
        `completed_block` is the raw text of the COMPLETED block, a JSON ERROR object
        if an exception occurred, or "" if the model produced an unrecognised block.
        `trace` is the JSON-encoded list of all blocks at this level ("" when the
        maximum depth was exceeded).
    '''

    if token_stats is None:
        token_stats = {
            "total_output": 0, 
            "total_reasoning": 0, 
            "max_prompt_tokens": 0, 
            "max_completion_tokens": 0, 
            "by_level": {}
        }

    if current_depth > max_depth:
        obj = {
            "type": "COMPLETED", 
            "level": current_depth, 
            "result": "Maximum Recursion Depth Exceeded! Stopping Inference."
        }
        return json.dumps(obj), "", token_stats

    # Create 'system' and 'user' prompts
    system_prompt = createSystemPrompt(current_depth, max_depth)
    state_prompt = createStatePrompt(question)
    accumulated_response, complete_accumulated_response, completed_block = [], [], ""

    # Initial context for the conversation
    current_context = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": state_prompt}
    ]

    while True:        
        try:
            ## Model Interactions ##
            # Create the model chat.completions object
            completion = client.chat.completions.create(
                model=model,
                messages=current_context
            )

            # Get the response text. Missing content is treated as an empty response
            # and falls into the invalid-JSON branch below.
            new_inference = (completion.choices[0].message.content or "").strip()

            ## Track Token Usage ##
            # Get and cache the total number of output_tokens used
            total_output_tokens = completion.usage.completion_tokens
            output_tokens = count_tokens(new_inference)
            reasoning_tokens = total_output_tokens - output_tokens
            prompt_tokens = completion.usage.prompt_tokens

            track_tokens(current_depth, output_tokens, reasoning_tokens, prompt_tokens, total_output_tokens, token_stats)

            ## Handle Model Response ##
            # Parse once and reuse the result. A strict json.loads on the raw text
            # could raise for responses the tolerant parser has already accepted.
            block_type, block_data = parse_json_response(new_inference)

            # Check if the inference is completed
            if block_type == "COMPLETED":
                complete_accumulated_response.append(block_data)
                completed_block = new_inference
                break
            # Check if the new inference contains a PARALLEL block
            elif block_type == "PARALLEL":
                processed_block, complete_processed_block = processParallelBlock(
                    new_inference, current_depth, max_depth, client, token_stats, model=model
                )
                accumulated_response.append(json.loads(processed_block))
                complete_accumulated_response.append(json.loads(complete_processed_block))
            # Check if the new inference contains a SERIAL block
            elif block_type == "SERIAL":
                accumulated_response.append(block_data)
                complete_accumulated_response.append(block_data)
            # Otherwise break due to invalid JSON tag
            else:
                error_obj = {
                    "type": "ERROR", 
                    "level": current_depth, 
                    "message": "INVALID JSON FORMATTING",
                    "raw": new_inference
                }
                accumulated_response.append(error_obj)
                complete_accumulated_response.append(error_obj)

                # Print statement for error monitoring
                print("ERROR: Invalid JSON was produced by the model, exiting current inference.")
                break

            # Update conversation context for continuation
            current_context = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": createStatePrompt(question, json.dumps(accumulated_response, indent=2))}
            ]

        except Exception as e:
            print(f"Error in llmForkJoin at depth {current_depth}: {e}")
            error_obj = {
                "type": "ERROR",
                "level": current_depth,
                "message": f"An error occurred: {str(e)}"
            }
            complete_accumulated_response.append(error_obj)
            completed_block = json.dumps(error_obj)
            # Stop here. Without this break a persistent failure (bad credentials,
            # client=None, context overflow) would retry the same request forever.
            break

    return completed_block, json.dumps(complete_accumulated_response, indent=2), token_stats


#### Testing

In [None]:
# Single-problem smoke test: run the fork/join pipeline and the plain baseline on the
# same question and grade both against the reference answer.
question = 'If $f(x) = \\frac{3x-2}{x-2}$, what is the value of $f(-2) +f(-1)+f(0)$? Express your answer as a common fraction.'
answer = '\\frac{14}{3}'

# Keep the raw question intact and build each prompt from it. Reassigning `question`
# to the wrapped fork/join prompt made the baseline wrap an already-wrapped prompt.
forkjoin_prompt = create_math500_prompt(question, True)
completed_block, accumulated_inference, token_usage = llmForkJoin(forkjoin_prompt, client=openrouter_client, model=model)

print('## Decomposition Inference ##\n', accumulated_inference, '\n')
print('## Decomposition Tokens ##\n', token_usage, '\n')

decomp_soln = extract_answer(completed_block, client=openrouter_client)

print('Decomposition Correct?: ', grader.grade_answer(decomp_soln, answer))

baseline_prompt = create_math500_prompt(question, False)
inf, baseline_token_usage = get_baseline_stats(baseline_prompt, client=openrouter_client, model=model)
print('## Baseline Inference ##\n', inf, '\n')
print('## Baseline Tokens##\n', baseline_token_usage, '\n')

baseline_soln = extract_answer(inf, openrouter_client)

print('Baseline Correct?: ', grader.grade_answer(baseline_soln, answer))

#### Eval

In [None]:
## Load Dataset
# Load the "math-ai/math500" dataset via the `datasets` library; the eval loop below
# iterates over its 'test' split and reads the 'unique_id', 'problem' and 'answer' fields.
ds = load_dataset("math-ai/math500")

In [None]:
# Evaluate the fork/join pipeline on every problem of the test split, collecting one
# result dict per problem and checkpointing the partial results to 'results.csv'.
df_list, count = [], 0
test_rows = ds['test']
total = len(test_rows)

# `count` is the 1-based index of the problem being processed, i.e. the number of rows
# finished once the iteration ends.
for count, row in enumerate(test_rows, start=1):
    print(f"Starting problem {row['unique_id']}, {count}/{total}")
    
    resp_dict = {'problem': row['problem'], 'answer': row['answer']}
    problem = row['problem']

    print("Correct Solution: ", row['answer'])

    ## forkJoin Inference/Analysis ##
    forkJoin_question = create_math500_prompt(problem, True)
    completed_block, accumulated_inference, token_usage = llmForkJoin(forkJoin_question, client=openrouter_client, model=model)
    resp_dict['decomp_inference'] = accumulated_inference
    resp_dict['decomp_tokens'] = token_usage

    # Print Decomposition Inference for Monitoring
    print('## Decomposition Inference ##\n', accumulated_inference)

    extracted_answer = extract_answer(completed_block, openrouter_client)
    resp_dict['decomp_extracted_answer'] = extracted_answer
    resp_dict['decomp_correct'] = grader.grade_answer(extracted_answer, row['answer'])

    print(f"\tCompleted Decomposition Inference and Analysis - Tokens: {token_usage['total_output'] + token_usage['total_reasoning']} - Correct: {resp_dict['decomp_correct']}\n")

    # ## Baseline Inference/Analysis ##
    # baseline_question = create_math500_prompt(problem, False)
    # baseline_inference, baseline_token_usage = get_baseline_stats(baseline_question, client=openrouter_client, model=model)
    # resp_dict['baseline_inference'] = baseline_inference
    # resp_dict['baseline_tokens'] = baseline_token_usage

    # baseline_extracted_answer = extract_answer(baseline_inference, openrouter_client)
    # resp_dict['baseline_extracted_answer'] = baseline_extracted_answer
    # resp_dict['baseline_correct'] = grader.grade_answer(baseline_extracted_answer, row['answer'])

    # print(f"\tCompleted Baseline Inference and Analysis - Tokens: {baseline_token_usage['total_output'] + baseline_token_usage['total_reasoning']} - Correct: {resp_dict['baseline_correct']}")
    
    # Append to df_list
    df_list.append(resp_dict)

    # Save every 25 rows. The check runs after the append, so the checkpoint holds
    # exactly 25, 50, ... completed rows.
    if count % 25 == 0:
        partial_df = pd.DataFrame(df_list)
        partial_df.to_csv('results.csv')

        # Clear jupyter Output
        clear_output(wait=False)

df = pd.DataFrame(df_list)
df.to_csv('results.csv')

## Note: Saved until 299/500 - restart from 300/500 [test/geometry/1097.json]

#### Analysis

In [None]:
# Load previously saved results for analysis. This rebinds `df`, replacing the frame
# built by the eval loop above.
# NOTE(review): presumably a renamed copy of a partial 'results.csv' covering the
# first rows of the run (see the restart note above) - confirm the provenance.
df = pd.read_csv("results_0_299.csv")

In [None]:
## Overall accuracy: fraction of rows whose `decomp_correct` flag is True
## (rows with decomp_correct == True divided by the total number of rows)
print("Correct: ", df[df["decomp_correct"] == True].shape[0] / df.shape[0])

In [None]:
# Rows whose recorded trace text contains the substring "PARALLEL". This is a plain
# text match on the stored JSON, not a parse of the trace.
decomp_df = df[df['decomp_inference'].str.contains("PARALLEL")]
print("Decomposition Percentage: ", decomp_df.shape[0] / df.shape[0])

# Guard the accuracy ratio: when no row decomposed, the denominator is zero
if decomp_df.shape[0]:
    print("Decomposition Accuracy: ", decomp_df[decomp_df["decomp_correct"] == True].shape[0] / decomp_df.shape[0])
else:
    print("Decomposition Accuracy: ", "n/a (no rows contain a PARALLEL block)")

In [None]:
# Number of rows in decomp_df (rows whose trace text contains "PARALLEL")
print(decomp_df.shape[0])

In [None]:
# Write the decomposed rows to their own CSV file (the DataFrame index is included
# as the first column, since to_csv is called with its defaults)
decomp_df.to_csv("recursive_llm_decompositions.csv")

In [None]:
# Rows whose recorded trace text contains the substring "ERROR". llmForkJoin records
# failures as blocks with "type": "ERROR", but this is a plain text match, so any
# other occurrence of the word in a trace is counted too.
err_df = df[df['decomp_inference'].str.contains("ERROR")]
# Printed value is a fraction of all rows (0-1), not a percentage
print("ERROR Percentage: ", err_df.shape[0] / df.shape[0])