In [1]:
from langchain.prompts import ChatPromptTemplate

# Prompt text for SQL generation. It has two placeholders: {schema} (the table
# schema text) and {question} (the user's natural-language question). The text
# ends with "SQL Query:" so the model's completion is the query itself.
template = (
    "Based on the table schema below, write a SQL query that would answer the user's question:\n"
    "{schema}\n"
    "\n"
    "Question: {question}\n"
    "SQL Query:"
)
# Build the chat prompt from the template text; its placeholders are
# {schema} and {question}.
prompt = ChatPromptTemplate.from_template(template)

In [2]:
from langchain.utilities import SQLDatabase


In [3]:
# Open the SQLite database file Chinook.db. The URI path is relative
# ("./Chinook.db"), so the file is looked up from the current working directory.
db = SQLDatabase.from_uri("sqlite:///./Chinook.db")

In [4]:
def get_schema(_):
    """Return the table info string reported by the module-level ``db``.

    The single positional argument is accepted and ignored; this lets the
    function be passed as a callable that receives the chain input (as done
    with ``RunnablePassthrough.assign(schema=get_schema)`` in this notebook).
    """
    table_info = db.get_table_info()
    return table_info

In [5]:
def run_query(query):
    """Run ``query`` against the module-level ``db`` and return its result.

    The query text is forwarded to ``db.run`` unchanged, and whatever
    ``db.run`` returns is handed back to the caller.
    """
    # NOTE(review): no validation happens here; if `query` comes from model
    # output, confirm the database connection is safe to run arbitrary SQL on.
    result = db.run(query)
    return result

In [6]:
from langchain.chat_models import ChatOpenAI
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough

# Chat model constructed with its default settings.
model = ChatOpenAI()

# First stage: extend the input mapping with a "schema" entry computed by
# get_schema, so the prompt receives both {schema} and {question}.
schema_step = RunnablePassthrough.assign(schema=get_schema)

# Same model, bound with a stop sequence so generation is cut off if the
# model begins emitting a "SQLResult:" line after the query.
model_with_stop = model.bind(stop=["\nSQLResult:"])

# Full pipeline: input mapping -> add schema -> prompt -> model -> string.
sql_response = schema_step | prompt | model_with_stop | StrOutputParser()

In [7]:
# Run the chain on a sample question. The chain ends in StrOutputParser, so
# the value displayed below is the generated SQL as a string.
sql_response.invoke({"question": "How many employees are there?"})

'SELECT COUNT(*) FROM employees;'