In [4]:
from openai import OpenAI
from dotenv import load_dotenv

import os
import re

# Requires that api keys are stored in .env
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")
os.environ["OPENAI_API_KEY"] = api_key

client = OpenAI()

In [None]:
def createPrompt(question, current_depth, max_depth=3):
    prompt = f'''
    You are an expert reasoning model that can delegate complex subtasks.
    Specifically, you are tasked with performing intelligent task decomposition and utilizing the parallelism that is accessible to you.

    When you need to perform a specific calculation, wrap it in:
        <FORK>
        QUESTION: "specific, self-contained question with all necessary context"
        </FORK>

    Group independent, parallelizable subtask forks within:
        <PARALLEL stage="[stage name]">
            [possibly multiple FORK blocks]
            <INTEGRATION>
                STRATEGY: "comprehensive strategy for integrating the above QUESTION/ANSWER pairs"
            </INTEGRATION>
        </PARALLEL>

    After execution, each FORK will be updated to:
        <FORK>
        QUESTION: "original question", ANSWER: "question response" 
        </FORK>

    A <JOIN> block will then be added immediately after the <PARALLEL> block:
        <JOIN stage="[stage name]>
            INTEGRATED RESULT: "result of executing the integration strategy on all ANSWER fields" 
        </JOIN>

    The current recursion depth is {current_depth}, and the max recursion depth is {max_depth}. 
    CRITICAL: If current depth == max depth, then simply provide a direct answer, and DO NOT use any <PARALLEL> blocks.
    
    BEST PRACTICES:
    1. When approaching a complex problem, identify subtasks that are independent (can be solved without waiting for other sub-threads), 
        and well-defined (each fork should contain sufficient context to be executed independently).
    2. Handle dependencies: Use sequential parallel blocks for multi-stage processes.
    3. Please limit the number of FORK blocks within each PARALLEL block to 3-5

    ----

    Your task: {question}
    Approach: Analyze if this benefits from parallel decomposition. If yes, create appropriate FORK/JOIN structure, if no, then answer the question directly.
    '''

    return prompt

def completedParallel(inference, num_parallel_blocks):
    '''
    Check if there is a new PARALLEL block to process
    '''
    num_matches = 0
    parallel_pattern = r'<PARALLEL\s+stage="([^"]*)">(.*?)</PARALLEL>'
    matches = re.finditer(parallel_pattern, inference, re.DOTALL | re.IGNORECASE)
    for match in matches:
        num_matches += 1
    return num_matches > num_parallel_blocks

def processParallelBlock(inference, current_depth, max_depth, client):
    '''
    Extract and process the most recent PARALLEL block in the inference thread
    '''
    parallel_pattern = r'<PARALLEL\s+stage="([^"]*)">(.+?)</PARALLEL>'
    matches = list(re.finditer(parallel_pattern, inference, re.DOTALL | re.IGNORECASE))

    if not matches:
        return "" # No PARALLEL blocks found

    last_parallel_block = matches[-1]
    stage = last_parallel_block.group(1)
    contents = last_parallel_block.group(2)

    fork_pattern = r'<FORK>\s*QUESTION:\s*"([^"]*)"'
    integration_pattern = r'<INTEGRATION>\s*STRATEGY:\s*"([^"]*)"\s*</INTEGRATION>'
    integration_match = re.search(integration_pattern, contents, re.DOTALL | re.IGNORECASE)

    processed_forks = []
    for fork_match in re.finditer(fork_pattern, contents, re.DOTALL | re.IGNORECASE):
        question = fork_match.group(1)
        answer = llmForkJoin(question, current_depth + 1, max_depth, client)
        processed_forks.append({'question': question, 'answer': answer})
    
    # Construct processed PARALLEL block
    processed_parallel_block = f"<PARALLEL stage={stage}>"
    for fork in processed_forks:
        processed_fork = f"\n\t<FORK>\n\t\tQUESTION:{fork['question']}, ANSWER:{fork['answer']}\n\t</FORK>"
        processed_parallel_block += processed_fork

    if integration_match:
        processed_parallel_block += f"\n\t<INTEGRATION>\n\t\tSTRATEGY: {integration_match.group(1)}\n\t</INTEGRATION>"
    processed_parallel_block += "</PARALLEL>"

    # Add a prompt to the model to continue the original stream
    continuation_prompt = f"{processed_parallel_block}\n\nNow complete the JOIN step as specified, and then continue to solve the problem."
    return continuation_prompt

def llmForkJoin(question, current_depth=0, max_depth=3, client=None):
    '''
    Main function to handle recursive parallel task decomposition
    '''

    if current_depth > max_depth:
        return "Maximum Recursion Depth Exceeded! Stopping Inference."

    prompt = createPrompt(question, current_depth, max_depth)
    num_parallel_blocks = 0
    processed_inference = ""

    # Initial context for the conversation
    current_context = [{"role": "user", "content": prompt}]

    while True:
        print(f"Creating a new stream at recursion level {current_depth}.")

        # Create stream with current context
        stream = client.responses.create(
            model="o4-mini",
            input=current_context,
            stream=True,
        )

        inference = ""

        # Stream and monitor for completed PARALLEL blocks
        for event in stream:
            if event.type == 'response.output_text.delta':
                inference += event.delta

                if completedParallel(inference, num_parallel_blocks):
                    # Interrupt: process the parallel block
                    processed_parallel_block = processParallelBlock(inference, current_depth, max_depth, client)

                    # Update context with the processed inference + processed parallel block
                    processed_inference += inference
                    processed_inference = processed_inference.replace(
                        list(re.finditer(r'<PARALLEL\s+stage="([^"]*)">(.+?)</PARALLEL>',
                                         processed_inference, re.DOTALL))[-1].group(0),
                        processed_parallel_block
                    )

                    # Update conversation context for continuation
                    current_context = [
                        {"role": "user", "content": prompt},
                        {"role": "assistant", "content": processed_inference}
                    ]

                    num_parallel_blocks += 1
                    break # Break from current stream to restart with new context

        else:
            # Stream completed without new PARALLEL blocks
            processed_inference += inference
            break # Break from while loop

    return processed_inference


In [10]:
question = ''' 
    Solve this system: Find x where 2x + 3y = 15, 4x - y = 7, and also calculate the derivatives of f(x) = x³ + 2x² - 5x + 1 at three points: x = 1, x = 2, x = 3
'''
response = llmForkJoin(question, client=client)
print(response)

Creating a new stream.
Creating a new stream.
Creating a new stream.
Creating a new stream.
<PARALLEL stage=System solving and derivatives>
	<FORK>
		QUESTION:Solve the system of equations 2x + 3y = 15 and 4x – y = 7 for x and y., ANSWER:This system is small and is most easily solved by straightforward elimination—no need for parallel sub‐tasks.

From  
 (1) 2x + 3y = 15  
 (2) 4x – y = 7  

Solve (2) for y:  
 4x – y = 7 ⇒ y = 4x – 7  

Substitute into (1):  
 2x + 3(4x – 7) = 15  
 2x + 12x – 21 = 15  
 14x = 36  
 x = 36/14 = 18/7  

Then y = 4·(18/7) – 7 = 72/7 – 49/7 = 23/7  

Answer:  
 x = 18/7, y = 23/7
	</FORK>
	<FORK>
		QUESTION:Compute the derivative f′(x) of f(x) = x³ + 2x² – 5x + 1 and evaluate f′ at x = 1, 2, and 3., ANSWER:The problem is simple enough that breaking it into parallel subtasks would add unnecessary overhead. So here’s the direct solution:

1. Compute the derivative:
   f(x) = x³ + 2x² – 5x + 1  
   ⇒ f′(x) = 3x² + 4x – 5

2. Evaluate at x = 1, 2, 3:
   • f′