## Amazon Bedrock Text-to-SQL (Tanner)

### Intro and Goal
This Jupyter Notebook is designed to illustrate a few-shot Text-to-SQL approach on the Northwind database.

The goal is to take a user question along with a SQL database schema, supplement the prompt with relevant samples retrieved from a ground-truth dataset, and then generate a corresponding SQL query.

### Steps
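1. Download the SQL schema
2. Download a ground-truth dataset of questions and SQL queries for a given database (e.g., Northwind)
3. Embed the ground-truth examples and store them in a Chroma collection for retrieval
4. Generate SQL queries with dynamically retrieved few-shot examples and run them against the database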

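Running the generated queries only shows that they execute. Because each ground-truth row also carries a reference query, one common extra check (often called execution accuracy) is to run both queries and compare the rows they return. The sketch below does this on a made-up in-memory SQLite table using only the standard library; the `results_match` helper and the sample rows are hypothetical and not part of this notebook's `utils` package.

```python
import sqlite3
from collections import Counter


def results_match(conn: sqlite3.Connection, generated_sql: str, reference_sql: str) -> bool:
    """Return True if both queries return the same multiset of rows (row order is ignored)."""
    generated_rows = conn.execute(generated_sql).fetchall()
    reference_rows = conn.execute(reference_sql).fetchall()
    # Compare as multisets so differing ORDER BY clauses are not counted as failures
    return Counter(generated_rows) == Counter(reference_rows)


# Hypothetical toy table standing in for the Northwind "customers" table
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE customers (customer_id TEXT, company_name TEXT, country TEXT)")
conn.executemany(
    "INSERT INTO customers VALUES (?, ?, ?)",
    [
        ("ALFKI", "Alfreds Futterkiste", "Germany"),
        ("ANATR", "Ana Trujillo Emparedados y helados", "Mexico"),
        ("BLAUS", "Blauer See Delikatessen", "Germany"),
    ],
)

# Same answer, different SQL text
print(results_match(conn, "SELECT COUNT(*) FROM customers", "SELECT COUNT(customer_id) FROM customers"))  # True
# Different answers (2 rows vs 3 rows)
print(results_match(conn, "SELECT company_name FROM customers WHERE country = 'Germany'",
                    "SELECT company_name FROM customers"))  # False
```

Comparing multisets ignores row order, which suits most questions but is too lenient for ones like "Who are the top 5 customers by order count?", where the ordering is part of the answer.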
### Set Environment Variables

In [1]:
# 1. Import necessary libraries and load environment variables

from dotenv import load_dotenv, find_dotenv
import os
import pandas as pd

# Load environment variables stored in a local file (override=True lets the file win over existing values)
local_env_filename = 'dev.env'
load_dotenv(find_dotenv(local_env_filename), override=True)

# load_dotenv already populated os.environ; indexing raises KeyError if a variable is missing
REGION = os.environ['REGION']
HF_TOKEN = os.environ['HF_TOKEN']
SQL_DATABASE = os.environ['SQL_DATABASE']  # LOCAL, SQLALCHEMY, REDSHIFT
SQL_DIALECT = os.environ['SQL_DIALECT']    # SQLite, PostgreSQL

print(f"Using database: {SQL_DATABASE} with sql dialect: {SQL_DIALECT}")

# Load the ground-truth dataset (one JSON object per line)
file_path = "./data/ground_truth.jsonl"
ground_truth_df = pd.read_json(file_path, lines=True)

Using database: SQLALCHEMY with sql dialect: PostgreSQL
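If `dev.env` is missing an entry, reading `os.environ` fails on the first absent variable only. A minimal sketch of failing fast with one message that lists every missing variable; `require_env` is a hypothetical helper, and the default names are the four variables this cell reads.

```python
import os

# The four variables this notebook reads from dev.env
REQUIRED_VARS = ("REGION", "HF_TOKEN", "SQL_DATABASE", "SQL_DIALECT")


def require_env(names=REQUIRED_VARS) -> dict:
    """Return the requested environment variables, raising if any are unset or empty."""
    missing = [name for name in names if not os.getenv(name)]
    if missing:
        raise RuntimeError(f"Missing environment variables: {', '.join(missing)}")
    return {name: os.environ[name] for name in names}
```

Usage after `load_dotenv(...)`: `config = require_env()` followed by `REGION = config["REGION"]`, and so on.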


In [2]:
# 2. Initialize the Chroma client from our persisted store
import chromadb
import boto3

# Open (or create) the persistent Chroma store on disk
chroma_client = chromadb.PersistentClient(path="../data/chroma")

# Also initialize a Bedrock runtime client in the configured region so we can call embedding models
session = boto3.Session(region_name=REGION)
bedrock = session.client('bedrock-runtime')

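The `bedrock` client can also be called directly. As a sketch of what a raw embedding request to Titan Text Embeddings V2 looks like, the helper below only builds the keyword arguments for `invoke_model`; it assumes the V2 request fields `inputText`, `dimensions` (256, 512 or 1024) and `normalize`. The name `titan_v2_embed_request` is hypothetical, and whether `AmazonBedrockEmbeddingFunction` sends exactly this body is not shown in this notebook.

```python
import json

TITAN_TEXT_EMBED_V2_ID = "amazon.titan-embed-text-v2:0"


def titan_v2_embed_request(text: str, dimensions: int = 1024, normalize: bool = True) -> dict:
    """Build invoke_model keyword arguments for a Titan Text Embeddings V2 call (assumed request schema)."""
    if dimensions not in (256, 512, 1024):
        raise ValueError("Titan Text Embeddings V2 supports 256, 512 or 1024 dimensions")
    body = {"inputText": text, "dimensions": dimensions, "normalize": normalize}
    return {
        "modelId": TITAN_TEXT_EMBED_V2_ID,
        "body": json.dumps(body),
        "contentType": "application/json",
        "accept": "application/json",
    }


# With the bedrock client above (not executed here; requires AWS credentials):
# response = bedrock.invoke_model(**titan_v2_embed_request("What is the total number of customers?"))
# embedding = json.loads(response["body"].read())["embedding"]
```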
In [3]:
# 3. Create chunks for the few-shot examples

from utils.splitter import DataFrameChunkingStrategy, RAGChunk

# TODO: update with the actual samples; for now every ground-truth Question/Query pair is a candidate example
relevant_df = ground_truth_df[['Question', 'Query']]

chunking_strategy = DataFrameChunkingStrategy(relevant_df)

# Get the chunks from the chunker
chunks: list[RAGChunk] = chunking_strategy.process()

# Print the number of chunks
print(f"Number of chunks: {len(chunks)}")

# Print the first 3 chunks
print(f"First 3 chunks: {chunks[:3]}")


Processing complete. Created 124 chunks.
Number of chunks: 124
First 3 chunks: [RAGChunk(id_='bba4fa30-e443-4b0e-b180-03c79ca551cc', text='Question: What is the total number of customers?\nQuery: SELECT COUNT(*) FROM customers;', metadata={'index': '0'}), RAGChunk(id_='038a6719-6e4b-4191-a600-6407b97c25f7', text='Question: List all product names and their unit prices.\nQuery: SELECT product_name, unit_price FROM products;', metadata={'index': '1'}), RAGChunk(id_='7f45ccd3-bc88-4131-9d44-7bc9aff933ee', text='Question: Who are the top 5 customers by order count?\nQuery: SELECT c.company_name, COUNT(o.order_id) as order_count FROM customers c JOIN orders o ON c.customer_id = o.customer_id GROUP BY c.company_name ORDER BY order_count DESC LIMIT 5;', metadata={'index': '2'})]
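`DataFrameChunkingStrategy` comes from the notebook's own `utils.splitter` module, whose source is not shown here. Judging only from the printed output, each chunk holds one Question/Query pair formatted as `Question: ...\nQuery: ...`, with the row index stored as a string in its metadata. A hypothetical plain-Python equivalent, using a stand-in `SimpleChunk` dataclass rather than the real `RAGChunk`:

```python
import uuid
from dataclasses import dataclass, field


@dataclass
class SimpleChunk:
    """Hypothetical stand-in for utils.splitter.RAGChunk: one Question/Query pair per chunk."""
    text: str
    metadata: dict
    id_: str = field(default_factory=lambda: str(uuid.uuid4()))  # random id, like the printed output


def chunk_question_query_pairs(rows) -> list[SimpleChunk]:
    """Turn (question, query) pairs into one chunk each, mirroring the printed chunk format."""
    return [
        SimpleChunk(text=f"Question: {question}\nQuery: {query}", metadata={"index": str(i)})
        for i, (question, query) in enumerate(rows)
    ]


pairs = [
    ("What is the total number of customers?", "SELECT COUNT(*) FROM customers;"),
    ("List all product names and their unit prices.", "SELECT product_name, unit_price FROM products;"),
]
chunks_sketch = chunk_question_query_pairs(pairs)
print(len(chunks_sketch))  # 2
print(chunks_sketch[0].text)
```

With pandas, `relevant_df.itertuples(index=False, name=None)` yields plain `(question, query)` tuples that can be passed straight to `chunk_question_query_pairs`.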


In [4]:
# 4. Create embeddings for the few-shot examples

from chromadb.utils.embedding_functions import AmazonBedrockEmbeddingFunction
from utils.chroma import BaseRetrievalTask, ChromaDBRetrievalTask
# Define some experiment variables
TITAN_TEXT_EMBED_V2_ID: str = 'amazon.titan-embed-text-v2:0'
COLLECTION_NAME: str = 'sqlsamples_collection'

embedding_function = AmazonBedrockEmbeddingFunction(
    session=session,
    model_name=TITAN_TEXT_EMBED_V2_ID
)

retrieval_task: BaseRetrievalTask = ChromaDBRetrievalTask(
    chroma_client=chroma_client,
    collection_name=COLLECTION_NAME,
    embedding_function=embedding_function,
    chunks=chunks,
)

# If the collection has already been created and populated, comment out this line
retrieval_task.add_chunks_to_collection()

Finished Ingesting Chunks Into Collection
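In the next cell, `retrieval_task.retrieve(query_text=..., n_results=3)` asks the collection for the three stored examples whose embeddings are closest to the question's embedding. The toy sketch below shows that nearest-neighbour idea with exact cosine similarity over made-up 3-dimensional vectors. It is an illustration only: Chroma uses an approximate HNSW index with a per-collection distance setting, and `top_k` is a hypothetical helper.

```python
import math


def cosine_similarity(a, b) -> float:
    """Cosine of the angle between two vectors (0.0 if either vector is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


def top_k(query_vec, docs: dict, k: int = 3) -> list:
    """Return the ids of the k documents most similar to query_vec, best first."""
    ranked = sorted(docs, key=lambda doc_id: cosine_similarity(query_vec, docs[doc_id]), reverse=True)
    return ranked[:k]


# Made-up "embeddings" purely for illustration
docs = {
    "count_customers": [0.9, 0.1, 0.0],
    "product_prices": [0.1, 0.9, 0.1],
    "top_customers": [0.7, 0.3, 0.2],
}
print(top_k([1.0, 0.2, 0.1], docs, k=2))  # ['count_customers', 'top_customers']
```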


In [5]:
# 5. Text-to-SQL dynamic few-shot prompt
def build_sqlquerygen_prompt(user_question: str, sql_database_schema: str) -> str:
    # Retrieve the 3 stored examples most similar to the user question
    sql_examples = retrieval_task.retrieve(query_text=user_question, n_results=3)
    sql_examples_str = "\n".join([example.document for example in sql_examples])
    prompt = """You are a SQL expert. You will be provided with the original user question and a SQL database schema.
                Only return the SQL query and nothing else.
                Here is the original user question.
                <user_question>
                {user_question}
                </user_question>

                Here is the SQL database schema.
                <sql_database_schema>
                {sql_database_schema}
                </sql_database_schema>

                Here are some examples of SQL queries that answer similar questions:
                <sql_examples>
                {sql_examples}
                </sql_examples>

                Instructions:
                Using the schema, generate a syntactically correct {sql_dialect} query that answers the original user question.
                Never query for all the columns from a specific table; only ask for a few relevant columns given the question.
                Always prefix table names with the "public." schema prefix.
                Use only the column names that you can see in the schema description.
                Be careful not to query for columns that do not exist.
                Pay attention to which column is in which table.
                Also, qualify column names with the table name when needed.
                If you cannot answer the user question with the help of the provided SQL database schema,
                state that the question cannot be answered based on the information stored in the database.
                Return the SQL query inside <SQL></SQL> tags.
                """.format(
                    user_question=user_question,
                    sql_database_schema=sql_database_schema,
                    sql_dialect=SQL_DIALECT,
                    sql_examples=sql_examples_str
                )
    return prompt

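The prompt asks the model to wrap its answer in `<SQL></SQL>` tags, and the next cell pulls the statement out with `util.extract_with_regex(..., util.SQL_PATTERN)`. That pattern is not shown in this notebook, so the following is only a hypothetical equivalent: a non-greedy, case-insensitive regex that returns `None` when the model answers without tags (for example, when it reports that a question cannot be answered).

```python
import re
from typing import Optional

# Hypothetical pattern; the notebook's util.SQL_PATTERN may differ
SQL_TAG_PATTERN = re.compile(r"<SQL>\s*(.*?)\s*</SQL>", re.DOTALL | re.IGNORECASE)


def extract_sql(completion: str) -> Optional[str]:
    """Return the text inside the first <SQL>...</SQL> pair, or None if there are no tags."""
    match = SQL_TAG_PATTERN.search(completion)
    return match.group(1) if match else None


print(extract_sql("Here you go:\n<SQL>\nSELECT COUNT(*) FROM public.customers;\n</SQL>"))
print(extract_sql("This question cannot be answered based on the information stored in the database."))  # None
```

Returning `None` instead of raising lets the calling loop record a missing query as an error row rather than stopping the whole run.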
In [6]:
# 6. Generate SQL for every ground-truth question and run it against the database
from utils.bedrock import BedrockLLMWrapper
from utils.database import DatabaseUtil
from utils.util import Util

# Other model IDs: "anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-3-5-sonnet-20240620-v1:0",
# "mistral.mixtral-8x7b-instruct-v0:1", "meta.llama3-1-70b-instruct-v1:0"
MODEL_ID = "anthropic.claude-3-haiku-20240307-v1:0"

# Use helper class for threaded API calls
llm = BedrockLLMWrapper(model_id=MODEL_ID, max_token_count=500, region=REGION)
util = Util()
databaseutil = DatabaseUtil(
    datasource_url=["https://d3q8adh3y5sxpk.cloudfront.net/sql-workshop/data/redshift-sourcedb.sql"],
    sql_database='SQLALCHEMY',
    region=REGION,
)

# Work on a copy so adding a column does not modify ground_truth_df
df2 = ground_truth_df.copy()
# Build one few-shot prompt per question; the Context column holds the schema (CREATE TABLE statements)
prompts_list = [build_sqlquerygen_prompt(row.Question, row.Context) for row in df2.itertuples()]
results = llm.generate_threaded(prompts_list, max_workers=5)

# Take the first element of each result, which holds the model's response
generated_sql_queries = [result[0] for result in results]

# Add the new column 'Generated_SQL_Query' to df2
df2['Generated_SQL_Query'] = generated_sql_queries

# Run each generated SQL query and record its result or error
run_results = []
for row in df2.itertuples():
    # Extract the SQL statement from the <SQL></SQL> tags in the response
    statement = util.extract_with_regex(row.Generated_SQL_Query, util.SQL_PATTERN)
    # Reset per row so a failed query never reuses the previous row's result
    result, error = None, None
    try:
        result = databaseutil.run_sql(statement)
    except Exception as e:
        error = e

    run_results.append({'Question': row.Question, 'Query': statement, 'Result': result, 'Error': error,
                        'ReferenceQuery': row.Query, 'Context': row.Context})

df2_results = pd.DataFrame(run_results)

# Inspect the first 3 results
print(df2_results.head(3))

# Review successful/unsuccessful queries
df2_good_results = df2_results[df2_results['Error'].isnull()]
print(f"Number of successful queries: {len(df2_good_results)}")

df2_bad_results = df2_results[df2_results['Error'].notnull()]
print(f"Number of unsuccessful queries: {len(df2_bad_results)}")

                                        Question  \
0         What is the total number of customers?   
1  List all product names and their unit prices.   
2    Who are the top 5 customers by order count?   

                                               Query  \
0             SELECT COUNT(*) FROM public.customers;   
1  SELECT public.products.product_name, public.pr...   
2  SELECT public.customers.company_name, COUNT(pu...   

                                              Result Error  \
0                                             [(91)]  None   
1  [(Chai, 18.0), (Chang, 19.0), (Aniseed Syrup, ...  None   
2  [(Save-a-lot Markets, 31), (Ernst Handel, 30),...  None   

                                      ReferenceQuery  \
0                    SELECT COUNT(*) FROM customers;   
1     SELECT product_name, unit_price FROM products;   
2  SELECT c.company_name, COUNT(o.order_id) as or...   

                                             Context  
0  CREATE TABLE categories (\n    cat