In [35]:
import json
import os
import sqlite3
import sys
from collections import Counter
from contextlib import closing
from typing import Optional, List, Dict, Any

import anthropic
import numpy as np
import pandas as pd
import requests
from dotenv import load_dotenv
from matplotlib import pyplot as plt

# Initialization: Using Claude & Querying the Database

### Querying the Database and Saving Results to Use for Eval

In [67]:
def execute_sql_query(db_path, query, test_mode=False):
    """
    Execute a SQL query against a SQLite database and return the results.

    Args:
        db_path (str): Path to the SQLite database file
        query (str): SQL query to execute (a single statement)
        test_mode (bool): If True, print the database error before returning None

    Returns:
        list: Query results as a list of row tuples
        None: If there's an error executing the query
    """
    # NOTE(review): sqlite3.connect creates an empty database if db_path does not exist,
    # and the query (possibly LLM-generated) runs with write access to the file.
    try:
        # closing() releases the connection even when execute() raises; the sqlite3
        # Connection context manager alone only commits/rolls back and does not close.
        with closing(sqlite3.connect(db_path)) as conn:
            return conn.execute(query).fetchall()
    except (sqlite3.Error, sqlite3.Warning) as e:
        # sqlite3.Warning is not a subclass of sqlite3.Error; older Pythons raise it for
        # multi-statement SQL, which generated queries can contain.
        if test_mode:
            print(f"Database error: {e}")
        return None

In [53]:
# Test drive: fetch the five oldest franchises to confirm the database is reachable
db_path = "nba.sqlite"
query = "SELECT full_name FROM team ORDER BY year_founded ASC LIMIT 5"
results = execute_sql_query(db_path, query, test_mode=True)
results

[('Boston Celtics',),
 ('Golden State Warriors',),
 ('New York Knicks',),
 ('Los Angeles Lakers',),
 ('Sacramento Kings',)]

In [59]:
def run_queries_and_save_answers(filename, database_path: Optional[str] = None) -> List[Dict[str, Any]]:
    """
    Run the SQL queries in ``data/input_raw/<filename>`` and save them, with their results,
    to ``data/input/<filename>``.

    Each query object must have a ``'sql'`` key. Queries that execute successfully get an
    ``'answer'`` key holding the result rows; queries that fail are reported and dropped.

    Args:
        filename (str): Name of the JSON file (a list of query objects) in ``data/input_raw``
        database_path (str, optional): SQLite database to run the queries against. When
            omitted, the notebook-level ``db_path`` is used (the original behavior).

    Returns:
        The list of query dicts that ran successfully, each with ``'answer'`` added.
    """
    input_path = os.path.join('data', 'input_raw', filename)
    output_path = os.path.join('data', 'input', filename)
    # Only look up the global when no explicit path is given, so existing calls still work
    sqlite_path = database_path if database_path is not None else db_path

    with open(input_path, 'r') as f:
        queries = json.load(f)

    all_answers = []
    for query in queries:
        query_text = query['sql']
        answer = execute_sql_query(sqlite_path, query_text, test_mode=True)
        if answer is None:
            print(f"Error executing query: {query_text}")
            continue
        # Rows are tuples here; json.dump writes them as JSON arrays (lists on reload)
        query['answer'] = answer
        all_answers.append(query)

    # Create the output folder if it does not exist yet
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, 'w') as f:
        json.dump(all_answers, f, indent=4)

    return all_answers

In [63]:
# Execute every raw query file once so each case carries its ground-truth answer
json_files = os.listdir(os.path.join('data', 'input_raw'))
for filename in json_files:
    if not filename.endswith('.json'):
        continue  # ignore anything that is not a query file
    print(f"Running queries from {filename}")
    gt_answers = run_queries_and_save_answers(filename)
    print(f"{len(gt_answers)} Answers saved to {filename}\n\n")

Running queries from ground_truth_data.json
98 Answers saved to ground_truth_data.json


Running queries from synthetic_data.json
160 Answers saved to synthetic_data.json




### Claude API Initialization and Test

In [21]:
# Populate the environment from a local .env file, then pick up the Claude key from it
load_dotenv()
CLAUDE_API_KEY = os.environ.get("CLAUDE_API_KEY")

In [140]:
# Class to interact with the Claude API
# I used Claude to write the bulk of this class, and added a method at the end to call the completion method and get the text from the response
class ClaudeAPI:
    """Thin wrapper around the Anthropic Messages API for single-turn prompts."""

    def __init__(self, api_key: Optional[str] = None, model: str = "claude-3-7-sonnet-20250219"):
        """
        Initialize the Claude API wrapper.

        Args:
            api_key: Your Anthropic API key. If not provided, will look for ANTHROPIC_API_KEY env variable
            model: The Claude model to use (defaults to Claude 3.7 Sonnet)

        Raises:
            ValueError: if no key is passed and ANTHROPIC_API_KEY is not set
        """
        self.api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
        if not self.api_key:
            raise ValueError("API key must be provided or set as ANTHROPIC_API_KEY environment variable")

        self.client = anthropic.Anthropic(api_key=self.api_key)
        self.model = model

    def completion(self,
            user_prompt: str,
            system_prompt: Optional[str] = None,
            max_tokens: int = 1024,
            temperature: float = 0.0, # Turn temperature down to 0.0 for deterministic results
            **kwargs) -> Any:
        """
        Send a completion request to Claude with system and user prompts.

        Args:
            user_prompt: The user's input prompt
            system_prompt: Optional system prompt to set context/instructions; it is only
                sent when non-empty
            max_tokens: Maximum number of tokens to generate
            temperature: Controls randomness (0.0 to 1.0)
            **kwargs: Additional parameters to pass to the API

        Returns:
            The complete response object returned by ``client.messages.create``
        """
        request = {
            "model": self.model,
            "messages": [{"role": "user", "content": user_prompt}],
            "max_tokens": max_tokens,
            "temperature": temperature,
        }

        # Add the system message only if provided, instead of sending an empty string
        if system_prompt:
            request["system"] = system_prompt

        # Call the Claude API
        return self.client.messages.create(**request, **kwargs)

    def get_response_text(self, response: Any) -> str:
        """
        Extract the text content from a Claude API response.

        Concatenates the ``text`` of every content block that has one and skips blocks
        without a ``text`` attribute, rather than failing on them.

        Args:
            response: The response from the Claude API

        Returns:
            The text content of Claude's response, or "" when there is none
        """
        contents = getattr(response, "content", None) or []
        return "".join(
            block.text for block in contents
            if isinstance(getattr(block, "text", None), str)
        )

    def generate_response(self,
            user_prompt: str,
            system_prompt: Optional[str] = None,
            max_tokens: int = 1024,
            temperature: float = 0.7,
            **kwargs) -> str:
        """
        Generate a response to a query using the Claude API.

        NOTE: this defaults to temperature 0.7, unlike ``completion`` (0.0). Pass
        ``temperature=0.0`` when repeatable outputs are needed (e.g. evaluation runs).

        Args:
            user_prompt: The user's input prompt
            system_prompt: Optional system prompt to set context/instructions
            max_tokens: Maximum number of tokens to generate
            temperature: Controls randomness (0.0 to 1.0)
            **kwargs: Additional parameters to pass to the API

        Returns:
            The text response from the Claude API
        """
        response = self.completion(
            user_prompt,
            system_prompt=system_prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs,
        )
        return self.get_response_text(response)

In [141]:
# Load a plain-text prompt template
def read_prompt_simple(question: str, path_to_prompt: str) -> str:
    """
    Load a ``.txt`` prompt template and substitute its ``{question}`` placeholder.

    Raises:
        FileNotFoundError: if ``path_to_prompt`` does not exist (checked first)
        ValueError: if ``path_to_prompt`` does not end with ``.txt``
    """
    file_missing = not os.path.exists(path_to_prompt)
    if file_missing:
        raise FileNotFoundError(f"Prompt file not found at {path_to_prompt}")

    is_txt = path_to_prompt.endswith('.txt')
    if not is_txt:
        raise ValueError("Prompt file must be a .txt file")

    with open(path_to_prompt, 'r') as template_file:
        return template_file.read().format(question=question)

In [142]:
def read_prompt_json(question: str, path_to_prompt: str) -> Dict[str, Optional[str]]:
    """
    Read a JSON prompt file and build the system and user prompts for ``question``.

    Recognised keys in the file:
      * ``"system"``: template formatted with ``{question}`` and ``{database_schema}``;
        the schema comes from ``get_database_schema(db_path)``
      * ``"user"``: template formatted with ``{question}``; when absent, the question
        itself is used as the user prompt

    Returns:
        ``{'system': <str or None>, 'user': <str>}``

    Raises:
        FileNotFoundError: if ``path_to_prompt`` does not exist
        ValueError: if ``path_to_prompt`` is not a ``.json`` file
    """
    if not os.path.exists(path_to_prompt):
        raise FileNotFoundError(f"Prompt file not found at {path_to_prompt}")
    if not path_to_prompt.endswith('.json'):
        raise ValueError("Prompt file must be a .json file")

    with open(path_to_prompt, 'r') as f:
        prompt = json.load(f)

    if 'system' in prompt:
        # The schema is only needed to fill the system template, so the database is only
        # read when there is one.
        # NOTE: relies on the notebook-level db_path and on get_database_schema, which is
        # defined in a later cell; both must exist in the kernel before this runs.
        schema = get_database_schema(db_path)
        system_prompt = prompt['system'].format(question=question, database_schema=schema)
    else:
        system_prompt = None

    user_prompt = prompt['user'].format(question=question) if 'user' in prompt else question

    return {
        'system': system_prompt,
        'user': user_prompt
    }

In [143]:
# Shared Claude client for the rest of the notebook, built from the key read out of .env
claude_api = ClaudeAPI(api_key=CLAUDE_API_KEY)

In [144]:
# Test drive the solution prompt on a single question.
# NOTE: read_prompt_json calls get_database_schema, which is defined in the later
# "Retrieve the Database Schema" section; on a fresh kernel, run that cell first.
prompt = read_prompt_json("How many teams are in the NBA?", "prompts/solution_prompt.json")
system_prompt, user_prompt = prompt['system'], prompt['user']
response = claude_api.generate_response(user_prompt, system_prompt)
print("Response:", response)

Response: SELECT COUNT(*) FROM team


### Evaluation Harness & Testing Framework

In [164]:
def test_eq_results_from_query(answer_results: List[Any], test_results: List[Any]) -> bool:
    """
    Compare the results from a SQL query with the results from the Claude API. Each row should be compared
    to the corresponding row in the other result set. Column order should not matter
    
    Args:
        answer_results: the results of running the answer key sql query
        test_results: The results from claude generated sql query
        
    Returns:
        bool: True if the results match, False otherwise
    """
    if len(answer_results) != len(test_results):
        return False
    
    for answer_row in answer_results:
        found_row = False
        for test_row in test_results:
            if set(answer_row).issubset(set(test_row)):
                found_row = True
                break
        if not found_row:
            return False
    return True
            

In [188]:
class TestFramework:
    """
    Evaluation harness for natural-language -> SQL prompts.

    For every prompt file ``prompts/<name>.json`` and every ``.json`` test file in
    ``data/input``, Claude generates SQL for each question. The SQL is run against
    ``db_path`` and its rows are compared with the stored ``answer``. Rows already present
    in ``prev_df`` for the same (prompt, input file, question) are reused instead of calling
    the API again.
    """

    def __init__(self, test_name: str, db_path: str, api_client: 'ClaudeAPI', prompt_names: List[str], prev_df = None):
        """
        Initialize the TestFramework with a database path, API client, and a list of prompt files.

        Args:
            test_name: Name used for the results CSV and the plot folder
            db_path: Path to the SQLite database the generated queries are run against
            api_client: Instance of the ClaudeAPI client
            prompt_names: Prompt file names (without ``.json``) located in ``prompts/``
            prev_df: Optional results DataFrame from an earlier run; reused only when its
                columns exactly match ``self.df_cols``
        """
        self.test_name = test_name
        self.db_path = db_path
        self.api_client = api_client
        self.prompt_names = prompt_names

        self.df_cols = ['prompt_method', 'input_file', 'natural_language', 'sql', 'type', 'answer', 'claude_response', 'claude_query_result', 'pass']
        if isinstance(prev_df, pd.DataFrame) and list(prev_df.columns) == self.df_cols:
            self.results_df = prev_df
        else:
            # Build the empty frame from df_cols so the column names cannot drift
            self.results_df = pd.DataFrame(columns=self.df_cols)

    def _cached_result(self, prompt: str, input_name: str, natural_language: str) -> Optional[Dict[str, Any]]:
        """Return the stored response, query result and pass flag for this combination, or None if untested."""
        previous = self.results_df
        if previous.shape[0] == 0:
            return None
        mask = ((previous['prompt_method'] == prompt) &
                (previous['input_file'] == input_name) &
                (previous['natural_language'] == natural_language))
        matches = previous[mask]
        if matches.shape[0] == 0:
            return None
        return {col: matches[col].values[0] for col in ('claude_response', 'claude_query_result', 'pass')}

    def _evaluate_query(self, prompt_path: str, natural_language: str, answer: Any) -> Optional[Dict[str, Any]]:
        """Ask Claude for SQL, run it and compare with ``answer``; returns None when no SQL was produced."""
        # NOTE(review): read_prompt_json builds the schema from the notebook-level db_path,
        # while the generated SQL runs against self.db_path; they are the same in this notebook.
        prompt_dict = read_prompt_json(natural_language, prompt_path)

        try:
            gen_sql = self.api_client.generate_response(prompt_dict['user'], prompt_dict['system'])
        except Exception as e:
            print(f"\tError generating SQL query: {e}")
            gen_sql = None
        if not gen_sql:
            # Such cases are left out of the results (and therefore out of the accuracy)
            print(f"Error generating SQL query for: {natural_language}")
            return None

        gen_query_result = execute_sql_query(self.db_path, gen_sql)
        if gen_query_result is None:
            pass_test = False
        else:
            pass_test = test_eq_results_from_query(answer, gen_query_result)

        return {'claude_response': gen_sql, 'claude_query_result': gen_query_result, 'pass': pass_test}

    def run_tests(self) -> pd.DataFrame:
        """
        Run every prompt against every query in the ``data/input`` JSON files.

        Returns:
            A dataframe with one row per evaluated (prompt, input file, question), with
            columns ``self.df_cols``; it also replaces ``self.results_df``.
        """
        input_files = sorted(name for name in os.listdir(os.path.join('data', 'input')) if name.endswith('.json'))

        test_results = []

        # Loop over prompt files
        for prompt in self.prompt_names:
            prompt_path = os.path.join('prompts', prompt) + '.json'
            if not os.path.exists(prompt_path):
                print(f"Prompt file not found at {prompt_path}")
                continue

            # Loop over input files
            for input_file in input_files:
                input_name = input_file.replace('.json', '')
                with open(os.path.join('data', 'input', input_file), 'r') as f:
                    queries = json.load(f)

                # Loop over queries
                for query in queries:
                    query_text = query['sql']
                    natural_language = query['natural_language']
                    answer = query['answer']
                    type_val = query['type']

                    # Reuse an earlier result for this (prompt, file, question) when available
                    outcome = self._cached_result(prompt, input_name, natural_language)
                    is_new = outcome is None
                    if is_new:
                        outcome = self._evaluate_query(prompt_path, natural_language, answer)
                        if outcome is None:
                            continue

                    test_results.append({
                        'prompt_method': prompt,
                        'input_file': input_name,
                        'natural_language': natural_language,
                        'sql': query_text,
                        'type': type_val,
                        'answer': answer,
                        **outcome,
                    })
                    if is_new:
                        print(f"\tPass: {outcome['pass']}")

        # columns= keeps the schema even when no test produced a row
        self.results_df = pd.DataFrame(test_results, columns=self.df_cols)
        return self.results_df

    @staticmethod
    def _accuracy_by_prompt(results: pd.DataFrame) -> pd.DataFrame:
        """Percentage of passing rows per prompt method (numeric, 2 dp), highest first."""
        accuracy = (results.assign(accuracy=results['pass'].astype(float))
                    .groupby('prompt_method', as_index=False)['accuracy'].mean())
        accuracy['accuracy'] = (accuracy['accuracy'] * 100).round(2)
        return accuracy.sort_values(by='accuracy', ascending=False).reset_index(drop=True)

    @staticmethod
    def _format_accuracy(accuracy: pd.DataFrame) -> pd.DataFrame:
        """Render the numeric accuracy column as strings such as '41.09%' for printing."""
        return accuracy.assign(accuracy=accuracy['accuracy'].astype(str) + '%')

    def print_accuracy(self) -> None:
        """
        Print the accuracy of the tests by prompt method. Also print the accuracy by input file by prompt method
        """
        print("\nAccuracy by Prompt Method:")
        print(self._format_accuracy(self._accuracy_by_prompt(self.results_df)))

        print("\n\nAccuracy by Input File:")
        for input_file in self.results_df['input_file'].unique():
            subset = self.results_df[self.results_df['input_file'] == input_file]
            print(f"\n{input_file} Accuracy by Prompt Method:")
            print(self._format_accuracy(self._accuracy_by_prompt(subset)))

    def _plot_accuracy(self, accuracy: pd.DataFrame, title: str, filename: str) -> str:
        """Draw one accuracy bar chart, save it to images/<test_name>/<filename>, display it, and return the path."""
        fig, ax = plt.subplots(figsize=(10, 6))
        # Bar heights must be the numeric percentages, not the '%'-suffixed display strings
        ax.bar(accuracy['prompt_method'], accuracy['accuracy'], color='skyblue')
        ax.set_xlabel('Prompt Method')
        ax.set_ylabel('Accuracy (%)')
        ax.set_title(title)
        ax.set_ylim(0, 100)
        ax.tick_params(axis='x', labelrotation=45)
        fig.tight_layout()

        output_path = os.path.join('images', self.test_name, filename)
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        # Save before showing: with the inline backend show() closes the figure, so a
        # pyplot savefig() afterwards writes a blank image.
        fig.savefig(output_path)
        plt.show()
        plt.close(fig)
        return output_path

    def plot_accuracy_by_prompt(self) -> None:
        """
        Plot (and save) the accuracy of the tests by prompt method, overall and per input file.
        """
        if self.results_df.empty:
            print("No test results to plot.")
            return

        output_path = self._plot_accuracy(
            self._accuracy_by_prompt(self.results_df),
            'Accuracy of Tests by Prompt Method',
            'total_accuracy_plot.png',
        )
        print(f"Accuracy plot saved to {output_path}")

        # One chart per input file
        for input_file in self.results_df['input_file'].unique():
            subset = self.results_df[self.results_df['input_file'] == input_file]
            self._plot_accuracy(
                self._accuracy_by_prompt(subset),
                'Accuracy of Tests by Prompt Method - ' + input_file,
                f'{input_file}_accuracy_plot.png',
            )

    def save_results(self, output_path: Optional[str] = None) -> Optional[str]:
        """
        Save the test results to a CSV file.

        Args:
            output_path: Path to save the CSV file; defaults to
                ``data/output/<test_name>_results.csv``

        Returns:
            The path written, or None when there are no results to save
        """
        if self.results_df.empty:
            print("No test results to save.")
            return None

        if output_path is None:
            output_path = os.path.join('data', 'output', f'{self.test_name}_results.csv')

        # Ensure the output directory exists
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        self.results_df.to_csv(output_path, index=False)
        print(f"Test results saved to {output_path}")
        return output_path
        

# Test: Original Prompt w/o Context

Let's first see how well the original prompt performs on its own, without any database context.

In [172]:
test1_prompts = ['original_prompt']
test1_framework = TestFramework(
    test_name='original_prompt_results',
    db_path=db_path,
    api_client=claude_api,
    prompt_names=test1_prompts,
)

In [173]:
# Evaluate every case, write the results CSV, then report accuracy
test1_results = test1_framework.run_tests()
test1_framework.save_results()
test1_framework.print_accuracy()

Test results saved to data/output/original_prompt_results_results.csv

Accuracy by Prompt Method:
     prompt_method accuracy
0  original_prompt     0.0%


Accuracy by Input File:

ground_truth_data Accuracy by Prompt Method:
     prompt_method accuracy
0  original_prompt     0.0%

synthetic_data Accuracy by Prompt Method:
     prompt_method accuracy
0  original_prompt     0.0%


In [124]:
# 0% is not great! Let's look at a few examples
test1_results.loc[:, ['natural_language', 'claude_response', 'claude_query_result', 'pass']].head()

Unnamed: 0,natural_language,claude_response,claude_query_result,pass
0,How many teams are in the NBA?,"To convert this question to SQL, I need to wri...",,False
1,What are the 5 oldest teams in the NBA?,"```sql\nSELECT team_name, founded_year\nFROM n...",,False
2,List all teams from California,```sql\nSELECT *\nFROM teams\nWHERE state = 'C...,,False
3,What's the total number of games played?,```sql\nSELECT COUNT(*) AS total_games_played\...,,False
4,What's the highest scoring game?,"To convert this question to SQL, I need to mak...",,False


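We can see that without steering Claude to write SQL, or giving it any context about the tables in the database and their relationships, we don't get any valid SQL queries back from Claude: responses are wrapped in prose or markdown fences and reference tables that don't exist (e.g. `teams` instead of `team`). As a result, every test fails.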

# Solution: Use the System Prompt to Embed Data & SQL Context to the Model

To solve this issue, we provide Claude with a few pieces of context in the system prompt every time we make a query:
1. Steer it to only ever respond with valid SQL. This way, you can rely on Claude to format its response for you, rather than needing to write custom parsing scripts to extract the SQL you want.
2. Give Claude a representation of your database schema so that it knows which tables and columns it can query. The schema is generated directly from the SQLite file, so other databases you have can be embedded the same way if needed.
3. Provide examples to Claude so that it can match queries that have worked for you in the past.

### Retrieve the Database Schema to Embed in the System Prompt

In [157]:
def get_database_schema(db_path: str) -> str:
    """
    Create an XML-like text representation of the SQLite DB for embedding in a prompt.

    The result contains a ``<tables>`` element listing every table name and a ``<schema>``
    element with one child per table. Each child holds a comma-separated list of
    ``name (TYPE)`` entries, optionally followed by ``NOT NULL``, ``PRIMARY KEY`` and
    ``DEFAULT <value>``.

    Args:
        db_path: Path to the SQLite database file

    Returns:
        The formatted schema string
    """
    # closing() releases the connection even if a query fails
    with closing(sqlite3.connect(db_path)) as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        table_names = [row[0] for row in cursor.fetchall()]

        column_text = {}
        for table_name in table_names:
            # Double embedded single quotes so unusual table names cannot break the PRAGMA
            quoted_name = table_name.replace("'", "''")
            cursor.execute(f"PRAGMA table_info('{quoted_name}');")

            descriptions = []
            # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
            for _, name, col_type, notnull, default_value, is_primary_key in cursor.fetchall():
                text = f"{name} ({col_type})"
                if notnull:
                    text += " NOT NULL"
                if is_primary_key:
                    text += " PRIMARY KEY"
                if default_value is not None:
                    text += f" DEFAULT {default_value}"
                descriptions.append(text)
            column_text[table_name] = ", ".join(descriptions)

    lines = ["\t<tables>", f"\t\t[{', '.join(table_names)}]", "\t</tables>", "\t<schema>"]
    for table_name, columns in column_text.items():
        lines += [f"\t\t<{table_name}>", f"\t\t\t{columns}", f"\t\t</{table_name}>"]
    lines.append("</schema>")
    return "\n".join(lines) + "\n"

### Run Tests for the new Prompt

In [189]:
solution_prompts = ['original_prompt', 'solution_prompt']
solution_framework = TestFramework(
    test_name='solution_results',
    db_path=db_path,
    api_client=claude_api,
    prompt_names=solution_prompts,
    prev_df=test1_results,
)

In [190]:
# Combinations already in test1_results (passed as prev_df) are reused; only new ones call the API
solution_results = solution_framework.run_tests()

Skipping original_prompt - ground_truth_data.json - How many teams are in the NBA? as it has already been tested
Skipping original_prompt - ground_truth_data.json - What are the 5 oldest teams in the NBA? as it has already been tested
Skipping original_prompt - ground_truth_data.json - List all teams from California as it has already been tested
Skipping original_prompt - ground_truth_data.json - What's the total number of games played? as it has already been tested
Skipping original_prompt - ground_truth_data.json - What's the highest scoring game? as it has already been tested
Skipping original_prompt - ground_truth_data.json - How many players are from Duke University? as it has already been tested
Skipping original_prompt - ground_truth_data.json - Which team has the most home games? as it has already been tested
Skipping original_prompt - ground_truth_data.json - What's the average points per game? as it has already been tested
Skipping original_prompt - ground_truth_data.json - L

In [191]:
# Persist the combined results, then report accuracy per prompt and per input file
solution_framework.save_results()
solution_framework.print_accuracy()

Test results saved to data/output/solution_results_results.csv

Accuracy by Prompt Method:
     prompt_method accuracy
0  solution_prompt   41.09%


Accuracy by Input File:

ground_truth_data Accuracy by Prompt Method:
     prompt_method accuracy
0  solution_prompt   46.94%

synthetic_data Accuracy by Prompt Method:
     prompt_method accuracy
0  solution_prompt    37.5%
