### **Text to SQL: Natural Language to Athena Query Execution**

### Description:
This application lets users query a database in natural language.

Users ask questions in plain English, and the system converts each question into a SQL query. The query is then executed in Amazon Athena, and the results are translated back into a natural-language answer that is easy to understand.

This workflow lets users explore data without writing SQL, which makes it more accessible to non-technical users.

![Text to SQL](./text_to_sql.png)

In [34]:
import warnings
warnings.filterwarnings("ignore")

In [35]:
!pip install boto3 sqlalchemy langchain langchain-community langchain-aws PyAthena -qq

In [36]:
import json
with open("../variables.json", "r") as f:
    variables = json.load(f)

variables

{'accountNumber': '791677101579',
 'regionName': 'us-west-2',
 'collectionArn': 'arn:aws:aoss:us-west-2:791677101579:collection/u99a2f111uq506nobq6l',
 'collectionId': 'u99a2f111uq506nobq6l',
 'vectorIndexName': 'ws-index-',
 'bedrockExecutionRoleArn': 'arn:aws:iam::791677101579:role/advanced-rag-workshop-bedrock_execution_role-us-west-2',
 's3Bucket': '791677101579-us-west-2-advanced-rag-workshop',
 'kbFixedChunk': '2OLAU6UCAW',
 'kbSemanticChunk': 'SCMPE1YU8Y',
 'kbHierarchicalChunk': 'UKZ63LEW5P',
 'kbCustomChunk': 'P55X5UTFYK',
 'sagemakerLLMEndpoint': 'endpoint-llama-3-2-3b-instruct-2025-04-22-19-37-32',
 'guardrail_id': 'a3z8dptcpo5h',
 'guardrail_version': '1'}

# Data Query Assistant

In [37]:
# First, let's import all necessary libraries
import boto3
from sqlalchemy import create_engine
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_aws import ChatBedrock
from langchain_community.agent_toolkits import create_sql_agent, SQLDatabaseToolkit
from langchain.agents import AgentType
from langchain_core.prompts import ChatPromptTemplate
from langchain.callbacks.base import BaseCallbackHandler

# 1: Create a Callback Handler to Capture SQL Queries

In [38]:
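class SQLHandler(BaseCallbackHandler):
    """Callback handler that captures the SQL query generated by the agent.

    Only the most recent query passed to the `sql_db_query` tool is kept.
    """
    def __init__(self):
        self.sql_query = None

    def on_agent_action(self, action, **kwargs):
        # The agent emits one action per tool call; keep the input of the query-execution tool
        if action.tool == "sql_db_query":
            self.sql_query = action.tool_input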
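`SQLHandler` keeps only the most recent `sql_db_query` input, so if the agent runs several queries while answering one question (for example, after a query error), the earlier ones are overwritten. The sketch below records every query instead. To keep it runnable without LangChain, it uses a hypothetical `FakeAction` stand-in for the action object LangChain passes to `on_agent_action`; in the notebook you would subclass `BaseCallbackHandler` as above. The class name `AllSQLQueriesHandler` is also hypothetical.

```python
from collections import namedtuple

# Hypothetical stand-in for LangChain's agent action (only .tool and .tool_input are used)
FakeAction = namedtuple("FakeAction", ["tool", "tool_input", "log"])


class AllSQLQueriesHandler:
    """Sketch: record every SQL query sent to the sql_db_query tool, in order."""

    def __init__(self):
        self.sql_queries = []

    def on_agent_action(self, action, **kwargs):
        # Only the query-execution tool carries the SQL we want to keep
        if action.tool == "sql_db_query":
            self.sql_queries.append(action.tool_input)

    @property
    def last_query(self):
        # Mirrors SQLHandler.sql_query: the most recent query, or None
        return self.sql_queries[-1] if self.sql_queries else None


handler = AllSQLQueriesHandler()
handler.on_agent_action(FakeAction("sql_db_list_tables", "", ""))
handler.on_agent_action(FakeAction("sql_db_query", "SELECT COUNT(*) FROM retail_order", ""))
handler.on_agent_action(FakeAction("sql_db_query", "SELECT COUNT(*) AS n FROM retail_order", ""))
print(handler.sql_queries)
print(handler.last_query)
```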

# 2: Configure AWS Bedrock Model

In [39]:
# Set your AWS region (read from the workshop variables file)
region = variables['regionName']

# ChatBedrock Configuration
model_id = "anthropic.claude-3-5-haiku-20241022-v1:0"
model_kwargs = {
    "max_tokens": 4096,
    "temperature": 0.0,
    "top_k": 250,
    "top_p": 1
}
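For Anthropic models, Bedrock expects a Messages API request body, and the `model_kwargs` above supply its sampling fields. `ChatBedrock` builds the real request for you; the sketch below assembles a similar body by hand only to show where each setting ends up. The helper name `build_anthropic_body` is hypothetical. Setting `temperature` to 0.0 favours repeatable output, which suits SQL generation.

```python
import json


def build_anthropic_body(question, model_kwargs, system=None):
    """Hypothetical helper: assemble a Bedrock Anthropic Messages request body as JSON."""
    body = {
        "anthropic_version": "bedrock-2023-05-31",  # version string Bedrock uses for Claude Messages
        "messages": [{"role": "user", "content": [{"type": "text", "text": question}]}],
        **model_kwargs,  # max_tokens, temperature, top_k, top_p
    }
    if system:
        body["system"] = system
    return json.dumps(body)


model_kwargs = {"max_tokens": 4096, "temperature": 0.0, "top_k": 250, "top_p": 1}
payload = json.loads(build_anthropic_body("How many orders are there?", model_kwargs))
print(sorted(payload))
```

A string like this could be passed as `body`, together with `modelId`, to the `bedrock-runtime` client's `invoke_model` call if you wanted to call the model without LangChain.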

# 3: Set Up Athena Connection

In [40]:
# Athena Configuration
athena_workgroup = 'primary'
athena_query_result_location = f"s3://{variables['s3Bucket']}/athena-query-results/"
db_name = "retail"

print("athena_query_result_location")
print(athena_query_result_location)
def create_athena_engine(aws_region, athena_workgroup, athena_query_result_location, db_name):
    """Create a LangChain SQLDatabase backed by a PyAthena SQLAlchemy engine for Athena"""
    athena_endpoint = f'athena.{aws_region}.amazonaws.com'
    athena_conn_string = (
        f"awsathena+rest://@{athena_endpoint}:443/{db_name}"
        f"?s3_staging_dir={athena_query_result_location}&work_group={athena_workgroup}"
    )
    athena_engine = create_engine(athena_conn_string, echo=False)
    return SQLDatabase(athena_engine)

athena_query_result_location
s3://791677101579-us-west-2-advanced-rag-workshop/athena-query-results/
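`create_athena_engine` interpolates the S3 staging path directly into the SQLAlchemy URL. PyAthena's SQLAlchemy examples URL-encode such values with `urllib.parse.quote_plus`, which avoids surprises if a path or workgroup name contains reserved characters. The sketch below builds the same `awsathena+rest` URL that way; the function name `athena_url` and the bucket name are placeholders, not values from this workshop.

```python
from urllib.parse import quote_plus


def athena_url(aws_region, db_name, s3_staging_dir, work_group="primary"):
    """Hypothetical helper: build a PyAthena SQLAlchemy URL with encoded query parameters."""
    host = f"athena.{aws_region}.amazonaws.com"
    return (
        f"awsathena+rest://@{host}:443/{db_name}"
        f"?s3_staging_dir={quote_plus(s3_staging_dir)}"
        f"&work_group={quote_plus(work_group)}"
    )


url = athena_url("us-west-2", "retail", "s3://example-bucket/athena-query-results/")
print(url)
```

The resulting string can be passed to `create_engine` exactly like `athena_conn_string` above.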


# 4: Create a Filtered Database Class

In [41]:
# Override the SQLDatabase class so the agent only sees the desired tables.
# Note: this reads the global `filtered_tables`, which is defined in step 9.
class FilteredSQLDatabase(SQLDatabase):
    def get_usable_table_names(self):
        return filtered_tables

# Filter tables to include only desired tables
def get_filtered_tables(db, desired_tables):
    all_tables = db.get_usable_table_names()
    return [table for table in all_tables if table in desired_tables]
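`get_filtered_tables` silently drops any desired table that does not exist, so a typo in `desired_tables` would quietly shrink what the agent can see. A stricter sketch (the name `filter_tables_strict` is hypothetical) keeps the database order but raises when a requested table is missing. LangChain's `SQLDatabase` constructor also accepts an `include_tables` list, which may achieve the same restriction without a subclass; confirm its behaviour against your installed `langchain_community` version.

```python
def filter_tables_strict(all_tables, desired_tables):
    """Return the desired tables in database order; raise if any are missing."""
    available = set(all_tables)
    missing = [t for t in desired_tables if t not in available]
    if missing:
        raise ValueError(f"Tables not found in database: {missing}")
    wanted = set(desired_tables)
    return [t for t in all_tables if t in wanted]


print(filter_tables_strict(["customers", "retail_order", "retail_returns"],
                           ["retail_returns", "retail_order"]))
```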

# 5: Create Bedrock LLM

In [42]:
def create_bedrock_llm(model_id, model_kwargs, region_name="us-west-2"):
    """Create a LangChain chat model wrapper for Amazon Bedrock"""
    bedrock_runtime = boto3.client(service_name="bedrock-runtime", region_name=region_name)
    return ChatBedrock(
        client=bedrock_runtime,
        model_id=model_id,
        model_kwargs=model_kwargs
    )

# 6: Create SQL Agent

In [43]:
def create_agent(db, llm, verbose=False):
    """Create a SQL agent with the database toolkit"""
    sql_toolkit = SQLDatabaseToolkit(llm=llm, db=db)
    agent_kwargs = {
        "handle_parsing_errors": True,
        "handle_sql_errors": True,
        "return_intermediate_steps": True
    }
    return create_sql_agent(
        llm=llm,
        toolkit=sql_toolkit,
        agent_executor_kwargs=agent_kwargs,
        agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=verbose  # Set to True to see the agent's thought process
    )
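With `AgentType.ZERO_SHOT_REACT_DESCRIPTION`, the model replies in plain text with `Action:` and `Action Input:` lines (or a `Final Answer:`), which LangChain parses to decide which tool to call next; the verbose trace in Example 6 shows this format. The regex sketch below is a simplified illustration of that parsing step, not LangChain's actual output parser, and `parse_react_step` is a hypothetical name.

```python
import re

# "Action: <tool>" followed on the next line by "Action Input: <input>"
ACTION_RE = re.compile(r"Action\s*:\s*(.*?)\s*\nAction\s*Input\s*:\s*(.*)", re.DOTALL)


def parse_react_step(text):
    """Simplified sketch: return ('final', answer) or ('action', tool, tool_input)."""
    if "Final Answer:" in text:
        return ("final", text.split("Final Answer:")[-1].strip())
    match = ACTION_RE.search(text)
    if not match:
        raise ValueError("Could not parse LLM output")
    tool = match.group(1).strip()
    tool_input = match.group(2).strip().strip('"')
    return ("action", tool, tool_input)


step = "I'll check the schema.\nAction: sql_db_schema\nAction Input: retail_order, retail_returns"
print(parse_react_step(step))
```

Unparseable replies are what `handle_parsing_errors=True` is meant to absorb: instead of crashing, the executor feeds the error back to the model.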

# 7: Define the Prompt Template

In [44]:
prompt_template = ChatPromptTemplate.from_messages([
    ("system", """
    You are an expert in Amazon Athena.
    You have access to the live database to query.
    To answer this question, you will first need to get the schema of the relevant tables to see what columns are available.
    Then query the relevant tables in the database to come up with the Final Answer.
    Do not assume any values for the data.
    Use [sql_db_list_tables] to get a list of tables in the database.
    Use [sql_db_schema] to get the schema for these tables.
    Use [sql_db_query_checker] to validate the SQL query.
    Execute the query using the [sql_db_query] tool and observe the output.
    Always explain the approach and any assumptions you made to arrive at the output.
    For forecasting questions:
    - No data will be available for future dates, so identify trends in the historical data.
    - Use appropriate methods to forecast future values based on historical data.
    - Clearly explain the forecasting methodology and results.
    """
    ),
    ("human", "{context}"),
])
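The system prompt asks the agent to forecast from historical trends but leaves the method to the model. As one simple example of such a method, an ordinary least-squares linear trend fitted to periodic totals can be extrapolated forward. This is a sketch of one possible technique, not what the agent is guaranteed to do, and the monthly figures are made up rather than taken from the `retail` tables.

```python
def linear_trend_forecast(values, periods_ahead=1):
    """Fit y = a + b*t by ordinary least squares (t = 0, 1, ...) and extrapolate."""
    n = len(values)
    if n < 2:
        raise ValueError("Need at least two points to fit a trend")
    t_mean = (n - 1) / 2
    y_mean = sum(values) / n
    # Slope b = cov(t, y) / var(t); intercept a = y_mean - b * t_mean
    cov = sum((t - t_mean) * (y - y_mean) for t, y in enumerate(values))
    var = sum((t - t_mean) ** 2 for t in range(n))
    b = cov / var
    a = y_mean - b * t_mean
    return [a + b * (n - 1 + k) for k in range(1, periods_ahead + 1)]


# Hypothetical monthly order counts
print(linear_trend_forecast([100, 110, 120, 130], periods_ahead=2))
```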

# 8: Function to Invoke the Agent

In [45]:
import time
import random

def _is_throttling_error(exc):
    """Heuristic check for Bedrock throttling based on the exception message."""
    message = str(exc)
    return "ThrottlingException" in message or "Too many requests" in message

def invoke_agent(agent, question, max_retries=5, base_delay=5, max_delay=60):
    """
    Invoke the agent with a question and return both the SQL and the output, retrying with backoff.
    
    Args:
        agent: The agent to invoke
        question: The question to ask
        max_retries: Maximum number of retries after the first attempt
        base_delay: Initial delay in seconds
        max_delay: Maximum delay in seconds
    
    Returns:
        Tuple of (sql_query, output), or (None, error_message) if all attempts fail
    """
    # Render the system and human messages into one string for the agent's {input}
    prompt = prompt_template.invoke(question).to_string()
    
    attempt = 0
    last_exception = None
    
    while attempt <= max_retries:
        # Fresh handler per attempt so a failed attempt cannot leak a stale query
        handler = SQLHandler()
        try:
            response = agent.invoke({"input": prompt}, {"callbacks": [handler]})
            print(f"Agent invocation successful on attempt {attempt + 1}")
            return handler.sql_query, response['output']
            
        except Exception as e:
            last_exception = e
            attempt += 1
            
            # Stop once the retry budget is exhausted
            if attempt > max_retries:
                break
            
            # Exponential backoff with jitter; back off more aggressively when throttled
            if _is_throttling_error(e):
                delay = min(base_delay * (4 ** (attempt - 1)) + random.uniform(0, 2), max_delay)
                print(f"Throttling detected on attempt {attempt}. Backing off for {delay:.2f} seconds...")
            else:
                delay = min(base_delay * (2 ** (attempt - 1)) + random.uniform(0, 1), max_delay)
                print(f"Agent invocation attempt {attempt} failed. Retrying in {delay:.2f} seconds...")
            
            time.sleep(delay)
    
    # All attempts failed: return a readable message instead of raising
    if _is_throttling_error(last_exception):
        error_message = ("Unable to get a response from the agent due to service throttling. "
                         "The service is currently experiencing high demand. "
                         "Please try again in a few minutes or reduce request frequency.")
    else:
        error_message = (f"Unable to get a response from the agent after {max_retries + 1} attempts. "
                         f"Error: {last_exception}")
    
    print(error_message)
    
    # Return None for the SQL query and the error message as output
    return None, error_message
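The retry policy in `invoke_agent` multiplies the wait by 4 per attempt for throttling errors and by 2 otherwise, adds a little random jitter, and caps the result at `max_delay`. Pulling that schedule into a pure function makes it easy to inspect. The sketch below (hypothetical name `backoff_delay`) reproduces the same formula, with jitter made optional so the numbers are predictable.

```python
import random


def backoff_delay(attempt, throttled, base_delay=5, max_delay=60, jitter=True):
    """Delay in seconds before retry number `attempt` (1-based), using invoke_agent's formula."""
    factor = 4 if throttled else 2
    max_jitter = 2 if throttled else 1
    noise = random.uniform(0, max_jitter) if jitter else 0.0
    return min(base_delay * factor ** (attempt - 1) + noise, max_delay)


print([backoff_delay(a, throttled=True, jitter=False) for a in range(1, 6)])
print([backoff_delay(a, throttled=False, jitter=False) for a in range(1, 6)])
```

With the defaults (`max_retries=5`), a run that is throttled on every attempt sleeps five times, for 5 + 20 + 60 + 60 + 60 = 205 seconds plus jitter on the uncapped steps, before returning the error message.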

# 9: Put Everything Together

In [46]:
# Create Athena engine
db = create_athena_engine(region, athena_workgroup, athena_query_result_location, db_name)

# List of desired tables
desired_tables = ["retail_order", "retail_returns"]
filtered_tables = get_filtered_tables(db, desired_tables)
filtered_db = FilteredSQLDatabase(db._engine)

# Create the LLM and agent
llm = create_bedrock_llm(model_id, model_kwargs)
agent = create_agent(filtered_db, llm)
print("Available tables:", filtered_tables)

Available tables: ['retail_order', 'retail_returns']


# 10: Example Queries

In [47]:
# Example 1: Simple query
question = "How many transactions are recorded in the order table?"
sql_query, output = invoke_agent(agent, question)

print("SQL Query:")
print(sql_query)
print("\nOutput:")
print(output)

Agent invocation successful on attempt 1
SQL Query:
SELECT COUNT(*) AS total_transactions FROM retail_order

Output:
There are 2,500 transactions recorded in the order table (retail_order). 

Explanation:
- I first listed the available tables in the database and found the retail_order table
- I checked the schema of the retail_order table to confirm it contained transaction data
- I used a simple COUNT(*) query to count the total number of rows in the table
- The query returned 2,500, which represents the total number of transactions in the order table

Note: The table schema showed an initial sample of 3 rows, but the actual table properties indicated a total record count of 2,372. However, the direct COUNT(*) query revealed 2,500 transactions, which is the most accurate representation of the current table contents.


In [48]:
# Example 2: Simple query with aggregation
question = "Find the total number of transactions per country."
sql_query, output = invoke_agent(agent, question)

print("SQL Query:")
print(sql_query)
print("\nOutput:")
print(output)

Agent invocation successful on attempt 1
SQL Query:
SELECT country, COUNT(*) as total_transactions 
FROM retail_order 
GROUP BY country 
ORDER BY total_transactions DESC 
LIMIT 10

Output:
The total number of transactions per country has been calculated, with the United States having the most transactions at 461, followed by Australia with 196 transactions, and France with 187 transactions.


In [49]:
# Example 3: Simple query with aggregation.
question = "What is the most purchased product in the UK?"
sql_query, output = invoke_agent(agent, question)

print("SQL Query:")
print(sql_query)
print("\nOutput:")
print(output)

Agent invocation successful on attempt 1
SQL Query:
SELECT product_name, 
       SUM(quantity) as total_quantity
FROM retail_order
WHERE country = 'United Kingdom'
GROUP BY product_name
ORDER BY total_quantity DESC
LIMIT 10;


Output:
The most purchased product in the UK is the "Nokia Smart Phone" with a total quantity of 15,915 units sold. 

Here's a breakdown of the top 10 most purchased products in the United Kingdom:
1. Nokia Smart Phone (15,915 units)
2. Motorola Smart Phone (7,721 units)
3. Cisco Smart Phone (7,230 units)
4. Brother Fax Machine (5,733 units)
5. Breville Microwave (5,714 units)
6. Hoover Stove (5,678 units)
7. Barricks Conference Table (5,451 units)
8. KitchenAid Refrigerator (5,273 units)
9. Safco Classic Bookcase (4,611 units)
10. Brother Wireless Fax (4,540 units)

The analysis is based on the total quantity of products sold in the United Kingdom from the retail_order table, sorted in descending order to highlight the most purchased items.


In [50]:
# Example 4: JOIN operations
question = "Which country has the most returned orders?"
sql_query, output = invoke_agent(agent, question)

print("SQL Query:")
print(sql_query)
print("\nOutput:")
print(output)

Agent invocation successful on attempt 1
SQL Query:
SELECT o.country, COUNT(*) as return_count
FROM retail_returns r
JOIN retail_order o ON r.order_id = o.order_id
GROUP BY o.country
ORDER BY return_count DESC
LIMIT 10;


Output:
The United States has the most returned orders, with 44 returns. This is followed by China with 35 returns, and Mexico with 17 returns. 

Explanation of the analysis:
1. I joined the `retail_returns` and `retail_order` tables using the `order_id` as the linking key.
2. I grouped the results by country and counted the number of returns for each country.
3. I ordered the results in descending order of return count to see which countries have the most returns.
4. The top countries with the most returned orders are:
   - United States: 44 returns
   - China: 35 returns
   - Mexico: 17 returns
   - France: 13 returns
   - United Kingdom: 12 returns

This analysis provides insights into the geographical distribution of product returns in the dataset.
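The JOIN in Example 4 counts one row per matching (return, order row) pair. The `retail_order` schema shown in Example 6 has both `row_id` and `order_id`, which suggests (an assumption, not confirmed here) that one order can span several rows; if so, `COUNT(*)` counts returned order lines rather than distinct returned orders. The sketch below runs the same query shape against an in-memory SQLite database with a few made-up rows, which is a quick way to check what an agent-generated query actually counts.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE retail_order (row_id INTEGER, order_id TEXT, country TEXT);
CREATE TABLE retail_returns (order_id TEXT);
-- Made-up rows: order A has two line items, B and C have one each
INSERT INTO retail_order VALUES (1, 'A', 'United States'), (2, 'A', 'United States'),
                                (3, 'B', 'United States'), (4, 'C', 'Mexico');
INSERT INTO retail_returns VALUES ('A'), ('C');
""")

query = """
SELECT o.country,
       COUNT(*) AS joined_rows,
       COUNT(DISTINCT o.order_id) AS returned_orders
FROM retail_returns r
JOIN retail_order o ON r.order_id = o.order_id
GROUP BY o.country
ORDER BY joined_rows DESC
"""
rows = conn.execute(query).fetchall()
print(rows)
```

Here the United States shows two joined rows but only one returned order, so the choice between `COUNT(*)` and `COUNT(DISTINCT ...)` changes the answer whenever orders have several lines.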


In [51]:
# Example 5: Use Common Table Expressions (CTE)
question = "Which product categories were the most common ones in the top 3 countries with the most returned orders."
sql_query, output = invoke_agent(agent, question)

print("SQL Query:")
print(sql_query)
print("\nOutput:")
print(output)

Agent invocation successful on attempt 1
SQL Query:
WITH returned_orders AS (
    SELECT o.country, COUNT(r.order_id) as return_count
    FROM retail_order o
    JOIN retail_returns r ON o.order_id = r.order_id
    GROUP BY o.country
    ORDER BY return_count DESC
    LIMIT 3
),
category_counts AS (
    SELECT 
        o.country, 
        o.category, 
        COUNT(*) as category_return_count,
        RANK() OVER (PARTITION BY o.country ORDER BY COUNT(*) DESC) as category_rank
    FROM retail_order o
    JOIN retail_returns r ON o.order_id = r.order_id
    JOIN returned_orders ro ON o.country = ro.country
    GROUP BY o.country, o.category
)
SELECT 
    country, 
    category, 
    category_return_count
FROM category_counts
WHERE category_rank <= 3
ORDER BY country, category_return_count DESC
LIMIT 10


Output:
In the top 3 countries with the most returned orders (United States, China, and Mexico), Furniture and Technology were the most common product categories for returns. The United
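Example 5 ranks categories within each country with `RANK() OVER (PARTITION BY ... ORDER BY COUNT(*) DESC)`, evaluated after the `GROUP BY`. The in-memory SQLite sketch below (window functions need SQLite 3.25 or later; the rows are made up) shows how that ranking behaves. Tied counts share a rank, so a filter such as `category_rank <= 3` can return more than three categories per country, and a trailing `LIMIT 10` can then cut rows off.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE returned_lines (country TEXT, category TEXT);
-- Made-up returned order lines
INSERT INTO returned_lines VALUES
  ('China', 'Furniture'), ('China', 'Furniture'), ('China', 'Technology'),
  ('China', 'Office Supplies'), ('Mexico', 'Technology'), ('Mexico', 'Furniture');
""")

query = """
SELECT country, category, COUNT(*) AS n,
       RANK() OVER (PARTITION BY country ORDER BY COUNT(*) DESC) AS category_rank
FROM returned_lines
GROUP BY country, category
ORDER BY country, category_rank, category
"""
rows = conn.execute(query).fetchall()
for row in rows:
    print(row)
```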

##### Turn on the `verbose` option for LangChain agents to see the agent's thought process at each step.

In [52]:
verbose_agent = create_agent(filtered_db, llm, verbose=True)

In [53]:
# Example 6: Complex query using JOIN and LIKE.
question = "Find any customer whose name is Justin and returned a furniture in 2013."

prompted_question = f"""Identify whether the QUESTION has multiple sub-questions. For each sub-question, think step by step and identify the right tables and columns while creating a SQL query.

QUESTION: {question}"""
sql_query, output = invoke_agent(verbose_agent, prompted_question)

print("SQL Query:")
print(sql_query)
print("\nOutput:")
print(output)


> Entering new SQL Agent Executor chain...
I'll break this down step by step:

1. First, I'll list the available tables:

Action: sql_db_list_tables
Action Input:
retail_order, retail_returns
I'll break this down step by step:

1. I see two relevant tables: retail_order and retail_returns

2. I'll check the schema of these tables:

Action: sql_db_schema
Action Input: retail_order, retail_returns
CREATE EXTERNAL TABLE retail_order (
	orders_index BIGINT,
	row_id BIGINT,
	order_id STRING,
	order_date STRING,
	ship_date STRING,
	ship_mode STRING,
	customer_id STRING,
	customer_name STRING,
	segment STRING,
	city STRING,
	state STRING,
	country STRING,
	postal_code FLOAT,
	market STRING,
	region STRING,
	product_id STRING,
	category STRING,
	`sub-category` STRING,
	product_name STRING,
	sales FLOAT,
	quantity BIGINT,
	discount FLOAT,
	profit FLOAT,
	shipping_cost FLOAT,
	order_priority STRING
)
ROW FORMAT SERDE 'o