# Retriever: Query Reformulation
+ Querying SQL with natural language

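## Querying SQL
+ Step 1: Convert the question into a SQL query. The model translates the user's input into a SQL statement.
+ Step 2: Execute the SQL query against the database.
+ Step 3: Answer the question. The model uses the query result to respond to the user's input.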
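
The three steps above can be sketched end to end without a model. The example below is a minimal illustration that uses Python's built-in `sqlite3` module and an in-memory table. The `canned_queries` lookup is a hypothetical stand-in for the model call, and the helper names and table contents are made up for the example.

```python
import sqlite3

# Hypothetical stand-in for the model: maps a question to a SQL statement.
canned_queries = {
    "How many employees are there?": "SELECT COUNT(*) FROM employees",
}


def question_to_sql(question: str) -> str:
    """Step 1: convert the question into SQL (a real system would call a model)."""
    return canned_queries[question]


def run_sql(conn: sqlite3.Connection, query: str) -> list:
    """Step 2: execute the SQL query and return all rows."""
    return conn.execute(query).fetchall()


def format_answer(question: str, rows: list) -> str:
    """Step 3: turn the query result into a response (a real system would call a model)."""
    return f"{question} -> {rows[0][0]}"


# Build a tiny in-memory database with made-up rows.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE employees (id INTEGER PRIMARY KEY, name TEXT)")
conn.executemany(
    "INSERT INTO employees (name) VALUES (?)",
    [("Ada",), ("Grace",), ("Linus",)],
)

question = "How many employees are there?"
query = question_to_sql(question)
rows = run_sql(conn, query)
answer = format_answer(question, rows)
print(answer)
```

The rest of this notebook replaces steps 1 and 3 with model calls and step 2 with a LangChain database tool.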

In [7]:
import os
from dotenv import load_dotenv

# Load environment variables from the .env file
load_dotenv(override=True)  # override=True ensures the latest values in .env take effect

True

In [1]:
! pip install langchainhub langgraph


### Inspecting the database

In [22]:
from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///chinook.db")
print(db.dialect)
print(db.get_usable_table_names())
# print(db.get_table_info())
db.run("SELECT * FROM Artists LIMIT 10")

sqlite
['albums', 'artists', 'customers', 'employees', 'genres', 'invoice_items', 'invoices', 'media_types', 'playlist_track', 'playlists', 'tracks']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

### Using a prompt from the Hub

In [None]:
from langchain import hub

query_prompt_template = hub.pull("langchain-ai/sql-query-system-prompt")

query_prompt_template.pretty_print()  # pretty_print() prints directly and returns None
# assert len(query_prompt_template) == 1
# query_prompt_template.messages[0].pretty_print()


Given an input question, create a syntactically correct {dialect} query to run to help find the answer. Unless the user specifies in his question a specific number of examples they wish to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.

Never query for all the columns from a specific table, only ask for a the few relevant columns given the question.

Pay attention to use only the column names that you can see in the schema description. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.

Only use the following tables:
{table_info}


Question: {input}

In [33]:
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(
    model=os.environ.get("OPENAPI_MODEL"),
    base_url=os.environ.get("OPENAPI_API_BASE"),
    api_key=os.environ.get("OPENAPI_API_KEY"),
    temperature=0,
)

In [None]:
# Build a minimal SQL query step with LCEL
from typing_extensions import Annotated, TypedDict


# Define the structured state and output types
class State(TypedDict):
    question: str
    query: str
    result: str
    answer: str


class QueryOutput(TypedDict):
    """Generated SQL query."""

    query: Annotated[str, ..., "Syntactically valid SQL query"]


def write_query(state: State) -> State:
    """Generate a SQL query from the user's question."""
    prompt = query_prompt_template.invoke(
        {
            "dialect": db.dialect,
            "top_k": 10,
            "table_info": db.get_table_info(),
            "input": state["question"],
        }
    )
    structured_llm = llm.with_structured_output(QueryOutput)
    result = structured_llm.invoke(prompt)
    return {"query": result["query"]}

In [35]:
sql_message = write_query({"question": "How many employees are there in total?"})
print(sql_message)

OutputParserException: Invalid json output: To get the total number of employees, I can use the following SQL query:

```sql
SELECT COUNT(*) AS TotalEmployees 
FROM employees;
```

This query counts the records in the employees table, which is the total number of employees.
For troubleshooting, visit: https://python.langchain.com/docs/troubleshooting/errors/OUTPUT_PARSING_FAILURE 
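
The exception above suggests that the model replied with prose and a fenced SQL block rather than the JSON object that `with_structured_output` expects. One possible workaround, sketched here under that assumption, is to call the model without structured output and pull the SQL out of the fenced block with a regular expression. `extract_sql` is a hypothetical helper, not part of LangChain, and the `raw` string is a hand-written sample of such a reply.

```python
import re

FENCE = "`" * 3  # three backticks, i.e. a Markdown code-fence delimiter

# Match an optional "sql" language tag, then capture everything up to the closing fence.
_SQL_BLOCK = re.compile(
    re.escape(FENCE) + r"(?:sql)?[ \t]*\n(.*?)" + re.escape(FENCE),
    re.DOTALL | re.IGNORECASE,
)


def extract_sql(text: str) -> str:
    """Return the contents of the first fenced SQL block in `text`."""
    match = _SQL_BLOCK.search(text)
    if match is None:
        raise ValueError("no fenced SQL block found in model output")
    return match.group(1).strip()


# Hand-written sample that mimics the shape of the failing reply.
raw = (
    "To get the total number of employees, I can use the following SQL query:\n\n"
    + FENCE + "sql\nSELECT COUNT(*) AS TotalEmployees \nFROM employees;\n" + FENCE
    + "\n\nThis query counts the records in the employees table."
)

print(extract_sql(raw))
```

Inside `write_query`, this would mean invoking `llm` directly and passing the reply text to `extract_sql`. Whether that is preferable to a structured-output mode depends on what the configured model endpoint supports.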

### Executing the generated SQL statement

In [36]:
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool


def execute_query(state: State) -> State:
    """Execute the SQL query."""
    query = state["query"]
    execute_query_tool = QuerySQLDatabaseTool(db=db)
    return {"result": execute_query_tool.invoke(query)}

In [37]:
execute_query(sql_message)

NameError: name 'sql_message' is not defined

In [38]:
# Combine the steps above into a complete chain
from langchain_core.runnables import RunnablePassthrough


# Define the chain to answer questions from SQL query
def answer_question(state: State):
    """Format answer based on the query result."""
    prompt = f"""Based on the SQL query:
    {state["query"]}
    
    And the query result:
    {state["result"]}
    
    Answer the user's question: {state["question"]}
    Provide a concise and informative response.
    """

    return {"answer": llm.invoke(prompt).content}


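# Create a full chain from question to answer.
# Each step function returns a dict such as {"query": ...}, so unwrap the value
# before assigning it; otherwise the key would hold a nested dict.
sql_chain = (
    RunnablePassthrough.assign(query=lambda state: write_query(state)["query"])
    .assign(result=lambda state: execute_query(state)["result"])
    .assign(answer=lambda state: answer_question(state)["answer"])
)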

# Usage example
question = "Get the top 5 employees by sales, along with their total sales"
response = sql_chain.invoke({"question": question})

print("Question", question)
print("\nGenerated SQL:")
print(response["query"])
print("\nExecution Result:")
print(response["result"])
print("\nAnswer:")
print(response["answer"])

OutputParserException: Invalid json output: Here's the SQL query to get the top 5 employees with the highest sales totals:

```sql
SELECT 
    e.EmployeeId, 
    e.FirstName, 
    e.LastName, 
    SUM(i.Total) AS TotalSales
FROM 
    employees e
JOIN 
    customers c ON e.EmployeeId = c.SupportRepId
JOIN 
    invoices i ON c.CustomerId = i.CustomerId
GROUP BY 
    e.EmployeeId, e.FirstName, e.LastName
ORDER BY 
    TotalSales DESC
LIMIT 5;
```

This query:
1. Joins employees with customers they support (via SupportRepId)
2. Joins those customers with their invoices
3. Groups by employee and calculates the sum of invoice totals
4. Orders by the total sales in descending order
5. Limits to the top 5 results
For troubleshooting, visit: https://python.langchain.com/docs/troubleshooting/errors/OUTPUT_PARSING_FAILURE 
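
The chain above threads one dictionary through three steps, with each `.assign(...)` call adding a key to it. The sketch below imitates that behaviour in plain Python so the data flow can be checked without a model. The `assign` function and the `StateDict` alias are hypothetical names for this illustration, not the LangChain implementation.

```python
from typing import Any, Callable, Dict

StateDict = Dict[str, Any]


def assign(state: StateDict, **steps: Callable[[StateDict], Any]) -> StateDict:
    """Return a copy of `state` with one new key per step.

    Every step receives the incoming state, and its return value is stored
    under the keyword name it was passed as.
    """
    updates = {key: step(state) for key, step in steps.items()}
    return {**state, **updates}


state: StateDict = {"question": "How many employees are there?"}
state = assign(state, query=lambda s: "SELECT COUNT(*) FROM employees")
state = assign(state, result=lambda s: f"ran: {s['query']}")
state = assign(state, answer=lambda s: f"{s['question']} {s['result']}")
print(state)

# A step that returns a dict is stored as-is, which produces a nested value.
nested = assign({"question": "q"}, query=lambda s: {"query": "SELECT 1"})
print(nested["query"])
```

The last two lines show why the return shape of each step matters: a function that returns `{"query": ...}` leaves a dict under the `query` key rather than the SQL string itself.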