In [1]:
import os
from getpass import getpass

# Use the key already present in the environment; prompt interactively only when it is
# missing, so no secret is ever written into (or committed with) the notebook.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API key: ")

In [2]:
from langchain_core.prompts import ChatPromptTemplate

# Prompt that asks the model to translate a natural-language question into SQL.
# Placeholders: {schema} (table descriptions) and {question} (the user's question).
template = (
    "\n"
    "Based on the table schema below, write a SQL query that would answer "
    "the user's question:\n"
    "{schema}\n"
    "\n"
    "Question: {question}\n"
    "SQL Query:\n"
)

prompt = ChatPromptTemplate.from_template(template)

In [3]:
# Render the SQL prompt with dummy values to inspect the exact text the model receives.
prompt.format(
    question="how many users are there?",
    schema="my schema",
)

"Human: \nBased on the table schema below, write a SQL query that would answer the user's question:\nmy schema\n\nQuestion: how many users are there?\nSQL Query:\n"

In [6]:
import os
from getpass import getpass
from urllib.parse import quote_plus

from langchain_community.utilities import SQLDatabase

# Connection settings come from the environment so no credentials live in the notebook.
# Set CHINOOK_DB_URI to supply the full SQLAlchemy URI; otherwise it is assembled from
# MYSQL_USER / MYSQL_PASSWORD / MYSQL_HOST / MYSQL_PORT, prompting for the password if unset.
db_uri = os.environ.get("CHINOOK_DB_URI")
if not db_uri:
    mysql_user = os.environ.get("MYSQL_USER", "root")
    mysql_password = os.environ.get("MYSQL_PASSWORD")
    if mysql_password is None:
        mysql_password = getpass("MySQL password: ")
    mysql_host = os.environ.get("MYSQL_HOST", "localhost")
    mysql_port = os.environ.get("MYSQL_PORT", "3306")
    # quote_plus escapes characters such as '@', ':' or '/' that would otherwise break the URI.
    db_uri = (
        f"mysql+mysqlconnector://{quote_plus(mysql_user)}:{quote_plus(mysql_password)}"
        f"@{mysql_host}:{mysql_port}/Chinook"
    )

db = SQLDatabase.from_uri(db_uri)

In [9]:
# Smoke-test the connection by fetching a few rows from the Album table.
db.run(
    "SELECT * FROM Album LIMIT 5"
)

"[(1, 'For Those About To Rock We Salute You', 1), (2, 'Balls to the Wall', 2), (3, 'Restless and Wild', 2), (4, 'Let There Be Rock', 1), (5, 'Big Ones', 3)]"

In [14]:
def get_schema(input: dict) -> str:
    """Return the schema description produced by ``db.get_table_info()``.

    The ``input`` mapping is ignored: it is only accepted because
    ``RunnablePassthrough.assign`` calls each assigned function with the chain's
    current inputs.
    """
    table_info = db.get_table_info()
    return table_info

In [15]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

# Guard against out-of-order execution: the name `prompt` is rebound later in this notebook
# to the answer-writing template, and this chain must be built from the SQL-generation one.
assert set(prompt.input_variables) == {"schema", "question"}, (
    "`prompt` is not the SQL-generation prompt; re-run the prompt-definition cell first"
)

# temperature=0 makes generated SQL as deterministic as the API allows, so re-running the
# notebook reproduces the same queries instead of sampling a different one each time.
llm = ChatOpenAI(temperature=0)

# question -> (+ schema) -> SQL prompt -> model -> raw SQL text
sql_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    # Stop sequences are documented as a list of strings; this one halts generation if the
    # model starts inventing a "SQL Result:" section after the query.
    | llm.bind(stop=["\nSQL Result:"])
    | StrOutputParser()
)

In [16]:
# Ask the SQL chain alone for a query; the result is the model's SQL text (it is not executed).
# NOTE(review): the recorded output counts DISTINCT ArtistId in Album, i.e. only artists that
# have at least one album — confirm against the schema whether an artist table should be used.
sql_chain.invoke(dict(question="how many artists are there?"))

'SELECT COUNT(DISTINCT ArtistId) AS NumberOfArtists\nFROM Album;'

In [17]:
# Prompt that turns the question, generated SQL and its raw result into a plain-language
# answer. Placeholders: {schema}, {question}, {query}, {response}.
# NOTE: this rebinds `template` and `prompt`, replacing the SQL-generation versions;
# `sql_chain` keeps the prompt object it captured when its cell ran, so that cell must
# run before this one.
template = (
    "\n"
    "Based on the table schema below, question, sql query, and sql response, "
    "write a natural language response:\n"
    "{schema}\n"
    "\n"
    "Question: {question}\n"
    "SQL Query: {query}\n"
    "SQL Response: {response}\n"
)

prompt = ChatPromptTemplate.from_template(template)

In [18]:
import re

# Matches a response that is entirely one Markdown code fence, e.g. "```sql\n...\n```",
# capturing the text inside. Model output may arrive wrapped this way.
_SQL_FENCE = re.compile(r"```[\w+-]*[ \t]*\n(.*?)\s*```", re.DOTALL)


def run_query(query: str, database=None):
    """Execute model-generated SQL and return the database's result.

    Args:
        query: SQL text produced by the model. Surrounding whitespace and a single
            enclosing Markdown code fence (```sql ... ```) are removed first, since the
            fence markers are not executable SQL.
        database: object exposing ``run(sql)``; defaults to the notebook-level ``db``.

    Returns:
        Whatever ``database.run`` returns for the cleaned query.
    """
    target = db if database is None else database
    sql = query.strip()
    fenced = _SQL_FENCE.fullmatch(sql)
    if fenced:
        sql = fenced.group(1).strip()
    return target.run(sql)

In [21]:
# Guard against out-of-order execution: this chain needs the answer-writing prompt, not the
# SQL-generation prompt that the name `prompt` refers to earlier in the notebook.
assert set(prompt.input_variables) == {"schema", "question", "query", "response"}, (
    "`prompt` is not the response prompt; re-run the response-template cell first"
)

full_chain = (
    # Step 1: generate SQL for the question.
    RunnablePassthrough.assign(query=sql_chain).assign(
        # Step 2: attach the schema and the result of executing that SQL.
        schema=get_schema,
        response=lambda inputs: run_query(inputs["query"]),
    )
    # Step 3: phrase question + SQL + result as a natural-language answer.
    | prompt
    | llm
    | StrOutputParser()
)

In [22]:
# End-to-end run: generate SQL, execute it against `db`, and phrase the result in prose.
full_chain.invoke(dict(question="how many artists are there?"))

'There are a total of 275 artists in the database.'