## Amazon Bedrock Text-to-SQL

### Intro and Goal
This notebook illustrates a zero-shot Text-to-SQL approach on the Northwind database.

The goal is to take a user prompt along with a SQL database schema, and then generate a corresponding SQL query.
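
A minimal sketch of how such a prompt might be assembled, assuming the question, schema, and dialect are wrapped in XML-style tags and the model is asked to answer inside `<sql>` tags. The function name `build_prompt` and the template wording are illustrative assumptions, not the prompt used later in this notebook.

```python
def build_prompt(question: str, schema: str, dialect: str) -> str:
    """Assemble a zero-shot Text-to-SQL prompt (hypothetical template)."""
    return (
        f"<sql_dialect>{dialect}</sql_dialect>\n"
        f"<sql_database_schema>\n{schema}\n</sql_database_schema>\n"
        f"<user_question>{question}</user_question>\n"
        "Write one query that answers the question. "
        "Return only the query, wrapped in <sql></sql> tags."
    )


prompt = build_prompt(
    question="How many customers are there?",
    schema="CREATE TABLE customers (customer_id TEXT PRIMARY KEY, company_name TEXT);",
    dialect="SQLite",
)
print(prompt)
```

"Zero-shot" here means the prompt carries no example question/query pairs; the model sees only the schema, the dialect, and the question.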

### Steps
1. Download SQL schema
2. Download a ground-truth dataset of questions and SQL queries for our sample database (e.g. Northwind)
3. Generate and run SQL queries with a smaller LLM
4. Generate and run SQL queries with a larger LLM
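
The "generate and run" steps can be sketched end to end without any AWS calls. The example below is only an illustration: `fake_llm` is a hypothetical stand-in for a Bedrock model, the `customers` table is a toy, and the query runs against an in-memory SQLite database rather than Northwind.

```python
import re
import sqlite3

SQL_PATTERN = r"<sql>(.*?)</sql>"


def fake_llm(prompt: str) -> str:
    """Stand-in for a Bedrock call; always returns the same tagged query."""
    return "<sql>SELECT COUNT(*) FROM customers;</sql>"


def extract_sql(response: str):
    """Return the text inside the first <sql>...</sql> pair, or None."""
    match = re.search(SQL_PATTERN, response, re.DOTALL | re.IGNORECASE)
    return match.group(1).strip() if match else None


# Toy database with two rows.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE customers (customer_id TEXT PRIMARY KEY)")
conn.executemany("INSERT INTO customers VALUES (?)", [("ALFKI",), ("ANATR",)])

# Generate, extract, run.
query = extract_sql(fake_llm("How many customers are there?"))
rows = conn.execute(query).fetchall()
conn.close()
print(query, rows)
```

Swapping `fake_llm` for a smaller or a larger model is the only difference between steps 3 and 4; extraction and execution stay the same.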

In [15]:
# 1. Create a python environment

# !conda create -y --name bedrock-router-eval python=3.11.8
# !conda init && activate bedrock-router-eval
# !conda install -n bedrock-router-eval ipykernel --update-deps --force-reinstall -y
# !conda install -c conda-forge ipython-sql

## OR
# !python -m venv venv
# !source venv/bin/activate  # On Windows, use `venv\Scripts\activate`

In [42]:
# 2. Install dependencies

# !pip install -r requirements.txt

Collecting psycopg2 (from -r requirements.txt (line 20))
  Downloading psycopg2-2.9.9.tar.gz (384 kB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: psycopg2
  Building wheel for psycopg2 (pyproject.toml) ... done
  Created wheel for psycopg2: filename=psycopg2-2.9.9-cp311-cp311-macosx_13_0_arm64.whl size=144388 sha256=6896d60e5e48135d32b88bbffdc89d7d70fdf4108e89bca5a2308c4ba8c79937
  Stored in directory: /Users/huthmac/Library/Caches/pip/wheels/ab/34/b9/78ebef1b3220b4840ee482461e738566c3c9165d2b5c914f51
Successfully built psycopg2
Installing collected packages: psycopg2
Successfully installed psycopg2-2.9.9


### Set Environment Variables

In [109]:
# 3. Import necessary libraries and load environment variables
import numpy as np
from scipy.spatial.distance import cdist
import json
from dotenv import load_dotenv, find_dotenv
import os
import boto3
import sqlite3
from pandas.io import sql
from botocore.config import Config
import pandas as pd
import io
from io import StringIO
import sqlparse
import time
import matplotlib.pyplot as plt
import re
import typing as t
from queue import Queue
from threading import Thread
from concurrent.futures import ThreadPoolExecutor, as_completed

# loading environment variables that are stored in local file
local_env_filename = 'dev.env'
load_dotenv(find_dotenv(local_env_filename),override=True)

# Read the settings directly; a missing variable raises KeyError here
# rather than failing later with a less obvious error.
REGION = os.environ['REGION']
HF_TOKEN = os.environ['HF_TOKEN']
SQL_DATABASE = os.environ['SQL_DATABASE'] # LOCAL, SQLALCHEMY, REDSHIFT
SQL_DIALECT = os.environ['SQL_DIALECT'] # SQLite, PostgreSQL


# Create a SageMaker session
import sagemaker
sagemaker_session = sagemaker.Session()

# Get the default bucket
default_bucket = sagemaker_session.default_bucket()
print(f"Default SageMaker S3 bucket: {default_bucket}")

print(f"Using database: {SQL_DATABASE} with sql dialect: {SQL_DIALECT}")

Default SageMaker S3 bucket: sagemaker-us-east-1-026459568683
Using database: SQLALCHEMY with sql dialect: PostgreSQL
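
A missing entry in `dev.env` is easier to diagnose when each required variable is checked up front. A minimal sketch, assuming a hypothetical helper named `require_env` (it is not part of this notebook); the demo values are set only so the example runs on its own.

```python
import os


def require_env(name: str) -> str:
    """Return the value of an environment variable or fail with a clear message."""
    value = os.getenv(name)
    if not value:
        raise RuntimeError(f"Missing required environment variable: {name}")
    return value


# Demo values only; in the notebook these come from dev.env.
os.environ.setdefault("SQL_DATABASE", "LOCAL")
os.environ.setdefault("SQL_DIALECT", "SQLite")

settings = {name: require_env(name) for name in ("SQL_DATABASE", "SQL_DIALECT")}
print(settings)
```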


In [128]:
# 4. Definition of helper classes

# Bedrock LLM Class
class BedrockLLMWrapper():
    def __init__(self,
        model_id: str = 'anthropic.claude-3-haiku-20240307-v1:0', #'anthropic.claude-3-sonnet-20240229-v1:0',
        top_k: int = 5,
        top_p: float = 0.7,
        temperature: float = 0.0,
        max_token_count: int = 4000,
        max_attempts: int = 3,
        debug: bool = False

    ):

        self.model_id = model_id
        self.top_k = top_k
        self.top_p = top_p
        self.temperature = temperature
        self.max_token_count = max_token_count
        self.max_attempts = max_attempts
        self.debug = debug
        config = Config(
            retries = {
                'max_attempts': 10,
                'mode': 'standard'
            }
        )

        self.bedrock_runtime = boto3.client(service_name="bedrock-runtime", config=config, region_name=REGION)
        
    def generate(self,prompt):
        if self.debug: 
            print('entered BedrockLLMWrapper generate')
        attempt = 1

        message = {
            "role": "user",
            "content": [{"text": prompt}]
        }
        messages = []
        messages.append(message)
        
        # model specific inference parameters to use.
        if "anthropic" in self.model_id.lower():
            # system_prompts = [{"text": "You are a helpful AI Assistant."}]
            system_prompts = []
            # Base inference parameters to use.
            inference_config = {
                                "temperature": self.temperature, 
                                "maxTokens": self.max_token_count,
                                "stopSequences": ["\n\nHuman:"],
                                "topP": self.top_p,
                            }
            additional_model_fields = {"top_k": self.top_k}
        else:
            system_prompts = []
            # Base inference parameters to use.
            inference_config = {
                                "temperature": self.temperature, 
                                "maxTokens": self.max_token_count,
                            }
            additional_model_fields = {"top_k": self.top_k}

        if self.debug: 
            print("Sending:\nSystem:\n",system_prompts,"\nMessages:\n",str(messages))

        # Defaults returned if every attempt fails, so the caller always gets three values.
        text = None
        usage = {'inputTokens': 0, 'outputTokens': 0, 'totalTokens': 0}
        latency = 0

        while True:
            try:

                # Send the message.
                response = self.bedrock_runtime.converse(
                    modelId=self.model_id,
                    messages=messages,
                    system=system_prompts,
                    inferenceConfig=inference_config,
                    additionalModelRequestFields=additional_model_fields
                )

                # Extract the generated text, token usage, and latency.
                text = response['output'].get('message').get('content')[0].get('text')
                usage = response['usage']
                latency = response['metrics'].get('latencyMs')

                if self.debug: 
                    print(f'text: {text} ; and token usage: {usage} ; and query_time: {latency}')    
                
                break
               
            except Exception as e:
                print("Error with calling Bedrock: "+str(e))
                attempt+=1
                if attempt>self.max_attempts:
                    print("Max attempts reached!")
                    # Return the error message in place of the model output.
                    text = str(e)
                    break
                else:#retry in 10 seconds
                    print("retry")
                    time.sleep(10)

        return [text,usage,latency]

    # Threaded function for queue processing.
    def thread_request(self, q, results):
        from queue import Empty  # the imports cell brings in only Queue

        while True:
            try:
                index, prompt = q.get(block=False)
            except Empty:
                # Nothing left to process.
                break
            try:
                results[index] = self.generate(prompt)
            except Exception as e:
                print(f'Error with prompt: {str(e)}')
                results[index] = str(e)
            finally:
                # Only mark items that were actually dequeued.
                q.task_done()

    def generate_threaded(self, prompts, max_workers=15):
        results = [None] * len(prompts)
        
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_to_index = {executor.submit(self.generate, prompt): i for i, prompt in enumerate(prompts)}
            for future in as_completed(future_to_index):
                index = future_to_index[future]
                try:
                    results[index] = future.result()
                except Exception as exc:
                    print(f'Generated an exception: {exc}')
                    results[index] = str(exc)
        
        return results


# Utility class to get cost and visualizations 
class Util():
    def __init__(self,
        debug: bool = False

    ):
        self.debug = debug
    
    SCORE_PATTERN = r'<score>(.*?)</score>'
    REASONING_PATTERN = r'<thinking>(.*?)</thinking>'
    SQL_PATTERN = r'<[sS][qQ][lL]>(.*?)</[sS][qQ][lL]>'
    DIFFICULTY_PATTERN = r'<difficulty>(.*?)</difficulty>'
    USER_QUESTION_PATTERN = r'<user_question>(.*?)</user_question>'
    SQL_DATABASE_SCHEMA_PATTERN = r'<sql_database_schema>(.*?)</sql_database_schema>'
    SQL_DIALECT_PATTERN = r'<sql_dialect>(.*?)</sql_dialect>'


    def compare_results(self, answer_results1, answer_results2):


        # Convert the 'score' column to integers (non-numeric values become 0)
        def convert_score(df):
            # df['score'] = df['score'].map({'correct': 1, 'incorrect': 0})
            df['score'] = pd.to_numeric(df['score'], errors='coerce').fillna(0).astype(int)
            return df

        # Apply the conversion to both dataframes
        answer_results1 = convert_score(answer_results1)
        answer_results2 = convert_score(answer_results2)

        # Calculate the average values for each metric
        metrics = ['score', 'latency' ,'cost', 'ex_score', 'em_score','ves_score']
        
        avg_results1 = [answer_results1[metric].mean() for metric in metrics]
        avg_results2 = [answer_results2[metric].mean() for metric in metrics]

        # Calculate percentage change, handling divide-by-zero and infinite cases
        def safe_percent_change(a, b):
            if pd.isna(a) or pd.isna(b):
                return 0
            if a == 0 and b == 0:
                return 0
            elif a == 0:
                return 100  # Arbitrarily set to 100% increase if original value was 0
            else:
                change = (b - a) / a * 100
                return change if np.isfinite(change) else 0

        percent_change = [safe_percent_change(a, b) for a, b in zip(avg_results1, avg_results2)]

        # Set up the bar chart
        x = np.arange(len(metrics))
        width = 0.5

        fig, ax = plt.subplots(figsize=(12, 6))

        # Create the bars
        bars = ax.bar(x, percent_change, width)

        # Customize the chart
        ax.set_ylabel('Percentage Change (%)')
        ax.set_title('Percentage Change in Metrics (Results 2 vs Results 1)')
        ax.set_xticks(x)
        ax.set_xticklabels(metrics)

        # Add a horizontal line at y=0
        ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)

        # Add value labels on top of each bar
        def autolabel(rects):
            for rect in rects:
                height = rect.get_height()
                ax.annotate(f'{height:.2f}%',
                            xy=(rect.get_x() + rect.get_width() / 2, height),
                            xytext=(0, 3 if height >= 0 else -3),  # 3 points vertical offset
                            textcoords="offset points",
                            ha='center', va='bottom' if height >= 0 else 'top')

        autolabel(bars)

        # Color the bars based on positive (green) or negative (red) change
        # For latency & cost, reverse the color logic
        for bar, change, metric in zip(bars, percent_change, metrics):
            if metric == 'latency' or metric == 'cost':
                bar.set_color('green' if change <= 0 else 'red')
            else:
                bar.set_color('green' if change >= 0 else 'red')
            

        # Adjust layout and display the chart
        fig.tight_layout()
        plt.show()

    def visualize_distribution(self, df, key):
        # Check that the requested column exists in the DataFrame
        if key not in df.columns:
            raise ValueError(f"The DataFrame does not contain a '{key}' column.")
        
        # Count the frequency of each score
        score_counts = df[key].value_counts().sort_index()
        
        # Create a bar chart
        plt.figure(figsize=(10, 6))
        plt.bar(score_counts.index, score_counts.values)
        
        # Customize the chart
        plt.title(f'Distribution of {key}')
        plt.xlabel(f'{key}')
        plt.ylabel('Frequency')
        plt.xticks(range(int(score_counts.index.min()), int(score_counts.index.max()) + 1))
        
        # Add value labels on top of each bar
        for i, v in enumerate(score_counts.values):
            plt.text(score_counts.index[i], v, str(v), ha='center', va='bottom')
        
        # Display the chart
        plt.tight_layout()
        plt.show()

    # Strip out the portion of the response with regex.
    def extract_with_regex(self, response, regex):
        matches = re.search(regex, response, re.DOTALL)
        # Extract the matched content, if any
        return matches.group(1).strip() if matches else None

    def calculate_cost(self, usage, model_id):
        '''
        Takes the input and output token counts returned by Bedrock and converts them to a cost in dollars.
        '''
        
        input_token_haiku = 0.25/1000000
        output_token_haiku = 1.25/1000000
        input_token_sonnet = 3.00/1000000
        output_token_sonnet = 15.00/1000000
        input_token_opus = 15.00/1000000
        output_token_opus = 75.00/1000000
        
        input_token_titan_embeddingv1 = 0.1/1000000
        input_token_titan_embeddingv2 = 0.02/1000000
        input_token_titan_embeddingmultimodal = 0.8/1000000
        input_token_titan_premier = 0.5/1000000
        output_token_titan_premier = 1.5/1000000
        input_token_titan_lite = 0.15/1000000
        output_token_titan_lite = 0.2/1000000
        input_token_titan_express = 0.2/1000000
        output_token_titan_express = 0.6/1000000
       
        input_token_cohere_command = 0.15/1000000
        output_token_cohere_command = 2/1000000
        input_token_cohere_commandlight = 0.3/1000000
        output_token_cohere_commandlight = 0.6/1000000
        input_token_cohere_commandrplus = 3/1000000
        output_token_cohere_commandrplus = 15/1000000
        input_token_cohere_commandr = 5/1000000
        output_token_cohere_commandr = 1.5/1000000
        input_token_cohere_embedenglish = 0.1/1000000
        input_token_cohere_embedmultilang = 0.1/1000000

        input_token_llama3_8b = 0.4/1000000
        output_token_llama3_8b = 0.6/1000000
        input_token_llama3_70b = 2.6/1000000
        output_token_llama3_70b = 3.5/1000000

        input_token_mistral_8b = 0.15/1000000
        output_token_mistral_8b = 0.2/1000000
        input_token_mistral_large = 4/1000000
        output_token_mistral_large = 12/1000000

        cost = 0

        if 'haiku' in model_id:
            cost+= usage['inputTokens']*input_token_haiku
            cost+= usage['outputTokens']*output_token_haiku
        if 'sonnet' in model_id:
            cost+= usage['inputTokens']*input_token_sonnet
            cost+= usage['outputTokens']*output_token_sonnet
        if 'opus' in model_id:
            cost+= usage['inputTokens']*input_token_opus
            cost+= usage['outputTokens']*output_token_opus
        if 'amazon.titan-embed-text-v1' in model_id:
            cost+= usage['inputTokens']*input_token_titan_embeddingv1
        if 'amazon.titan-embed-text-v2' in model_id:
            cost+= usage['inputTokens']*input_token_titan_embeddingv2
        if 'cohere.embed-multilingual' in model_id:
            cost+= usage['inputTokens']*input_token_cohere_embedmultilang
        if 'cohere.embed-english' in model_id:
            cost+= usage['inputTokens']*input_token_cohere_embedenglish 
        if 'meta.llama3-8b-instruct' in model_id:
            cost+= usage['inputTokens']*input_token_llama3_8b
            cost+= usage['outputTokens']*output_token_llama3_8b
        if 'meta.llama3-70b-instruct' in model_id:
            cost+= usage['inputTokens']*input_token_llama3_70b
            cost+= usage['outputTokens']*output_token_llama3_70b
        if 'cohere.command-text' in model_id:
            cost+= usage['inputTokens']*input_token_cohere_command
            cost+= usage['outputTokens']*output_token_cohere_command
        if 'cohere.command-light-text' in model_id:
            cost+= usage['inputTokens']*input_token_cohere_commandlight
            cost+= usage['outputTokens']*output_token_cohere_commandlight
        if 'cohere.command-r-plus' in model_id:
            cost+= usage['inputTokens']*input_token_cohere_commandrplus
            cost+= usage['outputTokens']*output_token_cohere_commandrplus
        # elif: 'cohere.command-r' is a substring of 'cohere.command-r-plus', so a plain if would bill R+ twice
        elif 'cohere.command-r' in model_id:
            cost+= usage['inputTokens']*input_token_cohere_commandr
            cost+= usage['outputTokens']*output_token_cohere_commandr
        if 'amazon.titan-text-express' in model_id:
            cost+= usage['inputTokens']*input_token_titan_express
            cost+= usage['outputTokens']*output_token_titan_express
        if 'amazon.titan-text-lite' in model_id:
            cost+= usage['inputTokens']*input_token_titan_lite
            cost+= usage['outputTokens']*output_token_titan_lite
        if 'amazon.titan-text-premier' in model_id:
            cost+= usage['inputTokens']*input_token_titan_premier
            cost+= usage['outputTokens']*output_token_titan_premier
        if 'mistral.mixtral-8x7b-instruct-v0:1' in model_id:
            cost+= usage['inputTokens']*input_token_mistral_8b
            cost+= usage['outputTokens']*output_token_mistral_8b

        return cost

# Utility class to get database schema, create tables, and run SQL queries
import requests
import sqlite3
import re
from pyathena import connect
# import psycopg2
from sqlalchemy import create_engine, MetaData, text


class DatabaseUtil():
    def __init__(self,
        debug: bool = False,
        datasource_url: list[str] = ['https://d3q8adh3y5sxpk.cloudfront.net/sql-workshop/data/redshift-sourcedb.sql'],
        sql_database: str = 'LOCAL',
        sql_database_name: str = 'dev',
        region: str = 'us-east-1',
        s3_bucketname: str = ''

    ):
        self.debug = debug
        self.datasource_url = datasource_url
        self.sql_database = sql_database
        self.sql_database_name = sql_database_name
        self.region = region
        self.s3_bucketname = s3_bucketname

    # retrieve AWS secret for database connection
    def get_secret(self, secret_name):
        session = boto3.session.Session()
        client = session.client(service_name='secretsmanager', region_name=self.region)
        get_secret_value_response = client.get_secret_value(SecretId=secret_name)
        return get_secret_value_response

    def get_table_reflections(self, engine) -> MetaData:
    
        # Instantiate MetaData object
        metadata = MetaData()
        
        # Reflect the database schema with the engine
        metadata.reflect(bind=engine)
        
        return metadata

    def convert_reflection_to_dict(self, metadata: MetaData) -> list[dict]:
        table_definitions: list[dict] = []
        for table_name in metadata.tables:
            definition = {}
            definition['table'] = table_name
            # The metadata.tables[x].columns value is type sqlalchemy.sql.base.ReadOnlyColumnCollection
            # Let's convert it into something more usable. c.type returns a SQLAlchemy object so we convert to string.
            definition['columns'] = { c.name: str(c.type) for c in metadata.tables[table_name].columns }
        
            table_header = f"Table: {table_name}"
            columns_definition = '\n'.join([f"Column: {c.name}, Type: {c.type}" for c in metadata.tables[table_name].columns])
            string_representation = f"{table_header}\n{columns_definition}"
        
            definition['string_representation'] = string_representation
        
            table_definitions.append(definition)
        
        
        # The metadata table is a FacadeDict object which is immutable so we need to remove unwanted tables in the new list.
        table_names_to_exclude = set(['table_embedding', 'alembic_version'])
        table_definitions = [d for d in table_definitions if d['table'] not in table_names_to_exclude]

        return table_definitions
    
    def create_database_tables(self):
        # Download the SQL files
        
            # create local db and import northwind database
            for url in self.datasource_url:
                response = requests.get(url)
                sql_content = response.text
                # Split the SQL content into individual statements
                sql_statements = re.split(r';\s*$', sql_content, flags=re.MULTILINE)
                
                if self.sql_database == 'LOCAL':
                    try:
                        # Create a SQLite database connection
                        conn = sqlite3.connect('devdb.db')
                        cursor = conn.cursor()

                        # Execute each SQL statement
                        for statement in sql_statements:
                            # Skip empty statements
                            if statement.strip():
                                # print(f'statement: {statement}')
                                # Replace PostgreSQL-specific syntax with SQLite equivalents
                                statement = statement.replace('SERIAL PRIMARY KEY', 'INTEGER PRIMARY KEY AUTOINCREMENT')
                                statement = statement.replace('::int', '')
                                statement = statement.replace('::varchar', '')
                                statement = statement.replace('::real', '')
                                statement = statement.replace('::date', '')
                                statement = statement.replace('::boolean', '')
                                statement = statement.replace('public.', '')
                                statement = re.sub(r'WITH \(.*?\)', '', statement)
                                
                                try:
                                    cursor.execute(statement)
                                except sqlite3.Error as e:
                                    print(f"Error executing statement: {e}")

                        # Commit the changes and close the connection
                        conn.commit()
                        conn.close()

                        print("SQL execution completed.")
                    except Exception as e:
                        print(f"Error creating tables: {e}")
                        raise

                if self.sql_database == 'REDSHIFT':
                    try:
                        rdc = boto3.client('redshift-data')
                        get_secret_value_response = self.get_secret("RedshiftCreds")
                        # Parse the secret once to extract WorkgroupName, Database, DbUser
                        secret = json.loads(get_secret_value_response['SecretString'])
                        WorkgroupName = secret.get('workgroupname')
                        # NOTE: Database is read from the 'workgroupname' key; adjust if the secret stores the database name under its own key.
                        Database = secret.get('workgroupname')
                        DbUser = secret.get('username')

                        for statement in sql_statements:
                            try:        
                                rdc.execute_statement(
                                    WorkgroupName=WorkgroupName,
                                    Database=Database,
                                    DbUser=DbUser,
                                    Sql=statement
                                )
                                
                            except Exception as e:
                                print(f"Error executing statement: {e}")
                        print("SQL execution completed.")
                    except Exception as e:
                        print(f"Error creating tables: {e}")
                        raise
                
                if self.sql_database =='SQLALCHEMY':
                    # create tables in database
                    try:
                        # SQLALCHEMY_URL = f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{SQL_DATABASE_NAME}"
                        get_secret_value_response = self.get_secret("SQLALCHEMY_URL")
                        SQLALCHEMY_URL = get_secret_value_response['SecretString']

                        # Do not print SQLALCHEMY_URL: it embeds the database credentials.
                        engine = create_engine(SQLALCHEMY_URL)
                        with engine.connect() as connection:
                            # Execute each SQL statement
                            for statement in sql_statements:
                                # Skip empty statements
                                if statement.strip():
                                    # print(f'statement: {statement}')
                                    # Same rewrites as the LOCAL branch, plus Redshift type names mapped to PostgreSQL ones (VARBYTE -> bytea, bpchar -> varchar)
                                    statement = statement.replace('SERIAL PRIMARY KEY', 'INTEGER PRIMARY KEY AUTOINCREMENT')
                                    statement = statement.replace('::int', '')
                                    statement = statement.replace('::varchar', '')
                                    statement = statement.replace('::real', '')
                                    statement = statement.replace('::date', '')
                                    statement = statement.replace('::boolean', '')
                                    statement = statement.replace('public.', '')
                                    statement = statement.replace('VARBYTE', 'bytea')
                                    statement = statement.replace('bpchar', 'varchar')
                                    
                                    statement = re.sub(r'WITH \(.*?\)', '', statement)
                                    
                                    try:
                                        connection.execute(text(statement))
                                    except Exception as e:
                                        print(f"Error executing statement: {e}")
                            connection.commit()
                            print("SQL execution completed.")    
                    except Exception as e:
                        print(f"Error creating tables: {e}")
                        raise


    def get_schema_as_string(self):
        if self.sql_database == 'LOCAL':
            db_path = 'devdb.db'          
            conn = sqlite3.connect(db_path)
            cursor = conn.cursor()

            # Query to get all table names
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            tables = cursor.fetchall()

            schema_string = ""

            for table in tables:
                table_name = table[0]
                # Query to get the CREATE TABLE statement for each table
                cursor.execute(f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table_name}';")
                create_table_stmt = cursor.fetchone()[0]
                
                schema_string += f"{create_table_stmt};\n\n"

            conn.close()
            return schema_string

        if self.sql_database =='SQLALCHEMY':
            try:
                # Use SQLAlchemy if SQL Alchemy is used
                # SQLALCHEMY_URL = f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{SQL_DATABASE_NAME}"
                get_secret_value_response = self.get_secret("SQLALCHEMY_URL")
                SQLALCHEMY_URL = get_secret_value_response['SecretString']
                
                engine = create_engine(SQLALCHEMY_URL)
                metadata = self.get_table_reflections(engine)
                table_definitions = self.convert_reflection_to_dict(metadata)

                return table_definitions
                # with engine.connect() as connection:
                #     result = connection.execute("""""")
                #     return result.fetchall()
            except Exception as e:
                error = f"Error executing statement: {e}"
                print(error)

        if self.sql_database == "REDSHIFT":
            try:
                get_secret_value_response = self.get_secret("RedshiftCreds")
                # Parse the secret once to extract WorkgroupName, Database, DbUser
                secret = json.loads(get_secret_value_response['SecretString'])
                WorkgroupName = secret.get('workgroupname')
                # NOTE: Database is read from the 'workgroupname' key; adjust if the secret stores the database name under its own key.
                Database = secret.get('workgroupname')
                DbUser = secret.get('username')
                
                rdc = boto3.client('redshift-data')
                result = rdc.execute_statement(
                    WorkgroupName=WorkgroupName,
                    Database=Database,
                    DbUser=DbUser,
                    Sql=f"select * from pg_table_def where schemaname = 'public';"
                )
                return result
            except Exception as e:
                print(f"Error executing statement: {e}")
      
            
        if self.sql_database == 'GLUE':
            # use a Glue database
            table_names=None
            try:
                glue_client = boto3.client('glue', region_name=self.region)
                table_schema_list = []
                response = glue_client.get_tables(DatabaseName=self.sql_database_name)

                all_table_names = [table['Name'] for table in response['TableList']]

                if table_names:
                    table_names = [name for name in table_names if name in all_table_names]
                else:
                    table_names = all_table_names

                for table_name in table_names:
                    response = glue_client.get_table(DatabaseName=self.sql_database_name, Name=table_name)
                    columns = response['Table']['StorageDescriptor']['Columns']
                    schema = {column['Name']: column['Type'] for column in columns}
                    table_schema_list.append({"Table: {}".format(table_name): 'Schema: {}'.format(schema)})
            except Exception as e:
                print(f"Error: {str(e)}")
            return table_schema_list
        
    def run_sql(self, statement):
    
        if self.sql_database == 'LOCAL':
            # Create a SQLite database connection
            conn = sqlite3.connect('devdb.db')
            try:
                cursor = conn.cursor()
                cursor.execute(statement)
                # Fetch all rows from the result; any sqlite3.Error propagates to the caller
                return cursor.fetchall()
            finally:
                # Runs on success and on failure, so the connection is closed exactly once
                conn.close()
                
        if self.sql_database == 'GLUE':
            try:
                # Use Athena if AWS Glue Schema is used
                athenacursor = connect(s3_staging_dir=f"s3://{self.s3_bucketname}/athena/",
                                        region_name=self.region).cursor()
                athenacursor.execute(statement)
                result = pd.DataFrame(athenacursor.fetchall()).to_string(index=False)
                # convert df to string
                return result
            
            except Exception as e:
                error = f"Error executing statement: {e}"
                raise
        
        if self.sql_database == "REDSHIFT":
            try:
                get_secret_value_response = self.get_secret("RedshiftCreds")
                # Parse the secret once to extract WorkgroupName, Database, DbUser
                secret = json.loads(get_secret_value_response['SecretString'])
                WorkgroupName = secret.get('workgroupname')
                # NOTE: Database is read from the 'workgroupname' key; adjust if the secret stores the database name under its own key.
                Database = secret.get('workgroupname')
                DbUser = secret.get('username')

                rdc = boto3.client('redshift-data')
                result = rdc.execute_statement(
                    WorkgroupName=WorkgroupName,
                    Database=Database,
                    DbUser=DbUser,
                    Sql=statement
                )
                return result
                
            except Exception as e:
                print(f"Error executing statement: {e}")

        if self.sql_database =='SQLALCHEMY':
            try:
                # SQLALCHEMY_URL = f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{SQL_DATABASE_NAME}"
                get_secret_value_response = self.get_secret("SQLALCHEMY_URL")
                SQLALCHEMY_URL = get_secret_value_response['SecretString']
                
                engine = create_engine(SQLALCHEMY_URL)
                with engine.connect() as connection:
                    result = connection.execute(text(statement))
                    return result.fetchall()
            except Exception as e:
                error = f"Error executing statement: {e}"
                raise

In [111]:
# 5. Initialize database wrapper to run sql queries, get database schema, and create tables in database

if SQL_DATABASE == 'SQLALCHEMY':
        # SQLALCHEMY
        databaseutil = DatabaseUtil(
                        datasource_url=["https://d3q8adh3y5sxpk.cloudfront.net/sql-workshop/data/redshift-sourcedb.sql"],
                        sql_database= 'SQLALCHEMY'
        )

if SQL_DATABASE == 'LOCAL':
        # LOCAL SQLite
        databaseutil = DatabaseUtil(
                        datasource_url=["https://d3q8adh3y5sxpk.cloudfront.net/sql-workshop/data/redshift-sourcedb.sql"],
                        sql_database= 'LOCAL'
        )

result = databaseutil.create_database_tables()
print(result)
           
schema = databaseutil.get_schema_as_string()
print(schema)

# result = databaseutil.run_sql("SELECT * from public.customers")
# print(result)


SQLALCHEMY_URL: postgresql://masteruser:<redacted>@workshop-database.cluster-cehkbhhtrbhr.us-east-1.rds.amazonaws.com:5432/dev
SQL execution completed.
None
[{'table': 'customer_demographics', 'columns': {'customer_type_id': 'VARCHAR', 'customer_desc': 'TEXT'}, 'string_representation': 'Table: customer_demographics\nColumn: customer_type_id, Type: VARCHAR\nColumn: customer_desc, Type: TEXT'}, {'table': 'customer_customer_demo', 'columns': {'customer_id': 'VARCHAR', 'customer_type_id': 'VARCHAR'}, 'string_representation': 'Table: customer_customer_demo\nColumn: customer_id, Type: VARCHAR\nColumn: customer_type_id, Type: VARCHAR'}, {'table': 'customers', 'columns': {'customer_id': 'VARCHAR', 'company_name': 'VARCHAR(40)', 'contact_name': 'VARCHAR(30)', 'contact_title': 'VARCHAR(30)', 'address': 'VARCHAR(60)', 'city': 'VARCHAR(15)', 'region': 'VARCHAR(15)', 'postal_code': 'VARCHAR(10)', 'country': 'VARCHAR(15)', 'phone': 'VARCHAR(24)', 'fax': 'VARCHAR(24)'}, 'string_representation':

In [112]:
# 6. Download ground truth 

import requests
import os

# URL of the file to download
url = "https://d3q8adh3y5sxpk.cloudfront.net/sql-workshop/data/question_query_good_results.jsonl"

# Path to the local data folder
data_folder = "./data"

# Create the data folder if it doesn't exist
os.makedirs(data_folder, exist_ok=True)

# File name to save the downloaded file
file_name = "ground_truth.jsonl"

# Full path to save the file
file_path = os.path.join(data_folder, file_name)

# Send a GET request to download the file and fail loudly on HTTP errors
response = requests.get(url, timeout=30)
response.raise_for_status()

# Save the file to the local data folder
with open(file_path, "wb") as file:
    file.write(response.content)

print(f"File downloaded and saved to {file_path}")

File downloaded and saved to ./data/ground_truth.jsonl
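
The download step does not check that the file has the shape later cells expect. As a minimal sketch, assuming each line is a JSON object with question and query fields (the exact key casing in the file is an assumption, so the check is case-insensitive), a hypothetical helper `load_jsonl` could validate the records before any query is executed:

```python
import json

def load_jsonl(lines, required_fields=("question", "query")):
    """Parse JSON Lines text, raising ValueError if a record lacks a required field."""
    records = []
    for line_number, line in enumerate(lines, start=1):
        line = line.strip()
        if not line:
            continue  # skip blank lines
        record = json.loads(line)
        # Compare keys case-insensitively, since the casing in the file is assumed
        keys = {key.lower() for key in record}
        missing = [field for field in required_fields if field not in keys]
        if missing:
            raise ValueError(f"line {line_number} is missing fields: {missing}")
        records.append(record)
    return records

# Illustrative sample lines, not the contents of the downloaded file
sample = [
    '{"question": "What is the total number of customers?", "query": "SELECT COUNT(*) FROM customers;"}',
    '',
    '{"Question": "List all product names and their unit prices.", "Query": "SELECT product_name, unit_price FROM products;"}',
]
print(len(load_jsonl(sample)))
```

Because a file object iterates over its lines, the same helper could be applied as `load_jsonl(open(file_path))`.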


In [113]:
# 7. Validate/ensure ground truth SQL queries run successfully
import pandas as pd
import json

results = []

# Check if df exists in the current namespace
if 'df' not in globals():
    # If it doesn't exist, try to load it from a JSONL file
    if os.path.exists(file_path):
        # Load the dataframe from the JSONL file
        df = pd.read_json(file_path, lines=True)
        print("df loaded from JSONL file.")
    else:
        # Stop here: without df the rest of this cell cannot run
        raise FileNotFoundError(f"JSONL file not found at {file_path}")
else:
    print(f"df with column names: {df.columns} already exists in memory.")

df.columns = df.columns.str.capitalize()


for row in df.itertuples():
    # print(row.query)
    error = None
    result = None
    try:
        
        result = databaseutil.run_sql(row.Query)
        
    except Exception as e:
        error = e

    results.append({'Question': row.Question,'Query': row.Query, 'Result': result, 'Error': error, 'Context': schema})


df_results = pd.DataFrame(results)

# pandas treats None as missing, so isnull()/notnull() is enough to split the rows
df_good_results = df_results[df_results['Error'].isnull()]
print(f"Number of successful queries: {len(df_good_results)}")

df_bad_results = df_results[df_results['Error'].notnull()]
print(f"Number of unsuccessful queries: {len(df_bad_results)}")

df with column names: Index(['Question', 'Query', 'Result', 'Error', 'Context'], dtype='object') already exists in memory.
Number of successful queries: 124
Number of unsuccessful queries: 0


In [122]:
# 8. Create Text-to-SQL zero-shot prompt
# This function builds a zero-shot Text-to-SQL prompt for a given user question and SQL database schema.
# The prompt includes the original user question, the SQL database schema, and instructions for generating a SQL query.
# The sql_dialect parameter specifies the SQL dialect of the generated query (e.g., awsathena or SQLite);
# when omitted, it falls back to the notebook-level SQL_DIALECT setting.


def build_sqlquerygen_prompt(user_question: str, sql_database_schema: str, sql_dialect: str = None):
    if sql_dialect is None:
        sql_dialect = SQL_DIALECT
    prompt = """You are a SQL expert. You will be provided with the original user question and a SQL database schema.
                Only return the SQL query and nothing else.
                Here is the original user question.
                <user_question>
                {user_question}
                </user_question>

                Here is the SQL database schema.
                <sql_database_schema>
                {sql_database_schema}
                </sql_database_schema>
                
                Instructions:
                Generate a SQL query that answers the original user question.
                Using the schema, first create a syntactically correct {sql_dialect} query to answer the question.
                Never query for all the columns from a specific table, only ask for a few relevant columns given the question.
                Always prefix table names with the "public." prefix.
                Pay attention to use only the column names that you can see in the schema description.
                Be careful to not query for columns that do not exist.
                Pay attention to which column is in which table.
                Also, qualify column names with the table name when needed.
                If you cannot answer the user question with the help of the provided SQL database schema,
                then output that this question cannot be answered based on the information stored in the database.
                You are required to use the following format:
                Return the SQL query inside <SQL></SQL> tags.
                """.format(
                    user_question=user_question,
                    sql_database_schema=sql_database_schema,
                    sql_dialect=sql_dialect
                )
    return prompt
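
The prompt asks the model to wrap its answer in `<SQL></SQL>` tags, and the following cells pull the query back out with `util.extract_with_regex(..., util.SQL_PATTERN)`, whose definition is not shown here. As a rough sketch of how such an extraction could work, assuming a hypothetical pattern `SQL_TAG_PATTERN` and helper `extract_sql` (the notebook's actual pattern may differ):

```python
import re
from typing import Optional

# Hypothetical pattern; the notebook's util.SQL_PATTERN may be defined differently.
# DOTALL lets the query span several lines; IGNORECASE tolerates <sql> tags.
SQL_TAG_PATTERN = re.compile(r"<SQL>(.*?)</SQL>", re.DOTALL | re.IGNORECASE)

def extract_sql(model_response: str) -> Optional[str]:
    """Return the text inside the first <SQL>...</SQL> pair, or None if no pair is found."""
    match = SQL_TAG_PATTERN.search(model_response)
    if match is None:
        return None
    return match.group(1).strip()

response = "Here is the query:\n<SQL>\nSELECT COUNT(*) FROM public.customers;\n</SQL>"
print(extract_sql(response))
```

Returning `None` when the tags are missing lets the caller separate "the model declined to answer" from "the query failed to run".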

In [129]:
# 9a. Use ground truth to run test with smaller LLM

MODEL_ID = "mistral.mixtral-8x7b-instruct-v0:1" # "anthropic.claude-3-haiku-20240307-v1:0" # "mistral.mixtral-8x7b-instruct-v0:1" "anthropic.claude-3-5-sonnet-20240620-v1:0" "meta.llama3-1-70b-instruct-v1:0"

# use helper class for threaded API calls
wrapper = BedrockLLMWrapper(debug=False, model_id=MODEL_ID, max_token_count=500)
util = Util()
df1 = df_good_results.copy()  # copy so the new column is not written to a slice of df_results
prompts_list = []
for row in df1.itertuples():
    prompt = build_sqlquerygen_prompt(row.Question, row.Context)
    prompts_list.append(prompt)
results = wrapper.generate_threaded(prompts_list)

# Create a list to store the generated SQL queries
generated_sql_queries = []
for result in results:
    generated_sql_query = result[0].replace("\\","") # workaround, switching to ConverseAPI introduced \ in Mistral response
    generated_sql_queries.append(generated_sql_query)

# Add the new column 'Generated_SQL_Query' to df1
df1['Generated_SQL_Query'] = generated_sql_queries


# Test generated SQL queries and verify they work
results = []

for row in df1.itertuples():
    statement = util.extract_with_regex(row.Generated_SQL_Query, util.SQL_PATTERN)
    error = None
    result = None
    try:
        
        result = databaseutil.run_sql(statement)

    except Exception as e:
        error = e

    results.append({'Question': row.Question,'Query': statement, 'Result': result, 'Error': error, 'ReferenceQuery': row.Query, 'Context': row.Context})


# inspect first 3 results
df1_results = pd.DataFrame(results)
print(df1_results.head(3))

# review successful/unsuccessful queries
df1_good_results = df1_results[df1_results['Error'].isnull()]
print(f"Number of successful queries: {len(df1_good_results)}")

df1_bad_results = df1_results[df1_results['Error'].notnull()]
print(f"Number of unsuccessful queries: {len(df1_bad_results)}")

                                        Question  \
0         What is the total number of customers?   
1  List all product names and their unit prices.   
2    Who are the top 5 customers by order count?   

                                               Query  \
0            SELECT COUNT(*)\nFROM public.customers;   
1  SELECT products.product_name, products.unit_pr...   
2  SELECT c.customer_id, COUNT(o.order_id) as ord...   

                                              Result Error  \
0                                             [(91)]  None   
1  [(Chai, 18.0), (Chang, 19.0), (Aniseed Syrup, ...  None   
2  [(SAVEA, 31), (ERNSH, 30), (QUICK, 28), (HUNGO...  None   

                                      ReferenceQuery  \
0                    SELECT COUNT(*) FROM customers;   
1     SELECT product_name, unit_price FROM products;   
2  SELECT c.company_name, COUNT(o.order_id) as or...   

                                             Context  
0  [{'table': 'customer_demographics'
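
Assigning `generated_sql_queries` back as a DataFrame column only lines up with the questions if the threaded helper returns results in the same order as `prompts_list`. The implementation of `generate_threaded` is not shown here, so the following is only a sketch of one order-preserving approach, using a stand-in function `fake_generate` in place of a real model call (both names below are hypothetical):

```python
from concurrent.futures import ThreadPoolExecutor

def fake_generate(prompt: str) -> str:
    """Stand-in for a single model call; a real version would invoke Bedrock here."""
    return f"<SQL>-- answer for: {prompt}</SQL>"

def generate_in_order(prompts, max_workers=4):
    """Run one call per prompt concurrently and return results in prompt order."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Executor.map yields results in the order of the inputs,
        # regardless of which call finishes first.
        return list(pool.map(fake_generate, prompts))

prompts = ["question A", "question B", "question C"]
for prompt, output in zip(prompts, generate_in_order(prompts)):
    print(prompt, "->", output)
```

If a helper collects results as they complete instead, each result would need to carry its prompt index so it can be re-sorted before the column assignment.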

In [130]:
# 9b. Use ground truth to run test with larger LLM

MODEL_ID = "anthropic.claude-3-haiku-20240307-v1:0" #"anthropic.claude-3-sonnet-20240229-v1:0"  "mistral.mixtral-8x7b-instruct-v0:1" "anthropic.claude-3-5-sonnet-20240620-v1:0" "meta.llama3-1-70b-instruct-v1:0"

# use helper class for threaded API calls
wrapper = BedrockLLMWrapper(debug=False, model_id=MODEL_ID, max_token_count=500)
util = Util()
df2 = df_good_results.copy()  # copy so the new column is not written to a slice of df_results
prompts_list = []
for row in df2.itertuples():
    prompt = build_sqlquerygen_prompt(row.Question, row.Context)
    prompts_list.append(prompt)
results = wrapper.generate_threaded(prompts_list, max_workers=8)

# Create a list to store the generated SQL queries
generated_sql_queries = []
for result in results:
    generated_sql_query = result[0]
    # print(f'generated_sql_query: {generated_sql_query}')
    generated_sql_queries.append(generated_sql_query)

# Add the new column 'Generated_SQL_Query' to df2
df2['Generated_SQL_Query'] = generated_sql_queries

# Test generated SQL queries and verify they work
results = []

for row in df2.itertuples():
    statement = util.extract_with_regex(row.Generated_SQL_Query, util.SQL_PATTERN)
    # print(f'SQL statement: {statement}')
    error = None
    result = None  # reset so a failed query does not inherit the previous row's result
    try:
        
        result = databaseutil.run_sql(statement)

    except Exception as e:
        error = e

    results.append({'Question': row.Question,'Query': statement, 'Result': result, 'Error': error, 'ReferenceQuery': row.Query, 'Context': row.Context})

df2_results = pd.DataFrame(results)

# inspect first 3 results
print(df2_results.head(3))

# review successful/unsuccessful queries
df2_good_results = df2_results[df2_results['Error'].isnull()]
print(f"Number of successful queries: {len(df2_good_results)}")

df2_bad_results = df2_results[df2_results['Error'].notnull()]
print(f"Number of unsuccessful queries: {len(df2_bad_results)}")

                                        Question  \
0         What is the total number of customers?   
1  List all product names and their unit prices.   
2    Who are the top 5 customers by order count?   

                                               Query  \
0  SELECT COUNT(*) AS total_customers\nFROM publi...   
1  SELECT public.products.product_name, public.pr...   
2  SELECT public.customers.customer_id, COUNT(pub...   

                                              Result Error  \
0                                             [(91)]  None   
1  [(Chai, 18.0), (Chang, 19.0), (Aniseed Syrup, ...  None   
2  [(SAVEA, 31), (ERNSH, 30), (QUICK, 28), (HUNGO...  None   

                                      ReferenceQuery  \
0                    SELECT COUNT(*) FROM customers;   
1     SELECT product_name, unit_price FROM products;   
2  SELECT c.company_name, COUNT(o.order_id) as or...   

                                             Context  
0  [{'table': 'customer_demographics'

### Conclusion
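As expected, the larger LLM (e.g. Claude 3 Haiku) produces valid SQL queries slightly more often with zero-shot prompting than the smaller LLM (e.g. Mixtral 8x7B). Here, "valid" means that the generated query executed without error; the cells above do not check whether its result matches the result of the reference query.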
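
One way to take the evaluation a step further is to check whether a generated query returns the same rows as its reference query. The sketch below is not part of the notebook's pipeline: it uses an in-memory SQLite database with a tiny illustrative table (not the Northwind data) and a hypothetical helper `results_match` that compares row sets while ignoring row order.

```python
import sqlite3
from collections import Counter

def results_match(conn, generated_sql, reference_sql):
    """Return True when both queries yield the same multiset of rows (order ignored)."""
    try:
        generated_rows = conn.execute(generated_sql).fetchall()
    except sqlite3.Error:
        return False  # a query that fails to run cannot match
    reference_rows = conn.execute(reference_sql).fetchall()
    return Counter(generated_rows) == Counter(reference_rows)

# Tiny illustrative table, not the Northwind data
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE customers (customer_id TEXT, country TEXT)")
conn.executemany(
    "INSERT INTO customers VALUES (?, ?)",
    [("ALFKI", "Germany"), ("ANATR", "Mexico"), ("BLAUS", "Germany")],
)

# Same answer despite the column alias
print(results_match(conn, "SELECT COUNT(*) AS n FROM customers",
                    "SELECT COUNT(*) FROM customers"))
# Different rows, so no match
print(results_match(conn, "SELECT customer_id FROM customers WHERE country = 'Mexico'",
                    "SELECT customer_id FROM customers"))
```

A strict row comparison like this also flags queries that answer the question with different columns, such as `customer_id` versus `company_name` in the third sample row above, so a mismatch is best treated as a candidate for manual review, not as a definite error.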