In [110]:
from dotenv import load_dotenv, find_dotenv

load_dotenv(find_dotenv())

import os
from mistralai import Mistral
from smolagents import LiteLLMModel
import duckdb

api_key = os.environ["MISTRAL_API_KEY"]
model = LiteLLMModel(model_id="mistral/mistral-medium", api_key=api_key)

client = Mistral(api_key=api_key)

conn = duckdb.connect('grid.duckdb')

In [111]:
from smolagents import tool

from database import (
    list_databases,
    list_schemas,
    list_tables,
    list_views,
    list_columns,
    list_dependencies,
)

@tool
def describe_table(table_name: str) -> str:
    """
    Return the column names & types for `grid_ops.<table_name>`.

    Args:
        table_name: The name of the table to describe.
    """
    cols = list_columns(table_name)
    return "\n".join(f"{c[6]}  ({c[11]})" for c in cols)

@tool
def query_db(query: str) -> str:
    """
    Execute arbitrary SQL against DuckDB and return results as text.

    Args:
        query: The SQL query to execute. DuckDB will be used as the database.
    """
    rows = conn.sql(query)
    return rows

@tool
def tool_list_databases() -> str:
    """lists all accessible databases"""
    return list_databases()

@tool
def tool_list_schemas() -> str:
    """lists all schemas in the current database"""
    return list_schemas()

@tool
def tool_list_tables() -> str:
    """lists all tables in the current database"""
    return list_tables()

@tool
def tool_list_views() -> str:
    """lists all views in the current database"""
    return list_views()

@tool
def tool_list_columns(table_name: str, schema: str = 'grid_ops') -> str:
    """lists all columns for a given table
    Args:
        table_name: The name of the table to describe.
        schema: The schema in which the table lives.
    """
    return list_columns(table_name, schema)

@tool
def tool_list_dependencies() -> str:
    """lists all dependencies between objects in the current database"""
    return list_dependencies()

In [112]:
tool_list = [
    describe_table,
    query_db,
    tool_list_databases,
    tool_list_schemas,
    tool_list_tables,
    tool_list_views,
    tool_list_columns,
    tool_list_dependencies,
]

from smolagents import CodeAgent

agent = CodeAgent(
    tools=tool_list,
    model=model,
)

In [113]:
eval_questions: list[dict] = [
    {
        "question": "Which substations have a capacity greater than 400 MW, and what are their names and locations?",
        "answer": "The substations with a capacity greater than 400 MW is Amsterdam Zuid Substation."
    },
    {
        "question": "For each substation, what is the total length of all transmission lines connected to it?",
        "answer":"""Substation:
        - Amsterdam Zuid Substation: 330 km
        - Rotterdam Noord Substation: 350 km
        - Utrecht Lunetten Substation: 380 km"""
    },
    {
        "question": "Show the complete maintenance history (date and description) for transmission line ID 2.",
        "answer": """Maintenance history for transmission line ID 2:
        - 2025-03-05: Line inspection
        - 2025-06-01: Insulator cleaning"""
    },
    {
        "question": "List all generators of type 'Solar' and the names of the substations they’re connected to.",
        "answer": """Generators of type 'Solar' and the names of the substations they’re connected to:
        - SP-EEM-01: Rotterdam Noord Substation"""
    },
    {
        "question": "What is the total generation capacity (sum of max_output_mw) connected to each substation?",
        "answer": """Total generation capacity connected to each substation:
        - Amsterdam Zuid Substation: 180 MW
        - Rotterdam Noord Substation: 75 MW
        - Utrecht Lunetten Substation: 0 MW"""
    }
]

In [114]:
generated_answers: list[dict] = []

for qa in eval_questions:
    resp = agent.run(qa["question"])
    generated_answers.append({"question": qa["question"], "model": model.model_id, "generated_answer": resp})

In [115]:
generated_answers

[{'question': 'Which substations have a capacity greater than 400 MW, and what are their names and locations?',
  'model': 'mistral/mistral-medium',
  'generated_answer': 'The substation with a capacity greater than 400 MW is Amsterdam Zuid Substation located in Amsterdam.'},
 {'question': 'For each substation, what is the total length of all transmission lines connected to it?',
  'model': 'mistral/mistral-medium',
  'generated_answer': ┌───────────────┬─────────────────────────────┬─────────────────┐
  │ substation_id │            name             │ total_length_km │
  │     int32     │           varchar           │     double      │
  ├───────────────┼─────────────────────────────┼─────────────────┤
  │             3 │ Utrecht Lunetten Substation │           380.0 │
  │             2 │ Rotterdam Noord Substation  │           350.0 │
  │             1 │ Amsterdam Zuid Substation   │           330.0 │
  └───────────────┴─────────────────────────────┴─────────────────┘},
 {'question': 

### Asynchronous evaluation of generated answers

In [116]:
import asyncio
import nest_asyncio
from pydantic import BaseModel
from mistralai import Mistral

nest_asyncio.apply()

client = Mistral(api_key=api_key)
model = "mistral-small"

class Score(BaseModel):
    score: int

async def evaluate_rag(query: str, model_answer: str, generated_answer: str):
    chat_response = await client.chat.parse_async(
        model=model,
        messages=[
            {
                "role": "system",
                "content": (
                    "You are a judge for evaluating a SQL Agent. "
                    "Evaluate the correctness of the generated answer based on model answer."
                    "If the generated answer is equal to the model answer, return 1. If not, return 0."
                )
            },
            {
                "role": "user",
                "content": f"Query: {query}\nModel Answer: {model_answer}\nGenerated Answer: {generated_answer}"
            },
        ],
        response_format=Score,
        temperature=0
    )
    return chat_response.choices[0].message.parsed


scores = []
async def evaluate_all():
    for i, qa in enumerate(eval_questions):
        generated_answer = generated_answers[i]["generated_answer"]
        score = await evaluate_rag(qa["question"], qa["answer"], generated_answer)
        scores.append(score)

    return scores

scores = asyncio.run(evaluate_all())

percentage_correct = sum([s.score for s in scores]) / len(scores)
print(f"fraction correct: {percentage_correct}")

fraction correct: 1.0


### Embedding comparison as a metric

In [122]:
embedding_model = "mistral-embed"

gen_answers_to_embed = [str(g["generated_answer"]) for g in generated_answers]

generated_answers_embedded = client.embeddings.create(
    model=embedding_model,
    inputs=gen_answers_to_embed,
)

answers_to_embed = [a["answer"] for a in eval_questions]
answers_embedded = client.embeddings.create(
    model=embedding_model,
    inputs=answers_to_embed,
)

import numpy as np

for i, (ga_embed, a_embed) in enumerate(zip(generated_answers_embedded.data, answers_embedded.data)):
    print("--------------------")
    print(f"\n\nGenerated answer: {generated_answers[i]["generated_answer"]}")
    print(f"\n\nActual answer: {eval_questions[i]["answer"]}")
    print(f"\n\nDot product of embeddings: {np.dot(ga_embed.embedding, a_embed.embedding)}")

    print("average dot product: ", np.mean([np.dot(ga_embed.embedding, a_embed.embedding) for ga_embed, a_embed in zip(generated_answers_embedded.data, answers_embedded.data)]))

--------------------


Generated answer: The substation with a capacity greater than 400 MW is Amsterdam Zuid Substation located in Amsterdam.


Actual answer: The substations with a capacity greater than 400 MW is Amsterdam Zuid Substation.


Dot product of embeddings: 0.9936256111321313
average dot product:  0.8941219688904084
--------------------


Generated answer: ┌───────────────┬─────────────────────────────┬─────────────────┐
│ substation_id │            name             │ total_length_km │
│     int32     │           varchar           │     double      │
├───────────────┼─────────────────────────────┼─────────────────┤
│             3 │ Utrecht Lunetten Substation │           380.0 │
│             2 │ Rotterdam Noord Substation  │           350.0 │
│             1 │ Amsterdam Zuid Substation   │           330.0 │
└───────────────┴─────────────────────────────┴─────────────────┘



Actual answer: Substation:
        - Amsterdam Zuid Substation: 330 km
        - Rotterdam Noord 

##### Conclusion: Embedding overlap comparison as a similarity metric is only useful for short answers