## Amazon Bedrock Text-to-SQL

### Intro and Goal
This Jupyter notebook illustrates a zero-shot Text-to-SQL approach on the Northwind database.

The goal is to take a user prompt along with a SQL database schema, and then generate a corresponding SQL query.

### Steps
1. Download SQL schema
2. Download a ground truth dataset of questions and SQL queries for our sample database
3. Generate and run SQL queries with a smaller LLM
4. Generate and run SQL queries with a larger LLM

In [5]:
# 1. Create a python environment

# !conda create -y --name bedrock-router-eval python=3.11.8
# !conda init && conda activate bedrock-router-eval
# !conda install -n bedrock-router-eval ipykernel --update-deps --force-reinstall -y
# !conda install -c conda-forge ipython-sql

## OR
# !python3 -m venv venv
# !source venv/bin/activate  # On Windows, use `venv\Scripts\activate`

# install ipykernel, which also installs IPython
# !pip install ipykernel
# create a kernel that can be used to run notebook commands inside the virtual environment
# !python3 -m ipykernel install --user --name=venv

In [1]:
# 2. Install dependencies

!pip install -r requirements.txt

Collecting transformers (from -r requirements.txt (line 19))
  Using cached transformers-4.45.1-py3-none-any.whl.metadata (44 kB)
Collecting safetensors>=0.4.1 (from transformers->-r requirements.txt (line 19))
  Using cached safetensors-0.4.5-cp312-cp312-macosx_11_0_arm64.whl.metadata (3.8 kB)
Using cached transformers-4.45.1-py3-none-any.whl (9.9 MB)
Using cached safetensors-0.4.5-cp312-cp312-macosx_11_0_arm64.whl (381 kB)
Installing collected packages: safetensors, transformers
Successfully installed safetensors-0.4.5 transformers-4.45.1


### Set Environment Variables

In [7]:
# 3. Import necessary libraries and load environment variables

from dotenv import load_dotenv, find_dotenv
import os

# loading environment variables that are stored in local file
local_env_filename = 'dev.env'
load_dotenv(find_dotenv(local_env_filename),override=True)

# Read the required settings; os.environ[...] raises a KeyError early if one is missing from dev.env
REGION = os.environ['REGION']
HF_TOKEN = os.environ['HF_TOKEN']
SQL_DATABASE = os.environ['SQL_DATABASE']  # LOCAL, SQLALCHEMY, REDSHIFT
SQL_DIALECT = os.environ['SQL_DIALECT']  # SQLite, PostgreSQL

print(f"Using database: {SQL_DATABASE} with sql dialect: {SQL_DIALECT}")

Using database: SQLALCHEMY with sql dialect: PostgreSQL
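
The cell above assumes that `dev.env` defines every variable it reads. A minimal sketch of a fail-fast check that reports all missing names at once is shown below. The helper name `require_env` and the example values (such as `us-east-1`) are illustrative assumptions, not part of the notebook's utilities.

```python
import os


def require_env(names, environ=None):
    """Return {name: value} for the given variables, or raise if any are unset or empty.

    Hypothetical helper for illustration; not part of the notebook's utils package.
    """
    environ = os.environ if environ is None else environ
    missing = [name for name in names if not environ.get(name)]
    if missing:
        raise RuntimeError(f"Missing environment variables: {', '.join(missing)}")
    return {name: environ[name] for name in names}


# Example with an explicit mapping instead of the real process environment.
# The values below are placeholders chosen for the example.
settings = require_env(
    ["REGION", "SQL_DATABASE", "SQL_DIALECT"],
    environ={"REGION": "us-east-1", "SQL_DATABASE": "LOCAL", "SQL_DIALECT": "SQLite"},
)
print(f"Using database: {settings['SQL_DATABASE']} with sql dialect: {settings['SQL_DIALECT']}")
```

Passing the mapping explicitly keeps the check easy to exercise without touching the real environment; calling `require_env([...])` with no second argument would read from `os.environ`.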


In [8]:
# 4. Initialize database wrapper to run sql queries, get database schema, and create tables in database
from utils.database import DatabaseUtil

if SQL_DATABASE == 'SQLALCHEMY':
        # SQLALCHEMY
        databaseutil = DatabaseUtil(
                        datasource_url=["https://d3q8adh3y5sxpk.cloudfront.net/sql-workshop/data/redshift-sourcedb.sql"],
                        sql_database= 'SQLALCHEMY',
                        region=REGION
        )

if SQL_DATABASE == 'LOCAL':
        # LOCAL SQLite
        databaseutil = DatabaseUtil(
                        datasource_url=["https://d3q8adh3y5sxpk.cloudfront.net/sql-workshop/data/redshift-sourcedb.sql"],
                        sql_database= 'LOCAL'
        )

result = databaseutil.create_database_tables()
print(result)
           
schema = databaseutil.get_schema_as_string()
print(schema)

# result = databaseutil.run_sql("SELECT * from public.customers")
# print(result)


SQL execution completed.
None
[{'table': 'customer_demographics', 'columns': {'customer_type_id': 'VARCHAR', 'customer_desc': 'TEXT'}, 'string_representation': 'Table: customer_demographics\nColumn: customer_type_id, Type: VARCHAR\nColumn: customer_desc, Type: TEXT'}, {'table': 'customer_customer_demo', 'columns': {'customer_id': 'VARCHAR', 'customer_type_id': 'VARCHAR'}, 'string_representation': 'Table: customer_customer_demo\nColumn: customer_id, Type: VARCHAR\nColumn: customer_type_id, Type: VARCHAR'}, {'table': 'customers', 'columns': {'customer_id': 'VARCHAR', 'company_name': 'VARCHAR(40)', 'contact_name': 'VARCHAR(30)', 'contact_title': 'VARCHAR(30)', 'address': 'VARCHAR(60)', 'city': 'VARCHAR(15)', 'region': 'VARCHAR(15)', 'postal_code': 'VARCHAR(10)', 'country': 'VARCHAR(15)', 'phone': 'VARCHAR(24)', 'fax': 'VARCHAR(24)'}, 'string_representation': 'Table: customers\nColumn: customer_id, Type: VARCHAR\nColumn: company_name, Type: VARCHAR(40)\nColumn: contact_name, Type: VARCHAR(
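
The printed schema repeats each table's information twice, once as a `columns` dictionary and once as a `string_representation`. One possible way to shorten the prompt context is to keep only the text form. The sketch below assumes the schema value is a list of dictionaries shaped like the output above (if `get_schema_as_string()` actually returns a string, it would need to be parsed first); `schema_to_text` is a hypothetical helper, and the two sample entries are copied from the printed output.

```python
def schema_to_text(schema_entries):
    """Join each table's 'string_representation' into one prompt-friendly block.

    Hypothetical helper; assumes a list of dicts shaped like the printed schema.
    """
    return "\n\n".join(entry["string_representation"] for entry in schema_entries)


# Two entries copied from the printed schema output
sample_schema = [
    {
        "table": "customer_demographics",
        "columns": {"customer_type_id": "VARCHAR", "customer_desc": "TEXT"},
        "string_representation": (
            "Table: customer_demographics\n"
            "Column: customer_type_id, Type: VARCHAR\n"
            "Column: customer_desc, Type: TEXT"
        ),
    },
    {
        "table": "customer_customer_demo",
        "columns": {"customer_id": "VARCHAR", "customer_type_id": "VARCHAR"},
        "string_representation": (
            "Table: customer_customer_demo\n"
            "Column: customer_id, Type: VARCHAR\n"
            "Column: customer_type_id, Type: VARCHAR"
        ),
    },
]

print(schema_to_text(sample_schema))
```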

In [9]:
# 5. Download ground truth 

import requests
import os

# URL of the file to download
url = "https://d3q8adh3y5sxpk.cloudfront.net/sql-workshop/data/question_query_good_results.jsonl"

# Path to the local data folder
data_folder = "./data"

# Create the data folder if it doesn't exist
os.makedirs(data_folder, exist_ok=True)

# File name to save the downloaded file
file_name = "ground_truth.jsonl"

# Full path to save the file
file_path = os.path.join(data_folder, file_name)

# Send a GET request to download the file; fail early on HTTP errors
response = requests.get(url, timeout=30)
response.raise_for_status()

# Save the file to the local data folder
with open(file_path, "wb") as file:
    file.write(response.content)

print(f"File downloaded and saved to {file_path}")

File downloaded and saved to ./data/ground_truth.jsonl
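
Before running the ground truth queries, it can help to confirm that every line of the downloaded file parses as JSON and carries both fields. The sketch below is one way to do that. The field names `question` and `query` are an assumption inferred from the `Question` and `Query` columns used later, and `load_question_query_pairs` is a hypothetical helper. The sample lines reuse questions and reference queries that appear in the notebook's output.

```python
import json


def load_question_query_pairs(lines):
    """Parse JSONL lines into (question, query) tuples, raising on malformed records.

    Hypothetical helper; the field names are assumed, matched case-insensitively.
    """
    pairs = []
    for line_number, line in enumerate(lines, start=1):
        line = line.strip()
        if not line:
            continue  # tolerate blank lines
        record = json.loads(line)
        lowered = {key.lower(): value for key, value in record.items()}
        if "question" not in lowered or "query" not in lowered:
            raise ValueError(f"Line {line_number} lacks a question or query field")
        pairs.append((lowered["question"], lowered["query"]))
    return pairs


sample_lines = [
    '{"question": "What is the total number of customers?", "query": "SELECT COUNT(*) FROM customers;"}',
    '{"question": "List all product names and their unit prices.", "query": "SELECT product_name, unit_price FROM products;"}',
]
pairs = load_question_query_pairs(sample_lines)
print(f"Loaded {len(pairs)} question/query pairs")
```

With the real file, the same function could be called as `load_question_query_pairs(open(file_path, encoding="utf-8"))`.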


In [10]:
# 6. Validate - ensure ground truth SQL queries run successfully
import pandas as pd
import json

results = []

# Check if df exists in the current namespace
if 'df' not in globals():
    # If it doesn't exist, try to load it from a JSONL file
    if os.path.exists(file_path):
        # Load the dataframe from the JSONL file
        df = pd.read_json(file_path, lines=True)
        print("df loaded from JSONL file.")
    else:
        # Stop here: the validation loop below needs df
        raise FileNotFoundError(f"JSONL file not found at {file_path}")
else:
    print(f"df with column names: {df.columns} already exists in memory.")

df.columns = df.columns.str.capitalize()


for row in df.itertuples():
    # print(row.query)
    error = None
    result = None
    try:
        
        result = databaseutil.run_sql(row.Query)
        
    except Exception as e:
        error = e

    results.append({'Question': row.Question,'Query': row.Query, 'Result': result, 'Error': error, 'Context': schema})


df_results = pd.DataFrame(results)

# Rows without an error ran successfully; isnull() already covers None values
df_good_results = df_results[df_results['Error'].isnull()]
print(f"Number of successful queries: {len(df_good_results)}")

df_bad_results = df_results[df_results['Error'].notnull()]
print(f"Number of unsuccessful queries: {len(df_bad_results)}")

df loaded from JSONL file.
Number of successful queries: 124
Number of unsuccessful queries: 0
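
The validation loop above follows a run-and-capture pattern: execute each query, record either the result or the error, and continue. The sketch below reproduces that pattern in a self-contained form, using Python's built-in `sqlite3` with an in-memory database as a stand-in for `databaseutil.run_sql`. The toy table and the function name `run_queries` are assumptions for illustration only.

```python
import sqlite3


def run_queries(connection, queries):
    """Run each query, recording the rows or the error message instead of stopping."""
    outcomes = []
    for query in queries:
        result, error = None, None  # reset per query so no stale values carry over
        try:
            result = connection.execute(query).fetchall()
        except sqlite3.Error as exc:
            error = str(exc)
        outcomes.append({"Query": query, "Result": result, "Error": error})
    return outcomes


# Toy in-memory database standing in for the workshop database
connection = sqlite3.connect(":memory:")
connection.execute("CREATE TABLE customers (customer_id TEXT, company_name TEXT)")
connection.execute("INSERT INTO customers VALUES ('ALFKI', 'Alfreds Futterkiste')")

outcomes = run_queries(
    connection,
    [
        "SELECT COUNT(*) FROM customers",        # valid
        "SELECT missing_column FROM customers",  # fails: column does not exist
    ],
)
failed = [outcome for outcome in outcomes if outcome["Error"] is not None]
print(f"{len(outcomes) - len(failed)} succeeded, {len(failed)} failed")
```

Storing the error as a string rather than the exception object keeps the resulting records easy to serialize and compare.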


In [11]:
# 7. Create Text-to-SQL zero-shot prompt
# This function builds a text-to-SQL zero-shot prompt for a given user question and SQL database schema.
# The prompt includes the original user question, the SQL database schema, and instructions for generating a SQL query.
# The SQL dialect used in the prompt (e.g., PostgreSQL, SQLite) comes from the SQL_DIALECT setting loaded above.

def build_sqlquerygen_prompt(user_question: str, sql_database_schema: str):
    prompt = """You are a SQL expert. You will be provided with the original user question and a SQL database schema. 
                Only return the SQL query and nothing else.
                Here is the original user question.
                <user_question>
                {user_question}
                </user_question>

                Here is the SQL database schema.
                <sql_database_schema>
                {sql_database_schema}
                </sql_database_schema>
                
                Instructions:
                Generate a SQL query that answers the original user question.
                Using the schema, first create a syntactically correct {sql_dialect} query to answer the question. 
                Never query for all the columns from a specific table, only ask for a few relevant columns given the question.
                Always prefix table names with the "public." prefix.
                Pay attention to use only the column names that you can see in the schema description. 
                Be careful to not query for columns that do not exist. 
                Pay attention to which column is in which table. 
                Also, qualify column names with the table name when needed.
                If you cannot answer the user question with the help of the provided SQL database schema, 
                then output that this question cannot be answered based on the information stored in the database.
                You are required to use the following format, each taking one line.
                Return the SQL query inside the <SQL></SQL> tags.
                """.format(
                    user_question=user_question,
                    sql_database_schema=sql_database_schema,
                    sql_dialect=SQL_DIALECT
                ) 
    return prompt
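
The prompt asks the model to wrap its answer in `<SQL></SQL>` tags, and later cells pull the statement out with `util.extract_with_regex(..., util.SQL_PATTERN)`. Neither the helper nor the pattern is shown in this notebook, so the sketch below is only a guess at an equivalent: `SQL_TAG_PATTERN` and `extract_sql` are hypothetical names, and the real pattern may differ.

```python
import re

# Assumed pattern: capture everything between <SQL> and </SQL>, across line breaks
SQL_TAG_PATTERN = re.compile(r"<SQL>(.*?)</SQL>", re.DOTALL | re.IGNORECASE)


def extract_sql(model_output):
    """Return the SQL text inside the first <SQL></SQL> pair, or None if absent."""
    match = SQL_TAG_PATTERN.search(model_output)
    return match.group(1).strip() if match else None


response_text = (
    "Here is the query:\n"
    "<SQL>\n"
    "SELECT COUNT(*) AS total_customers\n"
    "FROM public.customers;\n"
    "</SQL>"
)
print(extract_sql(response_text))
```

Returning `None` when no tags are found lets the caller distinguish a model that ignored the format from one that produced an invalid query.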

In [12]:
# 8a. Use ground truth to run test with smaller LLM
from utils.bedrock import BedrockLLMWrapper
from utils.util import Util
MODEL_ID = "mistral.mistral-7b-instruct-v0:2" # "mistral.mixtral-8x7b-instruct-v0:1" # "anthropic.claude-3-haiku-20240307-v1:0" # "mistral.mixtral-8x7b-instruct-v0:1" "anthropic.claude-3-5-sonnet-20240620-v1:0" "meta.llama3-1-70b-instruct-v1:0"

# use helper class for threaded API calls
llm = BedrockLLMWrapper(model_id=MODEL_ID, max_token_count=500, region=REGION)
util = Util()
df1 = df_good_results.copy()  # copy so the new column is not added to df_good_results
prompts_list = []
for row in df1.itertuples():
    prompt = build_sqlquerygen_prompt(row.Question, row.Context)
    prompts_list.append(prompt)
results = llm.generate_threaded(prompts_list)

# Create a list to store the generated SQL queries
generated_sql_queries = []
for result in results:
    generated_sql_query = result[0].replace("\\","") # workaround, switching to ConverseAPI introduced \ in Mistral response
    generated_sql_queries.append(generated_sql_query)

# Add the new column 'Generated_SQL_Query' to df1
df1['Generated_SQL_Query'] = generated_sql_queries

# Test generated SQL queries and verify they work
results = []

for row in df1.itertuples():
    statement = util.extract_with_regex(row.Generated_SQL_Query, util.SQL_PATTERN)
    error = None
    result = None
    try:    
        result = databaseutil.run_sql(statement)

    except Exception as e:
        error = e

    results.append({'Question': row.Question,'Query': statement, 'Result': result, 'Error': error, 'ReferenceQuery': row.Query, 'Context': row.Context})

# inspect first 3 results
df1_results = pd.DataFrame(results)
print(df1_results.head(3))

# review successful/unsuccessful queries
df1_good_results = df1_results[df1_results['Error'].isnull()]
print(f"Number of successful queries: {len(df1_good_results)}")

df1_bad_results = df1_results[df1_results['Error'].notnull()]
print(f"Number of unsuccessful queries: {len(df1_bad_results)}")

                                        Question  \
0         What is the total number of customers?   
1  List all product names and their unit prices.   
2    Who are the top 5 customers by order count?   

                                               Query  \
0  SELECT COUNT(DISTINCT c.customer_id) as total_...   
1  SELECT products.product_name, products.unit_pr...   
2  SELECT c.customer_id, COUNT(o.order_id) as ord...   

                                              Result Error  \
0                                              [(0)]  None   
1  [(Chai, 18.0), (Chang, 19.0), (Aniseed Syrup, ...  None   
2  [(SAVEA, 31), (ERNSH, 30), (QUICK, 28), (HUNGO...  None   

                                      ReferenceQuery  \
0                    SELECT COUNT(*) FROM customers;   
1     SELECT product_name, unit_price FROM products;   
2  SELECT c.company_name, COUNT(o.order_id) as or...   

                                             Context  
0  [{'table': 'customer_demographics'
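
The `BedrockLLMWrapper.generate_threaded` helper used in step 8a is not shown in this notebook. As a rough idea of what such a helper might do, the sketch below sends each prompt through the Bedrock Converse API from a thread pool. All function names here are hypothetical, the `temperature` of 0 is an assumption, and the sketch returns plain strings, whereas the real wrapper appears to return sequences (the notebook reads `result[0]`).

```python
from concurrent.futures import ThreadPoolExecutor


def build_converse_request(model_id, prompt, max_tokens=500):
    """Build keyword arguments for a bedrock-runtime converse() call."""
    return {
        "modelId": model_id,
        "messages": [{"role": "user", "content": [{"text": prompt}]}],
        # temperature 0 is an assumption, chosen to make SQL generation less random
        "inferenceConfig": {"maxTokens": max_tokens, "temperature": 0.0},
    }


def generate_one(client, model_id, prompt):
    """Send one prompt and return the text of the first content block."""
    response = client.converse(**build_converse_request(model_id, prompt))
    return response["output"]["message"]["content"][0]["text"]


def generate_threaded(client, model_id, prompts, max_workers=4):
    """Run prompts concurrently; pool.map() returns results in input order."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(lambda prompt: generate_one(client, model_id, prompt), prompts))


# Example wiring (requires AWS credentials and access to the chosen model):
# import boto3
# client = boto3.client("bedrock-runtime", region_name=REGION)
# texts = generate_threaded(client, MODEL_ID, prompts_list, max_workers=8)
```

Because `pool.map()` preserves input order, the returned list lines up with the prompts and can be assigned directly as a DataFrame column.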

In [13]:
# 8b. Use ground truth to run test with larger LLM

MODEL_ID = "anthropic.claude-3-haiku-20240307-v1:0" #"anthropic.claude-3-sonnet-20240229-v1:0"  "mistral.mixtral-8x7b-instruct-v0:1" "anthropic.claude-3-5-sonnet-20240620-v1:0" "meta.llama3-1-70b-instruct-v1:0"

# use helper class for threaded API calls
llm = BedrockLLMWrapper(model_id=MODEL_ID, max_token_count=500, region=REGION)
util = Util()
df2 = df_good_results.copy()  # copy so the new column is not added to df_good_results
prompts_list = []
for row in df2.itertuples():
    prompt = build_sqlquerygen_prompt(row.Question, row.Context)
    prompts_list.append(prompt)
results = llm.generate_threaded(prompts_list, max_workers=8)

# Create a list to store the generated SQL queries
generated_sql_queries = []
for result in results:
    generated_sql_query = result[0]
    # print(f'generated_sql_query: {generated_sql_query}')
    generated_sql_queries.append(generated_sql_query)

# Add the new column 'Generated_SQL_Query' to df2
df2['Generated_SQL_Query'] = generated_sql_queries

# Test generated SQL queries and verify they work
results = []

for row in df2.itertuples():
    statement = util.extract_with_regex(row.Generated_SQL_Query, util.SQL_PATTERN)
    # print(f'SQL statement: {statement}')
    error = None
    result = None  # reset so a failed query does not reuse the previous row's result
    try:     
        result = databaseutil.run_sql(statement)

    except Exception as e:
        error = e

    results.append({'Question': row.Question,'Query': statement, 'Result': result, 'Error': error, 'ReferenceQuery': row.Query, 'Context': row.Context})

df2_results = pd.DataFrame(results)

# inspect first 3 results
print(df2_results.head(3))

# review successful/unsuccessful queries
df2_good_results = df2_results[df2_results['Error'].isnull()]
print(f"Number of successful queries: {len(df2_good_results)}")

df2_bad_results = df2_results[df2_results['Error'].notnull()]
print(f"Number of unsuccessful queries: {len(df2_bad_results)}")

                                        Question  \
0         What is the total number of customers?   
1  List all product names and their unit prices.   
2    Who are the top 5 customers by order count?   

                                               Query  \
0  SELECT COUNT(*) AS total_customers\nFROM publi...   
1  SELECT public.products.product_name, public.pr...   
2  SELECT public.customers.customer_id, COUNT(pub...   

                                              Result Error  \
0                                             [(91)]  None   
1  [(Chai, 18.0), (Chang, 19.0), (Aniseed Syrup, ...  None   
2  [(SAVEA, 31), (ERNSH, 30), (QUICK, 28), (HUNGO...  None   

                                      ReferenceQuery  \
0                    SELECT COUNT(*) FROM customers;   
1     SELECT product_name, unit_price FROM products;   
2  SELECT c.company_name, COUNT(o.order_id) as or...   

                                             Context  
0  [{'table': 'customer_demographics'
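
The counts in steps 8a and 8b only measure whether a generated query runs without error, not whether it returns the right answer. A stricter check is to compare the rows returned by the generated query with those returned by the reference query. The sketch below shows one possible comparison that ignores row order; `results_match` is a hypothetical helper, and it assumes each result is a list of row tuples with hashable values.

```python
from collections import Counter


def results_match(generated_rows, reference_rows):
    """Compare two result sets as multisets of rows, ignoring row order.

    Returns False if either query produced no result (for example, it failed).
    """
    if generated_rows is None or reference_rows is None:
        return False
    return Counter(map(tuple, generated_rows)) == Counter(map(tuple, reference_rows))


# Same rows in a different order count as a match
print(results_match([("Chai", 18.0), ("Chang", 19.0)], [("Chang", 19.0), ("Chai", 18.0)]))

# Different values do not match, even though both queries ran without error
print(results_match([(0,)], [(91,)]))
```

This comparison is deliberately strict: a generated query that selects `customer_id` where the reference selects `company_name`, as in the third sample question, would count as a mismatch even if it identifies the same customers.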

### Conclusion
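As expected, a larger LLM (here, Claude 3 Haiku) produces valid SQL queries more reliably with zero-shot prompting than a smaller LLM (here, Mistral 7B). Note that a query that runs without error is not necessarily correct: in the sample output above, the Mistral 7B query for the first question executed successfully but returned 0 customers, while the Claude 3 Haiku query returned 91.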