Skip to content

Commit

Permalink
增加text2sql工具,支持特定表、智能判定表,支持对表名进行额外说明 (#4154)
Browse files Browse the repository at this point in the history
* 1、增加text2sql工具,支持特定表、智能判定表,支持对表名进行额外说明
  • Loading branch information
srszzw committed Jun 8, 2024
1 parent 94524f8 commit b1c5bf9
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 5 deletions.
21 changes: 20 additions & 1 deletion libs/chatchat-server/chatchat/configs/_model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,5 +227,24 @@
"text2images": {
"use": False,
},

"text2sql": {
"use": False,
#mysql连接信息
"db_host": "mysql_host",
"db_user": "mysql_user",
"db_password": "mysql_password",
"db_name": "mysql_database_name",
#限定返回的行数
"top_k":50,
#是否返回中间步骤
"return_intermediate_steps": True,
#如果想指定特定表,请填写表名称,如["sys_user","sys_dept"],不填写走智能判断应该使用哪些表
"table_names":[],
#对表名进行额外说明,辅助大模型更好的判断应该使用哪些表,尤其是SQLDatabaseSequentialChain模式下,是根据表名做的预测,很容易误判。
"table_comments":{
# 如果出现大模型选错表的情况,可尝试根据实际情况填写表名和说明
# "tableA":"用户表",
# "tanleB":"角色表",
}
},
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@

from .vqa_processor import vqa_processor
from .aqa_processor import aqa_processor
from .text2sql import text2sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain,SQLDatabaseSequentialChain
from chatchat.server.utils import get_tool_config
from chatchat.server.pydantic_v1 import Field
from .tools_registry import regist_tool, BaseToolOutput

def query_database(query: str,
config: dict):
db_user = config["db_user"]
db_password = config["db_password"]
db_host = config["db_host"]
db_name = config["db_name"]
top_k = config["top_k"]
return_intermediate_steps = config["return_intermediate_steps"]
db = SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}")
from chatchat.server.api_server.chat_routes import global_model_name
from chatchat.server.utils import get_ChatOpenAI
llm = get_ChatOpenAI(
model_name=global_model_name,
temperature=0,
streaming=True,
local_wrap=True,
verbose=True
)
table_names=config["table_names"]
table_comments=config["table_comments"]
result = None

#如果发现大模型判断用什么表出现问题,尝试给langchain提供额外的表说明,辅助大模型更好的判断应该使用哪些表,尤其是SQLDatabaseSequentialChain模式下,是根据表名做的预测,很容易误判
#由于langchain固定了输入参数,所以只能通过query传递额外的表说明
if table_comments:
TABLE_COMMNET_PROMPT="\n\nI will provide some special notes for a few tables:\n\n"
table_comments_str="\n".join([f"{k}:{v}" for k,v in table_comments.items()])
query=query+TABLE_COMMNET_PROMPT+table_comments_str+"\n\n"

#如果不指定table_names,优先走SQLDatabaseSequentialChain,这个链会先预测需要哪些表,然后再将相关表输入SQLDatabaseChain
#这是因为如果不指定table_names,直接走SQLDatabaseChain,Langchain会将全量表结构传递给大模型,可能会因token太长从而引发错误,也浪费资源
#如果指定了table_names,直接走SQLDatabaseChain,将特定表结构传递给大模型进行判断
if len(table_names) > 0:
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True,top_k=top_k,return_intermediate_steps=return_intermediate_steps)
result = db_chain.invoke({"query":query,"table_names_to_use":table_names})
else:
#先预测会使用哪些表,然后再将问题和预测的表给大模型
db_chain = SQLDatabaseSequentialChain.from_llm(llm, db, verbose=True,top_k=top_k,return_intermediate_steps=return_intermediate_steps)
result = db_chain.invoke(query)

context = f"""查询结果:{result['result']}\n\n"""

intermediate_steps=result["intermediate_steps"]
#如果存在intermediate_steps,且这个数组的长度大于2,则保留最后两个元素,因为前面几个步骤存在示例数据,容易引起误解
if intermediate_steps:
if len(intermediate_steps)>2:
sql_detail=intermediate_steps[-2:-1][0]["input"]
# sql_detail截取从SQLQuery到Answer:之间的内容
sql_detail=sql_detail[sql_detail.find("SQLQuery:")+9:sql_detail.find("Answer:")]
context = context+"执行的sql:'"+sql_detail+"'\n\n"
return context


@regist_tool(title="Text2Sql")
def text2sql(query: str = Field(description="No need for SQL statements,just input the natural language that you want to chat with database")):
'''Use this tool to chat with database,Input natural language, then it will convert it into SQL and execute it in the database, then return the execution result.'''
tool_config = get_tool_config("text2sql")
return BaseToolOutput(query_database(query=query, config=tool_config))
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
summary="文件对话"
)(file_chat)

#定义全局model信息,用于给Text2Sql中的get_ChatOpenAI提供model_name
global_model_name=None

@chat_router.post("/chat/completions", summary="兼容 openai 的统一 chat 接口")
async def chat_completions(
Expand All @@ -51,6 +53,8 @@ async def chat_completions(
for key in list(extra):
delattr(body, key)

global global_model_name
global_model_name=body.model
# check tools & tool_choice in request body
if isinstance(body.tool_choice, str):
if t := get_tool(body.tool_choice):
Expand Down
9 changes: 5 additions & 4 deletions libs/chatchat-server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ chatchat-kb = 'chatchat.init_database:main'
[tool.poetry.dependencies]
python = ">=3.8.1,<3.12,!=3.9.7"
model-providers = "^0.3.0"
langchain = "0.1.5"
langchain = "0.1.17"
langchainhub = "0.1.14"
langchain-community = "0.0.17"
langchain-community = "0.0.36"
langchain-openai = "0.0.5"
langchain-experimental = "0.0.50"
langchain-experimental = "0.0.58"
fastapi = "~0.109.2"
sse_starlette = "~1.8.2"
nltk = "~3.8.1"
Expand Down Expand Up @@ -51,12 +51,13 @@ python-multipart = "0.0.9"
streamlit = "1.34.0"
streamlit-option-menu = "0.3.12"
streamlit-antd-components = "0.3.1"
streamlit-chatbox = "1.1.12"
streamlit-chatbox = "1.1.12.post2"
streamlit-modal = "0.1.0"
streamlit-aggrid = "0.3.4.post3"
streamlit-extras = "0.4.2"
xinference_client = { version = "^0.11.1", optional = true }
zhipuai = { version = "^2.1.0", optional = true }
pymysql = "^1.1.0"

[tool.poetry.extras]
xinference = ["xinference_client"]
Expand Down

0 comments on commit b1c5bf9

Please sign in to comment.