# SQL 使用示例

In [1]:
from dotenv import load_dotenv

# Load variables from a local .env file into the process environment before any client is
# created. Presumably this supplies the OpenAI API key that ChatOpenAI() reads later in the
# notebook — TODO confirm the .env contents. The call's return value is this cell's output.
load_dotenv()

True

In [14]:
from pathlib import Path

from langchain_community.utilities import SQLDatabase

# Chinook sample database, resolved relative to the notebook's working directory.
DB_PATH = Path("Chinook.db")
# SQLite creates a new, empty database file when the path does not exist. A missing or
# misplaced Chinook.db would therefore silently produce an empty schema, and the LLM would
# invent tables. Fail loudly instead.
if not DB_PATH.is_file():
    raise FileNotFoundError(f"Chinook sample database not found: {DB_PATH.resolve()}")
db = SQLDatabase.from_uri(f"sqlite:///{DB_PATH.as_posix()}")

def get_schema(_=None):
    """Return the table definitions (CREATE TABLE statements plus sample rows) from `db`.

    The single positional argument is ignored. RunnablePassthrough.assign() calls this
    function with the chain's input dict, which the schema does not depend on. The
    argument defaults to None so the schema can also be inspected directly with
    get_schema().
    """
    return db.get_table_info()
def run_query(query):
    """Execute a SQL string against `db` and return the result as text.

    `query` is the raw LLM output produced by the `sql_query` chain. Chat models often wrap
    SQL in a Markdown code fence (```sql ... ```), which the database rejects as a syntax
    error. Any surrounding fence is therefore stripped before execution; plain SQL is
    passed through unchanged apart from trimming outer whitespace.

    NOTE(review): the SQL is model-generated and executed without validation, and the
    connection in this notebook is not opened read-only. Consider a read-only connection or
    an allow-list of statement types before pointing this at real data.
    """
    sql = query.strip()
    if sql.startswith("```"):
        # The opening fence may carry a language tag (```sql). When it sits on its own
        # line, drop that whole line; otherwise drop just the backticks.
        sql = sql.split("\n", 1)[1] if "\n" in sql else sql[3:]
        sql = sql.rstrip()
        if sql.endswith("```"):
            sql = sql[:-3]
        sql = sql.strip()
    return db.run(sql)

In [17]:
# Show the schema text the SQL prompt will receive. get_schema ignores its argument, so pass
# None explicitly rather than IPython's `_` (last-output) variable, which depends on cell order.
print(get_schema(None))


CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
3 rows from Album table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/


CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosmith
*/


CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("SupportRepId") REFERENCES "Empl

In [21]:
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

# temperature=0 keeps the generated SQL as repeatable as the API allows, so Restart & Run All
# reproduces the same queries. Without it the model samples at a non-zero temperature.
model = ChatOpenAI(temperature=0)
# Text-to-SQL prompt: {schema} is filled by get_schema (via RunnablePassthrough.assign below)
# and {question} comes from the chain's input dict.
prompt = ChatPromptTemplate.from_template("""Based on the table schema below, write an SQL query that would answer user's question:
{schema}
                                 
Question:{question}
SQL Query:
""")

# Question -> SQL chain.
# RunnablePassthrough() on its own just forwards its input. RunnablePassthrough.assign(...)
# forwards the input dict *plus* extra keys. Each extra value must be a runnable or a
# callable: it is invoked with the input dict when the chain runs, and its result is stored
# under that key. Here that key is `schema`, which the prompt needs alongside `question`.
sql_query = (
    RunnablePassthrough.assign(schema=get_schema) | prompt | model | StrOutputParser()
)

sql_query.invoke({"question": "How many employees are there?"})


'SELECT COUNT(*) AS TotalEmployees\nFROM Employee;'

In [25]:
from operator import itemgetter

# Prompt for the second stage: turn the question, the generated SQL and its raw result into prose.
prompt = ChatPromptTemplate.from_template("""
Based on the table schema below, question, sql query, and sql response, write a natural language response:
{schema}

Question:{question}                                
SQL Query:{sql_query}
SQL Response:{sql_response}
""")

# Full pipeline, run in sequence:
# question -> SQL (sql_query chain) -> execute it -> model writes a natural-language answer.
# The first assign adds `sql_query`. The second assign reads that key, so it must be a
# separate, later step; it adds `schema` and `sql_response` for the prompt.
chain = (
    RunnablePassthrough.assign(sql_query=sql_query)
    .assign(
        schema=get_schema,
        sql_response=lambda inputs: run_query(inputs["sql_query"]),
    )
    | prompt
    | model
)

chain.invoke({"question": "How many employees are there?"})

AIMessage(content='There are a total of 8 employees in the database.')