In [1]:
import getpass
import os

os.environ["OPENAI_API_KEY"] = 'Enter your key'

In [2]:
from langchain_core.prompts import ChatPromptTemplate

template = """
Based on the schema below, write a SQL query that would answer the user's question:
{schema}

Question: {question}
SQL Query:
"""

prompt = ChatPromptTemplate.from_template(template=template)

In [3]:
prompt.format(schema="my schema", question="How many users are there?")

"Human: \nBased on the schema below, write a SQL query that would answer the user's question:\nmy schema\n\nQuestion: How many users are there?\nSQL Query:\n"

In [4]:
from langchain_community.utilities import SQLDatabase

db_uri = "sqlite:///Chinook.db"
db = SQLDatabase.from_uri(database_uri=db_uri)

In [5]:
db.run("SELECT * from Artist")

'[(1, \'AC/DC\'), (2, \'Accept\'), (3, \'Aerosmith\'), (4, \'Alanis Morissette\'), (5, \'Alice In Chains\'), (6, \'Antônio Carlos Jobim\'), (7, \'Apocalyptica\'), (8, \'Audioslave\'), (9, \'BackBeat\'), (10, \'Billy Cobham\'), (11, \'Black Label Society\'), (12, \'Black Sabbath\'), (13, \'Body Count\'), (14, \'Bruce Dickinson\'), (15, \'Buddy Guy\'), (16, \'Caetano Veloso\'), (17, \'Chico Buarque\'), (18, \'Chico Science & Nação Zumbi\'), (19, \'Cidade Negra\'), (20, \'Cláudio Zoli\'), (21, \'Various Artists\'), (22, \'Led Zeppelin\'), (23, \'Frank Zappa & Captain Beefheart\'), (24, \'Marcos Valle\'), (25, \'Milton Nascimento & Bebeto\'), (26, \'Azymuth\'), (27, \'Gilberto Gil\'), (28, \'João Gilberto\'), (29, \'Bebel Gilberto\'), (30, \'Jorge Vercilo\'), (31, \'Baby Consuelo\'), (32, \'Ney Matogrosso\'), (33, \'Luiz Melodia\'), (34, \'Nando Reis\'), (35, \'Pedro Luís & A Parede\'), (36, \'O Rappa\'), (37, \'Ed Motta\'), (38, \'Banda Black Rio\'), (39, \'Fernanda Porto\'), (40, \'Os Ca

In [7]:
result = db.run("SELECT * from Artist")

In [13]:
result

'[(1, \'AC/DC\'), (2, \'Accept\'), (3, \'Aerosmith\'), (4, \'Alanis Morissette\'), (5, \'Alice In Chains\'), (6, \'Antônio Carlos Jobim\'), (7, \'Apocalyptica\'), (8, \'Audioslave\'), (9, \'BackBeat\'), (10, \'Billy Cobham\'), (11, \'Black Label Society\'), (12, \'Black Sabbath\'), (13, \'Body Count\'), (14, \'Bruce Dickinson\'), (15, \'Buddy Guy\'), (16, \'Caetano Veloso\'), (17, \'Chico Buarque\'), (18, \'Chico Science & Nação Zumbi\'), (19, \'Cidade Negra\'), (20, \'Cláudio Zoli\'), (21, \'Various Artists\'), (22, \'Led Zeppelin\'), (23, \'Frank Zappa & Captain Beefheart\'), (24, \'Marcos Valle\'), (25, \'Milton Nascimento & Bebeto\'), (26, \'Azymuth\'), (27, \'Gilberto Gil\'), (28, \'João Gilberto\'), (29, \'Bebel Gilberto\'), (30, \'Jorge Vercilo\'), (31, \'Baby Consuelo\'), (32, \'Ney Matogrosso\'), (33, \'Luiz Melodia\'), (34, \'Nando Reis\'), (35, \'Pedro Luís & A Parede\'), (36, \'O Rappa\'), (37, \'Ed Motta\'), (38, \'Banda Black Rio\'), (39, \'Fernanda Porto\'), (40, \'Os Ca

In [6]:
def get_schema(_):
    schema = db.get_table_info()
    return schema

In [9]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

llm = ChatOpenAI(
    temperature=0,
    model='gpt-3.5-turbo',
    streaming=True,
    callbacks=[StreamingStdOutCallbackHandler()]
)

sql_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)


In [10]:
user_question = 'Which are the tables in the schema?'
sql_chain.invoke({"question": user_question})


SELECT name
FROM sqlite_master
WHERE type = 'table';

"SELECT name\nFROM sqlite_master\nWHERE type = 'table';"

In [20]:
template = """Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}"""
prompt_response = ChatPromptTemplate.from_template(template)


In [21]:
def run_query(query):
    return db.run(query)


In [23]:
full_chain = (
    RunnablePassthrough.assign(query=sql_chain).assign(
        schema=get_schema,
        response=lambda vars: run_query(vars["query"]),
    )
    | prompt_response
    | llm
)


In [27]:
user_question = 'how many countries are the users from?'
full_chain.invoke({"question": user_question})

# 'There are 347 albums in the database.'


AIMessage(content='There are 24 different countries that the users are from.', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 12, 'prompt_tokens': 2188, 'total_tokens': 2200}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None}, id='run-31931966-8ac0-4d74-a50c-292031d4f349-0', usage_metadata={'input_tokens': 2188, 'output_tokens': 12, 'total_tokens': 2200})

In [29]:
user_question = 'list all the countries the users are from. also find the total of the countries'
full_chain.invoke({"question": user_question})

AIMessage(content='The users are from various countries such as Argentina, Australia, Brazil, Canada, Germany, USA, and more. In total, there are users from 23 different countries.', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 35, 'prompt_tokens': 2352, 'total_tokens': 2387}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None}, id='run-d21e1827-f35f-4e68-bed6-f0effe8be267-0', usage_metadata={'input_tokens': 2352, 'output_tokens': 35, 'total_tokens': 2387})