## Retriever: Query Construction
****
- Query a SQL database using natural language

#### Querying SQL
***
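- Step 1: Convert the question into a SQL query. The model translates the user's input into SQL.
- Step 2: Execute the SQL query against the database.
- Step 3: Answer the question. The model uses the query result to respond to the user's input.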
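The three steps can be sketched without an LLM, using Python's built-in `sqlite3` and hard-coded stand-ins for the two model calls. `fake_model_to_sql`, `fake_model_answer` and `answer_with_sql` are hypothetical placeholders, not LangChain APIs, and the toy `Employee` table is made up for illustration.

```python
import sqlite3


def fake_model_to_sql(question: str) -> str:
    # Hypothetical stand-in for step 1: a real system would ask an LLM here.
    canned = {"How many employees are there?": "SELECT COUNT(*) FROM Employee;"}
    return canned[question]


def fake_model_answer(question: str, rows: list) -> str:
    # Hypothetical stand-in for step 3: a real system would ask an LLM to phrase this.
    return f"Q: {question} A: {rows[0][0]}"


def answer_with_sql(conn: sqlite3.Connection, question: str) -> str:
    sql = fake_model_to_sql(question)         # Step 1: question -> SQL
    rows = conn.execute(sql).fetchall()       # Step 2: run the SQL
    return fake_model_answer(question, rows)  # Step 3: result -> answer


# Toy in-memory table standing in for Chinook's Employee table.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY, LastName TEXT)")
conn.executemany(
    "INSERT INTO Employee (LastName) VALUES (?)",
    [("Adams",), ("Edwards",), ("Peacock",)],
)
print(answer_with_sql(conn, "How many employees are there?"))
# prints: Q: How many employees are there? A: 3
```

In the notebook, step 1 and step 3 are real model calls, and step 2 goes through LangChain's database tooling instead of a raw `sqlite3` connection.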

In [1]:
! pip install --upgrade --quiet langchain-community langchainhub langgraph langchain-deepseek


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


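Inspect the database. This assumes the Chinook sample database is saved as `Chinook.db` in the current working directory; if the file is missing, SQLite silently creates an empty database and the table list comes back empty.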

In [2]:
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"
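As the output shows, `db.run` returns a string (the repr of the fetched rows), not a Python list. If you need the rows as data, a minimal sketch is to parse that string with `ast.literal_eval`, assuming the values are plain literals, which holds for SQLite's int/float/str/None results but may not for other backends (e.g. `Decimal` or `datetime` values). The helper name `parse_rows` is hypothetical, and the empty-string case reflects my understanding that `db.run` returns `""` when no rows match.

```python
import ast


def parse_rows(raw: str) -> list:
    """Turn the repr-style string returned by db.run back into a list of tuples."""
    # Assumption: an empty result comes back as "", which literal_eval would reject.
    if not raw:
        return []
    # literal_eval only accepts Python literals (numbers, strings, tuples, None, ...),
    # so it never executes code, but it fails on reprs such as Decimal('1.0').
    return ast.literal_eval(raw)


# String in the same format as the db.run(...) output above (first three rows).
raw = "[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith')]"
rows = parse_rows(raw)
print([name for _, name in rows])  # ['AC/DC', 'Accept', 'Aerosmith']
```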

Use a prebuilt prompt from LangChain Hub

In [3]:
from langchain import hub

query_prompt_template = hub.pull("langchain-ai/sql-query-system-prompt")

assert len(query_prompt_template.messages) == 1
query_prompt_template.messages[0].pretty_print()


Given an input question, create a syntactically correct {dialect} query to run to help find the answer. Unless the user specifies in his question a specific number of examples they wish to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

Only use the following tables:
{table_info}

Question: {input}


In [7]:
from langchain_deepseek import ChatDeepSeek
import os

llm = ChatDeepSeek(
    model="Pro/deepseek-ai/DeepSeek-V3",
    temperature=0,
    api_key=os.environ.get("DEEPSEEK_API_KEY"),
    api_base=os.environ.get("DEEPSEEK_API_BASE"),
)

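Build a minimal SQL query step. `State` holds the data passed between steps, `QueryOutput` is the structured output the model must return, and `write_query` fills in the prompt template and asks the model for a query. The steps are chained together with LCEL further down.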

In [8]:
from typing_extensions import Annotated
from typing_extensions import TypedDict

# Define the state type
class State(TypedDict):
    question: str
    query: str
    result: str
    answer: str

# Define the output type
class QueryOutput(TypedDict):
    """Generated SQL query."""

    query: Annotated[str, ..., "Syntactically valid SQL query."]

# Define the write_query function
def write_query(state: State):
    """Generate SQL query to fetch information."""
    prompt = query_prompt_template.invoke(
        {
            "dialect": db.dialect,
            "top_k": 10,
            "table_info": db.get_table_info(),
            "input": state["question"],
        }
    )
    structured_llm = llm.with_structured_output(QueryOutput)
    result = structured_llm.invoke(prompt)
    return {"query": result["query"]}

In [9]:
# Question: "How many employees are there in total?"
sqlMessage = write_query({"question": "一共有多少个员工?"})
print(sqlMessage)

{'query': 'SELECT COUNT(*) AS TotalEmployees FROM Employee;'}
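`llm.with_structured_output(QueryOutput)` turns the TypedDict into a schema the model has to fill in. Per LangChain's structured-output documentation, the class docstring serves as the schema description and `Annotated[type, default, description]` supplies each field's description, with `...` meaning the field has no default (it is required). The sketch below is a hand-rolled approximation, not LangChain's converter, showing what information the model receives; `sketch_schema` is a hypothetical helper, the real schema may contain extra keys, and `typing.Annotated` needs Python 3.9+.

```python
from typing import Annotated, TypedDict, get_args, get_origin, get_type_hints


class QueryOutput(TypedDict):
    """Generated SQL query."""

    query: Annotated[str, ..., "Syntactically valid SQL query."]


# Minimal Python-to-JSON-schema type mapping; enough for this example only.
JSON_TYPES = {str: "string", int: "integer", float: "number", bool: "boolean"}


def sketch_schema(cls: type) -> dict:
    """Hand-rolled approximation of a JSON schema for an Annotated TypedDict."""
    properties, required = {}, []
    for field, hint in get_type_hints(cls, include_extras=True).items():
        if get_origin(hint) is Annotated:
            base, *extras = get_args(hint)
        else:
            base, extras = hint, []
        prop = {"type": JSON_TYPES.get(base, "string")}
        # Convention used here: Annotated[type, default, description].
        if len(extras) >= 2 and isinstance(extras[1], str):
            prop["description"] = extras[1]
        # A default of `...` (or no Annotated metadata at all) marks the field as required.
        if not extras or extras[0] is ...:
            required.append(field)
        properties[field] = prop
    return {
        "title": cls.__name__,
        "description": cls.__doc__,
        "type": "object",
        "properties": properties,
        "required": required,
    }


print(sketch_schema(QueryOutput))
```

This is why the field description matters: it is the only per-field instruction the model sees besides the prompt itself.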


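The generated SQL can then be executed. ⚠️ This is risky: the statement comes from the model and runs directly against the database, so a wrong or malicious query could read or modify data it should not. Connect with the narrowest permissions possible, ideally read-only.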

In [10]:
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool


def execute_query(state: State):
    """Execute SQL query."""
    execute_query_tool = QuerySQLDatabaseTool(db=db)
    return {"result": execute_query_tool.invoke(state["query"])}

In [11]:
execute_query(sqlMessage)

{'result': '[(8,)]'}
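For SQLite specifically, one way to reduce the risk of running model-generated SQL is an authorizer callback: Python's `sqlite3.Connection.set_authorizer` lets you veto each action while a statement is being compiled. The sketch below uses a plain `sqlite3` connection rather than the notebook's `SQLDatabase`/`QuerySQLDatabaseTool`; `read_only_authorizer` and `run_read_only` are hypothetical helpers, and the allow-list (SELECT, READ, FUNCTION) is an assumption that covers simple queries like the ones in this notebook but rejects, for example, recursive CTEs. Opening the file read-only with `sqlite3.connect("file:Chinook.db?mode=ro", uri=True)` is a simpler, complementary safeguard.

```python
import sqlite3

# Authorizer action codes a plain read-only SELECT needs. Anything else
# (INSERT, UPDATE, DROP, ATTACH, PRAGMA, ...) is rejected when the SQL is compiled.
READ_ONLY_ACTIONS = {sqlite3.SQLITE_SELECT, sqlite3.SQLITE_READ, sqlite3.SQLITE_FUNCTION}


def read_only_authorizer(action, arg1, arg2, db_name, trigger_or_view):
    return sqlite3.SQLITE_OK if action in READ_ONLY_ACTIONS else sqlite3.SQLITE_DENY


def run_read_only(conn: sqlite3.Connection, sql: str) -> list:
    """Run untrusted SQL on a dedicated connection that only permits reads."""
    conn.set_authorizer(read_only_authorizer)
    return conn.execute(sql).fetchall()


# Toy database; the setup writes happen before the authorizer is installed.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY, LastName TEXT)")
conn.executemany("INSERT INTO Employee (LastName) VALUES (?)", [("Adams",), ("Edwards",)])
conn.commit()

print(run_read_only(conn, "SELECT COUNT(*) FROM Employee;"))  # [(2,)]

try:
    run_read_only(conn, "DROP TABLE Employee;")
    drop_allowed = True
except sqlite3.DatabaseError as exc:
    drop_allowed = False
    print("blocked:", exc)
```

The authorizer stays installed on this connection, so keep a separate connection for trusted code that needs to write.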

In [12]:
from langchain_core.runnables import RunnablePassthrough

# Step 3: turn the SQL query and its result into a natural-language answer
def answer_question(state: State):
    """Format answer based on the query result."""
    prompt = f"""Based on the SQL query:
{state["query"]}

And the query result:
{state["result"]}

Answer the user's question: {state["question"]}
Provide a concise and informative response.
"""
    return {"answer": llm.invoke(prompt).content}

# Create a full chain from question to answer
sql_chain = (
    RunnablePassthrough.assign(query=write_query)
    .assign(result=execute_query)
    .assign(answer=answer_question)
)

# Example usage
question = "获取销售额最高的5位员工及其销售总额"  # "Get the 5 employees with the highest sales and their total sales"
response = sql_chain.invoke({"question": question})

print("Question:", question)
print("\nGenerated SQL:")
print(response["query"])
print("\nExecution Result:")
print(response["result"])
print("\nAnswer:")
print(response["answer"])

Question: 获取销售额最高的5位员工及其销售总额

Generated SQL:
{'query': 'SELECT e.EmployeeId, e.FirstName, e.LastName, SUM(i.Total) AS TotalSales FROM Employee e JOIN Customer c ON e.EmployeeId = c.SupportRepId JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY e.EmployeeId ORDER BY TotalSales DESC LIMIT 5'}

Execution Result:
{'result': "[(3, 'Jane', 'Peacock', 833.04), (4, 'Margaret', 'Park', 775.4), (5, 'Steve', 'Johnson', 720.16)]"}

Answer:
{'answer': '根据查询结果，销售额最高的5位员工及其销售总额如下：\n\n1. **Jane Peacock** - 销售总额: 833.04\n2. **Margaret Park** - 销售总额: 775.40\n3. **Steve Johnson** - 销售总额: 720.16\n\n这些员工按照销售总额从高到低排列，仅显示了前三位员工的信息。'}
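In this output, the generated SQL, the execution result and the answer are each wrapped in a dict such as `{'query': ...}`. That happens because `write_query`, `execute_query` and `answer_question` return partial state dicts (the LangGraph-style convention), while `RunnablePassthrough.assign(query=write_query)` stores the function's whole return value under the `query` key. The chain still appears to work because `QuerySQLDatabaseTool` accepts `{'query': ...}` as tool input and `answer_question` simply formats the dicts into its prompt. The plain-Python sketch below mimics `assign` as a dict merge to show the nesting and one way to unwrap it; `assign` and `fake_write_query` are hypothetical stand-ins, not LangChain code.

```python
# Hypothetical plain-dict stand-in for RunnablePassthrough.assign: it calls each
# step on the current state and merges {key: step(state)} into a copy of the state.
def assign(state: dict, **steps) -> dict:
    return {**state, **{key: step(state) for key, step in steps.items()}}


def fake_write_query(state: dict) -> dict:
    # Stub that, like the notebook's write_query, returns a partial update dict.
    return {"query": "SELECT COUNT(*) FROM Employee;"}


start = {"question": "How many employees are there?"}

nested = assign(start, query=fake_write_query)
print(nested["query"])  # {'query': 'SELECT COUNT(*) FROM Employee;'}

flat = assign(start, query=lambda s: fake_write_query(s)["query"])
print(flat["query"])  # SELECT COUNT(*) FROM Employee;
```

In the notebook's chain, the same unwrapping would be `RunnablePassthrough.assign(query=lambda s: write_query(s)["query"])`, and likewise for `result` and `answer`. Alternatively, because the three functions already return partial updates, they could be wired into a LangGraph `StateGraph`, which merges such updates into the state instead of nesting them.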
