In [None]:
from vllm import LLM
from vllm.sampling_params import SamplingParams

In [None]:
# Instantiate the vLLM engine for the Qwen2.5 1.5B instruct checkpoint.
llm = LLM(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    tensor_parallel_size=4,
    gpu_memory_utilization=0.5,
)

In [None]:
import json
from typing import Any, Callable

import string
import random

def generate_random_id(length: int) -> str:
    """Return a random alphanumeric identifier of ``length`` characters.

    Each character is drawn independently from the ASCII letters and digits
    using the module-level ``random`` generator. A non-positive ``length``
    yields an empty string.
    """
    alphabet = string.ascii_letters + string.digits
    picked = []
    for _ in range(length):
        picked.append(random.choice(alphabet))
    return "".join(picked)

# Tool schemas advertised to the model (OpenAI-style function definitions).
# The names here must match the keys of ForkManager.tools_functions.
tools = [
    {
        "type": "function",
        "function": {
            "name": "fork",
            # Every fragment ends with a space so the implicit string
            # concatenation does not glue two sentences together.
            "description": (
                "Fork a conversation with the provided messages and makes it run asynchronously. "
                "The forked agent will not automatically inherit your message history. "
                "Instead, you must explicitly provide the forked agent with the messages "
                "that you think are essential to the task you give to it. "
                "First, form message_ids as a list of Message IDs (MIDs). An MID is an alphanumerical tag "
                "in square brackets that is prepended to every system, user and assistant message. "
                "Example of an MID: [A1B2]. "
                "Second, pass your message/task/query to the forked agent in the message field. "
                "The tool returns a Thread ID (TID) to be later joined. "
                "Do not forget to join the threads with the TIDs you own before exiting, "
                "to get the information you queried from them or the results of the task given. "
                ),
            "parameters": {
                "type": "object",
                "properties": {
                    "message_ids": {
                        "type": "array",
                        "description": (
                            "A list of message IDs to reference previous messages. "
                            "The selected messages will be included into the forked conversation."),
                        "items": {
                            "type": "string"
                        }
                    },
                    "message": {
                        "type": "string",
                        "description": "The message or an instruction to the forked agent",
                    },
                },
                "required": [],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "join",
            "description": "Join the conversation with the provided thread ID (TID)",
            "parameters": {
                "type": "object",
                "properties": {
                    "tids": {
                        "type": "array",
                        "description": (
                            "A list of thread IDs (TIDs) to join. "
                            "The agent will wait for the results of the forked conversations."),
                        # Element type declared like the other array
                        # parameters so the model emits plain TID strings.
                        "items": {
                            "type": "string"
                        }
                    },
                },
                "required": ["tids"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "status",
            "description": "Gets the status of all threads launched by this thread",
            "parameters": {
                "type": "object",
                "properties": {},
                "required": [],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "finish",
            "description": "Finishes the agent or the forked agent operation when the result is ready.",
            "parameters": {
                "type": "object",
                "properties": {
                    "message_ids": {
                        "type": "array",
                        "description": (
                            "A list of message IDs to reference previous messages. "
                            "The selected messages will be included into the response."),
                        "items": {
                            "type": "string"
                        }
                    },
                    "message": {
                        "type": "string",
                        "description": "The message to the caller agent",
                    },
                },
                "required": [],
            },
        },
    },
]


# Puzzle statement handed to the root agent (Monty Hall problem).
problem = """
You're a contestant on a game show. The host, Monty Hall, presents you with three doors. Behind one door is a valuable prize (like a car), and behind the other two doors are less desirable prizes (like goats).
The Setup:

You choose one door (say, Door #1)
Monty, who knows what's behind each door, opens one of the remaining doors that contains a goat (say, Door #3)
Monty then asks: "Do you want to stick with your original choice (Door #1) or switch to the remaining unopened door (Door #2)?"

The Question: What should you do to maximize your chances of winning the car?"""

# User request: the puzzle wrapped in an instruction to solve it two ways.
task = "Solve the following problem with two different methods:\n\n " + problem

# Seed conversation: system instructions followed by the user task.
messages = [
    dict(
        role="system",
        content=(
            "You are a helpful assistant that can solve problems. "
            "You can (MUST) use fork and join tools to parallelize your operation. "
            "You must parallelize your operation as much as possible to minimize the time to the final answer"
        ),
    ),
    dict(role="user", content=task),
]

import json
import re

def extract_tool_calls(response_text: str) -> list[dict[str, Any]]:
    """Parse ``<tool_call>...</tool_call>`` sections out of a model response.

    Each section is expected to hold a JSON object with at least a string
    ``"name"`` field. Sections that are not valid JSON, are not JSON objects,
    or lack a string name are skipped, so callers can index
    ``call["function"]["name"]`` safely.

    Args:
        response_text: Raw text generated by the model.

    Returns:
        One dict per accepted call with keys ``"id"`` (``call_<index>``,
        numbered over accepted calls only), ``"type"`` (always
        ``"function"``) and ``"function"`` (the decoded JSON object).
    """
    tool_call_pattern = r'<tool_call>(.*?)</tool_call>'
    payloads = re.findall(tool_call_pattern, response_text, re.DOTALL)

    tool_calls: list[dict[str, Any]] = []
    for payload in payloads:
        try:
            tool_data = json.loads(payload.strip())
        except json.JSONDecodeError:
            # Malformed JSON from the model: ignore this section only.
            continue
        if not isinstance(tool_data, dict) or not isinstance(tool_data.get("name"), str):
            # Valid JSON, but not something that can be dispatched as a call.
            continue
        tool_calls.append({
            "id": f"call_{len(tool_calls)}",
            "type": "function",
            "function": tool_data,
        })

    return tool_calls


def remove_tool_calls(response_text: str) -> str:
    """Return ``response_text`` without its ``<tool_call>`` sections.

    Every ``<tool_call>...</tool_call>`` span (which may cover several
    lines) is deleted, then surrounding whitespace is trimmed.
    """
    call_section = re.compile(r'<tool_call>.*?</tool_call>', re.DOTALL)
    without_calls = call_section.sub('', response_text)
    return without_calls.strip()


def tag_message(message: dict[str, Any], mid_length: int) -> dict[str, Any]:
    """Prefix a chat message's content with a fresh ``[MID]`` tag.

    Only system, user and assistant messages are tagged; for those a shallow
    copy carrying the tagged content is returned and the input is left
    unmodified. Messages with any other role are returned as-is (same
    object).
    """
    if message["role"] not in ("system", "user", "assistant"):
        return message

    tagged = dict(message)
    mid = generate_random_id(mid_length)
    tagged["content"] = f"[{mid}] {message['content']}"
    return tagged

def tag_messages(messages: list[dict[str, Any]], mid_length: int) -> list[dict[str, Any]]:
    """Return a new list in which each message has been passed through ``tag_message``.

    The input list itself is not modified.
    """
    return [tag_message(entry, mid_length) for entry in messages]

def extract_mid(message: dict[str, Any]) -> str | None:
    # Extract the Message ID (MID) from the message content
    if "content" in message:
        match = re.search(r'^\[([A-Za-z0-9]+)\]', message["content"])
        if match:
            return match.group(1)
    return None

import concurrent.futures
import threading
from typing import Any

class ForkManager:
    def __init__(self, 
                 llm: LLM,
                 tools: list[dict[str, Any]],
                 ):
        self.llm = llm
        self.tools = tools

        self.sampling_params = SamplingParams(max_tokens=8192, temperature=0.0)

        self.tools_functions = {
            "fork": self.fork,
            "join": self.join,
            "status": self.status,
            "finish": self.finish,
        }

        self.thread_pool = concurrent.futures.ThreadPoolExecutor()

        self.thread_records: dict[str, dict[str, Any]] = {}

        self.tid_length = 3 # Thread ID, chars
        self.tcid_length = 2 # Tool Call ID, chars
        self.mid_length = 4 # Message ID, chars

        self.max_turns = 10

    def fork(
        self,
        my_tid: str,
        messages: list[dict[str, Any]],
        message: str | None = None,
        message_ids: list[str] | None = None,
    ) -> str:
        selected_messages = []
        if len(message_ids) > 0:
            for message in messages:
                mid = extract_mid(message)
                if mid is not None and mid in message_ids:
                    selected_messages.append(message)
        if message is not None:
            selected_messages.append({"role": "user", "content": message})
        future_tid = self.submit_task(my_tid, selected_messages)
        return f"Forked TID: {future_tid}"

    def join(self, tids: list[str]) -> str:
        results: list[str] = []
        for tid in tids:
            if tid in self.thread_records:
                future: concurrent.futures.Future
                future = self.thread_records[tid]['future']
                if future is not None:
                    result = future.result()
                    results.append(f"Thread {tid} result:\n\n {result}")
                else:
                    raise ValueError(f"Thread {tid} has no future associated with it.")
            else:
                raise ValueError(f"Thread {tid} does not exist.")
        results_str = "Successfully joined threads.\n\n"
        for result in results:
            results_str += f"Thread {result}\n"
            results_str += f"{result}\n\n"
        return results_str

    def status(self) -> str:
        status_lines = []
        for tid, record in self.thread_records.items():
            parent_tid = record['parent_tid']
            child_tids = record['child_tids']
            future: concurrent.futures.Future = record['future']
            if future is not None:
                status = "Running" if not future.done() else "Finished"
                result = future.result() if future.done() else "Not yet completed"
            else:
                status = "Not started"
                result = "No result available"
            status_lines.append(
                f"Thread ID: {tid}, Parent TID: {parent_tid}, "
                f"Child TIDs: {', '.join(child_tids) if child_tids else 'None'}, "
                f"Status: {status}, Result: {result}"
            )
        if not status_lines:
            return "No threads have been created yet."
        return "\n".join(status_lines)

    def finish(self,
               messages: list[dict[str, Any]],
               message: str | None = None,
               message_ids: list[str] | None = None,
               ) -> str:

        selected_messages = []
        if len(message_ids) > 0:
            for message in messages:
                mid = extract_mid(message)
                if mid is not None and mid in message_ids:
                    selected_messages.append(message)

        result_str = ""
        for message in selected_messages:
            if message["role"] in ["system", "user", "assistant"]:
                result_str += f"{message['role'].capitalize()}: {message['content']}\n"
            elif message["role"] == "tool":
                tool_call_id = message.get("tool_call_id", "unknown")
                result_str += f"Tool Call [{tool_call_id}]: {message['content']}\n"
        
        if message is not None:
            result_str += f"Message: {message}\n"
        
        return result_str

    def submit_task(self, parent_tid: str, messages: list[dict[str, Any]]) -> str:
        new_tid = generate_random_id(self.tid_length)
        future = self.thread_pool.submit(
            self.run_agent, new_tid, messages
        )
        self.thread_records[parent_tid]['child_tids'].append(new_tid)
        self.thread_records[new_tid] = {
            "future": future,
            "parent_tid": parent_tid,
            "child_tids": [],
        }
        return new_tid

    def run_agent(self, my_tid: str, messages: list[dict[str, Any]]) -> str | None:
        for i in range(self.max_turns):
            outputs = self.llm.chat(
                messages,
                sampling_params=self.sampling_params,
                tools=self.tools)
            output = outputs[0].outputs[0].text.strip()
            content = remove_tool_calls(output)
            messages.append({"role": "assistant", "content": content})
            tool_calls = extract_tool_calls(output)
            tool_answers = []
            is_finish = False
            result_str: str | None = None
            for tool_call in tool_calls:
                func = tool_call["function"]
                tool_name = func["name"]
                if tool_name == "finish":
                    is_finish = True
                    tool_fn = self.tools_functions[tool_name]
                    result_str = tool_fn(messages=messages_forking, **tool_args)
                    break
            if is_finish:
                break

            for tool_call in tool_calls:
                if is_finish and tool_call["function"]["name"] != "finish":
                    continue
                if 'role' not in tool_call:
                    tool_call['role'] = 'tool'
                if 'id' not in tool_call:
                    tool_call['id'] = generate_random_id(self.tcid_length) # patch
                func = tool_call["function"]
                tool_name = func["name"]
                tool_args = func["arguments"]
                tool_fn = self.tools_functions[tool_name]
                if tool_name == 'fork':
                    messages_forking = messages.copy()
                    messages_forking.append(tool_call)
                    tool_answer = tool_fn(my_tid=my_tid, messages=messages_forking, **tool_args)
                else:
                    tool_answer = tool_fn(**tool_args)
                tool_answers.append(tool_answer)
            for tool_call, tool_answer in zip(tool_calls, tool_answers):
                tool_answer_message = {
                    "role": "tool",
                    "content": tool_answer,
                    "tool_call_id": tool_call['id'],
                }
                messages.append(tool_answer_message)
        return result_str

    def run_entry(self, messages):
        messages = tag_messages(messages, self.mid_length)
        new_tid = generate_random_id(self.tid_length)
        self.thread_records[new_tid] = {
            "future": None,
            "parent_tid": None,
            "child_tids": [],
        }
        whatever = self.run_agent(new_tid, messages)
        return whatever

# Build the manager from the loaded model and the tool schemas.
fork_manager = ForkManager(llm, tools)

# Run the root agent synchronously on the seed conversation.
response = fork_manager.run_entry(messages)

# Show whatever the root agent returned.
print(response)