In [85]:
from openai import OpenAI
from dotenv import load_dotenv

import os
import re

import tiktoken

# Requires that api keys are stored in .env (or already exported in the environment)
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    # Assigning None into os.environ raises an opaque "str expected, not NoneType"
    # TypeError; fail early with a message that says what is actually missing.
    raise RuntimeError(
        "OPENAI_API_KEY is not set. Add it to a .env file or export it before running this notebook."
    )
os.environ["OPENAI_API_KEY"] = api_key

client = OpenAI()

#### Working! Version 1 (streaming implementation)

In [None]:
def createPrompt(question, current_depth, max_depth=3):
    '''
    Build the single user prompt used by the streaming (version 1) pipeline.

    The prompt describes the FORK / PARALLEL / INTEGRATION / JOIN markup, states
    the current and maximum recursion depth, and ends with the task itself.

    Parameters:
        question: task text inserted after "Your task:".
        current_depth: recursion depth reported to the model.
        max_depth: maximum recursion depth reported to the model.

    Returns:
        The formatted prompt string.
    '''
    prompt = f'''
    You are an expert reasoning model that can delegate complex subtasks.
    Specifically, you are tasked with performing intelligent task decomposition and utilizing the parallelism that is accessible to you.

    When you need to perform a specific calculation, wrap it in:
        <FORK>
        QUESTION: "specific, self-contained question with all necessary context"
        </FORK>

    Group independent, parallelizable subtask forks within:
        <PARALLEL stage="[stage name]">
            [possibly multiple FORK blocks]
            <INTEGRATION>
                STRATEGY: "comprehensive strategy for integrating the above QUESTION/ANSWER pairs"
            </INTEGRATION>
        </PARALLEL>

    After execution, each FORK will be updated to:
        <FORK>
        QUESTION: "original question", ANSWER: "question response" 
        </FORK>

    A <JOIN> block will then be added immediately after the <PARALLEL> block:
        <JOIN stage="[stage name]">
            INTEGRATED RESULT: "result of executing the integration strategy on all ANSWER fields" 
        </JOIN>

    The current recursion depth is {current_depth}, and the max recursion depth is {max_depth}. 
    CRITICAL: If current depth == max depth, then simply provide a direct answer, and DO NOT use any <PARALLEL> blocks.
    
    BEST PRACTICES:
    1. When approaching a complex problem, identify subtasks that are independent (can be solved without waiting for other sub-threads), 
        and well-defined (each fork should contain sufficient context to be executed independently).
    2. Handle dependencies: Use sequential parallel blocks for multi-stage processes.
    3. Please limit the number of FORK blocks within each PARALLEL block to 3-5

    ----

    Your task: {question}
    Approach: Analyze if this benefits from parallel decomposition. If yes, create appropriate FORK/JOIN structure, if no, then answer the question directly.
    '''

    return prompt

def completedParallel(inference, num_parallel_blocks):
    '''
    Report whether `inference` holds more closed PARALLEL blocks than
    `num_parallel_blocks`, i.e. whether there is a new block to process.
    '''
    block_regex = re.compile(
        r'<PARALLEL\s+stage="([^"]*)">(.*?)</PARALLEL>',
        re.DOTALL | re.IGNORECASE,
    )
    closed_blocks = sum(1 for _ in block_regex.finditer(inference))
    return closed_blocks > num_parallel_blocks

def processParallelBlock(inference, current_depth, max_depth, client):
    '''
    Extract and process the most recent PARALLEL block in the inference thread.

    Each FORK question found in the block is answered by a recursive
    llmForkJoin call at current_depth + 1. The block is then rebuilt with
    QUESTION/ANSWER pairs, followed by the INTEGRATION strategy when one is
    present, and a trailing instruction asking the model to write the JOIN.

    Parameters:
        inference: text streamed so far.
        current_depth: recursion depth of the caller.
        max_depth: maximum recursion depth, forwarded to the recursive calls.
        client: client object forwarded to the recursive calls.

    Returns:
        (continuation_prompt, original_block_text). When no PARALLEL block is
        found both elements are empty strings.
    '''
    parallel_pattern = r'<PARALLEL\s+stage="([^"]*)">(.+?)</PARALLEL>'
    matches = list(re.finditer(parallel_pattern, inference, re.DOTALL | re.IGNORECASE))

    if not matches:
        # The caller unpacks two values, so keep the return arity stable.
        return "", ""

    last_parallel_block = matches[-1]
    stage = last_parallel_block.group(1)
    contents = last_parallel_block.group(2)
    original_block_text = last_parallel_block.group(0)  # Exact text to be replaced by the caller

    fork_pattern = r'<FORK>\s*QUESTION:\s*"([^"]*)"'
    integration_pattern = r'<INTEGRATION>\s*STRATEGY:\s*"([^"]*)"\s*</INTEGRATION>'
    integration_match = re.search(integration_pattern, contents, re.DOTALL | re.IGNORECASE)

    # Build the processed block piece by piece. Plain local names are
    # interpolated (instead of dict lookups using the same quote character as
    # the f-string) so the code also parses on Python versions before 3.12.
    block_parts = [f'<PARALLEL stage="{stage}">']
    for fork_match in re.finditer(fork_pattern, contents, re.DOTALL | re.IGNORECASE):
        question = fork_match.group(1)
        answer = llmForkJoin(question, current_depth + 1, max_depth, client)
        block_parts.append(f'\n\t<FORK>\n\t\tQUESTION: "{question}", ANSWER: "{answer}"\n\t</FORK>')

    if integration_match:
        strategy = integration_match.group(1)
        block_parts.append(f'\n\t<INTEGRATION>\n\t\tSTRATEGY: "{strategy}"\n\t</INTEGRATION>')
    block_parts.append("\n</PARALLEL>")
    processed_parallel_block = "".join(block_parts)

    # Add a prompt to the model to continue the original stream
    continuation_prompt = f"{processed_parallel_block}\n\nNow complete the JOIN step as specified, and then continue to solve the problem."
    return continuation_prompt, original_block_text

def llmForkJoin(question, current_depth=0, max_depth=3, client=None, model="o4-mini"):
    '''
    Main function to handle recursive parallel task decomposition (streaming).

    The model output is streamed; as soon as a complete PARALLEL block shows
    up, the stream is abandoned, the block is processed through
    processParallelBlock (which recurses into this function for every FORK),
    and a new stream is started with the processed text as assistant context.

    Parameters:
        question: task to solve.
        current_depth: current recursion depth.
        max_depth: depth beyond which a fixed message is returned instead.
        client: object exposing `responses.create(model=..., input=..., stream=True)`.
        model: model name passed to `responses.create`.

    Returns:
        The accumulated text with every processed PARALLEL block substituted
        in, or a fixed message when current_depth exceeds max_depth.

    Raises:
        ValueError: if `client` is None.
    '''

    if current_depth > max_depth:
        return "Maximum Recursion Depth Exceeded! Stopping Inference."

    if client is None:
        # The default is None only for signature convenience; without a client
        # the first API call would fail with an unhelpful AttributeError.
        raise ValueError("llmForkJoin requires a client; pass client=OpenAI().")

    prompt = createPrompt(question, current_depth, max_depth)
    processed_inference = ""

    # Initial context for the conversation
    current_context = [{"role": "user", "content": prompt}]

    while True:
        print(f"Creating a new stream at recursion level {current_depth}.")

        # Create stream with current context
        stream = client.responses.create(
            model=model,
            input=current_context,
            stream=True,
        )

        # Text produced by THIS stream only; it is reset on every restart.
        inference = ""

        # Stream and monitor for completed PARALLEL blocks
        for event in stream:
            if event.type != 'response.output_text.delta':
                continue
            inference += event.delta

            # `inference` never contains blocks from earlier streams, so the
            # first closed PARALLEL block in it is always a new one. Comparing
            # against a counter accumulated across streams would leave the
            # next block unnoticed until a second one had been emitted.
            if not completedParallel(inference, 0):
                continue

            # Interrupt: process the parallel block
            processed_parallel_block, original_block_text = processParallelBlock(
                inference, current_depth, max_depth, client
            )

            # Update context with the processed inference + processed parallel block
            processed_inference += inference
            processed_inference = processed_inference.replace(
                original_block_text,
                processed_parallel_block
            )

            # Update conversation context for continuation
            current_context = [
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": processed_inference}
            ]
            break  # Break from current stream to restart with new context

        else:
            # Stream completed without new PARALLEL blocks
            processed_inference += inference
            break  # Break from while loop

    return processed_inference


In [45]:
# Example task combining a linear system with three derivative evaluations.
question = ''' 
    Solve this system: Find x where 2x + 3y = 15, 4x - y = 7, and also calculate the derivatives of f(x) = x³ + 2x² - 5x + 1 at three points: x = 1, x = 2, x = 3
'''
# Run llmForkJoin on the task with the shared client and print the text it returns.
response = llmForkJoin(question, client=client)
print(response)

Creating a new stream at recursion level 0.
Creating a new stream at recursion level 1.
ResponseUsage(input_tokens=440, input_tokens_details=InputTokensDetails(cached_tokens=0), output_tokens=519, output_tokens_details=OutputTokensDetails(reasoning_tokens=320), total_tokens=959)
Creating a new stream at recursion level 1.
ResponseUsage(input_tokens=454, input_tokens_details=InputTokensDetails(cached_tokens=0), output_tokens=309, output_tokens_details=OutputTokensDetails(reasoning_tokens=192), total_tokens=763)
Creating a new stream at recursion level 0.
Creating a new stream at recursion level 1.
ResponseUsage(input_tokens=440, input_tokens_details=InputTokensDetails(cached_tokens=0), output_tokens=451, output_tokens_details=OutputTokensDetails(reasoning_tokens=256), total_tokens=891)
Creating a new stream at recursion level 1.
ResponseUsage(input_tokens=450, input_tokens_details=InputTokensDetails(cached_tokens=0), output_tokens=238, output_tokens_details=OutputTokensDetails(reasoning

#### Version 2 (state management)

In [None]:
def createSystemPrompt(current_depth, max_depth=3):
    '''
    Build the system prompt that defines the PARALLEL / JOIN / SERIAL /
    COMPLETED block format and reports the recursion depth as
    "<current_depth>/<max_depth>".
    '''
    depth_status = f"{current_depth}/{max_depth}"

    return f'''
    You are an expert reasoning model that decomposes complex tasks into parallel subtasks, 
    and then uses recursive calls, structured in the following format, to evaluate and execute these subtasks.

    Current recursion depth: {depth_status}

    TASK: Analyze the provided problem and format your response in EXACTLY the provided format

    **DECOMPOSITION GUIDELINES**
    - ONLY decompose when subtasks are independent AND require substantial work (>10 seconds of human effort)
    - DO NOT decompose basic mathematical operations, nor single-step algebraic manipulations (substitution, solving for one variable)

    **FORMAT**
    PARALLEL TASKS (for independent subtasks):
    <PARALLEL stage="descriptive_name">
        <FORK>
            INPUT: "specific self-contained question"
        </FORK>
        <FORK>
            INPUT: "another independent question"
        </FORK>
        <INTEGRATION>
            STRATEGY: "how to combine the fork results"
        </INTEGRATION>
    </PARALLEL>

    JOIN (after parallel processing):
    <JOIN stage="matching_parallel_stage_name">
        INTEGRATED RESULT: "combined result using integration strategy"
    </JOIN>

    SERIAL TASKS (for dependent subtasks):
    <SERIAL stage="descriptive_name">
        INPUT: "question requiring previous results"
        OUTPUT: "your complete solution with all intermediate steps shown"
    </SERIAL>

    COMPLETION (final answer - use only once for the FINAL answer):
    <COMPLETED level="current_recursion_level">
        FINAL ANSWER: "complete solution"
    </COMPLETED>

    **RULES**
    - At max recursion depth, you may only use the SERIAL and COMPLETED blocks
    - Each FORK must be fully self-contained and represent substantial work
    - Limit to 2-4 FORKs per PARALLEL block for meaningful parallelization
    - PARALLEL blocks require a subsequent JOIN block
    - Show complete work within SERIAL blocks rather than decomposing further
    - Continue until COMPLETED at level=0
    '''

def createStatePrompt(question, partial_answer=None):
    '''
    Build the user prompt for one step: the main task, the instruction to emit
    exactly one next block, and the progress so far (left empty when
    `partial_answer` is None).
    '''
    progress_so_far = "" if partial_answer is None else partial_answer

    return f'''Main task: {question}
    Execute the next logical step {{one of: PARALLEL, JOIN, SERIAL, COMPLETED}} in solving this problem.
    
    CRUCIAL: Produce exactly one block that represents the next step forward.

    The response progress up until this point is shown below:

    {progress_so_far}
    '''

def isParallelBlock(inference):
    '''True when `inference` contains a complete <PARALLEL stage="..."> ... </PARALLEL> block.'''
    found = re.search(
        r'<PARALLEL\s+stage="([^"]*)">(.*?)</PARALLEL>',
        inference,
        flags=re.DOTALL | re.IGNORECASE,
    )
    return found is not None

def isJoinBlock(inference):
    '''True when `inference` contains a complete <JOIN stage="..."> ... </JOIN> block.'''
    found = re.search(
        r'<JOIN\s+stage="([^"]*)">(.*?)</JOIN>',
        inference,
        flags=re.DOTALL | re.IGNORECASE,
    )
    return found is not None

def isSerialBlock(inference):
    '''True when `inference` contains a complete <SERIAL stage="..."> ... </SERIAL> block.'''
    found = re.search(
        r'<SERIAL\s+stage="([^"]*)">(.*?)</SERIAL>',
        inference,
        flags=re.DOTALL | re.IGNORECASE,
    )
    return found is not None

def isCompleted(inference):
    '''True when `inference` contains a complete <COMPLETED level="..."> ... </COMPLETED> block.'''
    found = re.search(
        r'<COMPLETED\s+level="([^"]*)">(.*?)</COMPLETED>',
        inference,
        flags=re.DOTALL | re.IGNORECASE,
    )
    return found is not None

# Global token tracking
token_stats = {"total_reasoning": 0, "total_output": 0, "by_level": {}}

def track_tokens(level, reasoning_tokens, output_tokens):
    '''
    Add one call's reasoning/output token counts to the global totals and to
    the per-level bucket for `level`, creating that bucket on first use.
    '''
    token_stats["total_reasoning"] += reasoning_tokens
    token_stats["total_output"] += output_tokens

    level_bucket = token_stats["by_level"].setdefault(
        level, {"reasoning": 0, "output": 0, "calls": 0}
    )
    level_bucket["reasoning"] += reasoning_tokens
    level_bucket["output"] += output_tokens
    level_bucket["calls"] += 1

def processParallelBlock(inference, current_depth, max_depth, client):
    '''
    Extract and process the most recent PARALLEL block in the inference thread.

    Each FORK input found in the block is answered by a recursive llmForkJoin
    call at current_depth + 1. The block is then rebuilt with INPUT/OUTPUT
    pairs, followed by the INTEGRATION strategy when one is present.

    Parameters:
        inference: model output to scan.
        current_depth: recursion depth of the caller.
        max_depth: maximum recursion depth, forwarded to the recursive calls.
        client: client object forwarded to the recursive calls.

    Returns:
        The processed PARALLEL block as a string, or "" when no PARALLEL block
        is found.
    '''
    # Extract all PARALLEL blocks in the current inference thread
    parallel_pattern = r'<PARALLEL\s+stage="([^"]*)">(.+?)</PARALLEL>'
    matches = list(re.finditer(parallel_pattern, inference, re.DOTALL | re.IGNORECASE))

    if not matches:
        # Always return a string: the caller concatenates the result, and a
        # tuple here would raise a TypeError.
        return ""

    # Extract the most recent PARALLEL block
    last_parallel_block = matches[-1]
    stage, contents = last_parallel_block.group(1), last_parallel_block.group(2)

    # Build the processed block piece by piece. Plain local names are
    # interpolated (instead of dict lookups using the same quote character as
    # the f-string) so the code also parses on Python versions before 3.12.
    block_parts = [f'<PARALLEL stage="{stage}">']

    # Extract FORK blocks from the larger PARALLEL block
    fork_pattern = r'<FORK>\s*INPUT:\s*"([^"]*)"'
    for fork_match in re.finditer(fork_pattern, contents, re.DOTALL | re.IGNORECASE):
        question = fork_match.group(1)

        # Recursive processing call
        answer = llmForkJoin(question, current_depth + 1, max_depth, client)
        block_parts.append(f'\n\t<FORK>\n\t\tINPUT: "{question}", \n\t\tOUTPUT: "{answer}"\n\t</FORK>')

    # Extract the INTEGRATION block from the larger PARALLEL block
    integration_pattern = r'<INTEGRATION>\s*STRATEGY:\s*"([^"]*)"\s*</INTEGRATION>'
    integration_match = re.search(integration_pattern, contents, re.DOTALL | re.IGNORECASE)

    # Add INTEGRATION block to the processed PARALLEL block
    if integration_match:
        strategy = integration_match.group(1)
        block_parts.append(f'\n\t<INTEGRATION>\n\t\tSTRATEGY: "{strategy}"\n\t</INTEGRATION>')
    block_parts.append("\n</PARALLEL>")

    return "".join(block_parts)

def llmForkJoin(question, current_depth=0, max_depth=3, client=None, model="o4-mini"):
    '''
    Main function to handle recursive parallel task decomposition (state-managed).

    The model is asked repeatedly for exactly one next block. PARALLEL blocks
    are expanded through processParallelBlock (which recurses into this
    function for every FORK), JOIN and SERIAL blocks are appended unchanged,
    and a COMPLETED block ends the loop. After every step the user prompt is
    rebuilt from the accumulated response.

    Parameters:
        question: task to solve.
        current_depth: current recursion depth.
        max_depth: depth beyond which a fixed message is returned instead.
        client: object exposing `responses.create(model=..., input=...)`.
        model: model name passed to `responses.create`.

    Returns:
        The accumulated response text, or a fixed message when current_depth
        exceeds max_depth.

    Raises:
        ValueError: if `client` is None, or if a model reply contains none of
            the PARALLEL, JOIN, SERIAL or COMPLETED blocks.
    '''

    if current_depth > max_depth:
        return "Maximum Recursion Depth Exceeded! Stopping Inference."

    if client is None:
        # The default is None only for signature convenience; without a client
        # the first API call would fail with an unhelpful AttributeError.
        raise ValueError("llmForkJoin requires a client; pass client=OpenAI().")

    # Create 'system' and 'user' prompts
    system_prompt = createSystemPrompt(current_depth, max_depth)
    state_prompt = createStatePrompt(question)
    accumulated_response = ""

    # Initial context for the conversation
    current_context = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": state_prompt}
    ]

    while True:
        print(f"Creating a new thread at recursion level {current_depth}.")

        # Create the model responses object
        response = client.responses.create(
            model=model,
            input=current_context
        )

        # Get and cache the total number of output_tokens and reasoning_tokens used
        output_tokens = response.usage.output_tokens
        reasoning_tokens = response.usage.output_tokens_details.reasoning_tokens
        track_tokens(current_depth, reasoning_tokens, output_tokens)

        # Get the response text
        new_inference = response.output_text

        # Check if the inference is completed
        if isCompleted(new_inference):
            accumulated_response += '\n' + new_inference
            break
        # Check if the new inference contains a PARALLEL block
        elif isParallelBlock(new_inference):
            processed_block = processParallelBlock(new_inference, current_depth, max_depth, client)
            accumulated_response += '\n' + processed_block
        # Check if the new inference contains a JOIN or SERIAL block
        elif isJoinBlock(new_inference) or isSerialBlock(new_inference):
            accumulated_response += '\n' + new_inference
        # Otherwise show the offending output and stop
        else:
            print(new_inference)
            # ValueError is a subclass of Exception, so existing
            # `except Exception` handlers keep working.
            raise ValueError('Invalid model output format! The output MUST contain one of PARALLEL, JOIN, SERIAL, COMPLETED')

        # Update conversation context for continuation
        state_prompt = createStatePrompt(question, accumulated_response)

        current_context = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": state_prompt}
        ]

    return accumulated_response

def print_token_stats():
    '''
    Print the global reasoning/output token totals, then one line per
    recursion level with its token counts, call count and a "parallelism"
    figure (the call count for levels above 0, otherwise 1).
    '''
    reasoning_total = token_stats['total_reasoning']
    output_total = token_stats['total_output']
    print(f"\nTotal Tokens - Reasoning: {reasoning_total}, Output: {output_total}")

    per_level = token_stats['by_level']
    for depth in per_level:
        level_stats = per_level[depth]
        if depth > 0:
            parallelism = level_stats['calls']
        else:
            parallelism = 1
        print(f"Level {depth}: {level_stats['reasoning']} reasoning, {level_stats['output']} output, {level_stats['calls']} calls, parallelism: {parallelism}")

def print_baseline_token_stats(question, client=None, model='o4-mini'):
    '''
    Send `question` as a single user message through `client.responses.create`
    and print the reply text followed by its reasoning and output token counts.
    '''
    baseline_response = client.responses.create(
        model=model,
        input=[{"role": "user", "content": question}],
    )

    usage = baseline_response.usage
    output_tokens = usage.output_tokens
    reasoning_tokens = usage.output_tokens_details.reasoning_tokens

    # Reply text first, then the token summary line
    print('\n', baseline_response.output_text)
    print(f"\n[Baseline] Total Tokens - Reasoning: {reasoning_tokens}, Output: {output_tokens}")

def test(question, client):
    '''
    Run `question` through llmForkJoin, print the result and the token usage
    recorded for that run, then print a single-call baseline for comparison.

    Parameters:
        question: task text passed to both the fork/join run and the baseline.
        client: client object passed to llmForkJoin and print_baseline_token_stats.
    '''
    # token_stats is a module-level accumulator. Reset it in place so the
    # figures printed below describe this run only; otherwise every further
    # call in the same kernel reports totals summed over all earlier runs
    # next to a single-run baseline.
    token_stats["total_reasoning"] = 0
    token_stats["total_output"] = 0
    token_stats["by_level"] = {}

    response = llmForkJoin(question, client=client)
    print('\n', response)
    print_token_stats()
    print_baseline_token_stats(question, client=client)

In [None]:
# Example task combining a linear system with three derivative evaluations.
question = ''' 
    Solve this system: Find x where 2x + 3y = 15, 4x - y = 7, and also calculate the derivatives of f(x) = x³ + 2x² - 5x + 1 at three points: x = 1, x = 2, x = 3
'''

# Run the task through test() with the shared client.
test(question, client)

Creating a new thread at recursion level 0.
Creating a new thread at recursion level 1.
Creating a new thread at recursion level 1.
Creating a new thread at recursion level 1.
Creating a new thread at recursion level 1.
Creating a new thread at recursion level 1.
Creating a new thread at recursion level 1.
Creating a new thread at recursion level 1.
Creating a new thread at recursion level 0.
Creating a new thread at recursion level 0.

 
<PARALLEL stage="solve_and_derivative">
	<FORK>
		INPUT: "Solve the system of equations 2x + 3y = 15 and 4x − y = 7 for x and y.", 
		OUTPUT: "
<SERIAL stage="Eliminate y">
    INPUT: "Multiply the second equation (4x − y = 7) by 3 to match the y-coefficient of the first equation."
    OUTPUT: "12x − 3y = 21"
</SERIAL>
<SERIAL stage="Add equations to eliminate y">
    INPUT: "Add 2x + 3y = 15 and 12x − 3y = 21."
    OUTPUT: "14x = 36, so x = 36/14 = 18/7."
</SERIAL>
<SERIAL stage="Solve for y">
    INPUT: "Substitute x = 18/7 into 4x − y = 7 and solve

In [None]:
# Example task combining a linear system with three derivative evaluations.
question = ''' 
    Solve this system: Find x where 2x + 3y = 15, 4x - y = 7, and also calculate the derivatives of f(x) = x³ + 2x² - 5x + 1 at three points: x = 1, x = 2, x = 3
'''

# Run the task through test() with the shared client.
test(question, client)