## Amazon Bedrock LLM Router Evaluation

### Intro and Goal
This Jupyter Notebook is designed to test an LLM (Large Language Model) routing system on a Text-to-SQL use case.

The goal is to take a prompt, determine the level of complexity and then route the prompt either to a small or large LLM to generate the corresponding SQL query.

### Steps
1. **Create a ground truth dataset comprised of questions and SQL queries for a given database (e.g. Northwind) (notebook 4a)**
2. Define Router (notebook 4b)
3. Evaluate accuracy, cost, and latency of LLM classifier router approach compared to the baseline of using a larger LLM for all queries. (notebook 4b)

In [1]:
# 1. Create a conda environment

# !conda create -y --name bedrock-router-eval python=3.11.8
# !conda init && activate bedrock-router-eval
# !conda install -n bedrock-router-eval ipykernel --update-deps --force-reinstall -y
# !conda install -c conda-forge ipython-sql

In [2]:
# 2. Install dependencies

# !pip install -r requirements.txt

### Set Environment Variables

In [10]:
# 3. Import necessary libraries and load environment variables
import numpy as np
from scipy.spatial.distance import cdist
import json
from dotenv import load_dotenv, find_dotenv
import os
import boto3
import sqlite3
from pandas.io import sql
from botocore.config import Config
import pandas as pd
import io
import json
from io import StringIO
import sqlparse
import sqlite3
import time
import matplotlib.pyplot as plt
import re
import typing as t
from queue import Queue
from threading import Thread
from concurrent.futures import ThreadPoolExecutor, as_completed

# loading environment variables that are stored in local file
local_env_filename = 'bedrock-router-eval.env'
load_dotenv(find_dotenv(local_env_filename),override=True)

os.environ['REGION'] = os.getenv('REGION')
os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN')
os.environ['SQL_DATABASE'] = os.getenv('SQL_DATABASE') # LOCAL or GLUE
os.environ['SQL_DIALECT'] = os.getenv('SQL_DIALECT') # SQlite or awsathena
os.environ['SQL_DATABASE_NAME'] = os.getenv('SQL_DATABASE_NAME')
# os.environ['AWS_ACCESS_KEY'] = os.getenv('AWS_ACCESS_KEY')
# os.environ['AWS_SECRET_ACCESS_KEY'] = os.getenv('AWS_SECRET_ACCESS_KEY')

REGION = os.environ['REGION']
HF_TOKEN = os.environ['HF_TOKEN']
SQL_DATABASE = os.environ['SQL_DATABASE']
SQL_DIALECT = os.environ['SQL_DIALECT']
SQL_DATABASE_NAME = os.environ['SQL_DATABASE_NAME']

MODEL_ID = "anthropic.claude-3-haiku-20240307-v1:0" # "anthropic.claude-3-5-sonnet-20240620-v1:0" "meta.llama3-1-70b-instruct-v1:0"

In [11]:
# 4. Definition of helper classes

class Util():
    def __init__(self,
        debug: bool = False

    ):
        self.debug = debug
    
    SCORE_PATTERN = r'<score>(.*?)</score>'
    REASONING_PATTERN = r'<thinking>(.*?)</thinking>'
    SQL_PATTERN = r'<SQL>(.*?)</SQL>'
    DIFFICULTY_PATTERN = r'<difficulty>(.*?)</difficulty>'


    def compare_results(self, answer_results1, answer_results2):


        # # Function to convert 'score' column
        def convert_score(df):
            # df['score'] = df['score'].map({'correct': 1, 'incorrect': 0})
            df['score'] = pd.to_numeric(df['score'], errors='coerce').fillna(0).astype(int)
            return df

        # Apply the conversion to both dataframes
        answer_results1 = convert_score(answer_results1)
        answer_results2 = convert_score(answer_results2)

        # Calculate the average values for each metric
        metrics = ['score', 'latency' ,'cost', 'ex_score', 'em_score','ves_score']
        
        avg_results1 = [answer_results1[metric].mean() for metric in metrics]
        avg_results2 = [answer_results2[metric].mean() for metric in metrics]

        # Calculate percentage change, handling divide-by-zero and infinite cases
        def safe_percent_change(a, b):
            if pd.isna(a) or pd.isna(b):
                return 0
            if a == 0 and b == 0:
                return 0
            elif a == 0:
                return 100  # Arbitrarily set to 100% increase if original value was 0
            else:
                change = (b - a) / a * 100
                return change if np.isfinite(change) else 0

        percent_change = [safe_percent_change(a, b) for a, b in zip(avg_results1, avg_results2)]

        # Set up the bar chart
        x = np.arange(len(metrics))
        width = 0.5

        fig, ax = plt.subplots(figsize=(12, 6))

        # Create the bars
        bars = ax.bar(x, percent_change, width)

        # Customize the chart
        ax.set_ylabel('Percentage Change (%)')
        ax.set_title('Percentage Change in Metrics (Results 2 vs Results 1)')
        ax.set_xticks(x)
        ax.set_xticklabels(metrics)

        # Add a horizontal line at y=0
        ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)

        # Add value labels on top of each bar
        def autolabel(rects):
            for rect in rects:
                height = rect.get_height()
                ax.annotate(f'{height:.2f}%',
                            xy=(rect.get_x() + rect.get_width() / 2, height),
                            xytext=(0, 3 if height >= 0 else -3),  # 3 points vertical offset
                            textcoords="offset points",
                            ha='center', va='bottom' if height >= 0 else 'top')

        autolabel(bars)

        # Color the bars based on positive (green) or negative (red) change
        # For latency & cost, reverse the color logic
        for bar, change, metric in zip(bars, percent_change, metrics):
            if metric == 'latency' or metric == 'cost':
                bar.set_color('green' if change <= 0 else 'red')
            else:
                bar.set_color('green' if change >= 0 else 'red')
            

        # Adjust layout and display the chart
        fig.tight_layout()
        plt.show()

    def visualize_distribution(self, df, key):
        # Check if 'score' column exists in the DataFrame
        if key not in df.columns:
            raise ValueError(f"The DataFrame does not contain a '{key}' column.")
        
        # Count the frequency of each score
        score_counts = df[key].value_counts().sort_index()
        
        # Create a bar chart
        plt.figure(figsize=(10, 6))
        plt.bar(score_counts.index, score_counts.values)
        
        # Customize the chart
        plt.title(f'Distribution of {key}')
        plt.xlabel(f'{key}')
        plt.ylabel('Frequency')
        plt.xticks(range(int(score_counts.index.min()), int(score_counts.index.max()) + 1))
        
        # Add value labels on top of each bar
        for i, v in enumerate(score_counts.values):
            plt.text(score_counts.index[i], v, str(v), ha='center', va='bottom')
        
        # Display the chart
        plt.tight_layout()
        plt.show()

    # Strip out the portion of the response with regex.
    def extract_with_regex(self, response, regex):
        matches = re.search(regex, response, re.DOTALL)
        # Extract the matched content, if any
        return matches.group(1).strip() if matches else None

    def calculate_cost(self, usage, model_id):
        '''
        Takes the usage tokens returned by Bedrock in input and output, and coverts to cost in dollars.
        '''
        
        input_token_haiku = 0.25/1000000
        output_token_haiku = 1.25/1000000
        input_token_sonnet = 3.00/1000000
        output_token_sonnet = 15.00/1000000
        input_token_opus = 15.00/1000000
        output_token_opus = 75.00/1000000
        
        input_token_titan_embeddingv1 = 0.1/1000000
        input_token_titan_embeddingv2 = 0.02/1000000
        input_token_titan_embeddingmultimodal = 0.8/1000000
        input_token_titan_premier = 0.5/1000000
        output_token_titan_premier = 1.5/1000000
        input_token_titan_lite = 0.15/1000000
        output_token_titan_lite = 0.2/1000000
        input_token_titan_express = 0.2/1000000
        output_token_titan_express = 0.6/1000000
       
        input_token_cohere_command = 0.15/1000000
        output_token_cohere_command = 2/1000000
        input_token_cohere_commandlight = 0.3/1000000
        output_token_cohere_commandlight = 0.6/1000000
        input_token_cohere_commandrplus = 3/1000000
        output_token_cohere_commandrplus = 15/1000000
        input_token_cohere_commandr = 5/1000000
        output_token_cohere_commandr = 1.5/1000000
        input_token_cohere_embedenglish = 0.1/1000000
        input_token_cohere_embedmultilang = 0.1/1000000

        input_token_llama3_8b = 0.4/1000000
        output_token_llama3_8b = 0.6/1000000
        input_token_llama3_70b = 2.6/1000000
        output_token_llama3_70b = 3.5/1000000

        input_token_mistral_8b = 0.15/1000000
        output_token_mistral_8b = 0.2/1000000
        input_token_mistral_large = 4/1000000
        output_token_mistral_large = 12/1000000

        cost = 0

        if 'haiku' in model_id:
            cost+= usage['inputTokens']*input_token_haiku
            cost+= usage['outputTokens']*output_token_haiku
        if 'sonnet' in model_id:
            cost+= usage['inputTokens']*input_token_sonnet
            cost+= usage['outputTokens']*output_token_sonnet
        if 'opus' in model_id:
            cost+= usage['inputTokens']*input_token_opus
            cost+= usage['outputTokens']*output_token_opus
        if 'amazon.titan-embed-text-v1' in model_id:
            cost+= usage['inputTokens']*input_token_titan_embeddingv1
        if 'amazon.titan-embed-text-v2' in model_id:
            cost+= usage['inputTokens']*input_token_titan_embeddingv2
        if 'cohere.embed-multilingual' in model_id:
            cost+= usage['inputTokens']*input_token_cohere_embedmultilang
        if 'cohere.embed-english' in model_id:
            cost+= usage['inputTokens']*input_token_cohere_embedenglish 
        if 'meta.llama3-8b-instruct' in model_id:
            cost+= usage['inputTokens']*input_token_llama3_8b
            cost+= usage['outputTokens']*output_token_llama3_8b
        if 'meta.llama3-70b-instruct' in model_id:
            cost+= usage['inputTokens']*input_token_llama3_70b
            cost+= usage['outputTokens']*output_token_llama3_70b
        if 'cohere.command-text' in model_id:
            cost+= usage['inputTokens']*input_token_cohere_command
            cost+= usage['outputTokens']*output_token_cohere_command
        if 'cohere.command-light-text' in model_id:
            cost+= usage['inputTokens']*input_token_cohere_commandlight
            cost+= usage['outputTokens']*output_token_cohere_commandlight
        if 'cohere.command-r-plus' in model_id:
            cost+= usage['inputTokens']*input_token_cohere_commandrplus
            cost+= usage['outputTokens']*output_token_cohere_commandrplus
        if 'cohere.command-r' in model_id:
            cost+= usage['inputTokens']*input_token_cohere_commandr
            cost+= usage['outputTokens']*output_token_cohere_commandr
        if 'amazon.titan-text-express' in model_id:
            cost+= usage['inputTokens']*input_token_titan_express
            cost+= usage['outputTokens']*output_token_titan_express
        if 'amazon.titan-text-lite' in model_id:
            cost+= usage['inputTokens']*input_token_titan_lite
            cost+= usage['outputTokens']*output_token_titan_lite
        if 'amazon.titan-text-premier' in model_id:
            cost+= usage['inputTokens']*input_token_titan_premier
            cost+= usage['outputTokens']*output_token_titan_premier
        if 'mistral.mixtral-8x7b-instruct-v0:1' in model_id:
            cost+= usage['inputTokens']*input_token_mistral_8b
            cost+= usage['outputTokens']*output_token_mistral_8b

        return cost


# Bedrock LLM Class
class BedrockLLMWrapper():
    def __init__(self,
        model_id: str = 'anthropic.claude-3-haiku-20240307-v1:0', #'anthropic.claude-3-sonnet-20240229-v1:0',
        top_k: int = 5,
        top_p: int = 0.7,
        temperature: float = 0.0,
        max_token_count: int = 4000,
        max_attempts: int = 3,
        debug: bool = False

    ):

        self.model_id = model_id
        self.top_k = top_k
        self.top_p = top_p
        self.temperature = temperature
        self.max_token_count = max_token_count
        self.max_attempts = max_attempts
        self.debug = debug
        config = Config(
            retries = {
                'max_attempts': 10,
                'mode': 'standard'
            }
        )

        self.bedrock_runtime = boto3.client(service_name="bedrock-runtime", config=config, region_name=REGION)
        
    def generate(self,prompt):
        if self.debug: 
            print('entered BedrockLLMWrapper generate')
        attempt = 1

        message = {
            "role": "user",
            "content": [{"text": prompt}]
        }
        messages = []
        messages.append(message)
        
        # model specific inference parameters to use.
        if "anthropic" in self.model_id.lower():
            # system_prompts = [{"text": "You are a helpful AI Assistant."}]
            system_prompts = []
            # Base inference parameters to use.
            inference_config = {
                                "temperature": self.temperature, 
                                "maxTokens": self.max_token_count,
                                "stopSequences": ["\n\nHuman:"],
                                "topP": self.top_p,
                            }
            additional_model_fields = {"top_k": self.top_k}
        else:
            system_prompts = []
            # Base inference parameters to use.
            inference_config = {
                                "temperature": self.temperature, 
                                "maxTokens": self.max_token_count,
                            }
            additional_model_fields = {"top_k": self.top_k}

        if self.debug: 
            print("Sending:\nSystem:\n",system,"\nMessages:\n",str(messages))

        while True:
            try:

                # Send the message.
                response = self.bedrock_runtime.converse(
                    modelId=self.model_id,
                    messages=messages,
                    system=system_prompts,
                    inferenceConfig=inference_config,
                    additionalModelRequestFields=additional_model_fields
                )

                # Log token usage.
                text = response['output'].get('message').get('content')[0].get('text')
                usage = response['usage']
                latency = response['metrics'].get('latencyMs')

                if self.debug: 
                    print(f'text: {text} ; and token usage: {usage} ; and query_time: {latency}')    
                
                break
               
            except Exception as e:
                print("Error with calling Bedrock: "+str(e))
                attempt+=1
                if attempt>self.max_attempts:
                    print("Max attempts reached!")
                    result_text = str(e)
                    break
                else:#retry in 10 seconds
                    print("retry")
                    time.sleep(10)

        # return result_text
        return [text,usage,latency]

     # Threaded function for queue processing.
    def thread_request(self, q, results):
        while True:
            try:
                index, prompt = q.get(block=False)
                data = self.generate(prompt)
                results[index] = data
            except Queue.Empty:
                break
            except Exception as e:
                print(f'Error with prompt: {str(e)}')
                results[index] = str(e)
            finally:
                q.task_done()

    def generate_threaded(self, prompts, max_workers=15):
        results = [None] * len(prompts)
        
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_index = {executor.submit(self.generate, prompt): i for i, prompt in enumerate(prompts)}
            for future in as_completed(future_to_index):
                index = future_to_index[future]
                try:
                    results[index] = future.result()
                except Exception as exc:
                    print(f'Generated an exception: {exc}')
                    results[index] = str(exc)
        
        return results


In [5]:
# 5. Get schema for all tables in database

if SQL_DATABASE == 'LOCAL':
    # create local db and import northwind database

    import requests
    import sqlite3
    import re

    # Download the SQL files
    url1 = "https://raw.githubusercontent.com/YugaByte/yugabyte-db/master/sample/northwind_ddl.sql"
    url2 = "https://raw.githubusercontent.com/YugaByte/yugabyte-db/master/sample/northwind_data.sql"

    urls = [url1,url2]

    for url in urls:
        response = requests.get(url)
        sql_content = response.text

        # Create a SQLite database connection
        conn = sqlite3.connect('routedb.db')
        cursor = conn.cursor()

        # Split the SQL content into individual statements
        sql_statements = re.split(r';\s*$', sql_content, flags=re.MULTILINE)

        # Execute each SQL statement
        for statement in sql_statements:
            # Skip empty statements
            if statement.strip():
                # print(f'statement: {statement}')
                # Replace PostgreSQL-specific syntax with SQLite equivalents
                statement = statement.replace('SERIAL PRIMARY KEY', 'INTEGER PRIMARY KEY AUTOINCREMENT')
                statement = statement.replace('::int', '')
                statement = statement.replace('::varchar', '')
                statement = statement.replace('::real', '')
                statement = statement.replace('::date', '')
                statement = statement.replace('::boolean', '')
                statement = statement.replace('public.', '')
                statement = re.sub(r'WITH \(.*?\)', '', statement)
                
                try:
                    cursor.execute(statement)
                except sqlite3.Error as e:
                    print(f"Error executing statement: {e}")
                    print(f"Statement: {statement}")

        # Commit the changes and close the connection
        conn.commit()
        conn.close()

        print("SQL execution completed.")


        def get_schema_as_string(db_path):
            conn = sqlite3.connect(db_path)
            cursor = conn.cursor()

            # Query to get all table names
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            tables = cursor.fetchall()

            schema_string = ""

            for table in tables:
                table_name = table[0]
                # Query to get the CREATE TABLE statement for each table
                cursor.execute(f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table_name}';")
                create_table_stmt = cursor.fetchone()[0]
                
                schema_string += f"{create_table_stmt};\n\n"

            conn.close()
            return schema_string
        
        schema = get_schema_as_string('routedb.db')
else: 
    # use a Glue database
    def get_schema(database_name, table_names=None):
        try:
            glue_client = boto3.client('glue', region_name=REGION)
            table_schema_list = []
            response = glue_client.get_tables(DatabaseName=database_name)

            all_table_names = [table['Name'] for table in response['TableList']]

            if table_names:
                table_names = [name for name in table_names if name in all_table_names]
            else:
                table_names = all_table_names

            for table_name in table_names:
                response = glue_client.get_table(DatabaseName=database_name, Name=table_name)
                columns = response['Table']['StorageDescriptor']['Columns']
                schema = {column['Name']: column['Type'] for column in columns}
                table_schema_list.append({"Table: {}".format(table_name): 'Schema: {}'.format(schema)})
        except Exception as e:
            print(f"Error: {str(e)}")
        return table_schema_list
    schema = get_schema(SQL_DATABASE_NAME)

Error executing statement: near "SET": syntax error
Statement: --
-- PostgreSQL database dump
--

SET statement_timeout = 0
Error executing statement: near "SET": syntax error
Statement: 
SET lock_timeout = 0
Error executing statement: near "SET": syntax error
Statement: 
SET client_encoding = 'UTF8'
Error executing statement: near "SET": syntax error
Statement: 
SET standard_conforming_strings = on
Error executing statement: near "SET": syntax error
Statement: 
SET check_function_bodies = false
Error executing statement: near "SET": syntax error
Statement: 
Error executing statement: near "SET": syntax error
Statement: 
SET default_tablespace = ''
Error executing statement: near "SET": syntax error
Statement: 
SET default_with_oids = false
SQL execution completed.
Error executing statement: near "SET": syntax error
Statement: --
-- PostgreSQL database dump
--

SET statement_timeout = 0
Error executing statement: near "SET": syntax error
Statement: 
SET lock_timeout = 0
Error executing

In [6]:
# 6. Generate questions and SQL queries based on database schema
## re-run cell to increase # of question query pairs which will function as groundtruth

import pandas as pd
import json

def generate_question_query_dataset(existing_questions):
    prompt = """Human: Review the provided database schema below. 
            Then create 100 questions in natural language along with corresponding SQL queries that would answer these questions based on this database schema.
            
            <database_schema>
            {database_schema}
            </database_schema>

            Ensure that the generated question is not already part of the existing data below.

            <existing_questions>
            {existing_questions}
            </existing_questions>
            
            Return the response in JSONL and return only the JSON and nothing else.      
            Assistant: {{""".format(database_schema=schema, existing_questions=existing_questions)

    MODEL_ID = 'anthropic.claude-3-5-sonnet-20240620-v1:0'
    wrapper = BedrockLLMWrapper(debug=False, model_id=MODEL_ID, max_token_count=3000)
    result = wrapper.generate(prompt)
    
    return result[0]

file_path = 'data/eval-source/question_query.jsonl'
if os.path.exists(file_path):
    with open(file_path, 'r') as file:
        existing_questions = file.read()
else:
    existing_questions = ''

response_text = generate_question_query_dataset(existing_questions)

# append response_text to existing_questions
response_text = existing_questions + response_text

# read jsonl_string into dataframe

def parse_json_line(line):
    try:
        return json.loads(line)
    except json.JSONDecodeError:
        print(f"Error parsing line: {line}")
        return None

if 'response_text' not in globals():
    response_text = []
    with open(file_path, 'r') as file:
        response_text = file.read()

# Split the string into lines and parse each line
data = [parse_json_line(line) for line in response_text.strip().split('\n')]

# Remove any None values (failed parses)
data = [d for d in data if d is not None]

# Create a DataFrame from the list of dictionaries
df = pd.DataFrame(data)

df.to_json(file_path, orient='records', lines=True)

print(f"Number of successfully parsed questions: {len(df)}")
print(df.head(5))


Number of successfully parsed questions: 172
                                            question  \
0             What is the total number of customers?   
1      List all product names and their unit prices.   
2        Who are the top 5 customers by order count?   
3       What is the average freight cost for orders?   
4  List all employees with their full names and t...   

                                               query  
0                    SELECT COUNT(*) FROM customers;  
1     SELECT product_name, unit_price FROM products;  
2  SELECT c.customer_id, c.company_name, COUNT(o....  
3                   SELECT AVG(freight) FROM orders;  
4  SELECT employee_id, first_name || ' ' || last_...  


In [7]:
# 7. Test generated SQL queries and remove those question query pairs that do not run successfully.
results = []

# Check if df exists in the current namespace
if 'df' not in globals():
    # If it doesn't exist, try to load it from a JSONL file
    if os.path.exists(file_path):
        # Load the dataframe from the JSONL file
        df = pd.read_json(file_path, lines=True)
        print("df loaded from JSONL file.")
    else:
        print(f"Error: JSONL file not found at {file_path}")
else:
    print(f"df with column names: {df.columns} already exists in memory.")


df.columns = df.columns.str.capitalize()
if SQL_DATABASE == 'LOCAL':
    # Create a SQLite database connection
    conn = sqlite3.connect('routedb.db')
    cursor = conn.cursor()

for row in df.itertuples():
    # print(row.query)
    error = None
    result = None
    try:
        
        if SQL_DATABASE == 'LOCAL':
            # Use local SQL lite
            statement = row.Query
            # Replace PostgreSQL-specific syntax with SQLite equivalents
            
            try:
                cursor.execute(statement)
                # Fetch all rows from the result
                result = cursor.fetchall()

            except sqlite3.Error as e:
                print(f"Error executing statement: {e}")
                error = e
                # print(f"Statement: {statement}")

        else:
            # Use Athena if AWS Glue Schema is used
            result = execute_athena_query(DATABASE, row.Query)
        
    except ClientError as e:
        error = e

    results.append({'Question': row.Question,'Query': row.Query, 'Result': result, 'Error': error, 'Context': schema})

if SQL_DATABASE == 'LOCAL':
    # close the connection
    conn.close()

df_results = pd.DataFrame(results)

# Use all generated prompts that resulted in valid SQL queries and filter out the rest

df_good_results = df_results[df_results['Error'].isnull() | (df_results['Error'] == None)]
print(f"Number of successful queries: {len(df_good_results)}")

df_bad_results = df_results[df_results['Error'].notnull() | (df_results['Error'] == 'None')]
print(f"Number of unsuccessful queries: {len(df_bad_results)}")

# safe good queries as jsonl as our golden dataset
good_results_file_path = 'data/eval-source/question_query_good_results.jsonl'
df_good_results.to_json(good_results_file_path, orient='records', lines=True)


df with column names: Index(['question', 'query'], dtype='object') already exists in memory.
Error executing statement: near "FROM": syntax error
Error executing statement: near "FROM": syntax error
Error executing statement: near "FROM": syntax error
Error executing statement: near "FROM": syntax error
Error executing statement: near "FROM": syntax error
Error executing statement: near "FROM": syntax error
Error executing statement: no such column: INTERVAL
Error executing statement: near "FROM": syntax error
Error executing statement: near "FROM": syntax error
Error executing statement: near "FROM": syntax error
Error executing statement: near "FROM": syntax error
Error executing statement: near "FROM": syntax error
Error executing statement: near "FROM": syntax error
Error executing statement: near "FROM": syntax error
Error executing statement: near "FROM": syntax error
Error executing statement: near "FROM": syntax error
Error executing statement: near "FROM": syntax error
Error e

In [8]:
# sql_dialect = awsathena or SQLite
def build_sqlquerygen_prompt(user_question: str, sql_database_schema: str):
    prompt = """You are a SQL expert. You will be provided with the original user question and a SQL database schema. 
                Only return the SQL query and nothing else.
                Here is the original user question.
                <user_question>
                {user_question}
                </user_question>

                Here is the SQL database schema.
                <sql_database_schema>
                {sql_database_schema}
                </sql_database_schema>
                
                Instructions:
                Generate a SQL query that answers the original user question.
                Use the schema, first create a syntactically correct {sql_dialect} query to answer the question. 
                Never query for all the columns from a specific table, only ask for a few relevant columns given the question.
                Pay attention to use only the column names that you can see in the schema description. 
                Be careful to not query for columns that do not exist. 
                Pay attention to which column is in which table. 
                Also, qualify column names with the table name when needed.
                If you cannot answer the user question with the help of the provided SQL database schema, 
                then output that this question question cannot be answered based of the information stored in the database.
                You are required to use the following format, each taking one line:
                Return the sql query inside the <SQL></SQL> tab.
                """.format(
                    user_question=user_question,
                    sql_database_schema=sql_database_schema,
                    sql_dialect=SQL_DIALECT
                ) 
    return prompt

In [12]:
# 8a. Use this golden dataset to run test with smaller LLM

MODEL_ID = "mistral.mixtral-8x7b-instruct-v0:1" # "anthropic.claude-3-haiku-20240307-v1:0" # "mistral.mixtral-8x7b-instruct-v0:1" "anthropic.claude-3-5-sonnet-20240620-v1:0" "meta.llama3-1-70b-instruct-v1:0"

# use helper class for threaded API calls
wrapper = BedrockLLMWrapper(debug=False, model_id=MODEL_ID, max_token_count=500)
util = Util()
df1 = df_good_results
prompts_list = []
for row in df1.itertuples():
    prompt = build_sqlquerygen_prompt(row.Question, row.Context)
    prompts_list.append(prompt)
results = wrapper.generate_threaded(prompts_list)

# Create a list to store the generated SQL queries
generated_sql_queries = []
for result in results:
    generated_sql_query = result[0].replace("\\","") # workaround, switching to ConverseAPI introduced \ in Mistral response
    print(f'generated_sql_query: {generated_sql_query}')
    generated_sql_queries.append(generated_sql_query)

# Add the new column 'Generated_SQL_Query' to df_results
df1['Generated_SQL_Query'] = generated_sql_queries


# Test generated SQL queries and verify they work
results = []

if SQL_DATABASE == 'LOCAL':
    # Create a SQLite database connection
    conn = sqlite3.connect('routedb.db')
    cursor = conn.cursor()

for row in df1.itertuples():
    statement = util.extract_with_regex(row.Generated_SQL_Query, util.SQL_PATTERN)
    # print(f'SQL statement: {statement}')
    error = None
    result = None
    try:
        
        if SQL_DATABASE == 'LOCAL':
            # Use local SQL lite
            
            # Replace PostgreSQL-specific syntax with SQLite equivalents
            
            try:
                cursor.execute(statement)
                # Fetch all rows from the result
                result = cursor.fetchall()
                # print(result)

            except sqlite3.Error as e:
                print(f"Error executing statement: {e}")
                error = e
                # print(f"Statement: {statement}")

        else:
            # Use Athena if AWS Glue Schema is used
            result = execute_athena_query(SQL_DATABASE_NAME, row.Generated_SQL_Query)
        
    except ClientError as e:
        error = e

    results.append({'Question': row.Question,'Query': statement, 'Result': result, 'Error': error, 'ReferenceQuery': row.Query, 'Context': row.Context})

if SQL_DATABASE == 'LOCAL':
    # close the connection
    conn.close()

# inspect first 3 results
df1_results = pd.DataFrame(results)
print(df1_results.head(3))

# review successful/unsucessful queries
df1_good_results = df1_results[df1_results['Error'].isnull() | (df1_results['Error'] == None)]
print(f"Number of successful queries: {len(df1_good_results)}")

df1_bad_results = df1_results[df1_results['Error'].notnull() | (df1_results['Error'] == 'None')]
print(f"Number of unsuccessful queries: {len(df1_bad_results)}")

generated_sql_query:  <SQL>
SELECT COUNT(*)
FROM customers;
</SQL>
generated_sql_query:  <SQL>
SELECT products.product_name, products.unit_price
FROM products;
</SQL>
generated_sql_query:  <SQL>
SELECT customers.customer_id, COUNT(orders.order_id) as order_count
FROM customers
JOIN orders ON customers.customer_id = orders.customer_id
GROUP BY customers.customer_id
ORDER BY order_count DESC
LIMIT 5;
</SQL>
generated_sql_query:  <SQL>
SELECT AVG(freight)
FROM orders;
</SQL>
generated_sql_query:  <SQL>
SELECT employees.first_name || ' ' || employees.last_name AS full_name, employees.title
FROM employees;
</SQL>
generated_sql_query:  <SQL>
SELECT categories.category_name, COUNT(products.product_id) AS product_count
FROM categories
LEFT JOIN products ON categories.category_id = products.category_id
GROUP BY categories.category_name;
</SQL>
generated_sql_query:  <SQL>
SELECT SUM(p.units_in_stock * p.unit_price) as total_value
FROM products p;
</SQL>
generated_sql_query:  <SQL>
SELECT company

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df1['Generated_SQL_Query'] = generated_sql_queries


Error executing statement: no such table: regions
Error executing statement: no such table: order details
Error executing statement: no such function: DATEDIFF
Error executing statement: ambiguous column name: unit_price
Error executing statement: no such column: od.employee_id
Error executing statement: no such column: shipped_country
Error executing statement: no such column: orders.region
Error executing statement: no such column: c.ship_country
Error executing statement: no such function: DATEDIFF
Error executing statement: no such column: orders.region
Error executing statement: no such function: DATEDIFF
Error executing statement: no such function: DATEDIFF
Error executing statement: no such column: orders.product_id
                                        Question  \
0         What is the total number of customers?   
1  List all product names and their unit prices.   
2    Who are the top 5 customers by order count?   

                                               Query  \
0 

In [13]:
# 8b. Use golden dataset to run test with larger LLM

MODEL_ID = "anthropic.claude-3-sonnet-20240229-v1:0" # "anthropic.claude-3-haiku-20240307-v1:0" # "mistral.mixtral-8x7b-instruct-v0:1" "anthropic.claude-3-5-sonnet-20240620-v1:0" "meta.llama3-1-70b-instruct-v1:0"

# use helper class for threaded API calls
wrapper = BedrockLLMWrapper(debug=False, model_id=MODEL_ID, max_token_count=500)
util = Util()
df2 = df_good_results
prompts_list = []
for row in df1.itertuples():
    prompt = build_sqlquerygen_prompt(row.Question, row.Context)
    prompts_list.append(prompt)
results = wrapper.generate_threaded(prompts_list)

# Create a list to store the generated SQL queries
generated_sql_queries = []
for result in results:
    generated_sql_query = result[0].replace("\\","") # workaround, switching to ConverseAPI introduced \ in Mistral response
    print(f'generated_sql_query: {generated_sql_query}')
    generated_sql_queries.append(generated_sql_query)

# Add the new column 'Generated_SQL_Query' to df_results
df2['Generated_SQL_Query'] = generated_sql_queries

# Test generated SQL queries and verify they work
results = []

if SQL_DATABASE == 'LOCAL':
    # Create a SQLite database connection
    conn = sqlite3.connect('routedb.db')
    cursor = conn.cursor()

for row in df2.itertuples():
    statement = util.extract_with_regex(row.Generated_SQL_Query, util.SQL_PATTERN)
    # print(f'SQL statement: {statement}')
    error = None
    try:
        
        if SQL_DATABASE == 'LOCAL':
            # Use local SQL lite
            
            # Replace PostgreSQL-specific syntax with SQLite equivalents
            
            try:
                cursor.execute(statement)
                # Fetch all rows from the result
                result = cursor.fetchall()
                # print(result)

            except sqlite3.Error as e:
                print(f"Error executing statement: {e}")
                error = e
                # print(f"Statement: {statement}")

        else:
            # Use Athena if AWS Glue Schema is used
            result = execute_athena_query(SQL_DATABASE_NAME, row.Generated_SQL_Query)
        
    except ClientError as e:
        error = e

    results.append({'Question': row.Question,'Query': statement, 'Result': result, 'Error': error, 'ReferenceQuery': row.Query, 'Context': row.Context})

if SQL_DATABASE == 'LOCAL':
    # close the connection
    conn.close()

df2_results = pd.DataFrame(results)

# inspect first 3 results
print(df2_results.head(3))

# review successful/unsucessful queries
df2_good_results = df2_results[df2_results['Error'].isnull() | (df2_results['Error'] == None)]
print(f"Number of successful queries: {len(df2_good_results)}")

df2_bad_results = df2_results[df2_results['Error'].notnull() | (df2_results['Error'] == 'None')]
print(f"Number of unsuccessful queries: {len(df2_bad_results)}")

generated_sql_query: <SQL>
SELECT COUNT(*) AS total_customers
FROM customers;
</SQL>
generated_sql_query: <SQL>
SELECT product_name, unit_price
FROM products;
</SQL>
generated_sql_query: <SQL>
SELECT c.company_name, COUNT(o.order_id) AS order_count
FROM customers c
JOIN orders o ON c.customer_id = o.customer_id
GROUP BY c.company_name
ORDER BY order_count DESC
LIMIT 5;
</SQL>
generated_sql_query: <SQL>
SELECT AVG(freight) AS average_freight_cost
FROM orders;
</SQL>
generated_sql_query: <SQL>
SELECT e.first_name || ' ' || e.last_name AS full_name, e.title
FROM employees e;
</SQL>
generated_sql_query: <SQL>
SELECT c.category_name, COUNT(p.product_id) AS product_count
FROM categories c
LEFT JOIN products p ON c.category_id = p.category_id
GROUP BY c.category_name;
</SQL>
generated_sql_query: <SQL>
SELECT SUM(products.unit_price * products.units_in_stock) AS total_inventory_value
FROM products;
</SQL>
generated_sql_query: <SQL>
SELECT company_name, contact_name, contact_title, address, cit

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df2['Generated_SQL_Query'] = generated_sql_queries


                                        Question  \
0         What is the total number of customers?   
1  List all product names and their unit prices.   
2    Who are the top 5 customers by order count?   

                                               Query  \
0  SELECT COUNT(*) AS total_customers\nFROM custo...   
1    SELECT product_name, unit_price\nFROM products;   
2  SELECT c.company_name, COUNT(o.order_id) AS or...   

                                              Result Error  \
0                                            [(91,)]  None   
1  [(Chai, 18.0), (Chang, 19.0), (Aniseed Syrup, ...  None   
2  [(Save-a-lot Markets, 31), (Ernst Handel, 30),...  None   

                                      ReferenceQuery  \
0                    SELECT COUNT(*) FROM customers;   
1     SELECT product_name, unit_price FROM products;   
2  SELECT c.customer_id, c.company_name, COUNT(o....   

                                             Context  
0  CREATE TABLE categories (\n    cat

As expected, we can observe that the larger LLM is able to produce valid SQL queries slightly more successfully.

### Conclusion
We generated groundtruth data and have shown that a larger LLM can create better SQL queries compared to a smaller LLM.