In [1]:
from langchain_core.prompts import ChatPromptTemplate

template = """Based on the table schema below, write a SQL query that would answer the user's question:
{schema}

Question: {question}
SQL Query:"""
prompt = ChatPromptTemplate.from_template(template)


In [11]:
import os
from dotenv import load_dotenv
load_dotenv()

os.environ['OPENAI_API_KEY']=os.getenv("OPENAI_API_KEY")

In [2]:
prompt.format(schema="my schema",question="how many users are there?")

"Human: Based on the table schema below, write a SQL query that would answer the user's question:\nmy schema\n\nQuestion: how many users are there?\nSQL Query:"

In [7]:
from langchain_community.utilities import SQLDatabase

# if you are using SQLite
#sqlite_uri = 'sqlite:///./Chinook.db' 

# if you are using MySQL
mysql_uri = 'mysql+mysqlconnector://root:Mohan%40123@localhost:3306/chinook'

db = SQLDatabase.from_uri(mysql_uri)

In [8]:
def get_schema(db):
    schema = db.get_table_info()
    return schema

In [10]:
db.run("SELECT * FROM Album LIMIT 5")

"[(1, 'For Those About To Rock We Salute You', 1), (2, 'Balls to the Wall', 2), (3, 'Restless and Wild', 2), (4, 'Let There Be Rock', 1), (5, 'Big Ones', 3)]"

In [15]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

llm = ChatOpenAI()

sql_chain = (
    RunnablePassthrough.assign(schema=lambda _:get_schema)
    | prompt
    | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)

In [16]:
user_question = 'how many albums are there in the database?'
#sql_chain.invoke({"question": user_question})
sql_chain.invoke({"question": 'how many albums are there in the database?'})

# 'SELECT COUNT(*) AS TotalAlbums\nFROM Album;'

'SELECT COUNT(*) as total_albums\nFROM albums'

In [17]:
template = """Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}"""
prompt_response = ChatPromptTemplate.from_template(template)

In [18]:
def run_query(query):
    return db.run(query)

In [25]:
full_chain = (
    RunnablePassthrough.assign(query=sql_chain).assign(
        schema=lambda _:get_schema,
        response=lambda vars: run_query(vars["query"]),
    )
    | prompt_response
    | llm
)

In [26]:
user_question = 'how many album are there in the database?'
full_chain.invoke({"question": user_question})

AIMessage(content='There are 347 albums in the database.', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 10, 'prompt_tokens': 76, 'total_tokens': 86, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None}, id='run-f4aaf9f4-c3d3-4d98-8230-2a7e514442b6-0', usage_metadata={'input_tokens': 76, 'output_tokens': 10, 'total_tokens': 86, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}})