In [None]:
import os
import json
import time

from core.prompts import (
    create_manager_prompt,
    create_data_loader_prompt,
    create_data_processor_prompt,
    create_model_designer_prompt,
    create_trainer_prompt,
)
from core.tools import (
    list_files,
    read_files,
    preview_file_content,
    tree,
    write_to_file,
    copy_file,
    run_script,
)

In [None]:
from langchain_openai import ChatOpenAI
from langchain_community.callbacks.openai_info import OpenAICallbackHandler
from dotenv import load_dotenv

# Shared token/cost accumulator. It is passed to ChatOpenAI via `callbacks=[...]` below,
# so every call made through `llm` updates its counters (read by the stats cell).
callback_manager = OpenAICallbackHandler()

# Model used by every agent, plus its per-million-token prices (USD) for the cost report.
# NOTE(review): prices are hard-coded — verify against the provider's current price list.
model_name = "google/gemini-2.5-flash"
prompt_cost_per_million, completion_cost_per_million = 0.15, 0.60

# NOTE(review): not referenced anywhere else in this notebook view.
class_name = "bottle"


In [None]:

# Pull variables from a local .env file (if any) into os.environ, then read the
# credentials from the environment rather than hard-coding them in the notebook.
load_dotenv()
api_key = os.environ.get("OPENAI_API_KEY")
if not api_key:
    raise RuntimeError(
        "OPENAI_API_KEY is not set; export it or add it to a .env file next to this notebook."
    )
# Optional: None lets ChatOpenAI use its default endpoint.
base_url = os.environ.get("OPENAI_API_BASE")

llm = ChatOpenAI(
    openai_api_key=api_key,
    model_name=model_name,
    base_url=base_url,
    # Every call through `llm` is tallied by the shared callback handler.
    callbacks=[callback_manager],
    # Provider-specific request field; presumably disables the model's "thinking" mode —
    # confirm your endpoint honours it.
    extra_body={"enable_thinking": False},
)


# Task specification handed to every prompt builder below. Use a context manager so the
# file handle is closed, and read it as UTF-8 (the JSON standard encoding) regardless of
# the platform's default locale.
with open("TaskCard.json", "r", encoding="utf-8") as task_card_file:
    task_card = json.load(task_card_file)

# Relative directories the agents work in and read reference material from.
work_path, knowledge_path = "workspace", "knowledge"

# NOTE(review): not passed anywhere in this view; run_agent uses its own default of 100.
recursion_limit = 100

# Notebook switches: render each worker's graph / run each worker on its own.
# (The misspelled DRAW_AGRNT name is what the worker cells reference.)
DRAW_AGRNT, RUN_AGENT = False, False

In [None]:
# Smoke tests: show the loaded task card and make one round-trip call to the LLM.
print("Task card test:", task_card)
# extra_body={"enable_thinking": False} is already configured on `llm`, so it is not
# repeated per call.
print("LLM test:", llm.invoke("hello! Who are you?"))

# callback_manager is attached to `llm`, so the call above was counted. Zero its counters
# so the stats cell reports only the agent run it times.
callback_manager.total_tokens = 0
callback_manager.prompt_tokens = 0
callback_manager.completion_tokens = 0
callback_manager.successful_requests = 0
callback_manager.total_cost = 0.0

In [None]:
from langgraph.prebuilt import create_react_agent
from core.utils import pretty_print_messages


def build_worker(name, llm, tools, prompt):
    """Create a named ReAct worker agent.

    Thin wrapper around ``create_react_agent``.

    Args:
        name: Agent name (also how the supervisor refers to it).
        llm: Chat model that drives the agent.
        tools: Tool callables the agent may invoke.
        prompt: System prompt for the agent.

    Returns:
        The agent object produced by ``create_react_agent``.
    """
    return create_react_agent(name=name, model=llm, prompt=prompt, tools=tools)
from core.utils import timeout

@timeout(600)
def run_agent(agent, prompt, recursion_limit=100):
    """Stream one agent run for a single user prompt, pretty-printing every chunk.

    Wrapped in ``timeout(600)`` from ``core.utils``; presumably this aborts runs longer
    than 600 seconds (the run cell catches ``TimeoutError``) — confirm in core.utils.

    Args:
        agent: A compiled LangGraph agent/graph exposing ``stream``.
        prompt: Content of the initial user message.
        recursion_limit: Value passed as ``recursion_limit`` in the run config.
    """
    initial_state = {"messages": [{"role": "user", "content": prompt}]}
    run_config = {"recursion_limit": recursion_limit}
    for update in agent.stream(initial_state, run_config):
        pretty_print_messages(update)

In [None]:
from IPython.display import Image, display
from langchain_core.runnables.graph import MermaidDrawMethod


def draw_graph(agent):
    """Render an agent's graph as a PNG inline in the notebook.

    The Mermaid diagram is rendered with ``MermaidDrawMethod.API`` (a remote rendering
    service), so this needs network access.
    """
    mermaid_graph = agent.get_graph()
    png_bytes = mermaid_graph.draw_mermaid_png(draw_method=MermaidDrawMethod.API)
    display(Image(png_bytes))

In [None]:
# Data-processing worker. Besides the read/write/run tools it also gets `list_files` and
# `tree` for exploring the workspace.
data_processor_agent = build_worker(
    "data_processor",
    llm,
    tools=[
        list_files,
        read_files,
        preview_file_content,
        tree,
        write_to_file,
        copy_file,
        run_script,
    ],
    prompt=create_data_processor_prompt(work_path, task_card, knowledge_path),
)

# Controlled by the same DRAW_AGRNT switch as the other worker cells, instead of a
# commented-out call.
if DRAW_AGRNT:
    draw_graph(data_processor_agent)
if RUN_AGENT:
    run_agent(data_processor_agent, "now you should make a python script and run it.")

In [None]:
# Data-loader worker: can read/preview files, write/copy files and run scripts.
data_loader_agent = build_worker(
    "data_loader",
    llm,
    prompt=create_data_loader_prompt(work_path, task_card, knowledge_path),
    tools=[read_files, preview_file_content, write_to_file, copy_file, run_script],
)

if DRAW_AGRNT:
    draw_graph(data_loader_agent)
if RUN_AGENT:
    # Standalone run of this worker only (the manager cell drives the full pipeline).
    run_agent(data_loader_agent, "now you should make a python class script and test it.")

In [None]:
# Model-designer worker: same file/script tool set as the data loader.
model_designer_agent = build_worker(
    "model_designer",
    llm,
    prompt=create_model_designer_prompt(work_path, task_card, knowledge_path),
    tools=[read_files, preview_file_content, write_to_file, copy_file, run_script],
)

if DRAW_AGRNT:
    draw_graph(model_designer_agent)
if RUN_AGENT:
    # Standalone run of this worker only (the manager cell drives the full pipeline).
    run_agent(model_designer_agent, "now you should make a python class script and test it.")

In [None]:
# Trainer worker: same file/script tool set; its standalone prompt asks it to train or
# optimize the model.
trainer_agent = build_worker(
    "trainer",
    llm,
    prompt=create_trainer_prompt(work_path, task_card, knowledge_path),
    tools=[read_files, preview_file_content, write_to_file, copy_file, run_script],
)

if DRAW_AGRNT:
    draw_graph(trainer_agent)
if RUN_AGENT:
    # Standalone run of this worker only (the manager cell drives the full pipeline).
    run_agent(trainer_agent, "now you should train the model or optimize the model.")


In [None]:
from langgraph_supervisor import create_supervisor

# Single source of truth for the worker roster: used both as the supervisor's agents and
# for the agent names embedded in its prompt, so the two cannot drift apart.
all_worker_agents = [
    data_processor_agent,
    data_loader_agent,
    model_designer_agent,
    trainer_agent,
]

# Supervisor ("manager") that delegates to the workers. It only gets inspect/run tools;
# writing files is left to the workers.
manager = create_supervisor(
    model=llm,
    tools=[
        list_files,
        read_files,
        preview_file_content,
        tree,
        run_script,
    ],
    agents=all_worker_agents,
    prompt=create_manager_prompt(
        work_path,
        task_card,
        agent_names=[agent.name for agent in all_worker_agents],
    ),
    add_handoff_back_messages=True,
    output_mode="full_history",
    supervisor_name="manager",
).compile()

# Controlled by the DRAW_AGRNT switch, like the worker cells.
if DRAW_AGRNT:
    draw_graph(manager)

In [None]:
import traceback

# Time the full manager run. Failures are reported instead of re-raised so the stats cell
# can still show the time and tokens spent; the traceback is printed for debugging.
start_time = time.time()
try:
    run_agent(manager, "now you should do the task.")
except TimeoutError as e:
    print(f"TimeoutError: {e}")
except Exception as e:
    print(f"Exception: {e}")
    # The message alone hides where the agent run failed.
    traceback.print_exc()
finally:
    end_time = time.time()

In [None]:
from core.utils import calculate_llm_cost

# === STATS ===
# Wall-clock time of the manager run plus token usage / estimated cost from the shared
# callback handler attached to `llm`.
print(f"Time used: {end_time - start_time:.2f}s")
total_prompt_tokens = callback_manager.prompt_tokens
total_completion_tokens = callback_manager.completion_tokens
total_cost = calculate_llm_cost(
    total_prompt_tokens,
    total_completion_tokens,
    prompt_cost_per_million=prompt_cost_per_million,
    completion_cost_per_million=completion_cost_per_million,
)
print(f"Total Tokens Used: {total_prompt_tokens + total_completion_tokens}")
print(f"Prompt Tokens: {total_prompt_tokens}")
print(f"Completion Tokens: {total_completion_tokens}")
print(f"Total Cost (USD): ${total_cost:.6f}")

# Reset every counter the handler accumulates, not only the two read above, so the next
# run starts from a clean state (the handler object itself stays attached to `llm`).
callback_manager.total_tokens = 0
callback_manager.prompt_tokens = 0
callback_manager.completion_tokens = 0
callback_manager.successful_requests = 0
callback_manager.total_cost = 0.0