# SQL Database Agent

This notebook showcases an agent designed to interact with SQL databases. The agent builds on [SQLDatabaseChain](https://langchain.readthedocs.io/en/latest/modules/chains/examples/sqlite.html) and is designed to answer more general questions about a database, as well as to recover from errors.

Note that, because this agent is in active development, not all answers will be correct. Additionally, there is no guarantee that the agent won't execute DML statements against your database in response to certain questions. Be careful running it on sensitive data!

This uses the example Chinook database. To set it up, follow the instructions at https://database.guide/2-sample-databases-sqlite/, placing the .db file in a notebooks folder at the root of this repository.
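
After downloading the file, a quick check that it sits where you expect and contains tables can save some debugging. This is a minimal sketch using the standard `sqlite3` module; `list_tables` is a hypothetical helper, and the path you pass should match wherever you placed the .db file.

```python
import os
import sqlite3


def list_tables(db_path):
    """Return the sorted names of all user tables in the SQLite file at db_path."""
    # sqlite3.connect would silently create a missing file, so check first.
    if not os.path.exists(db_path):
        raise FileNotFoundError(db_path)
    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute(
            "SELECT name FROM sqlite_master "
            "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' "
            "ORDER BY name"
        ).fetchall()
    finally:
        conn.close()
    return [name for (name,) in rows]
```

Calling `list_tables` on the Chinook file should return names such as `Track` and `PlaylistTrack`, which also appear in the agent traces later in this notebook.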

## Initialization

In [1]:
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.sql_database import SQLDatabase
from langchain.llms.openai import OpenAI
from langchain.agents import AgentExecutor

In [2]:
db = SQLDatabase.from_uri("sqlite:///../../../../notebooks/Chinook.db")
toolkit = SQLDatabaseToolkit(db=db)

agent_executor = create_sql_agent(
    llm=OpenAI(temperature=0),
    toolkit=toolkit,
    verbose=True
)

In [3]:
from langchain.chat_models.openai import ChatOpenAI
from langchain.agents.chat.base import ChatAgent

agent = ChatAgent.from_chat_model_and_tools(ChatOpenAI(temperature=0), toolkit.get_tools())
agent_executor.agent = agent

## Example: describing a table

In [4]:
agent_executor.run("Describe the play list track table")

> Entering new AgentExecutor chain...
Thought: I need to check the schema of the database to see if there is a table called "play list track"
Action:
```
{
  "action": "list_tables_sql_db",
  "action_input": ""
}
```

Observation: Customer, Invoice, Track, Artist, Genre, Employee, MediaType, InvoiceLine, Playlist, PlaylistTrack, Album, sales_table
Thought: There is a table called "PlaylistTrack" in the database. I need to use the schema_sql_db tool to get the schema and sample rows for this table.
Action:
```
{
  "action": "schema_sql_db",
  "action_input": "PlaylistTrack"
}
```

Observation:
CREATE TABLE "PlaylistTrack" (
	"PlaylistId" INTEGER NOT NULL, 
	"TrackId" INTEGER NOT NULL, 
	PRIMARY KEY ("PlaylistId", "TrackId"), 
	FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"), 
	FOREIGN KEY("PlaylistId") REFERENCES "Playlist" ("PlaylistId")
)

SELECT * FROM 'PlaylistTrack' LIMIT 3;
PlaylistId	TrackId
1	34

'The PlaylistTrack table has two columns: PlaylistId and TrackId. It has a composite primary key consisting of both columns. There are foreign key constraints on both columns referencing the Playlist and Track tables respectively. The sample rows show the first three entries in the table.'
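
For comparison, roughly the same information can be pulled by hand. The sketch below is an assumption about what a schema lookup involves, not LangChain's implementation of `schema_sql_db`: it reads the stored `CREATE TABLE` statement from `sqlite_master` and fetches a few sample rows. The helper name `describe_table` and the in-memory demo rows are illustrative.

```python
import sqlite3


def describe_table(conn, table, sample=3):
    """Return (create_statement, column_names, sample_rows) for one table."""
    row = conn.execute(
        "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?",
        (table,),
    ).fetchone()
    if row is None:
        raise ValueError(f"no such table: {table}")
    # Identifiers cannot be bound as parameters, so quote the verified name.
    quoted = '"' + table.replace('"', '""') + '"'
    cur = conn.execute(f"SELECT * FROM {quoted} LIMIT ?", (sample,))
    columns = [d[0] for d in cur.description]
    return row[0], columns, cur.fetchall()


# Demo: a cut-down PlaylistTrack table with made-up rows.
conn = sqlite3.connect(":memory:")
conn.execute(
    'CREATE TABLE "PlaylistTrack" ('
    '"PlaylistId" INTEGER NOT NULL, "TrackId" INTEGER NOT NULL, '
    'PRIMARY KEY ("PlaylistId", "TrackId"))'
)
conn.executemany(
    'INSERT INTO "PlaylistTrack" VALUES (?, ?)',
    [(1, 34), (1, 35), (2, 36), (2, 37)],
)
ddl, columns, rows = describe_table(conn, "PlaylistTrack")
print(ddl)
print(columns)
print(rows)
```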

## Example: describing a table, recovering from an error

In this example, the question does not name a table that exists, so the agent lists the available tables and picks the closest match.

In [5]:
agent_executor.run("Describe the table for that has song information")

> Entering new AgentExecutor chain...
Thought: I need to use the `schema_sql_db` tool to get the schema and sample rows for the table that has song information. But first, I need to know the name of the table.

Action:
```
{
  "action": "list_tables_sql_db",
  "action_input": ""
}
```

Observation: Customer, Invoice, Track, Artist, Genre, Employee, MediaType, InvoiceLine, Playlist, PlaylistTrack, Album, sales_table
Thought: The table that has song information is likely to be named "Track". I will use the `schema_sql_db` tool to get the schema and sample rows for this table.

Action:
```
{
  "action": "schema_sql_db",
  "action_input": "Track"
}
```

Observation:
CREATE TABLE "Track" (
	"TrackId" INTEGER NOT NULL, 
	"Name" NVARCHAR(200) NOT NULL, 
	"AlbumId" INTEGER, 
	"MediaTypeId" INTEGER NOT NULL, 
	"GenreId" INTEGER, 
	"Composer" NVARCHAR(220), 
	"Milliseconds" INTEGER NOT NULL, 
	"Bytes" INTEGER, 
	"Uni

  sample_rows = connection.execute(command)


The table "Track" has columns for TrackId, Name, AlbumId, MediaTypeId, GenreId, Composer, Milliseconds, Bytes, and UnitPrice. The sample rows show the first three tracks in the table.

Final Answer: The table that has song information is named "Track" and has columns for TrackId, Name, AlbumId, MediaTypeId, GenreId, Composer, Milliseconds, Bytes, and UnitPrice.

> Finished chain.


'The table that has song information is named "Track" and has columns for TrackId, Name, AlbumId, MediaTypeId, GenreId, Composer, Milliseconds, Bytes, and UnitPrice.'

## Example: running queries

In [6]:
agent_executor.run("List the total sales per country. Which country's customers spent the most?")

> Entering new AgentExecutor chain...
Question: List the total sales per country. Which country's customers spent the most?
Thought: I need to retrieve the total sales per country and then find the country with the highest total sales.
Action:
```
{"action": "query_sql_db", "action_input": "SELECT SUM(sales) AS total_sales, country FROM sales_table GROUP BY country ORDER BY total_sales DESC LIMIT 1"}
```

Observation: [(900, 'Japan')]
Thought: I need to modify the query to retrieve the country with the highest total sales, not just the first one in the list.
Action:
```
{"action": "query_sql_db", "action_input": "SELECT SUM(sales) AS total_sales, country FROM sales_table GROUP BY country ORDER BY total_sales DESC LIMIT 1"}
```

Observation: [(900, 'Japan')]
Thought: I need to modify the query to retrieve the country with the highest total sales, not just the first one in the list.
Action:
```
{"ac

'Japan'
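
Note that the agent's query in this trace ends in `LIMIT 1`, so it returns only the top country rather than the per-country list the question asks for. A hand-written query that covers both halves might look like the sketch below. The `Invoice` table appears in the agent's table listing, but its `BillingCountry` and `Total` columns are assumed from the standard Chinook schema and are not shown in this notebook; the demo rows are made up.

```python
import sqlite3

# Full per-country totals, largest first; the first row answers
# "which country's customers spent the most?".
SALES_PER_COUNTRY = """
SELECT BillingCountry AS country, SUM(Total) AS total_sales
FROM Invoice
GROUP BY BillingCountry
ORDER BY total_sales DESC
"""

# Demo on an in-memory stand-in for the Invoice table (hypothetical rows).
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE Invoice ("
    "InvoiceId INTEGER PRIMARY KEY, BillingCountry TEXT, Total NUMERIC)"
)
conn.executemany(
    "INSERT INTO Invoice (BillingCountry, Total) VALUES (?, ?)",
    [("USA", 10), ("Canada", 5), ("USA", 7), ("Brazil", 3)],
)
totals = conn.execute(SALES_PER_COUNTRY).fetchall()
top_country = totals[0][0]
print(totals)
print(top_country)
```

Running a query like this by hand against your own copy of the database is a simple way to cross-check the agent's answer.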

In [7]:
agent_executor.run("Show the total number of tracks in each playlist. The Playlist name should be included in the result.")

> Entering new AgentExecutor chain...
Thought: I need to retrieve data from multiple tables, so I need to check the schema of the tables and make sure I have the correct column names to join the tables. Then I can use a SQL query to get the total number of tracks in each playlist.

Action:
```
{
  "action": "schema_sql_db",
  "action_input": "Playlist, PlaylistTrack, Track"
}
```

Observation:
CREATE TABLE "Playlist" (
	"PlaylistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("PlaylistId")
)

SELECT * FROM 'Playlist' LIMIT 3;
PlaylistId	Name
1	Music
2	Movies
3	TV Shows


CREATE TABLE "Track" (
	"TrackId" INTEGER NOT NULL, 
	"Name" NVARCHAR(200) NOT NULL, 
	"AlbumId" INTEGER, 
	"MediaTypeId" INTEGER NOT NULL, 
	"GenreId" INTEGER, 
	"Composer" NVARCHAR(220), 
	"Milliseconds" INTEGER NOT NULL, 
	"Bytes" INTEGER, 
	"UnitPrice" NUMERIC(10, 2) NOT NULL, 
	PRIMARY KEY ("TrackId"), 
	FOREIGN KEY("MediaTypeId") REFERENCES "MediaType" ("Medi

"The total number of tracks in each playlist is displayed in the following format: [('90’s Music', 1477), ('Brazilian Music', 39), ('Classical', 75), ('Classical 101 - Deep Cuts', 25), ('Classical 101 - Next Steps', 25), ('Classical 101 - The Basics', 25), ('Grunge', 15), ('Heavy Metal Classic', 26), ('Music', 6580), ('Music Videos', 1), ('On-The-Go 1', 1), ('TV Shows', 426)]."

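The agent's final query is not visible in the truncated trace above, so the following is only a sketch of one way to count tracks per playlist, using the `Playlist` and `PlaylistTrack` columns shown in the schema output. The in-memory demo rows are made up for illustration.

```python
import sqlite3

# Count tracks per playlist; LEFT JOIN keeps playlists that have no tracks.
TRACKS_PER_PLAYLIST = """
SELECT p.Name, COUNT(pt.TrackId) AS track_count
FROM Playlist AS p
LEFT JOIN PlaylistTrack AS pt ON pt.PlaylistId = p.PlaylistId
GROUP BY p.PlaylistId, p.Name
ORDER BY p.Name
"""

# Demo on cut-down tables with hypothetical rows.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Playlist (PlaylistId INTEGER PRIMARY KEY, Name TEXT)")
conn.execute("CREATE TABLE PlaylistTrack (PlaylistId INTEGER, TrackId INTEGER)")
conn.executemany(
    "INSERT INTO Playlist VALUES (?, ?)",
    [(1, "Music"), (2, "Movies"), (3, "TV Shows")],
)
conn.executemany(
    "INSERT INTO PlaylistTrack VALUES (?, ?)",
    [(1, 10), (1, 11), (3, 12)],
)
counts = conn.execute(TRACKS_PER_PLAYLIST).fetchall()
print(counts)
```

Two design choices here may differ from what the agent did: grouping by `PlaylistId` keeps playlists that share a name separate, and the `LEFT JOIN` reports empty playlists with a count of 0 instead of omitting them.
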
## Example: recovering from an error

In this example, the agent recovers after its first query fails. In the run shown below, that query targets a table (`sales`) that doesn't exist, so the agent lists the available tables before trying again.

In [8]:
agent_executor.run("Who are the top 3 best selling artists?")

> Entering new AgentExecutor chain...
Thought: I need to query the database to get the best selling artists. I should use the `query_sql_db` tool for this.

Action:
```
{
  "action": "query_sql_db",
  "action_input": "SELECT artist, SUM(sales) AS total_sales FROM sales GROUP BY artist ORDER BY total_sales DESC LIMIT 3"
}
```

Observation: Error: (sqlite3.OperationalError) no such table: sales
[SQL: SELECT artist, SUM(sales) AS total_sales FROM sales GROUP BY artist ORDER BY total_sales DESC LIMIT 3]
(Background on this error at: https://sqlalche.me/e/14/e3q8)
Thought: I need to check if the table "sales" exists in the database before running the query. I should use the `list_tables_sql_db` tool for this.

Action:
```
{
  "action": "list_tables_sql_db",
  "action_input": ""
}
```

Observation: Customer, Invoice, Track, Artist, Genre, Employee, MediaType, InvoiceLine, Playlist, PlaylistTrack, Album, sales_tabl

'Iron Maiden, U2, Metallica.'
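
The recovery pattern in this trace (run a query, read the error, list the tables, try again) can be sketched without an LLM. The helper below is a hypothetical illustration using the standard `sqlite3` module, not the agent's implementation; the one-table demo database only borrows the `sales_table` name from the trace.

```python
import sqlite3


def run_or_list_tables(conn, sql):
    """Run sql; on failure, return the error text plus the available table names."""
    try:
        return {"ok": True, "rows": conn.execute(sql).fetchall()}
    except sqlite3.OperationalError as exc:
        tables = [
            name
            for (name,) in conn.execute(
                "SELECT name FROM sqlite_master WHERE type = 'table' ORDER BY name"
            )
        ]
        # The caller (a person or an agent) can use this to rewrite the query.
        return {"ok": False, "error": str(exc), "tables": tables}


# Demo: the first query names a table that does not exist.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE sales_table (country TEXT, sales INTEGER)")
conn.execute("INSERT INTO sales_table VALUES ('Japan', 900)")

first = run_or_list_tables(conn, "SELECT * FROM sales")
print(first)
second = run_or_list_tables(conn, "SELECT country, sales FROM sales_table")
print(second)
```

Returning the error and the table list together gives the caller the same two pieces of information the agent gathers in separate steps in the trace.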