## Amazon Bedrock Router Evaluation

In this notebook we built and evaluate a finetuned-LLM router for a Text-to-SQL use case.

This demonstrates a simple application using public Northwind with Amazon Titan Embeddings, Amazon Bedrock and Streamlit for the front-end.

The example receives a user’s prompt, generates a SQL query using in-memory vector database and few-shot examples. We then run the query using SQLite database and display query results in the user interface.

For simplicity, we use the in-memory Chroma database to store and search for embeddings vectors. In a real-world scenario at scale, you will likely want to use a persistent data store like the vector engine for Amazon OpenSearch Serverless or the pgvector extension for PostgreSQL.

In [None]:
# 1. Create a conda environment

# !conda create -y --name bedrock-router-eval python=3.11.8
# !conda init && activate bedrock-router-eval
# !conda install -n bedrock-router-eval ipykernel --update-deps --force-reinstall -y
# !conda install -c conda-forge ipython-sql

In [3]:
# 2. Install dependencies

# !pip install -r requirements.txt

Collecting importlib_resources==6.4.0 (from -r requirements.txt (line 5))
  Downloading importlib_resources-6.4.0-py3-none-any.whl.metadata (3.9 kB)
Collecting datasets==2.20.0 (from -r requirements.txt (line 6))
  Downloading datasets-2.20.0-py3-none-any.whl.metadata (19 kB)
Collecting filelock (from datasets==2.20.0->-r requirements.txt (line 6))
  Using cached filelock-3.15.4-py3-none-any.whl.metadata (2.9 kB)
Collecting pyarrow>=15.0.0 (from datasets==2.20.0->-r requirements.txt (line 6))
  Downloading pyarrow-17.0.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting pyarrow-hotfix (from datasets==2.20.0->-r requirements.txt (line 6))
  Using cached pyarrow_hotfix-0.6-py3-none-any.whl.metadata (3.6 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets==2.20.0->-r requirements.txt (line 6))
  Using cached dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting pandas (from datasets==2.20.0->-r requirements.txt (line 6))
  Using cached pandas-2.2.2-cp311-cp311-manylinux_2_

In [3]:
# 3. Import necessary libraries and load environment variables
import numpy as np
from scipy.spatial.distance import cdist
import json
from dotenv import load_dotenv, find_dotenv
import os
import boto3
import sqlite3
from pandas.io import sql
from botocore.config import Config

# loading environment variables that are stored in local file
local_env_filename = 'bedrock-router-eval.env'
load_dotenv(find_dotenv(local_env_filename),override=True)

os.environ['REGION'] = os.getenv('REGION')
os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN')
REGION = os.environ['REGION']
HF_TOKEN = os.environ['HF_TOKEN']

# Initialize Bedrock runtime
config = Config(
   retries = {
      'max_attempts': 10,
      'mode': 'standard'
   }
)
bedrock_runtime_client = boto3.client(
        service_name="bedrock-runtime",
        config=config,
        region_name=REGION)

bedrock_client = boto3.client(service_name='bedrock', region_name=REGION)
athena_client = boto3.client('athena')
glue_client = boto3.client('glue')
s3_client = boto3.client('s3')

model_id = "anthropic.claude-3-haiku-20240307-v1:0" # "anthropic.claude-3-5-sonnet-20240620-v1:0" "meta.llama3-1-70b-instruct-v1:0"

SQL_DATABASE = 'LOCAL' #GLUE
SQL_DIALECT = 'SQlite' #awsathena

# %load_ext sql

In [107]:
# TO DO: 
# review  https://bird-bench.github.io/ or latest spyder dataset or any SQL dataset from HF

# https://huggingface.co/datasets/b-mc2/sql-create-context

# from datasets import load_dataset
# # %sql sqlite:///routedb.db
# # to load SQL dataset from starcoder
## ds = load_dataset("bigcode/starcoderdata", data_dir="sql", split="train", token=True)
# ds = load_dataset("b-mc2/sql-create-context", split="train", token=True)
# from datasets import load_dataset_builder
# ds_builder = load_dataset_builder("b-mc2/sql-create-context")
# ds_builder.info.description
# ds_builder.info.features


In [26]:
# 4. Define Helper functions

import boto3
import pandas as pd
import io
import json
from io import StringIO

def dataframe_to_s3_jsonl(df, bucket_name, prefix, filename):
    """
    Convert a pandas DataFrame to JSONL format and upload it to S3.

    Parameters:
    df (pandas.DataFrame): The DataFrame to be converted and uploaded.
    bucket_name (str): The name of the S3 bucket.
    prefix (str): The S3 prefix (folder path) where the file will be uploaded.
    filename (str): The name of the file to be created in S3.

    Returns:
    str: The S3 URI of the uploaded file.
    """
    # Convert DataFrame to JSONL
    jsonl_buffer = StringIO()
    for _, row in df.iterrows():
        json.dump(row.to_dict(), jsonl_buffer)
        jsonl_buffer.write('\n')
    jsonl_buffer.seek(0)

    # Upload the JSONL data to S3
    s3_key = f"{prefix.rstrip('/')}/{filename}"
    s3_client.put_object(
        Bucket=bucket_name,
        Key=s3_key,
        Body=jsonl_buffer.getvalue(),
        ContentType='application/json'
    )

    # Return the S3 URI of the uploaded file
    return f"s3://{bucket_name}/{s3_key}"


def download_and_parse_jsonl(bucket_name, object_key):
    """
    Downloads a JSONL file from an Amazon S3 bucket and parses it into a pandas DataFrame.

    Args:
        bucket_name (str): The name of the S3 bucket where the JSONL file is stored.
        object_key (str): The key (path) of the JSONL file in the S3 bucket.

    Returns:
        pandas.DataFrame: A DataFrame containing the data from the JSONL file.
    """
    

    # Download the JSONL file from S3
    response = s3_client.get_object(Bucket=bucket_name, Key=object_key)
    jsonl_data = response['Body'].read().decode('utf-8')

    # Parse the JSONL data into a list of dictionaries
    data = [json.loads(line) for line in jsonl_data.strip().split('\n')]

    # Create a DataFrame from the list of dictionaries
    df = pd.DataFrame(data)

    return df

def check_job_status_and_wait(job_arn):
    # # check status
    # bedrock.get_model_invocation_job(jobIdentifier=jobArn)['status']

    # # list batch jobs
    # bedrock.list_model_invocation_jobs(
    #     maxResults=10,
    #     statusEquals="Failed",
    #     sortOrder="Descending"
    # )

    while True:
        job_status = bedrock_client.get_model_invocation_job(jobIdentifier=job_arn)['status']
        print(f"Job status: {job_status}")

        if job_status == 'COMPLETED':
            output_s3_uri = bedrock_client.get_model_invocation_job(jobIdentifier=job_arn)['outputDataConfig']['s3OutputDataConfig']['s3Uri']
            output_file_key = output_s3_uri.replace(f"s3://{output_bucket}/{output_prefix}", "")
            output_file_name = output_file_key.split("/")[-1]
            break
        elif job_status == 'FAILED':
            print("Job failed.")
            break
        else:
            time.sleep(60)  # Wait for 1 minute before checking again
    
    return output_s3_uri

def get_schema(database_name, table_names=None):
    try:
        
        table_schema_list = []
        response = glue_client.get_tables(DatabaseName=database_name)

        all_table_names = [table['Name'] for table in response['TableList']]

        if table_names:
            table_names = [name for name in table_names if name in all_table_names]
        else:
            table_names = all_table_names

        for table_name in table_names:
            response = glue_client.get_table(DatabaseName=database_name, Name=table_name)
            columns = response['Table']['StorageDescriptor']['Columns']
            schema = {column['Name']: column['Type'] for column in columns}
            table_schema_list.append({"Table: {}".format(table_name): 'Schema: {}'.format(schema)})
    except Exception as e:
        print(f"Error: {str(e)}")
    return table_schema_list

def execute_athena_query(database, query):
    # Start query execution
    response = athena_client.start_query_execution(
        QueryString=query,
        QueryExecutionContext={
            'Database': database
        },
        ResultConfiguration={
            'OutputLocation': outputLocation
        }
    )

    # Get query execution ID
    query_execution_id = response['QueryExecutionId']
    print(f"Query Execution ID: {query_execution_id}")

    # Wait for the query to complete
    response_wait = athena_client.get_query_execution(QueryExecutionId=query_execution_id)

    while response_wait['QueryExecution']['Status']['State'] in ['QUEUED', 'RUNNING']:
        print("Query is still running...")
        response_wait = athena_client.get_query_execution(QueryExecutionId=query_execution_id)

    print(f'response_wait {response_wait}')

    # Check if the query completed successfully
    if response_wait['QueryExecution']['Status']['State'] == 'SUCCEEDED':
        print("Query succeeded!")

        # Get query results
        query_results = athena_client.get_query_results(QueryExecutionId=query_execution_id)

        # Extract and return the result data
        code = 'SUCCEEDED'
        return code, extract_result_data(query_results)

    else:
        print("Query failed!")
        code = response_wait['QueryExecution']['Status']['State']
        message = response_wait['QueryExecution']['Status']['StateChangeReason']
    
        return code, message

def extract_result_data(query_results):
    #Return a cleaned response to the agent
    result_data = []

    # Extract column names
    column_info = query_results['ResultSet']['ResultSetMetadata']['ColumnInfo']
    column_names = [column['Name'] for column in column_info]

    # Extract data rows
    for row in query_results['ResultSet']['Rows']:
        data = [item['VarCharValue'] for item in row['Data']]
        result_data.append(dict(zip(column_names, data)))

    return result_data

# sql_dialect = awsathena or SQLite
def build_sqlquerygen_prompt(user_question: str, sql_database_schema: str):
    prompt = """You will be provided with the original user question and a SQL database schema. 
                Only return the SQL query and nothing else.
                Here is the original user question.
                <user_question>
                {user_question}
                </user_question>

                Here is the SQL database schema.
                <sql_database_schema>
                {sql_database_schema}
                </sql_database_schema>
                
                Instructions:
                Generate a SQL query that answers the original user question.
                Use the schema, first create a syntactically correct {sql_dialect} query to answer the question. 
                Never query for all the columns from a specific table, only ask for a few relevant columns given the question.
                Pay attention to use only the column names that you can see in the schema description. 
                Be careful to not query for columns that do not exist. 
                Pay attention to which column is in which table. 
                Also, qualify column names with the table name when needed.
                If you cannot answer the user question with the help of the provided SQL database schema, 
                then output that this question question cannot be answered based of the information stored in the database.
                You are required to use the following format, each taking one line:
                Return the sql query inside the <SQL></SQL> tab.
                """.format(
                    user_question=user_question,
                    sql_database_schema=sql_database_schema,
                    sql_dialect=SQL_DIALECT
                ) 
    return prompt

# grading rubric
RUBRIC = '''
- **5**:The generated SQL matches the reference SQL query.
- **4**:The generated SQL is valid, and produces a result without any errors, and is performant.
- **3**:The generated SQL is valid, and produce a result without any errors.
- **2**:The generated SQL is valid, but produces one or more errors.
- **1**:The generated SQL is invalid.
'''

def build_prediction_prompt(user_question: str, sql_database_schema: str):
    prompt = """You will be provided with the original user question and a SQL database schema. 
                Based on the question provided below, predict the score an expert evaluator would give to an AI assistant's response, 
                considering its helpfulness, relevance, adherence to facts, depth, creativity, and detail. 
                
                Here is the original user question.
                <user_question>
                {user_question}
                </user_question>

                Here is the SQL database schema.
                <sql_database_schema>
                {sql_database_schema}
                </sql_database_schema>
                
                Instructions:
                Your prediction should infer the level of proficiency needed to address the question effectively. 
                Use a scale from 1 to 5, where a higher score indicates a higher anticipated quality of response. 
                
                Score criteria:
                <rubric>
                {rubric}
                </rubric>

                Provide your prediction inside <score></score> tags.
                """.format(
                    user_question=user_question,
                    sql_database_schema=sql_database_schema,
                    rubric=RUBRIC,
                    sql_dialect=SQL_DIALECT
                ) 
    return prompt


def build_grader_prompt(original_instruction: str, sql_query: str):
    prompt = """You will be provided with the original user prompt, context, and generated SQL query, which is trying to answer the initial user prompt,
                and a rubric that instructs you on what makes this SQL query correct or incorrect.

    Here is the original instruction for the SQL query.
    <original_instruction>
    {original_instruction}
    </original_instruction>

    Here is the generated SQL query based on these instructions.
    <sql_query>
    {sql_query}
    </sql_query>
    
    Here is the rubric on how to grade the generated notification.
    <rubric>
    {rubric}
    </rubric>
    
    An answer is correct if it entirely meets the rubric criteria, and is otherwise incorrect.
    First, think through whether the answer is correct or incorrect based on the rubric inside <thinking></thinking> tags. 
    Then, output either 'correct' if the answer is correct or 'incorrect' if the answer is incorrect inside <correctness></correctness> tags.
    Use a scale from 1 to 5, where a higher score indicates a higher anticipated quality of response, provide the score inside <score></score> tags.
   
    """.format(
        sql_query=sql_query,
        rubric=RUBRIC,
        original_instruction=original_instruction
    ) 
    return prompt


# Helper class for async calls to Amazon Bedrock

import typing as t
import time
from queue import Queue
from threading import Thread
import boto3

class BedrockLLMWrapper():
    def __init__(self,
        model_id: str = 'anthropic.claude-3-haiku-20240307-v1:0', #'anthropic.claude-3-sonnet-20240229-v1:0',
        top_k: int = 5,
        top_p: int = 0.7,
        temperature: float = 0.0,
        max_token_count: int = 4000,
        anthropic_version: str = "bedrock-2023-05-31" if "anthropic" in model_id.lower() else "",
        max_attempts: int = 3,
        debug: bool = False

    ):

        self.model_id = model_id
        self.top_k = top_k
        self.top_p = top_p
        self.temperature = temperature
        self.max_token_count = max_token_count
        self.anthropic_version = anthropic_version
        self.max_attempts = max_attempts
        self.debug = debug
        self.bedrock_runtime = boto3.client(service_name="bedrock-runtime", config=config, region_name=REGION)
        
    def generate(self,prompt):
        if self.debug: 
            print('entered BedrockLLMWrapper generate')
        attempt = 1
        query_time = -1
        usage = (-1,-1)
        start_time = time.time()

        messages = prompt

        system = ''

        #if the messages are just a string, convert to the Messages API format.
        if type(messages)==str:
            messages = [{"role": "user", "content": messages}]
        
        #build the JSON to send to Bedrock
        if "anthropic" in self.model_id.lower():
            prompt_json = {
                "system":system,
                "messages": messages,
                "max_tokens": self.max_token_count, # 4096 is a hard limit to output length in Claude 3
                "temperature": self.temperature, #creativity on a scale from 0-1.
                "anthropic_version":self.anthropic_version,
                "top_k": self.top_k,
                "top_p": self.top_p,
                "stop_sequences": ["\n\nHuman:"]
            }
        else:
            
            prompt_json = {
                "prompt": prompt,
                "max_tokens": self.max_token_count,
                "temperature": 0,
            }

        if self.debug: 
            print("Sending:\nSystem:\n",system,"\nMessages:\n",str(messages))

        while True:
            try:
                response = self.bedrock_runtime.invoke_model(body=json.dumps(prompt_json), modelId=self.model_id, accept='application/json', contentType='application/json')
                response_body = json.loads(response.get('body').read())
                if "anthropic" in self.model_id.lower():
                    result_text = response_body.get("content")[0].get("text")
                else:
                    result_text = response_body.get("outputs")[0].get("text")
                    
                usage = response_body.get("usage")
                query_time = round(time.time()-start_time,2)
                if self.debug:
                    print("Retrieved:",result_text)
                
                break
                
            except Exception as e:
                print("Error with calling Bedrock: "+str(e))
                attempt+=1
                if attempt>self.max_attempts:
                    print("Max attempts reached!")
                    result_text = str(e)
                    break
                else:#retry in 10 seconds
                    print("retry")
                    time.sleep(10)
        if self.debug: 
            print(f'usage: {usage} and query_time: {query_time}')

        # return result_text
        return [result_text,usage,query_time]

    # Threaded function for queue processing.
    def thread_request(self, q, result):
        while not q.empty():
            work = q.get()    #fetch new work from the Queue
            try:
                data = self.generate(work[1])
                result[work[0]] = data  #Store data back at correct index
            except Exception as e:
                print('Error with prompt!',str(e))
                result[work[0]] = (str(e))
            #signal to the queue that task has been processed
            q.task_done()
        return True

    def generate_threaded(self,prompts):
        '''
        Call multi-threaded.
        Returns a dict of the prompts and responses.
        '''
        system=""
        ignore_cache=False
        q = Queue(maxsize=0)
        num_theads = min(50, len(prompts))
        #Populating Queue with tasks
        results = [{} for x in prompts];
        #load up the queue with the promts to fetch and the index for each job (as a tuple):
        for i in range(len(prompts)):
            #need the index and the url in each queue item.
            q.put((i,prompts[i]))
            
        #Starting worker threads on queue processing
        for i in range(num_theads):
            if self.debug:
                print('Starting thread ', i)
            worker = Thread(target=self.thread_request, args=(q,results))
            worker.daemon = True
            worker.start()

        #now we wait until the queue has been processed
        q.join()
        return results
    
    def calculate_cost(self, usage, model_id):
        '''
        Takes the usage tokens returned by Bedrock in input and output, and coverts to cost in dollars.
        '''
        
        input_token_haiku = 0.25/1000000
        output_token_haiku = 1.25/1000000
        input_token_sonnet = 3.00/1000000
        output_token_sonnet = 15.00/1000000
        input_token_opus = 15.00/1000000
        output_token_opus = 75.00/1000000
        
        input_token_titan_embeddingv1 = 0.1/1000000
        input_token_titan_embeddingv2 = 0.02/1000000
        input_token_titan_embeddingmultimodal = 0.8/1000000
        input_token_titan_premier = 0.5/1000000
        output_token_titan_premier = 1.5/1000000
        input_token_titan_lite = 0.15/1000000
        output_token_titan_lite = 0.2/1000000
        input_token_titan_express = 0.2/1000000
        output_token_titan_express = 0.6/1000000
       
        input_token_cohere_command = 0.15/1000000
        output_token_cohere_command = 2/1000000
        input_token_cohere_commandlight = 0.3/1000000
        output_token_cohere_commandlight = 0.6/1000000
        input_token_cohere_commandrplus = 3/1000000
        output_token_cohere_commandrplus = 15/1000000
        input_token_cohere_commandr = 5/1000000
        output_token_cohere_commandr = 1.5/1000000
        input_token_cohere_embedenglish = 0.1/1000000
        input_token_cohere_embedmultilang = 0.1/1000000

        input_token_llama3_8b = 0.4/1000000
        output_token_llama3_8b = 0.6/1000000
        input_token_llama3_70b = 2.6/1000000
        output_token_llama3_70b = 3.5/1000000

        cost = 0

        if 'haiku' in model_id:
            cost+= usage['input_tokens']*input_token_haiku
            cost+= usage['output_tokens']*output_token_haiku
        if 'sonnet' in model_id:
            cost+= usage['input_tokens']*input_token_sonnet
            cost+= usage['output_tokens']*output_token_sonnet
        if 'opus' in model_id:
            cost+= usage['input_tokens']*input_token_opus
            cost+= usage['output_tokens']*output_token_opus
        if 'amazon.titan-embed-text-v1' in model_id:
            cost+= usage['input_tokens']*input_token_titan_embeddingv1
        if 'amazon.titan-embed-text-v2' in model_id:
            cost+= usage['input_tokens']*input_token_titan_embeddingv2
        if 'cohere.embed-multilingual' in model_id:
            cost+= usage['input_tokens']*input_token_cohere_embedmultilang
        if 'cohere.embed-english' in model_id:
            cost+= usage['input_tokens']*input_token_cohere_embedenglish 
        if 'meta.llama3-8b-instruct' in model_id:
            cost+= usage['input_tokens']*input_token_llama3_8b
            cost+= usage['output_tokens']*output_token_llama3_8b
        if 'meta.llama3-70b-instruct' in model_id:
            cost+= usage['input_tokens']*input_token_llama3_70b
            cost+= usage['output_tokens']*output_token_llama3_70b
        if 'cohere.command-text' in model_id:
            cost+= usage['input_tokens']*input_token_cohere_command
            cost+= usage['output_tokens']*output_token_cohere_command
        if 'cohere.command-light-text' in model_id:
            cost+= usage['input_tokens']*input_token_cohere_commandlight
            cost+= usage['output_tokens']*output_token_cohere_commandlight
        if 'cohere.command-r-plus' in model_id:
            cost+= usage['input_tokens']*input_token_cohere_commandrplus
            cost+= usage['output_tokens']*output_token_cohere_commandrplus
        if 'cohere.command-r' in model_id:
            cost+= usage['input_tokens']*input_token_cohere_commandr
            cost+= usage['output_tokens']*output_token_cohere_commandr
        if 'amazon.titan-text-express' in model_id:
            cost+= usage['input_tokens']*input_token_titan_express
            cost+= usage['output_tokens']*output_token_titan_express
        if 'amazon.titan-text-lite' in model_id:
            cost+= usage['input_tokens']*input_token_titan_lite
            cost+= usage['output_tokens']*output_token_titan_lite
        if 'amazon.titan-text-premier' in model_id:
            cost+= usage['input_tokens']*input_token_titan_premier
            cost+= usage['output_tokens']*output_token_titan_premier

        return cost

def runBedrockBatchJob(modelid,df):
    BEDROCK_BATCH_API = False
    if BEDROCK_BATCH_API == True:
        ## WIP
        # upload df_results dataframe as jsonl to S3
        prefix = 'routeeval/input'
        filename = 'small_llm.jsonl'
        bucket_name = 'felixh-demo'

        s3_uri = dataframe_to_s3_jsonl(df_results, bucket_name, prefix, filename)
        print(f"File uploaded to: {s3_uri}")

        # generate SQL queries with Bedrock Batch API if filesize is >25mb otherwise use helper class for threaded calls

        input_bucket = bucket_name
        input_prefix = "input/"
        output_bucket = bucket_name
        output_prefix = "output/mistral8binstruct"
        BATCH_ROLE_ARN = 'arn:aws:iam::026459568683:role/admin'

        inputDataConfig=({
            "s3InputDataConfig": {
                "s3Uri": f"s3://{input_bucket}/{input_prefix}"
            }
        })

        outputDataConfig=({
            "s3OutputDataConfig": {
                "s3Uri": f"s3://{output_bucket}/{output_prefix}"
            }
        })

        batch_response=bedrock_client.create_model_invocation_job(
            roleArn=BATCH_ROLE_ARN,
            modelId=modelid,
            jobName="small-llm-genSQL-job",
            inputDataConfig=inputDataConfig,
            outputDataConfig=outputDataConfig
        )

        jobArn = batch_response.get('jobArn')


        # wait for batch job to complete
        output_s3_uri = check_job_status_and_wait(job_arn)
        print(output_s3_uri)
        output_file_name = output_s3_uri.split("/")[-1]
        print(output_file_name)

        # download results from S3 into dataframe
        df = download_and_parse_jsonl(bucket_name, object_key)

    else:
        # use helper class for threaded API calls
        wrapper = BedrockLLMWrapper(debug=False, model_id=modelid)

        prompts_list = []
        for row in df.itertuples():
            prompt = build_sqlquerygen_prompt(row.Question, row.Context)
            prompts_list.append(prompt)
        # [result_text,usage,query_time]
        results = wrapper.generate_threaded(prompts_list)

        # Create a list to store the generated SQL queries
        generated_sql_queries = []

        for result in results:
            generated_sql_query = result[0]
            generated_sql_queries.append(generated_sql_query)

        # Add the new column 'Generated_SQL_Query' to df_results
        df['Generated_SQL_Query'] = generated_sql_queries

        # for result in results:
        #     # print(f'result: {result}')
        #     generated_sql_query = result[0]

        #     if result[1] != None:
        #         cost = wrapper.calculate_cost(result[1], wrapper.model_id)
        #         generation_info = {"input_tokens": result['input_tokens'], "output_tokens": result[1]['output_tokens'], "cost": cost , "query_time": result[2] }
        #     else:
        #         # not all models support/return token count yet in Bedrock API
        #         generation_info = {"input_tokens": 'N/A', "output_tokens": 'N/A', "cost": 0 , "query_time": result[2] }
            
        #     # print(f'generation_info: {generation_info}')
        return df


import re
import json
SCORE_PATTERN = r'<score>(.*?)</score>'
REASONING_PATTERN = r'<thinking>(.*?)</thinking>'
CORRECTNESS_PATTERN = r'<correctness>(.*?)</correctness>'
SQL_PATTERN = r'<SQL>(.*?)</SQL>'

# Strip out the portion of the response with regex.
def extract_with_regex(response, regex):
    matches = re.search(regex, response, re.DOTALL)
    # Extract the matched content, if any
    return matches.group(1).strip() if matches else None

def format_results(grade: str, chat_conversation: list[dict]) -> dict:
    reasoning: str = extract_with_regex(grade, REASONING_PATTERN)
    correctness: str =  extract_with_regex(grade, CORRECTNESS_PATTERN)
    
    return {
        'chat_conversation': chat_conversation,
        'reasoning': reasoning,
        'correctness': correctness
    }

In [6]:
# 5. Get schema for all tables in database

if SQL_DATABASE == 'LOCAL':
    # create local db and import northwind database

    # %load_ext sql
    # %sql sqlite:///routedb.db

    import requests
    import sqlite3
    import re

    # Download the SQL files
    url1 = "https://raw.githubusercontent.com/YugaByte/yugabyte-db/master/sample/northwind_ddl.sql"
    url2 = "https://raw.githubusercontent.com/YugaByte/yugabyte-db/master/sample/northwind_data.sql"

    urls = [url2]

    for url in urls:
        response = requests.get(url)
        sql_content = response.text

        # Create a SQLite database connection
        conn = sqlite3.connect('routedb.db')
        cursor = conn.cursor()

        # Split the SQL content into individual statements
        sql_statements = re.split(r';\s*$', sql_content, flags=re.MULTILINE)

        # Execute each SQL statement
        for statement in sql_statements:
            # Skip empty statements
            if statement.strip():
                print(f'statement: {statement}')
                # Replace PostgreSQL-specific syntax with SQLite equivalents
                statement = statement.replace('SERIAL PRIMARY KEY', 'INTEGER PRIMARY KEY AUTOINCREMENT')
                statement = statement.replace('::int', '')
                statement = statement.replace('::varchar', '')
                statement = statement.replace('::real', '')
                statement = statement.replace('::date', '')
                statement = statement.replace('::boolean', '')
                statement = statement.replace('public.', '')
                statement = re.sub(r'WITH \(.*?\)', '', statement)
                
                try:
                    cursor.execute(statement)
                except sqlite3.Error as e:
                    print(f"Error executing statement: {e}")
                    print(f"Statement: {statement}")

        # Commit the changes and close the connection
        conn.commit()
        conn.close()

        print("SQL execution completed.")

    #hardcoded Northwind database schema
    schema = '''--
    -- PostgreSQL database dump
    --

    SET statement_timeout = 0;
    SET lock_timeout = 0;
    SET client_encoding = 'UTF8';
    SET standard_conforming_strings = on;
    SET check_function_bodies = false;
    SET client_min_messages = warning;



    SET default_tablespace = '';

    SET default_with_oids = false;


    ---
    --- drop tables
    ---


    DROP TABLE IF EXISTS customer_customer_demo;
    DROP TABLE IF EXISTS customer_demographics;
    DROP TABLE IF EXISTS employee_territories;
    DROP TABLE IF EXISTS order_details;
    DROP TABLE IF EXISTS orders;
    DROP TABLE IF EXISTS customers;
    DROP TABLE IF EXISTS products;
    DROP TABLE IF EXISTS shippers;
    DROP TABLE IF EXISTS suppliers;
    DROP TABLE IF EXISTS territories;
    DROP TABLE IF EXISTS us_states;
    DROP TABLE IF EXISTS categories;
    DROP TABLE IF EXISTS region;
    DROP TABLE IF EXISTS employees;

    --
    -- Name: categories; Type: TABLE; Schema: public; Owner: -; Tablespace: 
    --

    CREATE TABLE categories (
        category_id smallint NOT NULL PRIMARY KEY,
        category_name character varying(15) NOT NULL,
        description text,
        picture bytea
    );


    --
    -- Name: customer_demographics; Type: TABLE; Schema: public; Owner: -; Tablespace: 
    --

    CREATE TABLE customer_demographics (
        customer_type_id bpchar NOT NULL PRIMARY KEY,
        customer_desc text
    );


    --
    -- Name: customers; Type: TABLE; Schema: public; Owner: -; Tablespace: 
    --

    CREATE TABLE customers (
        customer_id bpchar NOT NULL PRIMARY KEY,
        company_name character varying(40) NOT NULL,
        contact_name character varying(30),
        contact_title character varying(30),
        address character varying(60),
        city character varying(15),
        region character varying(15),
        postal_code character varying(10),
        country character varying(15),
        phone character varying(24),
        fax character varying(24)
    );

    --
    -- Name: customer_customer_demo; Type: TABLE; Schema: public; Owner: -; Tablespace: 
    --

    CREATE TABLE customer_customer_demo (
        customer_id bpchar NOT NULL,
        customer_type_id bpchar NOT NULL,
        PRIMARY KEY (customer_id, customer_type_id),
        FOREIGN KEY (customer_type_id) REFERENCES customer_demographics,
        FOREIGN KEY (customer_id) REFERENCES customers
    );

    --
    -- Name: employees; Type: TABLE; Schema: public; Owner: -; Tablespace: 
    --

    CREATE TABLE employees (
        employee_id smallint NOT NULL PRIMARY KEY,
        last_name character varying(20) NOT NULL,
        first_name character varying(10) NOT NULL,
        title character varying(30),
        title_of_courtesy character varying(25),
        birth_date date,
        hire_date date,
        address character varying(60),
        city character varying(15),
        region character varying(15),
        postal_code character varying(10),
        country character varying(15),
        home_phone character varying(24),
        extension character varying(4),
        photo bytea,
        notes text,
        reports_to smallint,
        photo_path character varying(255),
        FOREIGN KEY (reports_to) REFERENCES employees
    );


    --
    -- Name: suppliers; Type: TABLE; Schema: public; Owner: -; Tablespace: 
    --

    CREATE TABLE suppliers (
        supplier_id smallint NOT NULL PRIMARY KEY,
        company_name character varying(40) NOT NULL,
        contact_name character varying(30),
        contact_title character varying(30),
        address character varying(60),
        city character varying(15),
        region character varying(15),
        postal_code character varying(10),
        country character varying(15),
        phone character varying(24),
        fax character varying(24),
        homepage text
    );


    --
    -- Name: products; Type: TABLE; Schema: public; Owner: -; Tablespace: 
    --

    CREATE TABLE products (
        product_id smallint NOT NULL PRIMARY KEY,
        product_name character varying(40) NOT NULL,
        supplier_id smallint,
        category_id smallint,
        quantity_per_unit character varying(20),
        unit_price real,
        units_in_stock smallint,
        units_on_order smallint,
        reorder_level smallint,
        discontinued integer NOT NULL,
        FOREIGN KEY (category_id) REFERENCES categories,
        FOREIGN KEY (supplier_id) REFERENCES suppliers
    );


    --
    -- Name: region; Type: TABLE; Schema: public; Owner: -; Tablespace: 
    --

    CREATE TABLE region (
        region_id smallint NOT NULL PRIMARY KEY,
        region_description bpchar NOT NULL
    );


    --
    -- Name: shippers; Type: TABLE; Schema: public; Owner: -; Tablespace: 
    --

    CREATE TABLE shippers (
        shipper_id smallint NOT NULL PRIMARY KEY,
        company_name character varying(40) NOT NULL,
        phone character varying(24)
    );


    --
    -- Name: orders; Type: TABLE; Schema: public; Owner: -; Tablespace: 
    --

    CREATE TABLE orders (
        order_id smallint NOT NULL PRIMARY KEY,
        customer_id bpchar,
        employee_id smallint,
        order_date date,
        required_date date,
        shipped_date date,
        ship_via smallint,
        freight real,
        ship_name character varying(40),
        ship_address character varying(60),
        ship_city character varying(15),
        ship_region character varying(15),
        ship_postal_code character varying(10),
        ship_country character varying(15),
        FOREIGN KEY (customer_id) REFERENCES customers,
        FOREIGN KEY (employee_id) REFERENCES employees,
        FOREIGN KEY (ship_via) REFERENCES shippers
    );


    --
    -- Name: territories; Type: TABLE; Schema: public; Owner: -; Tablespace: 
    --

    CREATE TABLE territories (
        territory_id character varying(20) NOT NULL PRIMARY KEY,
        territory_description bpchar NOT NULL,
        region_id smallint NOT NULL,
        FOREIGN KEY (region_id) REFERENCES region
    );


    --
    -- Name: employee_territories; Type: TABLE; Schema: public; Owner: -; Tablespace: 
    --

    CREATE TABLE employee_territories (
        employee_id smallint NOT NULL,
        territory_id character varying(20) NOT NULL,
        PRIMARY KEY (employee_id, territory_id),
        FOREIGN KEY (territory_id) REFERENCES territories,
        FOREIGN KEY (employee_id) REFERENCES employees
    );


    --
    -- Name: order_details; Type: TABLE; Schema: public; Owner: -; Tablespace: 
    --

    CREATE TABLE order_details (
        order_id smallint NOT NULL,
        product_id smallint NOT NULL,
        unit_price real NOT NULL,
        quantity smallint NOT NULL,
        discount real NOT NULL,
        PRIMARY KEY (order_id, product_id),
        FOREIGN KEY (product_id) REFERENCES products,
        FOREIGN KEY (order_id) REFERENCES orders
    );


    --
    -- Name: us_states; Type: TABLE; Schema: public; Owner: -; Tablespace: 
    --

    CREATE TABLE us_states (
        state_id smallint NOT NULL PRIMARY KEY,
        state_name character varying(100),
        state_abbr character varying(2),
        state_region character varying(50)
    );'''


else: 
    # use a Glue database
    DATABASE = ''
    schema = get_schema(DATABASE)

statement: --
-- PostgreSQL database dump
--

SET statement_timeout = 0
Error executing statement: near "SET": syntax error
Statement: --
-- PostgreSQL database dump
--

SET statement_timeout = 0
statement: 
SET lock_timeout = 0
Error executing statement: near "SET": syntax error
Statement: 
SET lock_timeout = 0
statement: 
SET client_encoding = 'UTF8'
Error executing statement: near "SET": syntax error
Statement: 
SET client_encoding = 'UTF8'
statement: 
SET standard_conforming_strings = on
Error executing statement: near "SET": syntax error
Statement: 
SET standard_conforming_strings = on
statement: 
SET check_function_bodies = false
Error executing statement: near "SET": syntax error
Statement: 
SET check_function_bodies = false
statement: 
Error executing statement: near "SET": syntax error
Statement: 
statement: 
SET default_tablespace = ''
Error executing statement: near "SET": syntax error
Statement: 
SET default_tablespace = ''
statement: 
SET default_with_oids = false
Error ex

In [7]:
# 6. Generate questions and SQL queries based on database schema with Sonnet 3.5
import json

prompt = """Human: Review the provided database schema below. Then create 100 questions in natural language along with corresponding SQL queries that would answer these questions based on this database schema.
        
        <database_schema>
        {database_schema}
        </database_schema>
        
        Return the response in JSONL and return only the JSON and nothing else.      
        Assistant: {{""".format(database_schema=schema)

MODEL_ID = 'anthropic.claude-3-5-sonnet-20240620-v1:0'
body = {
    "anthropic_version": "bedrock-2023-05-31",
    "max_tokens": 4000,
    "temperature": 0,
    "messages": [
        {
            "role": "user",
            "content": [{"type": "text", "text": prompt}],
        }
    ],
}

response = bedrock_runtime_client.invoke_model(
    body=json.dumps(body),
    modelId=MODEL_ID,

)

response_body = json.loads(response.get('body').read()) # read the response

response_text = response_body.get('content')[0].get('text')

print(response_text)


{"question": "What are the names of all categories?", "query": "SELECT category_name FROM categories;"}
{"question": "How many products are there in each category?", "query": "SELECT category_name, COUNT(*) FROM products JOIN categories ON products.category_id = categories.category_id GROUP BY category_name;"}
{"question": "Who are the suppliers with the most products?", "query": "SELECT suppliers.company_name, COUNT(*) as product_count FROM products JOIN suppliers ON products.supplier_id = suppliers.supplier_id GROUP BY suppliers.company_name ORDER BY product_count DESC LIMIT 5;"}
{"question": "What is the average unit price of products by category?", "query": "SELECT categories.category_name, AVG(products.unit_price) as avg_price FROM products JOIN categories ON products.category_id = categories.category_id GROUP BY categories.category_name;"}
{"question": "Which employees have the most orders?", "query": "SELECT employees.first_name, employees.last_name, COUNT(*) as order_count FROM

In [8]:
# 7. read jsonl_string into dataframe
import pandas as pd
import json

def parse_json_line(line):
    try:
        return json.loads(line)
    except json.JSONDecodeError:
        print(f"Error parsing line: {line}")
        return None

# Split the string into lines and parse each line
data = [parse_json_line(line) for line in response_text.strip().split('\n')]

# Remove any None values (failed parses)
data = [d for d in data if d is not None]

# Create a DataFrame from the list of dictionaries
df = pd.DataFrame(data)

df.to_parquet('question_query.parquet')

print(f"Number of successfully parsed questions: {len(df)}")
print(df.head(5))

Error parsing line: {"question": "Which employees have the highest sales growth year over year?", "query": "WITH yearly_sales AS (SELECT employees.employee_id, employees.first_name, employees.last_name, EXTRACT(YEAR FROM orders.order_date) as year, SUM(order_details.unit_price * order_details.quantity * (1 - order_details.discount)) as total_sales FROM employees JOIN orders ON employees.employee_id = orders.employee_id JOIN order_details ON orders.order_id = order_details.order_id GROUP BY employees.employee_id, EXTRACT(YEAR FROM orders.order_date)) SELECT ys1.employee_id, ys1.first_name, ys1.last_name, ys1.year, (
Number of successfully parsed questions: 44
                                            question  \
0              What are the names of all categories?   
1      How many products are there in each category?   
2      Who are the suppliers with the most products?   
3  What is the average unit price of products by ...   
4              Which employees have the most orders? 

In [9]:
# 8. Test generated SQL queries and verify they work
results = []

if SQL_DATABASE == 'LOCAL':
    # Create a SQLite database connection
    conn = sqlite3.connect('routedb.db')
    cursor = conn.cursor()

for row in df.itertuples():
    # print(row.query)
    error = None
    try:
        
        if SQL_DATABASE == 'LOCAL':
            # Use local SQL lite
            statement = row.query
            # Replace PostgreSQL-specific syntax with SQLite equivalents
            
            try:
                cursor.execute(statement)
                # Fetch all rows from the result
                result = cursor.fetchall()

            except sqlite3.Error as e:
                print(f"Error executing statement: {e}")
                error = e
                # print(f"Statement: {statement}")

        else:
            # Use Athena if AWS Glue Schema is used
            result = execute_athena_query(DATABASE, row.query)
        
    except ClientError as e:
        error = e

    results.append({'Question': row.question,'Query': row.query, 'Result': result, 'Error': error, 'Context': schema})

if SQL_DATABASE == 'LOCAL':
    # close the connection
    conn.close()

df_results = pd.DataFrame(results)


Error executing statement: near "FROM": syntax error
Error executing statement: near "FROM": syntax error
Error executing statement: near "FROM": syntax error
Error executing statement: near "FROM": syntax error


In [10]:
# 9. Use all generated prompts that resulted in valid SQL queries and filter out the rest

df_good_results = df_results[df_results['Error'].isnull() | (df_results['Error'] == None)]
print(f"Number of successful queries: {len(df_good_results)}")

df_bad_results = df_results[df_results['Error'].notnull() | (df_results['Error'] == 'None')]
print(f"Number of unsuccessful queries: {len(df_bad_results)}")



Number of successful queries: 40
Number of unsuccessful queries: 4


In [11]:
# 10. Use golden dataset to run test with smaller LLM

MODEL_ID = "mistral.mixtral-8x7b-instruct-v0:1" # "anthropic.claude-3-haiku-20240307-v1:0" # "mistral.mixtral-8x7b-instruct-v0:1" "anthropic.claude-3-5-sonnet-20240620-v1:0" "meta.llama3-1-70b-instruct-v1:0"

df1 = runBedrockBatchJob(MODEL_ID,df_good_results)

# Test generated SQL queries and verify they work
results = []

if SQL_DATABASE == 'LOCAL':
    # Create a SQLite database connection
    conn = sqlite3.connect('routedb.db')
    cursor = conn.cursor()

for row in df1.itertuples():
    statement = extract_with_regex(row.Generated_SQL_Query, SQL_PATTERN)
    # print(f'SQL statement: {statement}')
    error = None
    try:
        
        if SQL_DATABASE == 'LOCAL':
            # Use local SQL lite
            
            # Replace PostgreSQL-specific syntax with SQLite equivalents
            
            try:
                cursor.execute(statement)
                # Fetch all rows from the result
                result = cursor.fetchall()
                # print(result)

            except sqlite3.Error as e:
                print(f"Error executing statement: {e}")
                error = e
                # print(f"Statement: {statement}")

        else:
            # Use Athena if AWS Glue Schema is used
            result = execute_athena_query(DATABASE, row.Generated_SQL_Query)
        
    except ClientError as e:
        error = e

    results.append({'Question': row.Question,'Query': statement, 'Result': result, 'Error': error, 'Context': row.Context})

if SQL_DATABASE == 'LOCAL':
    # close the connection
    conn.close()

df1_results = pd.DataFrame(results)
print(df1_results.head(3))

Error executing statement: no such column: shipper_id
Error executing statement: no such function: DATEDIFF
Error executing statement: no such column: country
Error executing statement: ambiguous column name: unit_price
Error executing statement: no such column: od.order_date
Error executing statement: no such column: category_name
Error executing statement: no such column: customer_id
Error executing statement: ambiguous column name: customer_type_id
Error executing statement: no such function: DATEDIFF
                                        Question  \
0          What are the names of all categories?   
1  How many products are there in each category?   
2  Who are the suppliers with the most products?   

                                               Query  \
0  SELECT category_name\n                FROM cat...   
1  SELECT categories.category_name, COUNT(product...   
2  SELECT supplier_id, COUNT(*) as product_count\...   

                                              Result Err

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['Generated_SQL_Query'] = generated_sql_queries


In [12]:
# 11. Use all generated prompts that resulted in valid SQL queries and filter out the rest

df1_good_results = df1_results[df1_results['Error'].isnull() | (df1_results['Error'] == None)]
print(f"Number of successful queries: {len(df1_good_results)}")

df1_bad_results = df1_results[df1_results['Error'].notnull() | (df1_results['Error'] == 'None')]
print(f"Number of unsuccessful queries: {len(df1_bad_results)}")

df1_good_clean_results = df1_good_results
df1_good_clean_results['Result'] = df1_good_results['Result'].astype(str)
df1_good_clean_results.to_parquet('df1_good_results.parquet')
#   df1_good_results = pd.read_parquet('df1_good_results.parquet')


Number of successful queries: 31
Number of unsuccessful queries: 9


In [13]:
# 12. Use golden dataset to run test with larger LLM

MODEL_ID = "anthropic.claude-3-sonnet-20240229-v1:0" # "anthropic.claude-3-haiku-20240307-v1:0" # "mistral.mixtral-8x7b-instruct-v0:1" "anthropic.claude-3-5-sonnet-20240620-v1:0" "meta.llama3-1-70b-instruct-v1:0"

df2 = runBedrockBatchJob(MODEL_ID, df_good_results)


# Test generated SQL queries and verify they work
results = []

if SQL_DATABASE == 'LOCAL':
    # Create a SQLite database connection
    conn = sqlite3.connect('routedb.db')
    cursor = conn.cursor()

for row in df2.itertuples():
    statement = extract_with_regex(row.Generated_SQL_Query, SQL_PATTERN)
    # print(f'SQL statement: {statement}')
    error = None
    try:
        
        if SQL_DATABASE == 'LOCAL':
            # Use local SQL lite
            
            # Replace PostgreSQL-specific syntax with SQLite equivalents
            
            try:
                cursor.execute(statement)
                # Fetch all rows from the result
                result = cursor.fetchall()
                # print(result)

            except sqlite3.Error as e:
                print(f"Error executing statement: {e}")
                error = e
                # print(f"Statement: {statement}")

        else:
            # Use Athena if AWS Glue Schema is used
            result = execute_athena_query(DATABASE, row.Generated_SQL_Query)
        
    except ClientError as e:
        error = e

    results.append({'Question': row.Question,'Query': statement, 'Result': result, 'Error': error, 'Context': row.Context})

if SQL_DATABASE == 'LOCAL':
    # close the connection
    conn.close()

df2_results = pd.DataFrame(results)
print(df2_results.head(3))

Error executing statement: no such column: country
Error executing statement: no such column: day
Error executing statement: no such column: employee_id
Error executing statement: no such column: day
                                        Question  \
0          What are the names of all categories?   
1  How many products are there in each category?   
2  Who are the suppliers with the most products?   

                                               Query  \
0             SELECT category_name\nFROM categories;   
1  SELECT c.category_name, COUNT(p.product_id) AS...   
2  SELECT s.company_name, COUNT(p.product_id) AS ...   

                                              Result Error  \
0  [(Beverages,), (Condiments,), (Confections,), ...  None   
1  [(Beverages, 12), (Condiments, 12), (Confectio...  None   
2  [(Specialty Biscuits, Ltd., 5), (Plutzer Leben...  None   

                                             Context  
0  --\n    -- PostgreSQL database dump\n    --\n\...  
1  --\n

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['Generated_SQL_Query'] = generated_sql_queries


In [24]:
# 13. Use all generated prompts that resulted in valid SQL queries and filter out the rest

df2_good_results = df2_results[df2_results['Error'].isnull() | (df2_results['Error'] == None)]
print(f"Number of successful queries: {len(df2_good_results)}")

df2_bad_results = df2_results[df2_results['Error'].notnull() | (df2_results['Error'] == 'None')]
print(f"Number of unsuccessful queries: {len(df2_bad_results)}")

df2_good_clean_results = df2_good_results
df2_good_clean_results['Result'] = df2_good_results['Result'].astype(str)
df2_good_clean_results.to_parquet('df2_good_results.parquet')
#   df1_good_results = pd.read_parquet('df1_good_results.parquet')

Number of successful queries: 36
Number of unsuccessful queries: 4


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df2_good_clean_results['Result'] = df2_good_results['Result'].astype(str)


In [27]:
# 12a. Use LLM as a Judge to grade generated SQL from smaller LLM 
MODEL_ID = "anthropic.claude-3-sonnet-20240229-v1:0" # "mistral.mixtral-8x7b-instruct-v0:1" "anthropic.claude-3-5-sonnet-20240620-v1:0" "meta.llama3-1-70b-instruct-v1:0"

# use helper class for threaded API calls, reduce max_tokens to avoid throttling with Sonet 3.5
wrapper = BedrockLLMWrapper(debug=False, model_id=MODEL_ID, max_token_count= 512)

prompts_list = []
for row in df1.itertuples():
    
    prompt = build_grader_prompt(build_sqlquerygen_prompt(row.Question, row.Context), row.Generated_SQL_Query)
    prompts_list.append(prompt)
# [result_text,usage,query_time]
results = wrapper.generate_threaded(prompts_list)

formatted_results = []
for i,g in enumerate(results):
    reasoning = extract_with_regex(results[i][0], REASONING_PATTERN)
    score =  extract_with_regex(results[i][0], SCORE_PATTERN)
    correctness =  extract_with_regex(results[i][0], CORRECTNESS_PATTERN)
    formatted_results.append({ "reasoning": reasoning, "score": score , "correctness": correctness})

evaluated_df = pd.DataFrame(formatted_results) 

# merge df1 dataframe with columns from evaluated_df
df1 = pd.concat([df1, evaluated_df], axis=1)

print(df1.head(1).to_string())

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                             reasoning score correctness
0  The generated SQL query:\n\n<SQL>\nSELECT category_name\nFROM categories;\n</SQL>\n\nThis query is syntactically correct and will retrieve all the category names from the "categories" table. It directly a

In [28]:
# 12b. Use LLM as a Judge to grade generated SQL from bigger LLM
MODEL_ID = "anthropic.claude-3-sonnet-20240229-v1:0" 

# use helper class for threaded API calls, reduce max_tokens to avoid throttling with Sonet 3.5
wrapper = BedrockLLMWrapper(debug=False, model_id=MODEL_ID, max_token_count= 512)

prompts_list = []
for row in df2.itertuples():
    prompt = build_grader_prompt(build_sqlquerygen_prompt(row.Question, row.Context), row.Generated_SQL_Query)
    prompts_list.append(prompt)
# [result_text,usage,query_time]
results = wrapper.generate_threaded(prompts_list)

formatted_results = []
for i,g in enumerate(results):
    reasoning = extract_with_regex(results[i][0], REASONING_PATTERN)
    score =  extract_with_regex(results[i][0], SCORE_PATTERN)
    correctness =  extract_with_regex(results[i][0], CORRECTNESS_PATTERN)
    formatted_results.append({ "reasoning": reasoning, "score": score , "correctness": correctness})

evaluated_df2 = pd.DataFrame(formatted_results) 

# merge df1 dataframe with columns from evaluated_df
df2 = pd.concat([df2, evaluated_df2], axis=1)



In [34]:
# 12c. Review results: Pecentage correct

print(df1.head().to_dict())

# TBD: UserWarning: DataFrame columns are not unique, some columns will be omitted.

# percentage_correct = df1['correctness'].value_counts(normalize=True)['correct'] * 100
# print(f"Percentage correct for smaller LLM: {percentage_correct:.2f}%")

# percentage_correct = df2['correctness'].value_counts(normalize=True)['correct'] * 100
# print(f"Percentage correct for larger LLM: {percentage_correct:.2f}%")



  print(df1.head().to_dict())


In [None]:
# Calculate Text-to-SQL metrics:
# 1) Execution Accuracy (EX):  compares the generated SQL query to the labeled SQL query to determine if its a match or not. 
# 2) Exact Set Match Accuracy (EM) – did the returned result set actually answer the question, regardless of how the query was written
# 3) Valid Efficiency Score (VES)




In [None]:
#12d. Review results: Sample a subsection of 3 incorrect responses
from IPython.display import display, HTML

# Assuming you have a dataframe called 'df' with a column called 'result'
incorrect_rows = df1[df1['correctness'] == 'incorrect'].sample(n=3)

from IPython.display import display, HTML

# Convert the dataframe to an HTML table
table_html = incorrect_rows.to_html(index=False, classes='table table-striped')

# Display the HTML table
display(HTML(table_html))

In [None]:
#12e. Review results: Sample a subsection of 3 correct responses
from IPython.display import display, HTML

# Assuming you have a dataframe called 'df' with a column called 'result'
incorrect_rows = df1[df1['correctness'] == 'correct'].sample(n=3)

from IPython.display import display, HTML

# Convert the dataframe to an HTML table
table_html = incorrect_rows.to_html(index=False, classes='table table-striped')

# Display the HTML table
display(HTML(table_html))

### Train Classifier

In [None]:
# Let us assume that if the score is >= 4, we will route to the small LLM model (indicating the response quality is good enough); 
# otherwise, we will route to the large LLM model. Under this assumption, the data distribution looks like this

train_df["routing_label"] = train_df["mixtral_score"].apply(
    lambda x: 1 if x >= 4 else 0
)

visualize_label_distribution(train_df, key="routing_label")

# For classification tasks, it's recommended to train on label-balanced datasets to ensure models are not biased to a specific label. 
# We will balance the dataset based on routing_label, as this is the label of primary interest.
from src.utils import balance_dataset

balanced_train_df = balance_dataset(train_df, key="routing_label")

print(f"Train size: {len(balanced_train_df)}")

In [None]:
n_sample = 20
output_file = "/mnt/user_storage/train_data_sample.jsonl"

subsampled_df = balanced_train_df.sample(n=n_sample, random_state=42)
subsampled_df.to_json(output_file, orient="records", lines=True)



### FINE-TUNING JOB

In [None]:
# TBD: create fine-tuning job

In [None]:
# TBD: test scoring/routing prediction

### EVALUATION

In [None]:
# TBD: evaluate routing from accuracy,latency, and cost perspective

### Conclusion
In this tutorial, we have successfully built and evaluated a finetuned-LLM router. 
We generated synthetic labeled data using the LLM-as-a-judge method to train the model, finetuned an LLM classifier using Amazon Bedrock's API, 
and conducted offline evaluation.

### Sources

https://github.com/lm-sys/RouteLLM

https://medium.com/@learngrowthrive.fast/routellm-achieves-90-gpt-4-quality-at-80-lower-cost-6686e5f46e2a

https://medium.com/ai-insights-cobet/beyond-basic-chatbots-how-semantic-router-is-changing-the-game-783dd959a32d

https://medium.com/@bhawana.prs/semantic-routes-in-llms-to-make-chatbots-more-accurate-d99c17e30487


popular benchmarks: MT Bench, MMLU, and GSM8K.

* Semantic routing: Using a vector analysis to route the query to the closest “cluster”
https://github.com/aurelio-labs/semantic-router

* Prompt Chaining: Similar to what has been implemented inside Bedrock agents, and LangChain’s Custom function, these use an small LLM to analyze the question and route it to the next part of the chain. https://aws.amazon.com/blogs/machine-learning/enhance-conversational-ai-with-advanced-routing-techniques-with-amazon-bedrock/
You can optimize this by having the “router” model answer directly simple questions instead of routing them to another model.

* Intent Classification: Creating a custom model, similar to ROHF or Rerankers to classify the query and route it to the right LLM.  
https://medium.com/aimonks/intent-classification-generative-ai-based-application-architecture-3-79d2927537b4

https://www.anyscale.com/blog/building-an-llm-router-for-high-quality-and-cost-effective-responses

https://github.com/aws-samples/amazon-bedrock-samples/blob/main/function-calling/function_calling_text2SQL_converse_bedrock_streamlit.py

https://github.com/aws-samples/amazon-bedrock-samples/tree/main/rag-solutions/sql-query-generator



### SCRATCHPAD

In [None]:
# Use the native inference API to send a text message to Anthropic Claude.

import boto3
import json

from botocore.exceptions import ClientError

# Create a Bedrock Runtime client in the AWS Region of your choice.
client = boto3.client("bedrock-runtime", region_name="us-east-1")

# Set the model ID, e.g., Claude 3 Haiku.
model_id = "anthropic.claude-3-5-sonnet-20240620-v1:0"

# Define the prompt for the model.
prompt = "Describe the purpose of a 'hello world' program in one line."

# Format the request payload using the model's native structure.
native_request = {
    "anthropic_version": "bedrock-2023-05-31",
    "max_tokens": 512,
    "temperature": 0.5,
    "messages": [
        {
            "role": "user",
            "content": [{"type": "text", "text": prompt}],
        }
    ],
}

# Convert the native request to JSON.
request = json.dumps(native_request)

try:
    # Invoke the model with the request.
    response = client.invoke_model(modelId=model_id, body=request)

except (ClientError, Exception) as e:
    print(f"ERROR: Can't invoke '{model_id}'. Reason: {e}")
    exit(1)

# Decode the response body.
model_response = json.loads(response["body"].read())

# Extract and print the response text.
response_text = model_response["content"][0]["text"]
print(response_text)


