In [1]:
import os
import langchain
from langchain.agents import *
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.sql_database import SQLDatabase

from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.schema import Document


from packages.paths import DS_LLM_PATHS
import json

PATHS = DS_LLM_PATHS()

In [65]:
langchain.__version__

'0.0.293'

In [2]:
with open(PATHS.EXTRACT_KEYS) as creds:
    db_keys = json.load(creds)

db_user = db_keys['db_user']
db_password = db_keys['db_password']
db_host = db_keys['db_host']
project_id = db_keys['project_id_gcp']
db_name = db_keys['db_name']

with open(PATHS.OPEN_AI) as creds:
    openai_keys = json.load(creds)

OPENAI_API_KEY = openai_keys['OPENAI_API_KEY']

In [3]:
embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)
db = SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}")

In [None]:
few_shots = {'List all artists.': 'SELECT * FROM artists;',
             "Find all albums for the artist 'AC/DC'.": "SELECT * FROM albums WHERE ArtistId = (SELECT ArtistId FROM artists WHERE Name = 'AC/DC');",
             "List all tracks in the 'Rock' genre.": "SELECT * FROM tracks WHERE GenreId = (SELECT GenreId FROM genres WHERE Name = 'Rock');",
             'Find the total duration of all tracks.': 'SELECT SUM(Milliseconds) FROM tracks;',
             'List all customers from Canada.': "SELECT * FROM customers WHERE Country = 'Canada';",
             'How many tracks are there in the album with ID 5?': 'SELECT COUNT(*) FROM tracks WHERE AlbumId = 5;',
             'Find the total number of invoices.': 'SELECT COUNT(*) FROM invoices;',
             'List all tracks that are longer than 5 minutes.': 'SELECT * FROM tracks WHERE Milliseconds > 300000;',
             'Who are the top 5 customers by total purchase?': 'SELECT CustomerId, SUM(Total) AS TotalPurchase FROM invoices GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;',
             'Which albums are from the year 2000?': "SELECT * FROM albums WHERE strftime('%Y', ReleaseDate) = '2000';",
             'How many employees are there': 'SELECT COUNT(*) FROM "employee"'
            }

In [None]:
few_shot_docs = [Document(page_content=question, metadata={'sql_query': few_shots[question]}) for question in few_shots.keys()]
vector_db = FAISS.from_documents(few_shot_docs, embeddings)
retriever = vector_db.as_retriever()

In [None]:
from langchain.agents.agent_toolkits import create_retriever_tool

tool_description = """
This tool will help you understand similar examples to adapt them to the user question.
Input to this tool should be the user question.
"""

retriever_tool = create_retriever_tool(
        retriever,
        name='sql_get_similar_examples',
        description=tool_description
    )
custom_tool_list = [retriever_tool]

In [7]:
gpt_3_5_turbo_16k = 'gpt-3.5-turbo-16k'
gpt_3_5_instruct = 'gpt-3.5-turbo-instruct'
gpt_4 = 'gpt-4'

In [8]:
llm = ChatOpenAI(model_name=gpt_4, openai_api_key=OPENAI_API_KEY, temperature=0)
#llm= OpenAI(model='gpt-3.5-turbo-instruct', temperature=0, openai_api_key=OPENAI_API_KEY)

toolkit = SQLDatabaseToolkit(db=db, llm=llm)

'''
custom_suffix = """
I should first get the similar examples I know.
If the examples are enough to construct the query, I can build it.
Otherwise, I can then look at the tables in the database to see what I can query.
Then I should query the schema of the most relevant tables
"""

agent_executor = create_sql_agent(llm=llm,
                        toolkit=toolkit,
                        verbose=True,
                        agent_type=AgentType.OPENAI_FUNCTIONS,
                        extra_tools=custom_tool_list,
                        suffix=custom_suffix
                        )
'''

agent_executor = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    handle_parsing_errors=True,
)


In [9]:
agent_executor.run("How much did we sell on amazon last month, and how much on web? Give me net revenue. Use the table Data_Metrics. Username = incrediwear@dataspeaks.io")



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: sql_db_list_tables
Action Input: ''[0m
Observation: [38;5;200m[1;3mData_Metrics, actual_prior_data, calendar_events, expenses, experiments, merged_data, profit_margin, targets, users[0m
Thought:[32;1m[1;3mThe Data_Metrics table seems to be the most relevant one for this query. I should check its schema to understand its structure and the data it contains.
Action: sql_db_schema
Action Input: 'Data_Metrics'[0m
Observation: [33;1m[1;3mError: table_names {"'Data_Metrics'"} not found in database[0m
Thought:[32;1m[1;3mIt seems like there was a mistake in the table name. I should check the schema for the correct table name.
Action: sql_db_list_tables
Action Input: ''[0m
Observation: [38;5;200m[1;3mData_Metrics, actual_prior_data, calendar_events, expenses, experiments, merged_data, profit_margin, targets, users[0m
Thought:[32;1m[1;3mThe table name seems to be correct. Maybe there was a mistake in the input 

'The net revenue from Amazon last month was approximately $32,225.91, $28,688.57, $34,642.77, $34,404.94, $33,488.04, $30,255.22, $32,943.99, $31,867.94, $30,453.29, $28,492.75. The net revenue from the web last month was approximately $25,190.81, $17,694.8, $20,941.48, $21,967.36, $18,990.05, $20,990.06, $22,560.88, $22,258.45, $23,832.99, $21,259.78.'

In [10]:
agent_executor.run("How much did we sell last week? Give your answer in bullet points for each day")



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: sql_db_list_tables
Action Input: ""[0m
Observation: [38;5;200m[1;3mData_Metrics, actual_prior_data, calendar_events, expenses, experiments, merged_data, profit_margin, targets, users[0m
Thought:[32;1m[1;3mThe 'Data_Metrics' and 'actual_prior_data' tables seem like they could contain sales data. I should check their schemas to see if they have the information I need.
Action: sql_db_schema
Action Input: "Data_Metrics, actual_prior_data"[0m
Observation: [33;1m[1;3m
CREATE TABLE `Data_Metrics` (
	`Direct traffic (GA4)` BIGINT, 
	`Google DSP profit` DOUBLE, 
	`Google DSP ROAS` DOUBLE, 
	`Amazon DSP ROAS` DOUBLE, 
	`Amazon DSP profit` DOUBLE, 
	`Social media revenue (GA4)` DOUBLE, 
	`SEO revenue (GA4)` DOUBLE, 
	`SEO revenue` DOUBLE, 
	`Social media revenue` DOUBLE, 
	`Net sales (all)` DOUBLE, 
	`Orders (Amazon)` BIGINT, 
	`Sales (Amazon)` DOUBLE, 
	`SMS signup rate` DOUBLE, 
	`Email signup rate` DOUBLE, 
	`Sales (

OutputParserException: Could not parse LLM output: `I'm sorry, but I can't provide the observation without access to the actual database. In a real scenario, the observation would be the result of the query, which would be the sales data for each day of last week.`

In [None]:
agent_executor.run("Describe the table Data_Metrics. Enumerate the fields")