In [None]:
import json
import os
import re
from typing import List, Dict, Any

import requests
import sympy as sp
import torch
from dotenv import load_dotenv
from openai import OpenAI
from torch.distributions import Categorical
from torch.optim import Adam
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline, PreTrainedModel, PreTrainedTokenizerFast

In [None]:
# Environment, clients, constants and model loading.
load_dotenv()  # load variables such as OPENAI_API_KEY from a local .env file
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# Prefer CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# NOTE(review): Hub model IDs are usually namespaced (e.g. "google/gemma-2-27b") -- confirm.
LLM_MODEL = "gemma-2-27b"  # Placeholder model
REWARD_MODEL = "gemma-2-27b"  # Placeholder model
# Read the secret from the environment instead of hardcoding it in the notebook; the
# placeholder is only a fallback so the cell still runs when the variable is unset.
SEARCH_API_KEY = os.getenv("SEARCH_API_KEY", "YOUR_SEARCH_API_KEY")
MAX_HOTPOT_STEPS = 5
MAX_GSM8K_STEPS = 10

query_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
# Named `query_lm` (not `query_model`) so it does not collide with the `query_model()`
# helper defined in a later cell: re-running this cell would otherwise clobber that helper.
query_lm = AutoModelForCausalLM.from_pretrained(LLM_MODEL)
query_pipe = pipeline("text-generation", model=query_lm, tokenizer=query_tokenizer, device=device)

reward_tokenizer = AutoTokenizer.from_pretrained(REWARD_MODEL)
reward_model = AutoModelForCausalLM.from_pretrained(REWARD_MODEL)

# Registry of available tools, of the form {tool_name: tool_function}.
tools_list: Dict[str, Any] = {}

In [None]:
def query_model(prompt, **generate_kwargs):
    """Run `prompt` through the local text-generation pipeline and return the generated text.

    By default only the newly generated continuation is returned (``return_full_text=False``).
    Without it the pipeline echoes the prompt, and because our prompts embed tag templates
    (e.g. ``<|New Question Begin|>``), downstream tag parsing would match the template
    instead of the model's answer.

    Args:
        prompt: Prompt string passed to ``query_pipe``.
        **generate_kwargs: Extra keyword arguments forwarded to the pipeline call
            (e.g. ``max_new_tokens``); may explicitly override ``return_full_text``.

    Returns:
        The ``generated_text`` field of the first pipeline result.
    """
    generate_kwargs.setdefault("return_full_text", False)
    return query_pipe(prompt, **generate_kwargs)[0]["generated_text"]
    

In [None]:
# Prompting templates
def tool_prompt():
    """Return the static instruction text describing the two callable tools.

    The model may wrap Python code in <exec>...</exec> or a question for an external
    reasoning model in <query>...</query>; the text is embedded into other prompts.
    """
    sentences = (
        "You have access to two tools.",
        "To execute some Python code, please wrap the code you want to execute in tags like this: <exec>[CODE]</exec>.",
        "The code should be fully self-contained and will be executed in a separate Python file as a main script.",
        "To query an external reasoning model with advanced question-answering capabilities, please wrap the query you want to execute in tags like this: <query>[QUERY]</query>.",
        "You may only use at most one tool call in your output; the output of the tool call will be appended to the prompt the next time you are queried",
    )
    return " ".join(sentences)

def trajectory_history_prompt(trajectory):
    """Render previous outputs as a numbered history block for inclusion in a prompt.

    Each element of `trajectory` becomes an "Attempt k:" line (1-based) followed by the
    element's text and a newline. An empty trajectory yields only the header line.
    """
    header = "The entire history of your previous outputs when generating this problem, its test cases, and solution, is presented below.\n"
    attempts = "".join(
        f"Attempt {number}:\n{entry}\n"
        for number, entry in enumerate(trajectory, start=1)
    )
    return header + attempts

def mutation_prompt(problem):
    """Build the prompt asking the model to make `problem` slightly harder.

    The prompt consists of the task instruction, the tool instructions from tool_prompt(),
    the original question, and a template showing the <|New Question ...|> and
    <|Rationales ...|> tag pairs the model is asked to emit. It ends with a newline.
    """
    instruction = (
        "Please increase the difficulty of the given programming test question a bit. "
        "Do not provide any hints, solutions or outputs. "
        "Only one new instruction is allowed. "
        "Additionally, generate rationales for what new concepts your problem is designed to test, "
        "and how these make it harder. "
        "Ensure that you include the <|New Question Begin|>, <|New Question End|>, "
        "<|Rationales Begin|>, <|Rationales End|> tags as depicted."
    )
    template_lines = [
        "### Prompt Template",
        "<|New Question Begin|>",
        "[New Question]",
        "<|New Question End|>",
        "<|Rationales Begin|>",
        "[Rationales for New Question]",
        "<|Rationales End|>",
    ]
    # The trailing "" makes the joined prompt end with a newline.
    sections = [instruction, tool_prompt(), f"Original Question: {problem}", "", *template_lines, ""]
    return "\n".join(sections)

def generate_test_prompt(problem, history=""):
    """Build the prompt asking the model for unit tests (plus a passing solution) for `problem`.

    The response template uses <|Tests ...|> and <|Solution ...|> tag pairs, following the
    same "<|{tag} Begin|> ... <|{tag} End|>" convention that parse_tag() reads. (The
    previous version reused the mutation template and asked for a new question instead.)

    Args:
        problem: The question to write tests for.
        history: Optional pre-rendered record of earlier attempts and their execution
            output (e.g. from trajectory_history_prompt()). When non-empty, it is included
            under a "### Chat History" heading, and the model is told to revise its tests
            and/or solution.

    Returns:
        The prompt string, ending with a newline.
    """
    instructions = (
        "Please generate unit tests that can be used to verify a solution to this problem, "
        "along with a solution that passes them. Ensure that you include the "
        "<|Tests Begin|>, <|Tests End|>, <|Solution Begin|>, <|Solution End|> tags as depicted."
    )
    history_section = ""
    if history:
        # Only mention the chat history when there actually is one to show.
        instructions += (
            ' The entire chat history of your previous attempts to generate questions and unit tests '
            'is presented below in the "Chat History" section, along with the output of running your '
            'solution against your tests in a code execution environment. Please modify only your '
            'tests and/or solution to be more correct.'
        )
        history_section = f"\n### Chat History\n{history}\n"
    return f'''{instructions}

Question: {problem}
{history_section}
### Prompt Template
<|Tests Begin|>
[Unit Tests]
<|Tests End|>
<|Solution Begin|>
[Solution]
<|Solution End|>
'''

In [None]:
# Function to get AI response
def get_openai_response(prompt, model="gpt-4o", log_file="./api_log_swirl.txt"):
    """Send a single-turn chat request to the OpenAI API and return the reply text.

    Args:
        prompt: User message content.
        model: Chat-completions model name.
        log_file: Path of a text file the prompt/response pair is appended to (UTF-8),
            or None to disable logging.

    Returns:
        The stripped reply text ("" if the reply had no text content). If the API call
        itself fails, returns the string ``"Error: <exception>"`` instead of raising,
        so callers should check for that prefix.
    """
    try:
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}]
        )
        # `content` may be None (a reply with no text); treat that as empty text.
        msg = (response.choices[0].message.content or "").strip()
    except Exception as e:
        return f"Error: {e}"

    # Logging is best-effort and deliberately outside the try above: failing to write the
    # log must not throw away a response we already received.
    if log_file is not None:
        try:
            with open(log_file, "a", encoding="utf-8") as f:
                f.write("PROMPT:\n\n" + prompt + "\n\n" + "RESPONSE\n\n" + msg + "\n\n")
        except (OSError, UnicodeError) as log_err:
            import warnings
            warnings.warn(f"Could not write API log to {log_file!r}: {log_err}")

    return msg

def parse_tag(text, tag):
    """Return every span of `text` enclosed by ``<|{tag} Begin|>`` ... ``<|{tag} End|>``.

    Args:
        text: Model output to search.
        tag: Tag label, e.g. "New Question" or "Rationales". It is matched literally, so
            labels containing regex metacharacters are safe.

    Returns:
        List of enclosed substrings (whitespace not stripped) in order of appearance;
        empty if the tag pair never occurs. Matching is non-greedy and spans newlines.
    """
    label = re.escape(tag)
    # Raw f-string: in a plain string "\|" is an invalid escape (SyntaxWarning on 3.12+).
    pattern = rf"<\|{label} Begin\|>(.*?)<\|{label} End\|>"
    return re.findall(pattern, text, re.DOTALL)

def parse_tool_calls(text=None):
    """Extract tool invocations from a model output, in order of appearance.

    Recognises the two tag forms described in the tool instructions:
    ``<exec>CODE</exec>`` and ``<query>QUERY</query>``. The ``exec`` tag is reported as
    tool name ``"execute"`` (the name execute_tool dispatches on); ``query`` keeps its name.
    The tool instructions allow at most one call per output; enforcing that is left to
    the caller.

    Args:
        text: Model output to scan. None or "" yields no calls.

    Returns:
        List of ``(tool_name, payload)`` tuples. Payloads are returned verbatim (not
        stripped) so code indentation is preserved.
    """
    if not text:
        return []
    tool_names = {"exec": "execute", "query": "query"}
    # \1 requires the closing tag to match the opening one.
    return [
        (tool_names[match.group(1)], match.group(2))
        for match in re.finditer(r"<(exec|query)>(.*?)</\1>", text, re.DOTALL)
    ]

def mutate(problem):
    """Ask the local model for a harder variant of `problem` and parse its tagged answer.

    Args:
        problem: Text of the question to mutate.

    Returns:
        dict with keys:
          - "question": stripped text of the last <|New Question Begin|>...<|New Question End|>
            pair in the output, or None if there is none;
          - "rationales": the same for the <|Rationales ...|> pair;
          - "raw_output": the unparsed model output.
    """
    output = query_model(mutation_prompt(problem))

    def _last(tag):
        # Take the LAST match: if the generator echoes the prompt, earlier matches come from
        # the prompt's own tag list and template placeholder, not from the model's answer.
        matches = parse_tag(output, tag)
        return matches[-1].strip() if matches else None

    return {
        "question": _last("New Question"),
        "rationales": _last("Rationales"),
        "raw_output": output,
    }
    

In [None]:
# Tool calls
def _run_python_script(code, timeout):
    """Run `code` as a main script in a fresh interpreter inside a temp dir; return a report."""
    # SECURITY: `code` is model-generated and therefore untrusted. A subprocess with a
    # timeout only contains crashes and hangs, not file-system or network access; run this
    # inside a real sandbox (container/VM) before pointing it at anything sensitive.
    import os
    import subprocess
    import sys
    import tempfile

    with tempfile.TemporaryDirectory() as workdir:
        script_path = os.path.join(workdir, "main.py")
        with open(script_path, "w", encoding="utf-8") as f:
            f.write(code)
        try:
            result = subprocess.run(
                [sys.executable, script_path],
                cwd=workdir,
                capture_output=True,
                text=True,
                timeout=timeout,
            )
        except subprocess.TimeoutExpired:
            return f"Execution timed out after {timeout} seconds."
    return (
        f"Exit code: {result.returncode}\n"
        f"STDOUT:\n{result.stdout}\n"
        f"STDERR:\n{result.stderr}"
    )


def execute_tool(tool, payload, timeout=30):
    """Run a single tool call and return its textual result.

    Args:
        tool: "query" sends `payload` to the external model via get_openai_response().
            "execute" (alias "exec", the tag name used in the tool instructions) runs
            `payload` as a standalone Python script in a separate interpreter.
        payload: Query text or Python source code.
        timeout: Seconds before a script run is killed (execute only).

    Returns:
        For "query", the model's reply. For "execute", a report with the script's exit
        code, stdout and stderr, or a timeout notice.

    Raises:
        ValueError: if `tool` is not a known tool name.
    """
    if tool == "query":
        return get_openai_response(payload)
    if tool in ("execute", "exec"):
        return _run_python_script(payload, timeout)
    raise ValueError(f"Tried to use nonexistent tool {tool}")

In [None]:
def generate_trajectory(seed_problem, api_model="gpt-4o"):
    """Generate a trajectory starting from `seed_problem` -- not implemented yet.

    The previous version was a `def` with no body, an IndentationError that stopped the
    whole cell (including the next function) from being defined.

    Args:
        seed_problem: Problem text to start from.
        api_model: Presumably the OpenAI model name for API calls -- TODO confirm.

    Raises:
        NotImplementedError: always, until the trajectory logic is written.
    """
    raise NotImplementedError("generate_trajectory is not implemented yet")

def generate_mutation_trajectory(seed_problem, num_iterations):
    """Repeatedly mutate a problem and record every raw model output.

    Each iteration asks the local model to make the current problem harder. The last parsed
    <|New Question ...|> block becomes the input to the next iteration; if an output has no
    such block, the current problem is reused.

    Args:
        seed_problem: Problem text for the first mutation.
        num_iterations: Number of mutation rounds (<= 0 yields an empty trajectory).

    Returns:
        List of raw model outputs, one per iteration -- the list-of-attempts shape that
        trajectory_history_prompt() renders as "Attempt k" entries.
    """
    trajectory = []
    problem = seed_problem
    for _ in range(num_iterations):
        output = query_model(mutation_prompt(problem))
        trajectory.append(output)
        # Use the last match: an echoed prompt would otherwise contribute earlier matches
        # (the prompt's tag list and its "[New Question]" placeholder).
        new_questions = parse_tag(output, "New Question")
        if new_questions:
            problem = new_questions[-1].strip()
    return trajectory

    