In [None]:
import sqlite3
from smolagents import CodeAgent, DuckDuckGoSearchTool, HfApiModel, LiteLLMModel
from dotenv import load_dotenv
import os
import json 

# Load environment variables via python-dotenv (by default it looks for a
# `.env` file) so the API key does not have to be hard-coded here.
load_dotenv()
# os.getenv returns None when "cohere_key" is not set; nothing below checks
# for that, so a missing key only surfaces once the model is actually called.
cohere_token = os.getenv("cohere_key")

# Define the agent: a LiteLLM-backed model ("command-r") wrapped in a
# CodeAgent whose only tool is DuckDuckGo web search.
model = LiteLLMModel(
    model_id="command-r",  
    api_key=cohere_token,          
)
# `agent` is a module-level object; choose_best_algorithm() below uses it.
agent = CodeAgent(tools=[DuckDuckGoSearchTool()], model=model)

# Prepare the database and fetch data
def choose_best_algorithm(db_file="train_mTSP.sqlite3"):
    """Ask the agent to pick the best mTSP strategy for every stored instance.

    Reads the ``instances`` and ``algorithms`` tables from the SQLite
    database, serialises both to JSON, embeds the JSON in a natural-language
    task, runs the module-level ``agent`` on that task and prints its answer.

    Args:
        db_file: Path to the SQLite database file containing the
            ``instances`` and ``algorithms`` tables.

    Raises:
        sqlite3.Error: If the database cannot be opened or queried (for
            example, when one of the expected tables is missing).
    """
    conn = sqlite3.connect(db_file)
    try:
        # sqlite3.Row carries the column names of the SELECT, so every row
        # converts to a dict keyed by exactly the columns that were selected.
        conn.row_factory = sqlite3.Row
        instances_with_columns = [
            dict(row)
            for row in conn.execute(
                "SELECT instance_id, nr_cities, nr_salesmen FROM instances"
            )
        ]
        algorithms_with_columns = [
            dict(row)
            for row in conn.execute(
                "SELECT instance_id, strategy, total_cost, time_taken, distance_gap "
                "FROM algorithms"
            )
        ]
    finally:
        # Everything is fetched into memory above, so release the database
        # before the agent run (which can be slow or interrupted) and also
        # when a query fails.
        conn.close()

    # Prepare data for the agent
    data = {
        "instances": instances_with_columns,
        "algorithms": algorithms_with_columns,
    }
    data_str = json.dumps(data, indent=4)

    # List the strategies that actually occur in the data (first-seen order)
    # so the prompt can never disagree with the names stored in the table.
    strategy_names = list(
        dict.fromkeys(row["strategy"] for row in algorithms_with_columns)
    )
    strategy_list = "\n".join(
        f'    {number}. "{name}"'
        for number, name in enumerate(strategy_names, start=1)
    )

    # Task for the agent
    task = f"""
    You are an AI agent tasked with selecting the best algorithm for solving each instance of the mTSP problem.
    Below is the data from the instances and algorithms tables, including column names for clarity:

    {data_str}

    You need to choose the best algorithm that should minimize total_cost, time_taken, and distance_gap from the algorithms table. 
    At the moment the data contains the following strategies:
{strategy_list}
    For each instance, print the selected algorithm from these and explain why it was chosen.
    """

    # Run the agent
    response = agent.run(task)
    print(response)
# Run the selection with the function's default database file.
choose_best_algorithm()

KeyboardInterrupt: 