In [2]:
# Read the Google PaLM API key from a local file and publish it to later cells
# through the "google_api_key" environment variable.
import os

# Strip surrounding whitespace: a key file normally ends with a newline, which
# would otherwise become part of the key value handed to the LLM client.
with open("../keys/application_llm_key.txt", encoding="utf-8") as key_file:
    os.environ["google_api_key"] = key_file.read().strip()

In [6]:
from langchain.llms import GooglePalm
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.sql_database import SQLDatabase
from langchain.agents import AgentExecutor
from langchain.agents.agent_types import AgentType

# Define the LLM, using the API key stored in the "google_api_key"
# environment variable.
llm = GooglePalm(google_api_key=os.getenv("google_api_key"))

# Database URI format (postgres using psycopg2):
#   postgresql+psycopg2://username:password@host:port/mydatabase
# The URI is read from the SQL_AGENT_DB_URI environment variable when it is set,
# so real credentials do not have to be written into the notebook. When it is
# unset, the local development database below is used.
db_uri = os.getenv(
    "SQL_AGENT_DB_URI",
    "postgresql+psycopg2://postgres:postgres@localhost:5432/postgres",
)
db = SQLDatabase.from_uri(db_uri)

# SQL toolkit: bundles the database handle with the LLM for the agent to use.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)

# SQL agent. verbose=True is what makes each run print its
# Action / Observation / Thought trace.
agent = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)



In [9]:
# Simple head-count question; the prompt names the table so the agent knows
# where to look.
headcount_question = "The public.employee table contains information on employees. How many employees are there?"
agent.run(headcount_question)



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: sql_db_list_tables
Action Input: [0m
Observation: [38;5;200m[1;3memployee[0m
Thought:[32;1m[1;3mI know the schema of the employee table, so I can write a query.
Action: sql_db_schema
Action Input: employee[0m
Observation: [33;1m[1;3m
CREATE TABLE employee (
	emp_id INTEGER NOT NULL, 
	first_name VARCHAR(50) NOT NULL, 
	last_name VARCHAR(50), 
	salary NUMERIC(5, 1), 
	date_of_joining DATE, 
	CONSTRAINT emp_pk PRIMARY KEY (emp_id)
)

/*
3 rows from employee table:
emp_id	first_name	last_name	salary	date_of_joining
1001	Amit	Verma	1000.0	2021-02-23
1002	Sumit	Agarwal	1200.0	2020-02-23
1003	Anuj	Sharma	1500.0	2020-02-23
*/[0m
Thought:[32;1m[1;3mI can get the number of employees by counting the number of rows in the employee table.
Action: sql_db_query
Action Input: SELECT count(*) FROM employee[0m
Observation: [36;1m[1;3m[(6,)][0m
Thought:[32;1m[1;3mI now know the final answer
Final Answer: 6[0m

[1m> Finished chain.[0m

'6'

In [10]:
# Comparative question: needs Sumit's salary before the count can be computed.
# NOTE(review): in the recorded run the agent filtered on the literal 1200.0
# (Sumit's salary as shown in the schema's sample rows) instead of looking the
# value up in the query — confirm the answer if the data changes.
salary_question = "The public.employee table contains information on employees. How many employees earn more than Sumit?"
agent.run(salary_question)



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: sql_db_list_tables
Action Input: [0m
Observation: [38;5;200m[1;3memployee[0m
Thought:[32;1m[1;3mI found the employee table. I should query the schema of the employee table to find out what columns it has.
Action: sql_db_schema
Action Input: employee[0m
Observation: [33;1m[1;3m
CREATE TABLE employee (
	emp_id INTEGER NOT NULL, 
	first_name VARCHAR(50) NOT NULL, 
	last_name VARCHAR(50), 
	salary NUMERIC(5, 1), 
	date_of_joining DATE, 
	CONSTRAINT emp_pk PRIMARY KEY (emp_id)
)

/*
3 rows from employee table:
emp_id	first_name	last_name	salary	date_of_joining
1001	Amit	Verma	1000.0	2021-02-23
1002	Sumit	Agarwal	1200.0	2020-02-23
1003	Anuj	Sharma	1500.0	2020-02-23
*/[0m
Thought:[32;1m[1;3mI know that the employee table has a column named salary. I should write a query to get the number of employees who earn more than Sumit.
Action: sql_db_query
Action Input: SELECT count(*) FROM employee WHERE salary > 1200.0[

'4'

In [11]:
# Aggregate question with a date filter; unlike the earlier prompts it does not
# name the table.
# NOTE(review): the recorded run translated "after 2020" into
# date_of_joining > '2020-01-01', which also counts employees who joined
# during 2020 — confirm that is the intended meaning.
avg_salary_question = "What is the average salary of employees who joined after 2020?"
agent.run(avg_salary_question)



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: sql_db_list_tables
Action Input: [0m
Observation: [38;5;200m[1;3memployee[0m
Thought:[32;1m[1;3memployee table looks relevant. I should query the schema of it.
Action: sql_db_schema
Action Input: employee[0m
Observation: [33;1m[1;3m
CREATE TABLE employee (
	emp_id INTEGER NOT NULL, 
	first_name VARCHAR(50) NOT NULL, 
	last_name VARCHAR(50), 
	salary NUMERIC(5, 1), 
	date_of_joining DATE, 
	CONSTRAINT emp_pk PRIMARY KEY (emp_id)
)

/*
3 rows from employee table:
emp_id	first_name	last_name	salary	date_of_joining
1001	Amit	Verma	1000.0	2021-02-23
1002	Sumit	Agarwal	1200.0	2020-02-23
1003	Anuj	Sharma	1500.0	2020-02-23
*/[0m
Thought:[32;1m[1;3mI should write a query to get the average salary of employees who joined after 2020
Action: sql_db_query
Action Input: SELECT AVG(salary) FROM employee WHERE date_of_joining > '2020-01-01'[0m
Observation: [36;1m[1;3m[(Decimal('1983.3833333333333333'),)][0m
Thought:[

'1983.38'