# 介绍langchain的数据库查询链

## LangChain 推荐的 LCEL 方式：create_sql_query_chain

### 建立一个本地的mysql数据库，建立数据库连接

In [1]:
import os
from urllib.parse import quote_plus

from langchain_community.utilities import SQLDatabase

# Connection settings for the local MySQL instance. Each value can be
# overridden through an environment variable; the fallbacks are the values
# this notebook was originally written with.
db_user = os.getenv("MYSQL_USER", "root")
db_password = os.getenv("MYSQL_PASSWORD", "161212")
db_host = os.getenv("MYSQL_HOST", "127.0.0.1")
db_name = os.getenv("MYSQL_DATABASE", "langchain")

# NOTE: the mysql+pymysql driver requires the pymysql package to be installed.
# The user name and password are percent-encoded so that characters such as
# "@", ":" or "/" are not mistaken for URI delimiters.
db = SQLDatabase.from_uri(
    f"mysql+pymysql://{quote_plus(db_user)}:{quote_plus(db_password)}"
    f"@{db_host}/{db_name}"
)

# Sanity checks: SQL dialect, visible tables, and the rows of the demo table.
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM user;")


mysql
['user']


"[(1, 'John Doe', 'admin', 'p@ssword1'), (2, 'Alice Smith', 'manager', 'secure2'), (3, 'Bob Johnson', 'employee', 'pass123'), (4, 'Eva Davis', 'admin', 'adminPass'), (5, 'Charlie Brown', 'manager', '12345678'), (6, 'Grace White', 'employee', 'pwd987'), (7, 'Daniel Lee', 'admin', 'danny12'), (8, 'Olivia Moore', 'manager', 'pass432'), (9, 'Frank Miller', 'employee', 'fMiller'), (10, 'Sophia Taylor', 'admin', 'sophia7')]"

In [2]:
pip install pymysql

Collecting pymysql
Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip available: 22.2.2 -> 23.3.2
[notice] To update, run: python.exe -m pip install --upgrade pip



  Downloading PyMySQL-1.1.0-py3-none-any.whl (44 kB)
     -------------------------------------- 44.8/44.8 kB 548.1 kB/s eta 0:00:00
Installing collected packages: pymysql
Successfully installed pymysql-1.1.0


In [5]:
import os

from langchain.chains import create_sql_query_chain
from langchain_openai import ChatOpenAI

# Chat model at temperature 0; the API key and endpoint are read from the
# environment rather than being written into the notebook.
llm = ChatOpenAI(
    base_url=os.environ.get("OPENAI_BASE_URL"),
    openai_api_key=os.environ.get("OPENAI_API_KEY"),
    temperature=0,
)

# Chain that turns a natural-language question into a SQL query for `db`.
chain = create_sql_query_chain(llm, db)
inputs = {"question": "How many people are there in the user table?"}
response = chain.invoke(inputs)
response

'SELECT COUNT(*) FROM `user`'

## 用 get_prompts() 方法查看 LangChain 内置的 prompt 模板

In [6]:
# Display the first prompt template used by the chain.
first_prompt = chain.get_prompts()[0]
first_prompt.pretty_print()

You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQL

In [7]:
# Ask the same counting question in Chinese; the chain returns the SQL text.
inputs = {"question": "user表里有多少用户?"}
response = chain.invoke(inputs)
response

'SELECT COUNT(*) FROM `user`'

In [8]:
# Ask a question that needs a filter condition; the chain returns the SQL text.
inputs = {"question": "user表里有多少个admin用户?"}
response = chain.invoke(inputs)
response

"SELECT COUNT(*) FROM `user` WHERE `role` = 'admin'"

In [9]:
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

# Step 1 writes the SQL text; step 2 runs that text against the database.
write_query = create_sql_query_chain(llm, db)
execute_query = QuerySQLDataBaseTool(db=db)

# Piping the two steps makes `chain` return the query result instead of the SQL.
chain = write_query | execute_query
inputs = {"question": "user表里有多少个admin用户?"}
chain.invoke(inputs)

'[(4,)]'

## 构造一个回答问题的链 answer，其输入包括：用户问题 question、SQL 语句 query、SQL 执行结果 result

In [12]:
from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

# Template with three placeholders: the user question, the generated SQL
# query and the result of running it. The model is asked to answer in Chinese.
answer_template = """Given the following user question, corresponding SQL query, and SQL result, answer the user question. 用中文回答最终答案

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """

answer_prompt = PromptTemplate.from_template(answer_template)

# Prompt -> chat model -> plain string output.
answer = answer_prompt | llm | StrOutputParser()


## 构造最终链 final_chain：生成 SQL query、执行 SQL query，并将结果传递给 answer

In [13]:
# Add the generated SQL to the input mapping under "query".
with_query = RunnablePassthrough.assign(query=write_query)

# Run that SQL and add its output under "result".
with_result = with_query.assign(result=itemgetter("query") | execute_query)

# Feed question, query and result into the answering chain.
final_chain = with_result | answer

final_chain.invoke({"question": "user表里有多少个admin用户?"})

'user表里有4个admin用户。'

In [18]:
# Print the composed chain's representation to inspect how its steps are nested.
print(final_chain)

first=RunnableAssign(mapper={
  query: {
           input: RunnableLambda(...),
           top_k: RunnableLambda(...),
           table_info: RunnableLambda(...)
         }
         | PromptTemplate(input_variables=['input', 'table_info', 'top_k'], template='You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.\nPay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do no