In [None]:
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain_core.runnables import RunnablePassthrough
from pyprojroot import here
from operator import itemgetter
from langchain_openai import ChatOpenAI
import os
from dotenv import load_dotenv
load_dotenv()

True

### LLM and DB connections

In [4]:
# Fail fast with an actionable message when the key is missing. The earlier
# `os.environ[...] = os.getenv(...)` round-trip was a no-op when the key was
# already set, and raised an opaque TypeError (environment values must be
# strings) when it was not.
if not os.getenv("OPENAI_API_KEY"):
    raise RuntimeError(
        "OPENAI_API_KEY is not set. Add it to your .env file or export it "
        "before running this notebook."
    )

# Chat model shared by every chain and agent below; the client is expected to
# read OPENAI_API_KEY from the environment.
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)


In [None]:
# Path to the SQLite file, resolved by pyprojroot's `here` (despite the
# variable name, this points at a file, not a directory).
sqldb_directory = here("data/db/imdb_sample.db")

# SQLite will typically create a brand-new, empty database when the target
# file is missing, which would only show up later as confusing "no tables"
# answers. Stop here with a clear error instead.
if not os.path.isfile(sqldb_directory):
    raise FileNotFoundError(f"SQLite database file not found: {sqldb_directory}")

db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")

### LLM chain writes the SQL query

In [7]:
# Chain that drafts a SQL query for `db` from a natural-language question.
write_query = create_sql_query_chain(llm, db)

In [18]:
# Ask a schema-level question and let the chain draft the SQL for it.
question = "How many tables do I have in the database? and what are their names?"
sql_query = write_query.invoke(dict(question=question))
sql_query  # bare last expression so the notebook displays the generated SQL

"SELECT name FROM sqlite_master WHERE type='table';"

### Executing query

In [19]:
""" Write and execute query """
from langchain_community.tools import QuerySQLDataBaseTool

# Tool that runs a SQL string against `db`.
execute_query = QuerySQLDataBaseTool(db=db)

# Feed it the query drafted in the previous step and show what comes back.
sql_result = execute_query.invoke(dict(query=sql_query))
sql_result


"[('title.ratings',), ('title.principals',), ('title.akas',), ('name.basics',), ('title.basics',), ('title.episode',), ('title.crew',)]"

### LLM answers question

In [None]:
""" Answer question in a user friendly way """

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate

# Prompt text with three placeholders filled at invoke time:
# {question}, {query} and {result}.
answer_template = """Given the following user question, corresponding SQL query, and SQL result, answer the users question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
answer_prompt = PromptTemplate.from_template(answer_template)

# prompt -> model -> string output
answer = answer_prompt | llm | StrOutputParser()
answer.invoke(dict(question=question, query=sql_query, result=sql_result))

"You have 7 tables in the database. Their names are 'title.ratings', 'title.principals', 'title.akas', 'name.basics', 'title.basics', 'title.episode', and 'title.crew'."

### Complete chain

In [27]:
# Prompt used for the final answering step. The text is kept exactly as-is
# because it is sent to the model verbatim.
system_role = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n
    Question: {question}\n
    SQL Query: {query}\n
    SQL Result: {result}\n
    Answer:
    """

# Step 1: question -> SQL text.  Step 2: SQL text -> query result.
write_query = create_sql_query_chain(llm, db)
execute_query = QuerySQLDataBaseTool(db=db)

# Step 3: question + SQL + result -> answer string.
answer_prompt = PromptTemplate.from_template(system_role)
answer = answer_prompt | llm | StrOutputParser()

# Pull the generated SQL out of the running dict and execute it.
run_generated_query = itemgetter("query") | execute_query

# Each `assign` adds one key ("query", then "result") to the input dict
# before the whole dict is handed to the answering step.
chain = (
    RunnablePassthrough.assign(query=write_query)
    .assign(result=run_generated_query)
    | answer
)

In [28]:
# Run the end-to-end chain: write SQL, execute it, and phrase the answer.
message = "How many tables do I have in the database? and what are their names?"
response = chain.invoke(dict(question=message))
response

'You have 7 tables in the database. Their names are: title.ratings, title.principals, title.akas, name.basics, title.basics, title.episode, and title.crew.'

## 2. Agents

In [25]:
from langchain_community.agent_toolkits import create_sql_agent

# SQL agent over the same model and database; verbose=True prints each tool
# call the agent makes while it works.
agent_executor = create_sql_agent(
    llm,
    db=db,
    agent_type="openai-tools",
    verbose=True,
)

In [26]:
# Let the agent explore the schema and answer a multi-table question.
agent_executor.invoke(
    dict(
        input="List the highest rated movies, their titles, and the number of votes. The listed movies need to have at least 10000 votes."
    )
)



[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{}`


[0m[38;5;200m[1;3mname.basics, title.akas, title.basics, title.crew, title.episode, title.principals, title.ratings[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'title.ratings, title.basics'}`


[0m[33;1m[1;3m
CREATE TABLE "title.basics" (
	tconst TEXT, 
	"titleType" TEXT, 
	"primaryTitle" TEXT, 
	"originalTitle" TEXT, 
	"isAdult" BIGINT, 
	"startYear" TEXT, 
	"endYear" TEXT, 
	"runtimeMinutes" TEXT, 
	genres TEXT
)

/*
3 rows from title.basics table:
tconst	titleType	primaryTitle	originalTitle	isAdult	startYear	endYear	runtimeMinutes	genres
tt0000001	short	Carmencita	Carmencita	0	1894	None	1	Documentary,Short
tt0000002	short	Le clown et ses chiens	Le clown et ses chiens	0	1892	None	5	Animation,Short
tt0000003	short	Poor Pierrot	Pauvre Pierrot	0	1892	None	5	Animation,Comedy,Romance
*/


CREATE TABLE "title.ratings" (
	tconst TEXT, 
	"averageRating" FLO

{'input': 'List the highest rated movies, their titles, and the number of votes. The listed movies need to have at least 10000 votes.',
 'output': 'The highest rated movie with at least 10000 votes is "The Arrival of a Train" with an average rating of 7.4 and 13346 votes.'}

In [29]:

# NOTE(review): `{output}` is filled with the agent's final natural-language
# answer (the 'output' key of the executor's result), not raw SQL rows, even
# though the prompt labels it "SQL Result" - confirm this is intended.
agent_answer_template = """Given the following user question, and SQL result, answer the users question.

Question: {input}
SQL Result: {output}
Answer: """
answer_prompt = PromptTemplate.from_template(agent_answer_template)
answer = answer_prompt | llm | StrOutputParser()

# The executor's result dict (keys 'input' and 'output') feeds the prompt directly.
chain = agent_executor | answer

chain.invoke(
    dict(
        input="List the highest rated movies, their titles, and the number of votes. The listed movies need to have at least 10000 votes."
    )
)




[1m> Entering new SQL Agent Executor chain...[0m
[32;1m[1;3m
Invoking: `sql_db_list_tables` with `{}`


[0m[38;5;200m[1;3mname.basics, title.akas, title.basics, title.crew, title.episode, title.principals, title.ratings[0m[32;1m[1;3m
Invoking: `sql_db_schema` with `{'table_names': 'title.ratings'}`


[0m[33;1m[1;3m
CREATE TABLE "title.ratings" (
	tconst TEXT, 
	"averageRating" FLOAT, 
	"numVotes" BIGINT
)

/*
3 rows from title.ratings table:
tconst	averageRating	numVotes
tt0000001	5.7	2127
tt0000002	5.6	286
tt0000003	6.5	2163
*/[0m[32;1m[1;3m
Invoking: `sql_db_query` with `{'query': 'SELECT tconst, averageRating, numVotes FROM "title.ratings" WHERE numVotes >= 10000 ORDER BY averageRating DESC LIMIT 10'}`


[0m[36;1m[1;3m[('tt0000012', 7.4, 13346)][0m[32;1m[1;3mThe highest rated movie with at least 10000 votes is the movie with the title "tt0000012" and an average rating of 7.4 with 13346 votes.[0m

[1m> Finished chain.[0m


'The highest rated movie with at least 10000 votes is the movie with the title "tt0000012" with an average rating of 7.4 and a total of 13346 votes.'