# Agent

LangChain has an SQL Agent which provides a more flexible way of interacting with SQL Databases than the `SQLDatabaseChain`.

The main advantages of using the SQL Agent are:

- It can answer questions about the database's schema (such as describing a specific table) as well as about its content
- It can recover from errors by running a generated query, catching the traceback and regenerating it correctly
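
The error recovery described in the second bullet can be pictured as a retry loop. The sketch below is not LangChain's implementation: `propose_query` is a hypothetical stand-in for the LLM, and the database is a small in-memory table created with the standard-library `sqlite3` module.

```python
import sqlite3


def propose_query(question, last_error=None):
    """Hypothetical stand-in for the LLM.

    The first attempt uses a wrong table name; the retry, made after
    seeing the error text, is corrected.
    """
    if last_error is None:
        return "SELECT COUNT(*) FROM Artists"  # wrong table name on purpose
    return "SELECT COUNT(*) FROM Artist"


def answer_with_retries(conn, question, max_attempts=3):
    """Run a proposed query; on failure, feed the error back and try again."""
    last_error = None
    for _ in range(max_attempts):
        query = propose_query(question, last_error)
        try:
            return query, conn.execute(query).fetchall()
        except sqlite3.Error as exc:
            last_error = str(exc)  # the error message the model would see
    raise RuntimeError(f"No valid query after {max_attempts} attempts: {last_error}")


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)")
conn.executemany("INSERT INTO Artist (Name) VALUES (?)", [("AC/DC",), ("Accept",)])

query, rows = answer_with_retries(conn, "How many artists are there?")
print(query, rows)
```

The first query fails because the table `Artists` does not exist, and the second attempt succeeds once the error has been passed back.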

To initialize the agent, we use the `create_sql_agent` function.

The agent is built around the `SQLDatabaseToolkit`, which provides tools to:

* Create and execute queries
* Check query syntax
* Retrieve table descriptions
* ... and more
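
To make the list above concrete, here is a rough standard-library analogue of what such tools do against SQLite. The helper functions are hypothetical and only illustrate the idea; they are not the toolkit's implementation. In particular, the syntax check uses SQLite's `EXPLAIN` as a stand-in.

```python
import sqlite3


def list_tables(conn):
    """Return table names, analogous to listing the usable tables."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
    ).fetchall()
    return [name for (name,) in rows]


def table_schema(conn, table):
    """Return the CREATE TABLE statement SQLite stored for one table."""
    row = conn.execute(
        "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?", (table,)
    ).fetchone()
    return row[0] if row else None


def check_syntax(conn, query):
    """Ask SQLite to compile the query without running it (a stand-in check)."""
    try:
        conn.execute(f"EXPLAIN {query}")
        return True
    except sqlite3.Error:
        return False


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)")
conn.execute(
    "CREATE TABLE Album (AlbumId INTEGER PRIMARY KEY, Title TEXT, ArtistId INTEGER)"
)

tables = list_tables(conn)
schema = table_schema(conn, "Artist")
valid = check_syntax(conn, "SELECT Name FROM Artist")
invalid = check_syntax(conn, "SELEC Name FROM Artist")
print(tables, valid, invalid)
```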

## Setup

First, install the required packages and set the environment variables:

In [2]:
%pip install --upgrade --quiet  langchain langchain-community langchain-experimental langchain-openai

We default to OpenAI models in this guide, but you can swap them out for the model provider of your choice.

In [None]:
import getpass
import os

os.environ["OPENAI_API_KEY"] = getpass.getpass()

# Uncomment the below to use LangSmith. Not required.
# os.environ["LANGCHAIN_API_KEY"] = getpass.getpass()
# os.environ["LANGCHAIN_TRACING_V2"] = "true"

The example below uses a SQLite connection with the Chinook database. Follow [these installation steps](https://database.guide/2-sample-databases-sqlite/) to create `Chinook.db` in the same directory as this notebook:

* Save [this file](https://raw.githubusercontent.com/lerocha/chinook-database/master/ChinookDatabase/DataSources/Chinook_Sqlite.sql) as `Chinook_Sqlite.sql`
* Run `sqlite3 Chinook.db`
* Run `.read Chinook_Sqlite.sql`
* Test `SELECT * FROM Artist LIMIT 10;`
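
The same steps can also be scripted from Python with the standard-library `sqlite3` module. This is a sketch: `build_database` is a hypothetical helper, and the demo runs a tiny inline script against an in-memory database rather than the real `Chinook_Sqlite.sql` file.

```python
import sqlite3


def build_database(db_path, sql_script):
    """Create (or open) a SQLite database and run a SQL script against it."""
    conn = sqlite3.connect(db_path)
    conn.executescript(sql_script)  # similar in spirit to `.read file.sql`
    conn.commit()
    return conn


# Tiny stand-in script. For the real database you would read the contents of
# Chinook_Sqlite.sql from disk and pass "Chinook.db" as db_path.
demo_script = """
CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT);
INSERT INTO Artist (Name) VALUES ('AC/DC');
INSERT INTO Artist (Name) VALUES ('Accept');
"""

conn = build_database(":memory:", demo_script)
rows = conn.execute("SELECT * FROM Artist LIMIT 10;").fetchall()
print(rows)
```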

Now, `Chinook.db` is in our directory and we can interface with it using the SQLAlchemy-driven `SQLDatabase` class:

In [6]:
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

In [14]:
from langchain.agents import create_sql_agent
from langchain.agents.agent_types import AgentType
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_openai import ChatOpenAI

db = SQLDatabase.from_uri("sqlite:///Chinook.db")
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
agent_executor = create_sql_agent(
    llm=llm, toolkit=toolkit, verbose=True, agent_type=AgentType.OPENAI_FUNCTIONS
)

### Agent task example #1 - Running queries


In [15]:
agent_executor.run(
    "List the total sales per country. Which country's customers spent the most?"
)



> Entering new AgentExecutor chain...
Action: sql_db_list_tables
Action Input:
Observation: Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Thought: I should query the schema of the Invoice and Customer tables.
Action: sql_db_schema
Action Input: Invoice, Customer
Observation:
CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)

/*
3 rows from Customer table:
CustomerId	FirstName	LastName	Company	Address	City	State	Countr

'The country with the highest total sales is the USA, with a total of $523.06.'

Looking at the [LangSmith trace](https://smith.langchain.com/public/a86dbe17-5782-4020-bce6-2de85343168a/r), we can see:

* The agent is using a ReAct style prompt
* First, it looks at the available tables by calling the `sql_db_list_tables` tool
* Given the tables as an observation, it *thinks* and then determines the next action:

```
Observation: Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Thought: I should query the schema of the Invoice and Customer tables.
Action: sql_db_schema
Action Input: Invoice, Customer
```

* It then formulates the query using the schema from tool `sql_db_schema`

```
Thought: I should query the total sales per country.
Action: sql_db_query
Action Input: SELECT Country, SUM(Total) AS TotalSales FROM Invoice INNER JOIN Customer ON Invoice.CustomerId = Customer.CustomerId GROUP BY Country ORDER BY TotalSales DESC LIMIT 10
```

* It finally executes the generated query using tool `sql_db_query`

![sql_usecase.png](../../../static/img/SQLDatabaseToolkit.png)
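
To see what the generated query computes, the sketch below runs it against a tiny in-memory database. The rows are made up for illustration and are not Chinook data; only the query text comes from the trace above.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, Country TEXT);
    CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, CustomerId INTEGER, Total REAL);
    INSERT INTO Customer VALUES (1, 'USA'), (2, 'Canada'), (3, 'USA');
    INSERT INTO Invoice VALUES (1, 1, 10.0), (2, 2, 7.5), (3, 3, 5.0), (4, 2, 1.5);
    """
)

# Join each invoice to its customer, then total the invoices per country.
query = (
    "SELECT Country, SUM(Total) AS TotalSales "
    "FROM Invoice INNER JOIN Customer ON Invoice.CustomerId = Customer.CustomerId "
    "GROUP BY Country ORDER BY TotalSales DESC LIMIT 10"
)
rows = conn.execute(query).fetchall()
print(rows)
```

Because the result is ordered by `TotalSales` in descending order, the first row is the country whose customers spent the most.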

### Agent task example #2 - Describing a Table

In [14]:
agent_executor.run("Describe the playlisttrack table")



> Entering new AgentExecutor chain...
Action: sql_db_list_tables
Action Input:
Observation: Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Thought: The PlaylistTrack table is the most relevant to the question.
Action: sql_db_schema
Action Input: PlaylistTrack
Observation:
CREATE TABLE "PlaylistTrack" (
	"PlaylistId" INTEGER NOT NULL, 
	"TrackId" INTEGER NOT NULL, 
	PRIMARY KEY ("PlaylistId", "TrackId"), 
	FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"), 
	FOREIGN KEY("PlaylistId") REFERENCES "Playlist" ("PlaylistId")
)

/*
3 rows from PlaylistTrack table:
PlaylistId	TrackId
1	3402
1	3389
1	3390
*/
Thought: I now know the final answer
Final Answer: The PlaylistTrack table contains two columns, PlaylistId and TrackId, which are both integers and form a primary key. It also has two foreign keys, one to the Track table and one to 

'The PlaylistTrack table contains two columns, PlaylistId and TrackId, which are both integers and form a primary key. It also has two foreign keys, one to the Track table and one to the Playlist table.'

### Extending the SQL Toolkit

Although the out-of-the-box SQL Toolkit contains the tools needed to start working with a database, extra tools are often useful for extending the agent's capabilities. This is particularly true when the solution should draw on **domain-specific knowledge** to improve its overall performance.

Some examples include:

- Including dynamic few-shot examples
- Finding misspellings in proper nouns to use as column filters

We can create separate tools which tackle these specific use cases and include them as a complement to the standard SQL Toolkit. Let's see how to include these two custom tools.

#### Including dynamic few-shot examples

To include dynamic few-shot examples, we need a custom **Retriever Tool** backed by a vector database, which retrieves the examples that are semantically similar to the user’s question.

Let's start by creating a dictionary with some examples: 

In [1]:
few_shots = {
    "List all artists.": "SELECT * FROM artists;",
    "Find all albums for the artist 'AC/DC'.": "SELECT * FROM albums WHERE ArtistId = (SELECT ArtistId FROM artists WHERE Name = 'AC/DC');",
    "List all tracks in the 'Rock' genre.": "SELECT * FROM tracks WHERE GenreId = (SELECT GenreId FROM genres WHERE Name = 'Rock');",
    "Find the total duration of all tracks.": "SELECT SUM(Milliseconds) FROM tracks;",
    "List all customers from Canada.": "SELECT * FROM customers WHERE Country = 'Canada';",
    "How many tracks are there in the album with ID 5?": "SELECT COUNT(*) FROM tracks WHERE AlbumId = 5;",
    "Find the total number of invoices.": "SELECT COUNT(*) FROM invoices;",
    "List all tracks that are longer than 5 minutes.": "SELECT * FROM tracks WHERE Milliseconds > 300000;",
    "Who are the top 5 customers by total purchase?": "SELECT CustomerId, SUM(Total) AS TotalPurchase FROM invoices GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;",
    "Which albums are from the year 2000?": "SELECT * FROM albums WHERE strftime('%Y', ReleaseDate) = '2000';",
    "How many employees are there": 'SELECT COUNT(*) FROM "employee"',
}

We can then create a retriever using the list of questions, assigning the target SQL query as metadata:

In [24]:
from langchain.schema import Document
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings

embeddings = OpenAIEmbeddings()

few_shot_docs = [
    Document(page_content=question, metadata={"sql_query": few_shots[question]})
    for question in few_shots.keys()
]
vector_db = FAISS.from_documents(few_shot_docs, embeddings)
retriever = vector_db.as_retriever()
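
Conceptually, the retriever maps a new question to the stored questions closest to it and hands back their SQL. The sketch below mimics that with a bag-of-words cosine similarity from the standard library instead of OpenAI embeddings and FAISS, so the ranking quality will differ. `most_similar_examples` and the two-entry `examples` dictionary are hypothetical and for illustration only.

```python
import math
import re
from collections import Counter


def bag_of_words(text):
    """Lowercase word counts; a crude stand-in for an embedding vector."""
    return Counter(re.findall(r"[a-z0-9]+", text.lower()))


def cosine(a, b):
    """Cosine similarity between two sparse word-count vectors."""
    dot = sum(a[word] * b[word] for word in a.keys() & b.keys())
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(
        sum(v * v for v in b.values())
    )
    return dot / norm if norm else 0.0


def most_similar_examples(question, examples, k=1):
    """Return the k (question, sql) pairs whose question is closest to the input."""
    query_vec = bag_of_words(question)
    ranked = sorted(
        examples.items(),
        key=lambda item: cosine(query_vec, bag_of_words(item[0])),
        reverse=True,
    )
    return ranked[:k]


examples = {
    "How many employees are there": 'SELECT COUNT(*) FROM "employee"',
    "List all customers from Canada.": "SELECT * FROM customers WHERE Country = 'Canada';",
}

top = most_similar_examples("How many employees do we have?", examples)
print(top)
```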

Now we can create our own custom tool, which we will later pass to the `create_sql_agent` function as an extra tool:

In [5]:
from langchain_community.agent_toolkits import create_retriever_tool

tool_description = """
This tool will help you understand similar examples to adapt them to the user question.
Input to this tool should be the user question.
"""

retriever_tool = create_retriever_tool(
    retriever, name="sql_get_similar_examples", description=tool_description
)
custom_tool_list = [retriever_tool]

Now we can create the agent, adjusting the standard SQL Agent suffix to fit our use case. Although the most straightforward approach would be to describe the intended behavior only in the tool description, this is often not enough, so we also specify it in the agent prompt using the `suffix` argument in the constructor.

In [22]:
from langchain.agents import AgentType, create_sql_agent
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI

db = SQLDatabase.from_uri("sqlite:///Chinook.db")
llm = ChatOpenAI(model_name="gpt-4", temperature=0)

toolkit = SQLDatabaseToolkit(db=db, llm=llm)

custom_suffix = """
I should first get the similar examples I know.
If the examples are enough to construct the query, I can build it.
Otherwise, I can then look at the tables in the database to see what I can query.
Then I should query the schema of the most relevant tables
"""

agent = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    agent_type=AgentType.OPENAI_FUNCTIONS,
    extra_tools=custom_tool_list,
    suffix=custom_suffix,
)

Let's try it out:

In [23]:
agent.run("How many employees do we have?")



> Entering new AgentExecutor chain...

Invoking: `sql_get_similar_examples` with `How many employees do we have?`


[Document(page_content='How many employees are there', metadata={'sql_query': 'SELECT COUNT(*) FROM "employee"'}), Document(page_content='Find the total number of invoices.', metadata={'sql_query': 'SELECT COUNT(*) FROM invoices;'})]
Invoking: `sql_db_query_checker` with `SELECT COUNT(*) FROM employee`
responded: {content}

SELECT COUNT(*) FROM employee
Invoking: `sql_db_query` with `SELECT COUNT(*) FROM employee`


[(8,)]
We have 8 employees.

> Finished chain.


'We have 8 employees.'

As we can see, the agent first used the `sql_get_similar_examples` tool to retrieve similar examples. Because the question was very similar to one of the few-shot examples, the agent **didn't need to use any other tool** from the standard Toolkit, thus **saving time and tokens**.

#### Finding and correcting misspellings for proper nouns

To filter on columns that contain proper nouns such as addresses, song names or artists, we first need to double-check the spelling so that the filter matches the data correctly.

We can achieve this by creating a vector store using all the distinct proper nouns that exist in the database. We can then have the agent query that vector store each time the user includes a proper noun in their question, to find the correct spelling for that word. In this way, the agent can make sure it understands which entity the user is referring to before building the target query.

Let's follow an approach similar to the few-shot examples, but without metadata: we embed the proper nouns and then query for the one most similar to the misspelled term in the user's question.

First we need the unique values for each entity we want, for which we define a function that parses the result into a list of elements:

In [37]:
import ast
import re


def run_query_save_results(db, query):
    res = db.run(query)
    res = [el for sub in ast.literal_eval(res) for el in sub if el]
    res = [re.sub(r"\b\d+\b", "", string).strip() for string in res]
    return res


artists = run_query_save_results(db, "SELECT Name FROM Artist")
albums = run_query_save_results(db, "SELECT Title FROM Album")
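
To see what the helper does to a result, note that `db.run` returns the rows as a string (as in the `Artist` output earlier). The sketch below repeats the function with comments and feeds it a hypothetical `FakeDB` whose `run` method returns a hard-coded string, so it runs without a database.

```python
import ast
import re


def run_query_save_results(db, query):
    res = db.run(query)
    # Parse the string into a list of tuples, flatten it, and drop empty values.
    res = [el for sub in ast.literal_eval(res) for el in sub if el]
    # Remove standalone numbers (whole-word digit runs) and trim whitespace.
    res = [re.sub(r"\b\d+\b", "", string).strip() for string in res]
    return res


class FakeDB:
    """Hypothetical stand-in for SQLDatabase; returns a fixed string result."""

    def run(self, query):
        return "[('AC/DC',), ('Blink 182',), ('',), ('U2',)]"


names = run_query_save_results(FakeDB(), "SELECT Name FROM Artist")
print(names)
```

The empty string is dropped, the standalone number in `'Blink 182'` is removed, and `'U2'` is kept intact because its digit is part of a word rather than a separate token.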

Now we can proceed with creating the custom **retriever tool** and the final agent:

In [51]:
from langchain_community.agent_toolkits import create_retriever_tool
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings

texts = artists + albums

embeddings = OpenAIEmbeddings()
vector_db = FAISS.from_texts(texts, embeddings)
retriever = vector_db.as_retriever()

retriever_tool = create_retriever_tool(
    retriever,
    name="name_search",
    description="use to learn how a piece of data is actually written, can be from names, surnames addresses etc",
)

custom_tool_list = [retriever_tool]

In [54]:
from langchain.agents import AgentType, create_sql_agent
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI

# db = SQLDatabase.from_uri("sqlite:///Chinook.db")
llm = ChatOpenAI(model_name="gpt-4", temperature=0)

toolkit = SQLDatabaseToolkit(db=db, llm=llm)

custom_suffix = """
If a user asks for me to filter based on proper nouns, I should first check the spelling using the name_search tool.
Otherwise, I can then look at the tables in the database to see what I can query.
Then I should query the schema of the most relevant tables
"""

agent = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    agent_type=AgentType.OPENAI_FUNCTIONS,
    extra_tools=custom_tool_list,
    suffix=custom_suffix,
)

Let's try it out:

In [55]:
agent.run("How many albums does alis in pains have?")



> Entering new AgentExecutor chain...

Invoking: `name_search` with `alis in pains`


[Document(page_content='House of Pain', metadata={}), Document(page_content='Alice In Chains', metadata={}), Document(page_content='Aisha Duo', metadata={}), Document(page_content='House Of Pain', metadata={})]
Invoking: `sql_db_list_tables` with ``
responded: {content}

Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Invoking: `sql_db_schema` with `Album, Artist`
responded: {content}

CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
3 rows from Album table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/


CREATE

'Alice In Chains has 1 album in the database.'

As we can see, the agent used the `name_search` tool in order to check how to correctly query the database for this specific artist.
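
The same idea can be tried without embeddings. The sketch below uses `difflib.get_close_matches` from the standard library as a stand-in for the vector store lookup. It matches on character similarity rather than meaning, so it is only an approximation, and `correct_spelling` and the short `known_names` list are hypothetical.

```python
import difflib


def correct_spelling(user_text, known_names, n=1, cutoff=0.5):
    """Return up to n known names that most closely resemble user_text."""
    # Compare in lowercase, but return the names as they are stored.
    lowered = {name.lower(): name for name in known_names}
    matches = difflib.get_close_matches(
        user_text.lower(), list(lowered), n=n, cutoff=cutoff
    )
    return [lowered[m] for m in matches]


known_names = ["AC/DC", "Accept", "Aerosmith", "Alice In Chains", "Audioslave"]
suggestions = correct_spelling("alis in pains", known_names)
print(suggestions)
```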