# SQL Agent

Convert natural-language questions into SQL, run them against a database, and return the answers in natural language.

In [5]:
# Download the Chinook sample database

import os
import requests

url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db"

# A timeout keeps the cell from hanging if the server does not respond
response = requests.get(url, timeout=60)

if response.status_code == 200:
    # Open a local file in binary write mode
    with open("Chinook.db", "wb") as file:
        # Write the content of the response (the file) to the local file
        file.write(response.content)
    print("File downloaded and saved as Chinook.db")
else:
    print(f"Failed to download the file. Status code: {response.status_code}")

File downloaded and saved as Chinook.db


In [6]:
import sqlite3

# Connect to the database
connection = sqlite3.connect("Chinook.db")
cursor = connection.cursor()

# List all tables in the database
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = cursor.fetchall()

print("Tables in the database:")
for table in tables:
    print(table[0])

# For each table, fetch and display sample records
for table in tables:
    print(f"\nSample records from {table[0]}:")
    cursor.execute(f"SELECT * FROM {table[0]} LIMIT 5;")
    rows = cursor.fetchall()

    for row in rows:
        print(row)

# Close the connection
connection.close()


Tables in the database:
Album
Artist
Customer
Employee
Genre
Invoice
InvoiceLine
MediaType
Playlist
PlaylistTrack
Track

Sample records from Album:
(1, 'For Those About To Rock We Salute You', 1)
(2, 'Balls to the Wall', 2)
(3, 'Restless and Wild', 2)
(4, 'Let There Be Rock', 1)
(5, 'Big Ones', 3)

Sample records from Artist:
(1, 'AC/DC')
(2, 'Accept')
(3, 'Aerosmith')
(4, 'Alanis Morissette')
(5, 'Alice In Chains')

Sample records from Customer:
(1, 'Luís', 'Gonçalves', 'Embraer - Empresa Brasileira de Aeronáutica S.A.', 'Av. Brigadeiro Faria Lima, 2170', 'São José dos Campos', 'SP', 'Brazil', '12227-000', '+55 (12) 3923-5555', '+55 (12) 3923-5566', 'luisg@embraer.com.br', 3)
(2, 'Leonie', 'Köhler', None, 'Theodor-Heuss-Straße 34', 'Stuttgart', None, 'Germany', '70174', '+49 0711 2842222', None, 'leonekohler@surfeu.de', 5)
(3, 'François', 'Tremblay', None, '1498 rue Bélanger', 'Montréal', 'QC', 'Canada', 'H2G 1A7', '+1 (514) 721-4711', None, 'ftremblay@gmail.com', 3)
(4, 'Bjørn', 'Han
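The sample rows above are printed as bare tuples, so it is not obvious which value belongs to which column. A minimal sketch of one way to pair values with column names, using `cursor.description` from the standard `sqlite3` module. It runs against a small in-memory table (a simplified stand-in for `Artist`, not the real file) so that it is self-contained; the same pattern should apply to a cursor opened on `Chinook.db`.

```python
import sqlite3

# In-memory stand-in for the Artist table (illustrative, not the real file)
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)")
conn.executemany("INSERT INTO Artist VALUES (?, ?)", [(1, "AC/DC"), (2, "Accept")])

cursor = conn.execute("SELECT * FROM Artist LIMIT 5")
# cursor.description holds one 7-tuple per column; the first item is the column name
columns = [col[0] for col in cursor.description]
records = [dict(zip(columns, row)) for row in cursor.fetchall()]

for record in records:
    print(record)

conn.close()
```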

### Import LangChain modules

In [7]:
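# Recent LangChain releases expose the SQL agent pieces from langchain_community;
# the older langchain.agents / langchain.sql_database import paths are deprecated.
from langchain_community.agent_toolkits import create_sql_agent, SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain.agents.agent_types import AgentType
from langchain_openai import AzureChatOpenAI, ChatOpenAI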
from dotenv import load_dotenv

load_dotenv('.env')

True

In [9]:
# Define LLM Model using Azure Chat OpenAI

def get_llm(
    azure_deployment="gpt-4o",
    api_version="2024-06-01",
    temperature=0,
    timeout=600,
    max_tokens=4096,
    max_retries=1,
):
    return AzureChatOpenAI(
        azure_deployment=azure_deployment,
        api_version=api_version,
        temperature=temperature,
        timeout=timeout,
        max_tokens=max_tokens,
        max_retries=max_retries,
    )


llm = get_llm()

# test the model
llm.invoke("Tell me a joke about SQL")

### Prompt

Default prompt templates used by the LangChain SQL agent (source linked below):

SQL_PREFIX = """You are an agent designed to interact with a SQL database.
Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

If the question does not seem related to the database, just return "I don't know" as the answer.
"""

SQL_SUFFIX = """Begin!

Question: {input}
Thought: I should look at the tables in the database to see what I can query.  Then I should query the schema of the most relevant tables.
{agent_scratchpad}"""

SQL_FUNCTIONS_SUFFIX = """I should look at the tables in the database to see what I can query.  Then I should query the schema of the most relevant tables."""

Reference: https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/agent_toolkits/sql/prompt.py
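The prefix template contains two placeholders, `{dialect}` and `{top_k}`, that are filled in before the prompt reaches the model, while `{input}` and `{agent_scratchpad}` in the suffix are filled at run time. A minimal sketch of that substitution with plain `str.format`. `PREFIX_TEMPLATE` is a shortened stand-in for `SQL_PREFIX`, `render_prefix` is a hypothetical helper, and the values `"sqlite"` and `10` are assumptions chosen for illustration.

```python
# Shortened stand-in for SQL_PREFIX; only the placeholders matter here
PREFIX_TEMPLATE = (
    "You are an agent designed to interact with a SQL database.\n"
    "Given an input question, create a syntactically correct {dialect} query to run.\n"
    "Always limit your query to at most {top_k} results."
)


def render_prefix(dialect: str, top_k: int) -> str:
    """Fill the two placeholders that the prefix template expects."""
    return PREFIX_TEMPLATE.format(dialect=dialect, top_k=top_k)


rendered = render_prefix(dialect="sqlite", top_k=10)
print(rendered)
```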


In [5]:
# Set up the SQL database connection
db = SQLDatabase.from_uri('sqlite:///Chinook.db')

In [6]:
# Set up the SQL database toolkit
toolkit = SQLDatabaseToolkit(db=db, llm=llm)

In [7]:
# Create the SQL agent executor
sql_agent_executor = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    return_intermediate_steps=True
)

In [8]:
example_query = "Which countries have the highest sales?"

# sql_agent_executor.invoke({"input": example_query})

events = sql_agent_executor.stream(
    {"input": example_query})

for event in events:
    event["messages"][-1].pretty_print()



> Entering new SQL Agent Executor chain...

To answer this question, I need to find out which tables in the database contain information about sales and countries.

Action: sql_db_list_tables
Action Input:
To answer this question, I need to find out which tables in the database contain information about sales and countries.

Action: sql_db_list_tables

Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track

Based on the tables listed, it seems that the Invoice table might contain information about sales, and the Customer table likely contains information about countries since customers would have billing addresses or country information.

Action: sql_db_schema
Action Input: Invoice, Customer
Based on the tables listed, it seems that the Invoice table might contain information about sales, and the Customer table likely contains information about countries since customers would have billing addresses o

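The output above is cut off before the agent's final query appears. As an illustration only, here is one plausible query for this question, assuming the `Invoice` table has `BillingCountry` and `Total` columns as in the standard Chinook schema. The rows are invented toy data in an in-memory database, and the column types are simplified.

```python
import sqlite3

# Toy in-memory Invoice table; the values are made up for illustration
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, BillingCountry TEXT, Total REAL)"
)
conn.executemany(
    "INSERT INTO Invoice (BillingCountry, Total) VALUES (?, ?)",
    [("USA", 10.0), ("Canada", 4.0), ("USA", 5.0), ("Brazil", 6.0)],
)

# Sum invoice totals per country and keep the largest ones first
query = """
SELECT BillingCountry, SUM(Total) AS TotalSales
FROM Invoice
GROUP BY BillingCountry
ORDER BY TotalSales DESC
LIMIT 10
"""
top_countries = conn.execute(query).fetchall()
print(top_countries)
conn.close()
```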

### SQL Database Toolkit

1. **sql_db_list_tables**
    List the tables in the database
    - Input  : empty string
    - Output : comma-separated list of table names

2. **sql_db_schema**
    Get the schema and sample rows for the relevant tables
    - Input  : comma-separated list of relevant table names
    - Output : schema (CREATE TABLE statements) and sample rows for those tables

3. **sql_db_query_checker**
    Double-check a SQL query for mistakes before it is executed
    - Input  : candidate SQL query
    - Output : checked (and, if needed, corrected) SQL query

4. **sql_db_query**
    Execute the SQL query against the database
    - Input  : checked SQL query
    - Output : query results, or an error message if the query fails
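The four tools map roughly onto the steps below. This is a minimal sketch with plain `sqlite3`, not LangChain's implementation: the function names are hypothetical, the `Employee` table is a made-up stand-in, and `check_query` uses `EXPLAIN` as a cheap syntax and name check, whereas the real `sql_db_query_checker` asks the LLM to review the query.

```python
import sqlite3

# Toy in-memory table; names and dates are placeholders
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY, LastName TEXT, BirthDate TEXT)"
)
conn.executemany(
    "INSERT INTO Employee (LastName, BirthDate) VALUES (?, ?)",
    [("Older", "1950-01-01"), ("Younger", "1980-01-01")],
)


def list_tables(conn):
    """Rough analogue of sql_db_list_tables: comma-separated table names."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
    ).fetchall()
    return ", ".join(name for (name,) in rows)


def table_schema(conn, table_names):
    """Rough analogue of sql_db_schema: CREATE statements for the named tables."""
    names = [n.strip() for n in table_names.split(",")]
    placeholders = ", ".join("?" for _ in names)
    rows = conn.execute(
        f"SELECT sql FROM sqlite_master WHERE type='table' AND name IN ({placeholders})",
        names,
    ).fetchall()
    return "\n".join(sql for (sql,) in rows)


def check_query(conn, query):
    """Stand-in for sql_db_query_checker: EXPLAIN compiles the query without running it."""
    try:
        conn.execute(f"EXPLAIN {query}")
        return True
    except sqlite3.Error:
        return False


def run_query(conn, query):
    """Rough analogue of sql_db_query: rows on success, an error string on failure."""
    try:
        return conn.execute(query).fetchall()
    except sqlite3.Error as exc:
        return f"Error: {exc}"


tables = list_tables(conn)
schema = table_schema(conn, "Employee")
query = "SELECT LastName FROM Employee ORDER BY BirthDate ASC LIMIT 1"
ok = check_query(conn, query)
result = run_query(conn, query) if ok else None
bad_ok = check_query(conn, "SELECT Nope FROM Employee")
print(tables, ok, result, bad_ok)
```

Returning the error as a string instead of raising mirrors how the agent is expected to read a failure and rewrite its query.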


In [9]:
example_query = "Who is the oldest employee?"

events = sql_agent_executor.stream(
    {"input": example_query})

for event in events:
    event["messages"][-1].pretty_print()



> Entering new SQL Agent Executor chain...

I need to find the table that contains employee information, including their birthdates or ages, to determine who the oldest employee is.
Action: sql_db_list_tables
Action Input:
I need to find the table that contains employee information, including their birthdates or ages, to determine who the oldest employee is.
Action: sql_db_list_tables

Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track

The Employee table likely contains the information needed to determine the oldest employee.
Action: sql_db_schema
Action Input: Employee
The Employee table likely contains the information needed to determine the oldest employee.
Action: sql_db_schema
Action Input: Employee
CREATE TABLE "Employee" (
	"EmployeeId" INTEGER NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"FirstName" NVARCHAR(20) NOT NULL, 
	"Title" NVARCHAR(30), 
	"ReportsTo" INTEGER,

In [10]:

example_query = "Which countries have the highest sales?"

events = sql_agent_executor.stream(
    {"input": example_query})

for event in events:
    event["messages"][-1].pretty_print()



> Entering new SQL Agent Executor chain...

To answer this question, I need to find out which tables in the database contain information about sales and countries. First, I'll list all the tables in the database.

Action: sql_db_list_tables
Action Input:
To answer this question, I need to find out which tables in the database contain information about sales and countries. First, I'll list all the tables in the database.

Action: sql_db_list_tables

Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track

Based on the tables listed, it seems that the "Invoice" table might contain information about sales, and the "Customer" table could contain information about countries since customers are likely associated with sales and would have country information. To confirm this, I should check the schema of both "Invoice" and "Customer" tables.

Action: sql_db_schema
Action Input: Invoice, Customer
Based on the

In [11]:
example_query = "What is the avg length of tracks in each genre?"

events = sql_agent_executor.stream(
    {"input": example_query})

for event in events:
    event["messages"][-1].pretty_print()



> Entering new SQL Agent Executor chain...

First, I need to find out what tables are available in the database to determine where the information about tracks and genres is stored.

Action: sql_db_list_tables
Action Input:
First, I need to find out what tables are available in the database to determine where the information about tracks and genres is stored.

Action: sql_db_list_tables

Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track

Given the tables listed, it seems likely that the information about tracks is stored in the "Track" table and information about genres is stored in the "Genre" table. To calculate the average length of tracks in each genre, I'll need to join these tables on a common field, probably a genre identifier, and then group the results by genre to calculate the average track length.

Action: sql_db_schema
Action Input: Track,Genre
Given the tables listed, it seems likel