In [1]:
import os
from dotenv import load_dotenv

# Load variables from a local .env file, if one is found, into os.environ.
# (load_dotenv was imported above but never called, so .env was ignored.)
load_dotenv()

# ChatOpenAI() is constructed later without an explicit key, so the key must be
# available in the environment. Fail early with a clear message instead of the
# opaque TypeError produced by assigning None into os.environ.
if not os.environ.get("OPENAI_API_KEY"):
    raise RuntimeError(
        "OPENAI_API_KEY is not set; export it or add it to a .env file."
    )

In [2]:
from langchain_core.prompts import ChatPromptTemplate

# Stage-1 prompt: ask the model to translate a natural-language question into
# SQL. Input variables: {schema} (filled by get_schema in the chain) and
# {question} (supplied by the caller).
template = """
Based on the table schema below, write a SQL query that would answer the user's question.
{schema}

Question: {question}
SQL Query:
"""
prompt = ChatPromptTemplate.from_template(template=template)

In [3]:
# Preview the rendered stage-1 prompt using placeholder values.
prompt.format(question="how many users are there", schema="my schema")

"Human: \nBased on the table schema below, write a SQL query that would answer the user's question.\nmy schema\n\nQuestion: how many users are there\nSQL Query:\n"

In [4]:
from langchain_community.utilities import SQLDatabase

# Connection string for the Chinook sample database. Supply it through the
# CHINOOK_DB_URI environment variable so real credentials are not committed with
# the notebook; the literal below is only a local-development fallback.
db_uri = os.environ.get(
    "CHINOOK_DB_URI",
    "mysql+mysqlconnector://root:root@localhost:3306/chinook",
)
db = SQLDatabase.from_uri(db_uri)

In [5]:
def get_schema(_):
    """Return ``db.get_table_info()`` for the prompt's ``{schema}`` slot.

    The single positional argument is the chain input handed over by
    ``RunnablePassthrough.assign``; it is intentionally ignored.
    """
    table_info = db.get_table_info()
    return table_info

In [6]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

# Default ChatOpenAI settings are used.
# NOTE(review): consider temperature=0 for more repeatable SQL generation.
llm = ChatOpenAI()

# Stage 1: question -> SQL text.
# RunnablePassthrough.assign adds a "schema" key (from get_schema) to the input
# dict, so both {schema} and {question} in the prompt are filled.
# `prompt` is captured when this cell runs and must be the SQL-generation
# prompt. `prompt` is rebound later in the notebook, so re-running this cell
# after that point would build the chain with the wrong prompt.
sql_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    # `stop` is a list of stop sequences. A bare string happens to work with the
    # OpenAI backend, but some integrations iterate over / join `stop`, which for
    # a str would mean single characters.
    | llm.bind(stop=["\nSQL Result:"])
    | StrOutputParser()
)

In [7]:
# Smoke-test stage 1: the chain should return just the generated SQL string.
sql_chain.invoke(dict(question="how many artists are there?"))

'SELECT COUNT(ArtistId) AS TotalArtists FROM artist;'

In [8]:
# Stage-2 prompt: turn the question, the generated SQL, and the raw database
# result into a natural-language answer. Input variables: {schema},
# {question}, {query}, {response}.
# NOTE(review): this rebinds `prompt`, shadowing the SQL-generation prompt.
# `sql_chain` already holds the earlier prompt object, so it keeps working.
# However, re-running the cell that builds `sql_chain` after this one would wire
# in this prompt, whose {query}/{response} slots that chain does not fill.
# Consider a distinct name such as `response_prompt`.
template = """
Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}"""

prompt = ChatPromptTemplate.from_template(template=template)

In [9]:
def run_query(query):
    """Execute ``query`` against ``db`` and return the result of ``db.run``.

    NOTE(review): in ``full_chain`` this runs SQL generated by the LLM without
    any validation, possibly over a privileged account. Consider a read-only
    database user or a statement allow-list before using untrusted questions.
    """
    result = db.run(query)
    return result

In [10]:
# Sanity-check the database connection with a hand-written query.
run_query(query="SELECT COUNT(ArtistId) AS TotalArtists FROM Artist;")

'[(275,)]'

In [11]:
# Stage 2: question -> SQL -> query result -> natural-language answer.
# First add "query" (the SQL produced by sql_chain). Then add "schema" and
# "response" (the result of executing that SQL), so every slot of the answer
# prompt is filled. `prompt` here is the answer-generation prompt defined in the
# preceding template cell.
full_chain = (
    RunnablePassthrough.assign(query=sql_chain)
    .assign(
        schema=get_schema,
        response=lambda inputs: run_query(inputs["query"]),
    )
    | prompt
    | llm
    | StrOutputParser()
)

In [12]:
# End-to-end run: question in, natural-language answer out.
full_chain.invoke(dict(question="how many customers are there?"))

'There are a total of 59 customers in the database.'