[fix] Increase API endpoint concurrency: run synchronous code asynchronously so a single slow endpoint does not block the service #869

Closed
wants to merge 9 commits
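Every non-streaming endpoint in this change follows the same pattern: the blocking model call is moved into a local synchronous helper and dispatched with asyncio.to_thread, so the event loop can keep accepting requests while the model runs. A minimal sketch of the pattern, independent of the diff below (the endpoint and function names here are illustrative only):

```python
import asyncio
import time

from fastapi import FastAPI

app = FastAPI()


def slow_inference(question: str) -> str:
    # Stand-in for a blocking call such as get_knowledge_based_answer.
    time.sleep(5)
    return f"answer to: {question}"


@app.post("/blocking_chat")
async def blocking_chat(question: str):
    # Runs on the event loop itself: all other requests wait 5 s.
    return {"answer": slow_inference(question)}


@app.post("/offloaded_chat")
async def offloaded_chat(question: str):
    # Runs in a worker thread via asyncio.to_thread; the loop stays free.
    answer = await asyncio.to_thread(slow_inference, question)
    return {"answer": answer}
```

Plain `def` path operations would be offloaded to a threadpool by FastAPI automatically; these handlers are `async def`, so the offload has to be explicit.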
api.py: 90 changes (50 additions, 40 deletions)
@@ -1,4 +1,4 @@
-#encoding:utf-8
+# encoding:utf-8
import argparse
import json
import os
@@ -300,7 +300,6 @@ async def update_doc(
        return BaseResponse(code=500, msg=file_status)

-

async def local_doc_chat(
    knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"),
    question: str = Body(..., description="Question", example="工伤保险是什么?"),
@@ -327,26 +326,30 @@ async def local_doc_chat(
        )
    else:
        if (streaming):
-            def generate_answer ():
+            def generate_answer():
                last_print_len = 0
                for resp, next_history in local_doc_qa.get_knowledge_based_answer(
-                    query=question, vs_path=vs_path, chat_history=history, streaming=True
+                        query=question, vs_path=vs_path, chat_history=history, streaming=True
                ):
                    yield resp["result"][last_print_len:]
-                    last_print_len=len(resp["result"])
+                    last_print_len = len(resp["result"])

            return StreamingResponse(generate_answer())
        else:
-            for resp, history in local_doc_qa.get_knowledge_based_answer(
-                query=question, vs_path=vs_path, chat_history=history, streaming=True
-            ):
-                pass
-
-            source_documents = [
-                f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
-                f"""相关度:{doc.metadata['score']}\n\n"""
-                for inum, doc in enumerate(resp["source_documents"])
-            ]
+            def _sync_method(question, history):
+                # synchronous method code
+                for resp, history in local_doc_qa.get_knowledge_based_answer(
+                        query=question, vs_path=vs_path, chat_history=history, streaming=True
+                ):
+                    pass
+                source_documents = [
+                    f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
+                    f"""相关度:{doc.metadata['score']}\n\n"""
+                    for inum, doc in enumerate(resp["source_documents"])
+                ]
+                return resp, source_documents
+
+            resp, source_documents = await asyncio.to_thread(_sync_method, question=question, history=history)

            return ChatMessage(
                question=question,
@@ -369,15 +372,19 @@ async def bing_search_chat(
        ],
    ),
):
-    for resp, history in local_doc_qa.get_search_result_based_answer(
-        query=question, chat_history=history, streaming=True
-    ):
-        pass
-
-    source_documents = [
-        f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
-        for inum, doc in enumerate(resp["source_documents"])
-    ]
+    def _sync_method(question, history):
+        # synchronous method code
+        for resp, history in local_doc_qa.get_search_result_based_answer(
+                query=question, chat_history=history, streaming=True
+        ):
+            pass
+        source_documents = [
+            f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n"""
+            for inum, doc in enumerate(resp["source_documents"])
+        ]
+        return resp, source_documents
+
+    resp, source_documents = await asyncio.to_thread(_sync_method, question=question, history=history)
    return ChatMessage(
        question=question,
        response=resp["result"],
@@ -401,16 +408,28 @@ async def chat(
    ),
):
    if (streaming):
-        def generate_answer ():
+        def generate_answer():
            last_print_len = 0
            answer_result_stream_result = local_doc_qa.llm_model_chain(
                {"prompt": question, "history": history, "streaming": True})
            for answer_result in answer_result_stream_result['answer_result_stream']:
                yield answer_result.llm_output["answer"][last_print_len:]
                last_print_len = len(answer_result.llm_output["answer"])

        return StreamingResponse(generate_answer())
    else:
+        def _sync_method(question, history):
+            # synchronous method code
+            answer_result_stream_result = local_doc_qa.llm_model_chain(
+                {"prompt": question, "history": history, "streaming": True})
+
+            for answer_result in answer_result_stream_result['answer_result_stream']:
+                resp = answer_result.llm_output["answer"]
+                history = answer_result.history
+                pass
+            return resp, history
+
+        resp, history = await asyncio.to_thread(_sync_method, question=question, history=history)
        answer_result_stream_result = local_doc_qa.llm_model_chain(
            {"prompt": question, "history": history, "streaming": True})
        for answer_result in answer_result_stream_result['answer_result_stream']:
@@ -424,19 +443,6 @@ def generate_answer ():
            history=history,
            source_documents=[],
        )
-        answer_result_stream_result = local_doc_qa.llm_model_chain(
-            {"prompt": question, "history": history, "streaming": True})
-
-        for answer_result in answer_result_stream_result['answer_result_stream']:
-            resp = answer_result.llm_output["answer"]
-            history = answer_result.history
-            pass
-        return ChatMessage(
-            question=question,
-            response=resp,
-            history=history,
-            source_documents=[],
-        )


async def stream_chat(websocket: WebSocket):
@@ -482,6 +488,7 @@ async def stream_chat(websocket: WebSocket):
        )
        turn += 1

+
async def stream_chat_bing(websocket: WebSocket):
    """
    Streaming Q&A based on Bing search
@@ -495,7 +502,8 @@ async def stream_chat_bing(websocket: WebSocket):
        await websocket.send_json({"question": question, "turn": turn, "flag": "start"})

        last_print_len = 0
-        for resp, history in local_doc_qa.get_search_result_based_answer(question, chat_history=history, streaming=True):
+        for resp, history in local_doc_qa.get_search_result_based_answer(question, chat_history=history,
+                                                                         streaming=True):
            await websocket.send_text(resp["result"][last_print_len:])
            last_print_len = len(resp["result"])

@@ -518,6 +526,7 @@ async def stream_chat_bing(websocket: WebSocket):
        )
        turn += 1

+
async def document():
    return RedirectResponse(url="/docs")

@@ -561,7 +570,8 @@ def api_start(host, port, **kwargs):
    app.get("/local_doc_qa/list_files", response_model=ListDocsResponse, summary="获取知识库内的文件列表")(list_docs)
    app.delete("/local_doc_qa/delete_knowledge_base", response_model=BaseResponse, summary="删除知识库")(delete_kb)
    app.delete("/local_doc_qa/delete_file", response_model=BaseResponse, summary="删除知识库内的文件")(delete_doc)
-    app.post("/local_doc_qa/update_file", response_model=BaseResponse, summary="上传文件到知识库,并删除另一个文件")(update_doc)
+    app.post("/local_doc_qa/update_file", response_model=BaseResponse, summary="上传文件到知识库,并删除另一个文件")(
+        update_doc)

    local_doc_qa = LocalDocQA()
    local_doc_qa.init_cfg(
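Two notes on the approach above. First, asyncio.to_thread was added in Python 3.9; on older interpreters the same offload can be written with run_in_executor, roughly as in this sketch (to_thread_compat is an illustrative name, not part of the PR):

```python
import asyncio
import functools


async def to_thread_compat(func, *args, **kwargs):
    # Pre-3.9 stand-in for asyncio.to_thread: run the blocking callable in
    # the loop's default ThreadPoolExecutor and await its result.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, functools.partial(func, *args, **kwargs))
```

Second, the streaming branches can stay plain synchronous generators, presumably because Starlette iterates a non-async iterable passed to StreamingResponse in a threadpool rather than on the event loop.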
test/api/test_async.py: 42 changes (42 additions, 0 deletions)
@@ -0,0 +1,42 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@DOC : 测试接口并发性
@Date :2023/7/17 15:25
"""

import requests
import time
import concurrent.futures

inputBody = {
    "knowledge_base_id": None,
    "question": "写一个python语言实现的二叉树demo",
    # "question": "介绍一下自己?",
    "history": []
}
headers = {
    'Content-Type': 'application/json',
    'Accept': 'application/json'
}


def send_request(url):
    t1 = time.time()
    print("start:", url)
    response = requests.post(url, json=inputBody, headers=headers)
    t2 = time.time()
    print("Time taken:", t2 - t1)
    print("end:", url)
    return response.text


url = 'http://localhost:7861/chat'
# host = 'localhost:7861/local_doc_qa/bing_search_chat'


with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = [executor.submit(send_request, url) for _ in range(5)]
    for future in concurrent.futures.as_completed(futures):
        response = future.result()
        print("Response:", response)