### **Text to SQL: Natural Language to Athena Query Execution**

### Description:
This application allows users to interact with a database using natural language queries. 

Users ask questions in plain English, and the system converts each question into an SQL query. The query is then executed with Amazon Athena, and the results are translated back into natural language to give an easy-to-understand response.

This workflow lets users query databases without writing complex SQL, which makes the data accessible to non-technical users.

![Text to SQL](./text_to_sql.png)
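The three stages above (question to SQL, SQL execution, result back to prose) can be sketched without any AWS services. In this minimal sketch, an in-memory SQLite table stands in for Athena. `question_to_sql` and `summarize` are hypothetical stubs that stand in for the LLM calls, and the two-row table is toy data.

```python
import sqlite3

def question_to_sql(question: str) -> str:
    """Hypothetical stand-in for the LLM step that writes SQL."""
    canned = {
        "How many transactions are recorded?": "SELECT COUNT(*) FROM retail_transactions",
    }
    return canned[question]

def summarize(question: str, rows: list) -> str:
    """Hypothetical stand-in for the LLM step that phrases the result in English."""
    return f"Answer to '{question}': {rows[0][0]}"

def answer(conn: sqlite3.Connection, question: str) -> tuple[str, str]:
    sql = question_to_sql(question)        # 1. natural language -> SQL
    rows = conn.execute(sql).fetchall()    # 2. execute the SQL
    return sql, summarize(question, rows)  # 3. result -> natural language

# Toy in-memory table standing in for the Athena database
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE retail_transactions (invoiceno TEXT, country TEXT)")
conn.executemany(
    "INSERT INTO retail_transactions VALUES (?, ?)",
    [("536365", "United Kingdom"), ("536366", "France")],
)

sql, text = answer(conn, "How many transactions are recorded?")
print(sql)
print(text)
```

In the notebook, the LangChain SQL agent covers both LLM stages, and it can also inspect the schema and retry failed queries before it answers.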

```python
import warnings
warnings.filterwarnings("ignore")
```

```python
!pip install boto3 sqlalchemy langchain langchain-community langchain-aws PyAthena -qq
```

```python
import json
with open("../Lab 1/variables.json", "r") as f:
    variables = json.load(f)

variables
```

```text
{'accountNumber': '307297743176',
 'regionName': 'us-west-2',
 'collectionArn': 'arn:aws:aoss:us-west-2:307297743176:collection/h7cmj732p9d3v91spkhd',
 'collectionId': 'h7cmj732p9d3v91spkhd',
 'vectorIndexName': 'ws-index-',
 'bedrockExecutionRoleArn': 'arn:aws:iam::307297743176:role/advanced-rag-workshop-bedrock_execution_role-us-west-2',
 's3Bucket': '307297743176-us-west-2-advanced-rag-workshop',
 'kbFixedChunk': '4P6PBDDEGL',
 'kbSemanticChunk': 'IC3ZCBORXT',
 'kbCustomChunk': 'Q2T9CZ5VFA',
 'kbHierarchicalChunk': '1YIFVW0Z5E',
 'sagemakerLLMEndpoint': 'endpoint-llama-3-2-3b-instruct-2025-04-07-16-05-17',
 'guardrail_id': 'fe7ryshi7i7b',
 'guardrail_version': '1'}
```

# Data Query Assistant

```python
# First, let's import all necessary libraries
import boto3
from sqlalchemy import create_engine
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_aws import ChatBedrock
from langchain_community.agent_toolkits import create_sql_agent, SQLDatabaseToolkit
from langchain.agents import AgentType
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.callbacks import BaseCallbackHandler
```

# 1: Create a Callback Handler to Capture SQL Queries

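```python
class SQLHandler(BaseCallbackHandler):
    """Callback handler that records the SQL the agent sends to the query tool."""

    def __init__(self):
        super().__init__()
        self.sql_query = None  # most recent query passed to sql_db_query

    def on_agent_action(self, action, **kwargs):
        # "sql_db_query" is the toolkit's execution tool; list-tables, schema and
        # checker calls are ignored. If the agent runs several queries, only the
        # last one is kept.
        if action.tool == "sql_db_query":
            self.sql_query = action.tool_input
```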
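`SQLHandler` overwrites `sql_query` on every call to `sql_db_query`, so any queries the agent runs before its last one are lost. The following is a hypothetical variant that keeps the full history. To make the sketch run without LangChain, it is a plain class and `types.SimpleNamespace` stands in for LangChain's `AgentAction`. In the notebook it would subclass `BaseCallbackHandler` like the original.

```python
from types import SimpleNamespace

class SQLHistoryHandler:
    """Hypothetical variant of SQLHandler that records every executed query in order."""

    def __init__(self):
        self.sql_queries = []

    @property
    def sql_query(self):
        # Same attribute name as SQLHandler: the most recent query, or None
        return self.sql_queries[-1] if self.sql_queries else None

    def on_agent_action(self, action, **kwargs):
        if action.tool == "sql_db_query":
            self.sql_queries.append(action.tool_input)

# Simulated sequence of agent actions (SimpleNamespace stands in for AgentAction)
handler = SQLHistoryHandler()
for tool, tool_input in [
    ("sql_db_list_tables", ""),
    ("sql_db_schema", "retail_transactions"),
    ("sql_db_query", "SELECT COUNT(*) FROM retail_transactions"),
    ("sql_db_query", "SELECT country, COUNT(*) FROM retail_transactions GROUP BY country"),
]:
    handler.on_agent_action(SimpleNamespace(tool=tool, tool_input=tool_input))

print(len(handler.sql_queries))  # 2
print(handler.sql_query)         # the GROUP BY query
```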

# 2: Configure AWS Bedrock Model

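```python
# Use the region recorded in Lab 1 so all clients target the same region
region = variables['regionName']

# ChatBedrock configuration
model_id = "anthropic.claude-3-5-haiku-20241022-v1:0"
model_kwargs = {
    "max_tokens": 4096,  # upper bound on tokens generated per call
    "temperature": 0.0,  # minimize randomness so generated SQL is reproducible
    "top_k": 250,
    "top_p": 1
}
```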

# 3: Set Up Athena Connection

```python
from urllib.parse import quote_plus

# Athena configuration
athena_workgroup = 'primary'
athena_query_result_location = f"s3://{variables['s3Bucket']}/athena-query-results/"
db_name = "retail"

print("athena_query_result_location")
print(athena_query_result_location)

def create_athena_engine(aws_region, athena_workgroup, athena_query_result_location, db_name):
    """Create an SQLAlchemy engine for Athena and wrap it in a LangChain SQLDatabase."""
    athena_endpoint = f'athena.{aws_region}.amazonaws.com'
    # No credentials in the URL, so boto3's default credential chain is used.
    # The S3 staging directory is URL-encoded, as the PyAthena docs recommend.
    athena_conn_string = (
        f"awsathena+rest://@{athena_endpoint}:443/{db_name}"
        f"?s3_staging_dir={quote_plus(athena_query_result_location)}&work_group={athena_workgroup}"
    )
    athena_engine = create_engine(athena_conn_string, echo=False)
    return SQLDatabase(athena_engine)
```

```text
athena_query_result_location
s3://307297743176-us-west-2-advanced-rag-workshop/athena-query-results/
```
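The Athena connection string is an ordinary URL, so the `s3://` staging path has to survive being placed in a query string. The following sketch uses a hypothetical helper, `build_athena_conn_string`, that rebuilds the same URL as `create_athena_engine` with the staging directory URL-encoded, then parses the URL back to confirm that the path round-trips. The bucket name `my-bucket` is a placeholder.

```python
from urllib.parse import quote_plus, urlsplit, parse_qs

def build_athena_conn_string(region, workgroup, staging_dir, db_name):
    """Hypothetical helper: the same URL create_athena_engine builds, as a pure function."""
    endpoint = f"athena.{region}.amazonaws.com"
    return (
        f"awsathena+rest://@{endpoint}:443/{db_name}"
        f"?s3_staging_dir={quote_plus(staging_dir)}&work_group={quote_plus(workgroup)}"
    )

conn = build_athena_conn_string(
    "us-west-2", "primary", "s3://my-bucket/athena-query-results/", "retail"
)
print(conn)

# Round-trip check: parsing the query string recovers the original S3 path
parts = urlsplit(conn)
params = parse_qs(parts.query)
print(parts.hostname, parts.port, parts.path)
print(params["s3_staging_dir"][0])
print(params["work_group"][0])
```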


# 4: Create a Filtered Database Class

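```python
# Override SQLDatabase so the agent only sees the tables we allow.
# Note: get_usable_table_names reads the module-level `filtered_tables`,
# so that list must be defined (step 9) before the agent runs.
# (SQLDatabase's own `include_tables=[...]` argument is an alternative.)
class FilteredSQLDatabase(SQLDatabase):
    def get_usable_table_names(self):
        return filtered_tables

# Keep only the desired tables that actually exist in the database
def get_filtered_tables(db, desired_tables):
    all_tables = db.get_usable_table_names()
    return [table for table in all_tables if table in desired_tables]
```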
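`get_filtered_tables` silently drops any desired table that does not exist. If every name is wrong, the agent ends up with an empty table list and no error. Here is a hypothetical variant that also reports the missing names. It works on plain lists, so it runs without a database. The extra table names (`customers`, `staging_tmp`, `returns`) are illustrative.

```python
def filter_tables(available, desired):
    """Hypothetical variant of get_filtered_tables that also reports missing tables."""
    available_set = set(available)
    kept = [t for t in desired if t in available_set]
    missing = [t for t in desired if t not in available_set]
    return kept, missing

kept, missing = filter_tables(
    ["retail_transactions", "customers", "staging_tmp"],  # tables in the database
    ["retail_transactions", "returns"],                   # tables we want to expose
)
print(kept)     # ['retail_transactions']
print(missing)  # ['returns']
```

Iterating over `desired` rather than the database listing keeps the output in the caller's order. The set makes each membership check constant-time.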

# 5: Create Bedrock LLM

```python
def create_bedrock_llm(model_id, model_kwargs, aws_region="us-west-2"):
    """Create a LangChain chat model backed by Amazon Bedrock."""
    bedrock_runtime = boto3.client(service_name="bedrock-runtime", region_name=aws_region)
    return ChatBedrock(
        client=bedrock_runtime,
        model_id=model_id,
        model_kwargs=model_kwargs
    )
```

# 6: Create SQL Agent

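```python
def create_agent(db, llm):
    """Create a ReAct-style SQL agent backed by the LangChain SQL toolkit."""
    sql_toolkit = SQLDatabaseToolkit(llm=llm, db=db)
    agent_kwargs = {
        # Feed unparseable LLM output back to the model instead of raising
        "handle_parsing_errors": True,
        # Include the (action, observation) pairs in the response for inspection
        "return_intermediate_steps": True
        # SQL errors need no extra flag: the query tool returns the error text
        # to the agent as an observation so it can correct the query.
    }
    return create_sql_agent(
        llm=llm,
        toolkit=sql_toolkit,
        agent_executor_kwargs=agent_kwargs,
        agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=False  # Set to True to see the agent's thought process
    )
```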

# 7: Define the Prompt Template

```python
prompt_template = ChatPromptTemplate.from_messages([
    ("system", """
    You are an expert in Amazon Athena.
    You have access to the live database to query.
    To answer this question,
        you will first need to get the schema of the relevant tables to see what columns are available.
    Then query the relevant tables in the database to come up with the Final Answer.
    Do not assume any values for the data.
    Use [sql_db_list_tables] to get a list of tables in the database.
    Use [sql_db_schema] to get the schema for these tables.
    Use [sql_db_query_checker] to validate the SQL query.
    Execute the query using the [sql_db_query] tool and observe the output.
    Always provide the explanation and assumptions that you have made to come up with the output.
    For forecasting questions:
    - There won't be any data available for future dates, so identify historical data trends.
    - Use appropriate methods to forecast future values based on historical data.
    - Clearly explain the forecasting methodology and results.
    """
    ),
    ("human", "{context}"),
])
```

# 8: Function to Invoke the Agent

```python
import time
import random

def _is_throttling(exc):
    """Return True if the exception message looks like a Bedrock throttling error."""
    message = str(exc)
    return "ThrottlingException" in message or "Too many requests" in message

def invoke_agent(agent, question, max_retries=5, base_delay=5, max_delay=60):
    """
    Invoke the agent with a question and return both the SQL and the output, retrying with backoff.

    Args:
        agent: The agent to invoke
        question: The natural-language question to ask
        max_retries: Maximum number of retries after the first attempt
        base_delay: Initial delay in seconds
        max_delay: Maximum delay in seconds

    Returns:
        Tuple of (sql_query, output), or (None, error_message) if all attempts fail
    """
    # Fill the {context} slot and flatten the chat prompt to plain text,
    # since the SQL agent expects a string under "input"
    prompt = prompt_template.invoke({"context": question}).to_string()

    attempt = 0
    last_exception = None

    while attempt <= max_retries:
        # Fresh handler per attempt so a failed attempt cannot leak a stale query
        handler = SQLHandler()
        try:
            response = agent.invoke({"input": prompt}, {"callbacks": [handler]})
            print(f"Agent invocation successful on attempt {attempt + 1}")
            return handler.sql_query, response['output']

        except Exception as e:
            last_exception = e
            attempt += 1

            # Stop once the retry budget is spent
            if attempt > max_retries:
                break

            # Exponential backoff with jitter; throttling grows faster (x4 vs x2)
            if _is_throttling(e):
                delay = min(base_delay * (4 ** (attempt - 1)) + random.uniform(0, 2), max_delay)
                print(f"Throttling detected on attempt {attempt}. Backing off for {delay:.2f} seconds...")
            else:
                delay = min(base_delay * (2 ** (attempt - 1)) + random.uniform(0, 1), max_delay)
                print(f"Agent invocation attempt {attempt} failed. Retrying in {delay:.2f} seconds...")

            time.sleep(delay)

    # All attempts failed: return a readable message instead of raising
    if _is_throttling(last_exception):
        error_message = ("Unable to get a response from the agent due to service throttling. "
                         "The service is currently experiencing high demand. "
                         "Please try again in a few minutes or reduce request frequency.")
    else:
        error_message = f"Unable to get a response from the agent after {attempt} attempts. Error: {last_exception}"

    print(error_message)
    return None, error_message
```
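To see what the retry loop in `invoke_agent` does with the defaults, the following sketch computes the sleep schedule with the random jitter omitted. It sleeps once after each failed attempt except the last, so there are `max_retries` delays. `backoff_schedule` is a hypothetical helper for illustration, not part of the notebook.

```python
def backoff_schedule(max_retries=5, base_delay=5, max_delay=60, factor=2):
    """Delays (seconds, jitter omitted) that invoke_agent sleeps between attempts."""
    return [min(base_delay * factor ** (attempt - 1), max_delay)
            for attempt in range(1, max_retries + 1)]

standard = backoff_schedule(factor=2)    # non-throttling errors
throttling = backoff_schedule(factor=4)  # ThrottlingException / "Too many requests"
print(standard, sum(standard))      # [5, 10, 20, 40, 60] 135
print(throttling, sum(throttling))  # [5, 20, 60, 60, 60] 205
```

With the defaults, a call that keeps getting throttled therefore waits a little over three minutes (205 s plus up to 2 s of jitter on the uncapped delays) before `invoke_agent` returns its throttling message.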

# 9: Put Everything Together

```python
# Create the Athena-backed SQLDatabase
db = create_athena_engine(region, athena_workgroup, athena_query_result_location, db_name)

# Restrict the agent to the desired tables (FilteredSQLDatabase reads `filtered_tables`)
desired_tables = ["retail_transactions"]
filtered_tables = get_filtered_tables(db, desired_tables)
filtered_db = FilteredSQLDatabase(db._engine)

# Create the LLM and agent
llm = create_bedrock_llm(model_id, model_kwargs)
agent = create_agent(filtered_db, llm)
```

# 10: Example Queries

```python
# Example 1: Query the database
question = "How many transactions are recorded in the sales table?"
sql_query, output = invoke_agent(agent, question)

print("SQL Query:")
print(sql_query)
print("\nOutput:")
print(output)
```

```text
Agent invocation successful on attempt 1
SQL Query:
SELECT COUNT(*) as total_transactions FROM retail_transactions

Output:
There are 541,909 transactions recorded in the sales (retail_transactions) table.

Explanation of my process:
1. I first listed the available tables and found the "retail_transactions" table
2. I examined the table schema, which showed transaction details like invoice number, stock code, quantity, etc.
3. I noticed in the table properties that the initial record count was 569,467, but I wanted to verify the exact count
4. I ran a direct COUNT(*) query, which returned 541,909 transactions

The slight discrepancy between the initial table properties (569,467) and the actual count (541,909) could be due to various factors like data updates, data cleaning, or metadata not being perfectly synchronized. The COUNT(*) query provides the most accurate current count of transactions in the table.
```


```python
# Example 2: More complex question
question = "Find the total number of transactions per country."
sql_query, output = invoke_agent(agent, question)

print("SQL Query:")
print(sql_query)
print("\nOutput:")
print(output)
```

```text
Agent invocation successful on attempt 1
SQL Query:
SELECT country, COUNT(*) as total_transactions
FROM retail_transactions
GROUP BY country
ORDER BY total_transactions DESC
LIMIT 10

Output:
Here are the total number of transactions per country, sorted from highest to lowest:

1. United Kingdom: 490,941 transactions
2. Germany: 9,450 transactions
3. France: 8,532 transactions
4. EIRE (Ireland): 8,142 transactions
5. Spain: 2,514 transactions
6. Netherlands: 2,356 transactions
7. Belgium: 2,061 transactions
8. Switzerland: 1,992 transactions
9. Portugal: 1,514 transactions
10. Blank/Unknown Country: 1,365 transactions

Key Observations:
- The United Kingdom dominates the transactions, accounting for over 95% of all transactions in the dataset.
- There's a significant drop-off in transaction volume after the United Kingdom.
- There are 1,365 transactions with no specified country, which might indicate missing or incomplete data.

The query grouped the transactions by country and cou
```
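The agent's "Key Observations" are model-generated prose, so it is worth checking them against the numbers the queries actually returned. Dividing the United Kingdom count from this example by the `COUNT(*)` total from Example 1 gives the UK share of all rows:

```python
total_transactions = 541_909  # COUNT(*) returned in Example 1
uk_transactions = 490_941     # United Kingdom row returned in Example 2

uk_share = uk_transactions / total_transactions * 100
print(f"UK share: {uk_share:.1f}%")  # UK share: 90.6%
```

The share is about 90.6%, not "over 95%" as the summary states. The generated SQL also ends with `LIMIT 10`, so only the ten largest countries appear even though the question asked for every country.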

```python
# Example 3: More complex question
question = "What is the most frequently purchased product in the UK?"
sql_query, output = invoke_agent(agent, question)

print("SQL Query:")
print(sql_query)
print("\nOutput:")
print(output)
```

```text
Agent invocation successful on attempt 1
SQL Query:
SELECT stockcode, description, SUM(quantity) as total_quantity
FROM retail_transactions
WHERE country = 'United Kingdom'
GROUP BY stockcode, description
ORDER BY total_quantity DESC
LIMIT 1

Output:
The most frequently purchased product in the UK is the "WORLD WAR 2 GLIDERS ASSTD DESIGNS" with stock code 84077, which was purchased with a total quantity of 48,326 units.

Explanation of my approach:
1. I first listed the available tables and found the retail_transactions table
2. I checked the schema of the table to understand its structure
3. I created a SQL query that:
   - Filtered transactions for the United Kingdom
   - Grouped the data by stock code and description
   - Summed the total quantity for each product
   - Ordered the results by total quantity in descending order
   - Limited the result to the top 1 to get the most frequently purchased product
4. I executed the query and retrieved the result

The query provides a clear 
```

```python
# Example 4: Multi-step question
question = "What is the most frequently purchased product in the countries with 7th most transaction? What's the name of this country and what's name of the product?"
sql_query, output = invoke_agent(agent, question)

print("SQL Query:")
print(sql_query)
print("\nOutput:")
print(output)
```

```text
Agent invocation successful on attempt 1
SQL Query:
WITH CountryTransactionCounts AS (
    SELECT country, 
           COUNT(*) as transaction_count,
           DENSE_RANK() OVER (ORDER BY COUNT(*) DESC) as country_rank
    FROM retail_transactions
    GROUP BY country
),
SeventhMostTransactingCountry AS (
    SELECT country
    FROM CountryTransactionCounts
    WHERE country_rank = 7
),
ProductFrequency AS (
    SELECT description, 
           COUNT(*) as product_count,
           RANK() OVER (ORDER BY COUNT(*) DESC) as product_rank
    FROM retail_transactions
    WHERE country = (SELECT country FROM SeventhMostTransactingCountry)
    GROUP BY description
)
SELECT 
    (SELECT country FROM SeventhMostTransactingCountry) AS country,
    description AS most_frequent_product,
    product_count
FROM ProductFrequency
WHERE product_rank = 1;


Output:
- Country: Belgium (the 7th most transacting country)
- Most Frequently Purchased Product: "POSTAGE"
- Number of Transactions for this Produ
```
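Examples 3 and 4 interpret "most frequently purchased" differently. Example 3's SQL ranks products by `SUM(quantity)` (units sold), while Example 4's ranks them by `COUNT(*)` (number of transaction lines). Non-product lines such as `POSTAGE` can come out on top under the second definition. The sketch below runs both query shapes against a toy in-memory SQLite table (the rows and quantities are invented for illustration) to show that they can disagree.

```python
import sqlite3

# Toy data: one bulk order of a single item vs. several small postage lines
rows = [
    ("A1", "BULK ITEM", 480),
    ("P1", "POSTAGE", 1),
    ("P1", "POSTAGE", 2),
    ("P1", "POSTAGE", 1),
]
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE retail_transactions (stockcode TEXT, description TEXT, quantity INTEGER)"
)
conn.executemany("INSERT INTO retail_transactions VALUES (?, ?, ?)", rows)

# Example 3's definition: largest total quantity
by_quantity = conn.execute("""
    SELECT description, SUM(quantity) AS total_quantity
    FROM retail_transactions
    GROUP BY stockcode, description
    ORDER BY total_quantity DESC
    LIMIT 1
""").fetchone()

# Example 4's definition: most transaction lines
by_lines = conn.execute("""
    SELECT description, COUNT(*) AS product_count
    FROM retail_transactions
    GROUP BY description
    ORDER BY product_count DESC
    LIMIT 1
""").fetchone()

print(by_quantity)  # ('BULK ITEM', 480)
print(by_lines)     # ('POSTAGE', 3)
```

When the intended metric matters, stating it in the question (for example "by total units sold") or in the system prompt removes the ambiguity.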