In [1]:
from dotenv import load_dotenv
load_dotenv()
import openai 
import os

openai.api_key=os.getenv('OPENAI_API_KEY')

In [2]:
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///Chinook1.db")
print(db.dialect)
print(db.get_usable_table_names())


sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


In [3]:
db.run("SELECT * FROM Artist LIMIT 10;")

"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

In [4]:
db.run("SELECT * FROM Customer LIMIT 10;")

"[(1, 'Luís', 'Gonçalves', 'Embraer - Empresa Brasileira de Aeronáutica S.A.', 'Av. Brigadeiro Faria Lima, 2170', 'São José dos Campos', 'SP', 'Brazil', '12227-000', '+55 (12) 3923-5555', '+55 (12) 3923-5566', 'luisg@embraer.com.br', 3), (2, 'Leonie', 'Köhler', None, 'Theodor-Heuss-Straße 34', 'Stuttgart', None, 'Germany', '70174', '+49 0711 2842222', None, 'leonekohler@surfeu.de', 5), (3, 'François', 'Tremblay', None, '1498 rue Bélanger', 'Montréal', 'QC', 'Canada', 'H2G 1A7', '+1 (514) 721-4711', None, 'ftremblay@gmail.com', 3), (4, 'Bjørn', 'Hansen', None, 'Ullevålsveien 14', 'Oslo', None, 'Norway', '0171', '+47 22 44 22 22', None, 'bjorn.hansen@yahoo.no', 4), (5, 'František', 'Wichterlová', 'JetBrains s.r.o.', 'Klanova 9/506', 'Prague', None, 'Czech Republic', '14700', '+420 2 4172 5555', '+420 2 4172 5555', 'frantisekw@jetbrains.com', 4), (6, 'Helena', 'Holý', None, 'Rilská 3174/6', 'Prague', None, 'Czech Republic', '14300', '+420 2 4177 0449', None, 'hholy@gmail.com', 5), (7, 'As

In [5]:
from langchain_community.agent_toolkits import create_sql_agent
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)

In [6]:
def load_llm():
    from dotenv import load_dotenv
    load_dotenv()
    from langchain_fireworks import ChatFireworks
    api_key = os.getenv('FIREWORKS_API_KEY')
    if not api_key:
        raise ValueError("FIREWORKS_API_KEY not found. Please set it in the .env file.")
    os.environ["FIREWORKS_API_KEY"] = api_key
    llm = ChatFireworks(model="accounts/fireworks/models/llama-v3p1-70b-instruct")
    return llm

In [7]:
firework_llm=load_llm()

In [8]:
firework_agent_executor = create_sql_agent(firework_llm, db=db, agent_type="openai-tools", verbose=True)

In [9]:
firework_agent_executor.invoke(
    "List the total sales per country. Which country's customers spent the most?"
)



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{'tool_input': ''}`


[0m[38;5;200m[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'Customer, Invoice'}`


[0m[33;1m[1;3m
CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)

/*
3 rows from Customer table:
CustomerId	FirstName	LastName	Company	Address	City	State	Country	PostalCode	Phone	Fax	Email	SupportRepId
1	Luís	Gonçalv

{'input': "List the total sales per country. Which country's customers spent the most?",
 'output': 'The country whose customers spent the most is the USA, with a total of $523.06.'}

In [10]:
firework_agent_executor.invoke(
    "Give me the name of album of artist Aerosmith."
)



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{'tool_input': ''}`


[0m[38;5;200m[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'Album, Artist'}`


[0m[33;1m[1;3m
CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
3 rows from Album table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/


CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosmith
*/[0m[32;1m[1;3m
Invoking: `sql_db_query_checker` with `{'query': 'SELECT T1.Title FROM Album AS T1 INNER JOIN Ar

{'input': 'Give me the name of album of artist Aerosmith.',
 'output': 'I apologize, but I could not find the answer to your question in the database.'}

In [None]:
agent_executor.invoke(
    "Give me the name of album of artist Aerosmith."
)



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{}`


[0m[38;5;200m[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'Album, Artist'}`


[0m[33;1m[1;3m
CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
3 rows from Album table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/


CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosmith
*/[0m[32;1m[1;3m
Invoking: `sql_db_query` with `{'query': 'SELECT Title FROM Album WHERE ArtistId = 3'}`


[0m[36;1m[1;3m[('Bi

{'input': 'Give me the name of album of artist Aerosmith.',
 'output': 'The name of the album by the artist Aerosmith is "Big Ones".'}

In [12]:
while True:
    question=input("Enter a question:")
    if question=='q':
        break
    firework_agent_executor.invoke(question)
    



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{'tool_input': ''}`


[0m[38;5;200m[1;3mAlbum, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'Artist, Album'}`


[0m[33;1m[1;3m
CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
3 rows from Album table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/


CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosmith
*/[0m[32;1m[1;3m
Invoking: `sql_db_query_checker` with `{'query': 'SELECT T2.Name FROM Album AS T1 INNER JOIN Art

In [13]:
# CSV agent

In [14]:
import pandas as pd
df=pd.read_csv('100 Sales Records.csv')
df.head()

Unnamed: 0,Region,Country,Item Type,Sales Channel,Order Priority,Order Date,Order ID,Ship Date,Units Sold,Unit Price,Unit Cost,Total Revenue,Total Cost,Total Profit
0,Australia and Oceania,Tuvalu,Baby Food,Offline,H,5/28/2010,669165933,6/27/2010,9925,255.28,159.42,2533654.0,1582243.5,951410.5
1,Central America and the Caribbean,Grenada,Cereal,Online,C,8/22/2012,963881480,9/15/2012,2804,205.7,117.11,576782.8,328376.44,248406.36
2,Europe,Russia,Office Supplies,Offline,L,5/2/2014,341417157,5/8/2014,1779,651.21,524.96,1158502.59,933903.84,224598.75
3,Sub-Saharan Africa,Sao Tome and Principe,Fruits,Online,C,6/20/2014,514321792,7/5/2014,8102,9.33,6.92,75591.66,56065.84,19525.82
4,Sub-Saharan Africa,Rwanda,Office Supplies,Offline,L,2/1/2013,115456712,2/6/2013,5062,651.21,524.96,3296425.02,2657347.52,639077.5


In [15]:
df.columns

Index(['Region', 'Country', 'Item Type', 'Sales Channel', 'Order Priority',
       'Order Date', 'Order ID', 'Ship Date', 'Units Sold', 'Unit Price',
       'Unit Cost', 'Total Revenue', 'Total Cost', 'Total Profit'],
      dtype='object')

In [16]:
from langchain_community.utilities import SQLDatabase
from sqlalchemy import create_engine

engine = create_engine("sqlite:///sales.db")
df.to_sql("sales", engine, index=False)

100

In [17]:
db = SQLDatabase(engine=engine)
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM sales WHERE Country=='Grenada';")

sqlite
['sales']


"[('Central America and the Caribbean', 'Grenada', 'Cereal', 'Online', 'C', '8/22/2012', 963881480, '9/15/2012', 2804, 205.7, 117.11, 576782.8, 328376.44, 248406.36)]"

In [18]:
df.head()

Unnamed: 0,Region,Country,Item Type,Sales Channel,Order Priority,Order Date,Order ID,Ship Date,Units Sold,Unit Price,Unit Cost,Total Revenue,Total Cost,Total Profit
0,Australia and Oceania,Tuvalu,Baby Food,Offline,H,5/28/2010,669165933,6/27/2010,9925,255.28,159.42,2533654.0,1582243.5,951410.5
1,Central America and the Caribbean,Grenada,Cereal,Online,C,8/22/2012,963881480,9/15/2012,2804,205.7,117.11,576782.8,328376.44,248406.36
2,Europe,Russia,Office Supplies,Offline,L,5/2/2014,341417157,5/8/2014,1779,651.21,524.96,1158502.59,933903.84,224598.75
3,Sub-Saharan Africa,Sao Tome and Principe,Fruits,Online,C,6/20/2014,514321792,7/5/2014,8102,9.33,6.92,75591.66,56065.84,19525.82
4,Sub-Saharan Africa,Rwanda,Office Supplies,Offline,L,2/1/2013,115456712,2/6/2013,5062,651.21,524.96,3296425.02,2657347.52,639077.5


In [23]:
firework_agent_executor = create_sql_agent(firework_llm, db=db, agent_type="openai-tools", verbose=True)
firework_agent_executor.invoke({"input": "which country has the highest total revenue?"})



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{'tool_input': ''}`


[0m[38;5;200m[1;3msales[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'sales'}`


[0m[33;1m[1;3m
CREATE TABLE sales (
	"Region" TEXT, 
	"Country" TEXT, 
	"Item Type" TEXT, 
	"Sales Channel" TEXT, 
	"Order Priority" TEXT, 
	"Order Date" TEXT, 
	"Order ID" BIGINT, 
	"Ship Date" TEXT, 
	"Units Sold" BIGINT, 
	"Unit Price" FLOAT, 
	"Unit Cost" FLOAT, 
	"Total Revenue" FLOAT, 
	"Total Cost" FLOAT, 
	"Total Profit" FLOAT
)

/*
3 rows from sales table:
Region	Country	Item Type	Sales Channel	Order Priority	Order Date	Order ID	Ship Date	Units Sold	Unit Price	Unit Cost	Total Revenue	Total Cost	Total Profit
Australia and Oceania	Tuvalu	Baby Food	Offline	H	5/28/2010	669165933	6/27/2010	9925	255.28	159.42	2533654.0	1582243.5	951410.5
Central America and the Caribbean	Grenada	Cereal	Online	C	8/22/2012	963881480	9/15/2012	2804	205.7	117.11	576782.8	328

{'input': 'which country has the highest total revenue?',
 'output': 'The country with the highest total revenue is Honduras.'}

In [25]:
df['Total Revenue'].max()

5997054.98

In [28]:
df[df['Country']=='Honduras']

Unnamed: 0,Region,Country,Item Type,Sales Channel,Order Priority,Order Date,Order ID,Ship Date,Units Sold,Unit Price,Unit Cost,Total Revenue,Total Cost,Total Profit
13,Central America and the Caribbean,Honduras,Household,Offline,H,2/8/2017,522840487,2/13/2017,8974,668.27,502.54,5997054.98,4509793.96,1487261.02
22,Central America and the Caribbean,Honduras,Snacks,Online,L,6/30/2016,795490682,7/26/2016,2225,152.58,97.44,339490.5,216804.0,122686.5


In [29]:
firework_agent_executor.invoke({"input": "How many sales channel are there?"})



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{'tool_input': ''}`


[0m[38;5;200m[1;3msales[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'sales'}`


[0m[33;1m[1;3m
CREATE TABLE sales (
	"Region" TEXT, 
	"Country" TEXT, 
	"Item Type" TEXT, 
	"Sales Channel" TEXT, 
	"Order Priority" TEXT, 
	"Order Date" TEXT, 
	"Order ID" BIGINT, 
	"Ship Date" TEXT, 
	"Units Sold" BIGINT, 
	"Unit Price" FLOAT, 
	"Unit Cost" FLOAT, 
	"Total Revenue" FLOAT, 
	"Total Cost" FLOAT, 
	"Total Profit" FLOAT
)

/*
3 rows from sales table:
Region	Country	Item Type	Sales Channel	Order Priority	Order Date	Order ID	Ship Date	Units Sold	Unit Price	Unit Cost	Total Revenue	Total Cost	Total Profit
Australia and Oceania	Tuvalu	Baby Food	Offline	H	5/28/2010	669165933	6/27/2010	9925	255.28	159.42	2533654.0	1582243.5	951410.5
Central America and the Caribbean	Grenada	Cereal	Online	C	8/22/2012	963881480	9/15/2012	2804	205.7	117.11	576782.8	328

{'input': 'How many sales channel are there?',
 'output': 'There are 2 sales channels.'}

In [21]:
df['Sales Channel'].value_counts()

Sales Channel
Offline    50
Online     50
Name: count, dtype: int64

In [30]:
firework_agent_executor.invoke({"input": "How many countries are there in the sale table?"})



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{'tool_input': ''}`


[0m[38;5;200m[1;3msales[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'sales'}`


[0m[33;1m[1;3m
CREATE TABLE sales (
	"Region" TEXT, 
	"Country" TEXT, 
	"Item Type" TEXT, 
	"Sales Channel" TEXT, 
	"Order Priority" TEXT, 
	"Order Date" TEXT, 
	"Order ID" BIGINT, 
	"Ship Date" TEXT, 
	"Units Sold" BIGINT, 
	"Unit Price" FLOAT, 
	"Unit Cost" FLOAT, 
	"Total Revenue" FLOAT, 
	"Total Cost" FLOAT, 
	"Total Profit" FLOAT
)

/*
3 rows from sales table:
Region	Country	Item Type	Sales Channel	Order Priority	Order Date	Order ID	Ship Date	Units Sold	Unit Price	Unit Cost	Total Revenue	Total Cost	Total Profit
Australia and Oceania	Tuvalu	Baby Food	Offline	H	5/28/2010	669165933	6/27/2010	9925	255.28	159.42	2533654.0	1582243.5	951410.5
Central America and the Caribbean	Grenada	Cereal	Online	C	8/22/2012	963881480	9/15/2012	2804	205.7	117.11	576782.8	328

{'input': 'How many countries are there in the sale table?',
 'output': 'There are 76 countries in the sales table.'}

In [31]:
df['Country'].value_counts()

Country
The Gambia               4
Sierra Leone             3
Sao Tome and Principe    3
Mexico                   3
Australia                3
                        ..
Comoros                  1
Iceland                  1
Macedonia                1
Mauritania               1
Mozambique               1
Name: count, Length: 76, dtype: int64

In [32]:
while True:
    question=input("Enter a question:")
    if question=='q':
        break
    firework_agent_executor.invoke(question)



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{'tool_input': ''}`


[0m[38;5;200m[1;3msales[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'sales'}`


[0m[33;1m[1;3m
CREATE TABLE sales (
	"Region" TEXT, 
	"Country" TEXT, 
	"Item Type" TEXT, 
	"Sales Channel" TEXT, 
	"Order Priority" TEXT, 
	"Order Date" TEXT, 
	"Order ID" BIGINT, 
	"Ship Date" TEXT, 
	"Units Sold" BIGINT, 
	"Unit Price" FLOAT, 
	"Unit Cost" FLOAT, 
	"Total Revenue" FLOAT, 
	"Total Cost" FLOAT, 
	"Total Profit" FLOAT
)

/*
3 rows from sales table:
Region	Country	Item Type	Sales Channel	Order Priority	Order Date	Order ID	Ship Date	Units Sold	Unit Price	Unit Cost	Total Revenue	Total Cost	Total Profit
Australia and Oceania	Tuvalu	Baby Food	Offline	H	5/28/2010	669165933	6/27/2010	9925	255.28	159.42	2533654.0	1582243.5	951410.5
Central America and the Caribbean	Grenada	Cereal	Online	C	8/22/2012	963881480	9/15/2012	2804	205.7	117.11	576782.8	328