### **Text to SQL: Natural Language to Athena Query Execution**

### Description:
This application allows users to interact with a database using natural language queries. 

Users can ask questions in plain English, and the system will convert the question into an SQL query. This query is then executed on an Amazon Athena database. Once the query is executed, the results are translated back into natural language, providing an easy-to-understand response. 

This workflow enables users to interact with databases without needing to write complex SQL, making it more accessible for non-technical users.

![Text to SQL](./text_to_sql.png)

In [1]:
import warnings
from datetime import datetime, timedelta, UTC

notebook_start_time = datetime.now(UTC)
warnings.filterwarnings("ignore")


In [2]:
!pip install boto3 sqlalchemy langchain langchain-community langchain-aws PyAthena -qq

# Data Query Assistant

In [3]:
# First, let's import all necessary libraries
import boto3
from sqlalchemy import create_engine
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_aws import ChatBedrock
from langchain_community.agent_toolkits import create_sql_agent, SQLDatabaseToolkit
from langchain.agents import AgentType
from langchain_core.prompts import ChatPromptTemplate
from langchain.callbacks.base import BaseCallbackHandler

# 1: Create a Callback Handler to Capture SQL Queries

In [4]:
class SQLHandler(BaseCallbackHandler):
    """Callback handler to capture the SQL query generated by the agent"""
    def __init__(self):
        self.sql_query = None

    def on_agent_action(self, action, **kwargs):
        if action.tool == "sql_db_query":
            self.sql_query = action.tool_input
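
The handler above keeps only the most recent `sql_db_query` call, so earlier queries in a multi-step run are overwritten. As a minimal sketch, a hypothetical variant named `SQLHistoryHandler` (not part of the original notebook) could keep every query in order. It assumes the tool input arrives either as a plain string or as a dict with a `query` key, and it uses stand-in action objects so it runs without an agent.

```python
from types import SimpleNamespace

try:
    from langchain_core.callbacks import BaseCallbackHandler
except ImportError:
    # Fallback so the sketch still runs where LangChain is not installed.
    BaseCallbackHandler = object


class SQLHistoryHandler(BaseCallbackHandler):
    """Hypothetical variant of SQLHandler that records every SQL query, in order."""

    def __init__(self):
        self.sql_queries = []

    def on_agent_action(self, action, **kwargs):
        # Only record calls to the query-execution tool.
        if action.tool == "sql_db_query":
            tool_input = action.tool_input
            # Assumption: the input is a string, or a dict holding the SQL under "query".
            if isinstance(tool_input, dict):
                tool_input = tool_input.get("query", "")
            self.sql_queries.append(tool_input)

    @property
    def last_query(self):
        """Most recent query, or None if the agent never ran one."""
        return self.sql_queries[-1] if self.sql_queries else None


# Simulate two agent steps with stand-in action objects.
handler = SQLHistoryHandler()
handler.on_agent_action(SimpleNamespace(tool="sql_db_schema", tool_input="orders"))
handler.on_agent_action(
    SimpleNamespace(tool="sql_db_query", tool_input="SELECT COUNT(*) FROM orders")
)
print(handler.sql_queries)
```

Only the second simulated step is recorded, because the first one targets a different tool.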

# 2: Configure AWS Bedrock Model

In [5]:
# Set your AWS region
region = 'us-west-2'

# ChatBedrock Configuration
model_id = "anthropic.claude-3-5-haiku-20241022-v1:0"
model_kwargs = {
    "max_tokens": 4096,
    "temperature": 0.0,
    "top_k": 250,
    "top_p": 1
}

# 3: Set Up Athena Connection

In [6]:
# Athena Configuration
athena_workgroup = 'primary'

session = boto3.session.Session()
# Prefer the session's configured region; fall back to the value set above if none is configured
region = session.region_name or region
account_id = session.client('sts').get_caller_identity()['Account']
athena_query_result_bucket = f"{account_id}-{region}-athena-output"
athena_query_result_location = f"s3://{athena_query_result_bucket}/athena-query-results/"
db_name = "retail"

print("athena_query_result_location")
print(athena_query_result_location)

from urllib.parse import quote_plus

def create_athena_engine(aws_region, athena_workgroup, athena_query_result_location, db_name):
    """Create a connection to AWS Athena"""
    athena_endpoint = f'athena.{aws_region}.amazonaws.com'
    # URL-encode the S3 path so characters such as ':' and '/' are safe in the query string
    athena_conn_string = (
        f"awsathena+rest://@{athena_endpoint}:443/{db_name}"
        f"?s3_staging_dir={quote_plus(athena_query_result_location)}&work_group={athena_workgroup}"
    )
    athena_engine = create_engine(athena_conn_string, echo=False)
    return SQLDatabase(athena_engine)

athena_query_result_location
s3://270597685972-us-west-2-athena-output/athena-query-results/
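
The connection string is an ordinary URL: the Athena endpoint and database form the host and path, while the S3 staging directory and workgroup travel as query parameters. The sketch below builds such a URL with percent-encoded values and decodes it again. The helper name `build_athena_url` and the bucket name are placeholders for illustration, not part of the notebook.

```python
from urllib.parse import parse_qs, quote_plus, urlsplit


def build_athena_url(aws_region, db_name, s3_staging_dir, work_group):
    """Return a PyAthena-style REST connection URL with URL-encoded query values."""
    endpoint = f"athena.{aws_region}.amazonaws.com"
    return (
        f"awsathena+rest://@{endpoint}:443/{db_name}"
        f"?s3_staging_dir={quote_plus(s3_staging_dir)}"
        f"&work_group={quote_plus(work_group)}"
    )


# Placeholder values, not real resources.
url = build_athena_url(
    "us-west-2", "retail", "s3://example-bucket/athena-query-results/", "primary"
)
print(url)

# Decoding the query string recovers the original, unencoded values.
params = parse_qs(urlsplit(url).query)
print(params["s3_staging_dir"][0])
print(params["work_group"][0])
```

Encoding matters mostly when a value contains characters such as `&` or `=`, which would otherwise be read as query-string delimiters.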


# 4: Create Bedrock LLM

In [7]:
def create_bedrock_llm(model_id, model_kwargs):
    """Create a LangChain wrapper for AWS Bedrock"""
    bedrock_runtime = session.client(service_name="bedrock-runtime",
                                   region_name=region)
    return ChatBedrock(
        client=bedrock_runtime,
        model_id=model_id,
        model_kwargs=model_kwargs
    )

# 5: Create SQL Agent

In [8]:
def create_agent(db, llm, verbose = False):
    """Create a SQL agent with the database toolkit"""
    sql_toolkit = SQLDatabaseToolkit(llm=llm, db=db)
    agent_kwargs = {
        "handle_parsing_errors": True,
        "handle_sql_errors": True,
        "return_intermediate_steps": True
    }
    return create_sql_agent(
        llm=llm,
        toolkit=sql_toolkit,
        agent_executor_kwargs=agent_kwargs,
        agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=verbose  # Set to True to see the agent's thought process
    )

# 6: Define the Prompt Template

In [9]:
prompt_template = ChatPromptTemplate.from_messages([
    ("system", """
    You are an expert in Amazon Athena.
    You have access to the live database to query.
    To answer this question, 
        you will first need to get the schema of the relevant tables to see what columns are available.
    Then query the relevant tables in the database to come up with Final Answer.
    Do not assume any values for the data.
    Use [sql_db_list_tables] to get a list of tables in the database.
    Use [sql_db_schema] to get the schema for these tables.
    Use [sql_db_query_checker] to validate the SQL query.
    Execute the query using [sql_db_query] tool and observe the output.
    Always provide the explanation and assumptions that you have made to come up with the output.
    For forecasting questions:
    - There won't be any data available for the future dates. So, identify historical data trends.
    - Use appropriate methods to forecast future values based on historical data.
    - Clearly explain the forecasting methodology and results.
    """
    ),
    ("human", "{context}"),
])

# 7: Function to Invoke the Agent

In [10]:
import time
import random

def invoke_agent(agent, question, max_retries=5, base_delay=5, max_delay=60):
    """
    Invoke the agent with a question and return both SQL and output with retry backoff.
    
    Args:
        agent: The agent to invoke
        question: The question to ask
        max_retries: Maximum number of retry attempts
        base_delay: Initial delay in seconds
        max_delay: Maximum delay in seconds
    
    Returns:
        Tuple of (sql_query, output) or (None, error_message) if all retries fail
    """
    handler = SQLHandler()
    prompt = prompt_template.invoke(question)
    
    attempt = 0
    last_exception = None
    
    while attempt <= max_retries:
        try:
            # Attempt to invoke the agent
            response = agent.invoke({"input": prompt}, {"callbacks": [handler]})
            output = response['output']
            sqlquery = handler.sql_query
            
            # If successful, return the results
            print(f"Agent invocation successful on attempt {attempt+1}")
            return sqlquery, output
            
        except Exception as e:
            last_exception = e
            attempt += 1
            
            # Check if we've reached the retry limit
            if attempt > max_retries:
                break
            
            # Check if this is a throttling error by examining the exception message
            is_throttling = "ThrottlingException" in str(e) or "Too many requests" in str(e)
            
            # Calculate delay with exponential backoff and jitter
            if is_throttling:
                # For throttling, use a more aggressive backoff strategy
                delay = min(base_delay * (4 ** (attempt - 1)) + random.uniform(0, 2), max_delay)
                print(f"Throttling detected on attempt {attempt}. Backing off for {delay:.2f} seconds...")
            else:
                # For other errors, use standard exponential backoff
                delay = min(base_delay * (2 ** (attempt - 1)) + random.uniform(0, 1), max_delay)
                print(f"Agent invocation attempt {attempt} failed. Retrying in {delay:.2f} seconds...")
                
            # Wait before retrying
            time.sleep(delay)
    
    # If we've exhausted all retries, provide a graceful message instead of raising an exception
    if "ThrottlingException" in str(last_exception) or "Too many requests" in str(last_exception):
        error_message = "Unable to get a response from the agent due to service throttling. The service is currently experiencing high demand. Please try again in a few minutes or reduce request frequency."
    else:
        # `attempt` now equals the total number of tries (the first call plus max_retries retries)
        error_message = f"Unable to get a response from the agent after {attempt} attempts. Error: {str(last_exception)}"
    
    print(error_message)
    
    # Return None for SQL query and the error message as output
    return None, error_message
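
To make the retry timing in `invoke_agent` concrete, here is a small sketch of the delay schedule on its own. The helper `backoff_delay` is hypothetical; it mirrors the formula used above (base delay times a growth factor raised to the attempt number minus one, plus optional jitter, capped at `max_delay`), with jitter switched off by default so the numbers are reproducible.

```python
import random


def backoff_delay(attempt, base_delay=5, max_delay=60, factor=2, max_jitter=0):
    """Delay in seconds before retry number `attempt` (1-based), capped at max_delay."""
    # Jitter is optional; with max_jitter=0 the schedule is deterministic.
    jitter = random.uniform(0, max_jitter) if max_jitter else 0
    return min(base_delay * factor ** (attempt - 1) + jitter, max_delay)


# Standard errors double the delay; throttling errors quadruple it.
standard = [backoff_delay(n, factor=2) for n in range(1, 6)]
throttled = [backoff_delay(n, factor=4) for n in range(1, 6)]
print(standard)   # [5, 10, 20, 40, 60]
print(throttled)  # [5, 20, 60, 60, 60]
```

With the default settings, the throttling schedule reaches the 60-second cap on the third retry, while the standard schedule reaches it on the fifth.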

# 8: Put Everything Together

In [11]:
# Create Athena engine
db = create_athena_engine(region, athena_workgroup, athena_query_result_location, db_name)

# Create the LLM and agent
llm = create_bedrock_llm(model_id, model_kwargs)
agent = create_agent(db, llm)


# 9: Example Queries

In [12]:
# Example 1: Simple query
question = "How many transactions are recorded in the order table?"
sql_query, output = invoke_agent(agent, question)

print("SQL Query:")
print(sql_query)
print("\nOutput:")
print(output)

Agent invocation successful on attempt 1
SQL Query:
SELECT COUNT(*) AS total_transactions FROM orders

Output:
There are 2,500 transactions recorded in the order table.

Explanation:
- I first listed the available tables in the database and found an "orders" table
- I checked the schema of the orders table to confirm its structure
- I used a simple SQL COUNT(*) query to count the total number of rows in the orders table
- The query returned 2,500, which represents the total number of transactions in the database

The query was straightforward: `SELECT COUNT(*) AS total_transactions FROM orders` which counts all rows in the orders table, giving us the total number of transactions.


In [13]:
# Example 2: Simple query with aggregation
question = "Find the total number of transactions per country."
sql_query, output = invoke_agent(agent, question)

print("SQL Query:")
print(sql_query)
print("\nOutput:")
print(output)

Agent invocation successful on attempt 1
SQL Query:
SELECT country, COUNT(DISTINCT order_id) as total_transactions 
FROM orders 
GROUP BY country 
ORDER BY total_transactions DESC 
LIMIT 10

Output:
Here are the total number of transactions per country, sorted from highest to lowest:

1. United States: 428 transactions
2. Australia: 184 transactions
3. France: 176 transactions
4. China: 159 transactions
5. India: 118 transactions
6. Mexico: 117 transactions
7. United Kingdom: 94 transactions
8. Indonesia: 89 transactions
9. Germany: 84 transactions
10. Brazil: 61 transactions

Explanation:
- I queried the 'orders' table to count the number of unique order IDs for each country.
- Used COUNT(DISTINCT order_id) to ensure each transaction is counted only once.
- Sorted the results in descending order to show countries with the most transactions first.
- Limited the output to the top 10 countries.

The United States leads with the highest number of transactions at 428, which is significantl

In [14]:
# Example 3: Aggregation with a filter
question = "What is the most purchased product in the UK?"
sql_query, output = invoke_agent(agent, question)

print("SQL Query:")
print(sql_query)
print("\nOutput:")
print(output)

Agent invocation successful on attempt 1
SQL Query:
SELECT product_name, SUM(quantity) as total_quantity
FROM orders
WHERE country = 'United Kingdom'
GROUP BY product_name
ORDER BY total_quantity DESC
LIMIT 10;


Output:
The most purchased product in the UK is the "Nokia Smart Phone" with a total quantity of 15,915 units sold. 

Explanation:
- I queried the orders table, filtering for orders from the United Kingdom
- I grouped the results by product name and summed the total quantity
- I ordered the results in descending order of quantity to find the top-selling product
- The Nokia Smart Phone significantly outperforms other products, with nearly twice the quantity of the next most purchased item (Motorola Smart Phone at 7,721 units)

Top 3 most purchased products in the UK:
1. Nokia Smart Phone (15,915 units)
2. Motorola Smart Phone (7,721 units)
3. Cisco Smart Phone (7,230 units)


In [15]:
# Example 4: JOIN operations
question = "Which country has the most returned orders?"
sql_query, output = invoke_agent(agent, question)

print("SQL Query:")
print(sql_query)
print("\nOutput:")
print(output)

Agent invocation successful on attempt 1
SQL Query:
SELECT o.country, COUNT(*) as return_count
FROM orders o
JOIN returns r ON o.order_id = r.order_id
GROUP BY o.country
ORDER BY return_count DESC
LIMIT 10;


Output:
The United States has the most returned orders, with 44 returned orders in the dataset. This is followed by China with 35 returned orders, and Mexico with 17 returned orders.

Explanation of the analysis:
1. I joined the "orders" and "returns" tables using the order_id as the common key
2. I grouped the results by country and counted the number of returns
3. I ordered the results in descending order of return count
4. The query returned the top 10 countries by number of returns
5. The United States clearly leads with the highest number of returned orders at 44

The data suggests that the United States has the most significant return volume in this retail dataset, which could be due to various factors such as:
- Larger market size
- More lenient return policies
- Higher cus

In [35]:
# Example 5: Use Common Table Expressions (CTE)
question = "Which product categories were the most common ones in the top 3 countries with the most returned orders."
sql_query, output = invoke_agent(agent, question)

print("SQL Query:")
print(sql_query)
print("\nOutput:")
print(output)

Agent invocation successful on attempt 1
SQL Query:
WITH returned_orders AS (
    SELECT o.country, COUNT(r.order_id) as return_count
    FROM orders o
    JOIN returns r ON o.order_id = r.order_id
    GROUP BY o.country
    ORDER BY return_count DESC
    LIMIT 3
),
category_counts AS (
    SELECT 
        o.country, 
        o.category, 
        COUNT(r.order_id) as category_return_count,
        RANK() OVER (PARTITION BY o.country ORDER BY COUNT(r.order_id) DESC) as category_rank
    FROM orders o
    JOIN returns r ON o.order_id = r.order_id
    JOIN returned_orders ro ON o.country = ro.country
    GROUP BY o.country, o.category
)
SELECT 
    country, 
    category, 
    category_return_count
FROM category_counts
WHERE category_rank <= 3
ORDER BY country, category_return_count DESC;


Output:
In the top 3 countries with the most returned orders (China, Mexico, and United States), Furniture and Technology were the most common product categories for returns. Specifically:
- In China, 

#### To see the thought process of the SQL agent, turn on the "verbose" option for the agent

In [16]:
verbose_agent = create_agent(db, llm, verbose=True)

In [17]:
# Example 6: Complex query using JOIN and LIKE.
question = "Find any customer whose name is Justin and returned a furniture in 2013."

question = f"""Identify if the QUESTION specified has multiple sub questions. For each subquestion, think step by step and identify the right tables and columns while creating a SQL query.

QUESTION: {question}"""
sql_query, output = invoke_agent(verbose_agent, question)

print("SQL Query:")
print(sql_query)
print("\nOutput:")
print(output)



> Entering new SQL Agent Executor chain...
I'll break this down step by step:

1. First, I'll list the available tables:

Action: sql_db_list_tables
Action Input:
orders, returns
I'll break this down step by step:

1. First, I'll check the schema of the 'orders' and 'returns' tables to understand their structure:

Action: sql_db_schema
Action Input: orders, returns
CREATE EXTERNAL TABLE orders (
	row_id BIGINT,
	order_id STRING,
	order_date TIMESTAMP,
	ship_date TIMESTAMP,
	ship_mode STRING,
	customer_id STRING,
	customer_name STRING,
	segment STRING,
	city STRING,
	state STRING,
	country STRING,
	postal_code FLOAT,
	market STRING,
	region STRING,
	product_id STRING,
	category STRING,
	`sub-category` STRING,
	product_name STRING,
	sales FLOAT,
	quantity BIGINT,
	discount FLOAT,
	profit FLOAT,
	shipping_cost FLOAT,
	order_priority STRING
)
ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
WI

### Cost Summary for Running This Notebook
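In this notebook, the only model usage is LLM text generation. Note that the agentic flow calls the LLM not just for the final answer but also for each intermediate reasoning step (listing tables, reading schemas, checking and running queries), so a single question results in several model invocations.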

In [18]:
# Mark end of query executions here:
notebook_end_time = datetime.now(UTC)

In [19]:
from cost_analysis_helper import get_bedrock_token_based_cost

print(notebook_start_time, notebook_end_time)

inference_cost = get_bedrock_token_based_cost(model_id, notebook_start_time, notebook_end_time)

2025-05-03 02:55:03.609926+00:00 2025-05-03 02:58:41.320993+00:00


In [20]:
inference_cost

{'model_id': 'anthropic.claude-3-5-haiku-20241022-v1:0',
 'start_time': '2025-05-03T02:55:03.609926+00:00',
 'end_time': '2025-05-03T02:58:41.320993+00:00',
 'duration in minutes': 3.6285177833333337,
 'input_tokens': 46929,
 'output_tokens': 2831,
 'invocation_count': 30,
 'per million input token costs': 0.8,
 'per million output token costs': 4.0,
 'input token costs': 0.0375432,
 'output token costs': 0.011324,
 'total token costs': 0.0488672,
 'average token costs per invocation': 0.0016289066666666666,
 'token costs per MILLION such invocations': 1628.9066666666665}
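
The figures above can be checked by hand. The internals of `get_bedrock_token_based_cost` are not shown in this notebook, so the following is only a sketch of the arithmetic that the reported numbers are consistent with: token counts multiplied by the per-million-token prices, then averaged over the invocation count. The function name `token_cost_summary` is hypothetical.

```python
def token_cost_summary(input_tokens, output_tokens,
                       input_price_per_million, output_price_per_million,
                       invocation_count):
    """Recompute token-based cost figures from raw token counts and prices (USD)."""
    input_cost = input_tokens * input_price_per_million / 1_000_000
    output_cost = output_tokens * output_price_per_million / 1_000_000
    total_cost = input_cost + output_cost
    per_invocation = total_cost / invocation_count
    return {
        "input token costs": input_cost,
        "output token costs": output_cost,
        "total token costs": total_cost,
        "average token costs per invocation": per_invocation,
        "token costs per MILLION such invocations": per_invocation * 1_000_000,
    }


# Token counts, prices, and invocation count taken from the output above.
summary = token_cost_summary(46929, 2831, 0.8, 4.0, 30)
for name, value in summary.items():
    print(f"{name}: {value:.7f}")
```

Up to floating-point rounding, the results should match the reported values, for example a total of about 0.0489 USD for the 30 invocations in this run.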