From 3adcfc0b28598b3024b467865365c1631ad9a801 Mon Sep 17 00:00:00 2001 From: aleimu Date: Mon, 17 Jul 2023 16:05:58 +0800 Subject: [PATCH 1/2] =?UTF-8?q?[fix]=E5=A2=9E=E5=8A=A0api=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3=E7=9A=84=E5=B9=B6=E5=8F=91=E6=80=A7=E8=83=BD,?= =?UTF-8?q?=E5=90=8C=E6=AD=A5=E4=BB=A3=E7=A0=81=E5=BC=82=E6=AD=A5=E5=8C=96?= =?UTF-8?q?=E9=98=B2=E6=AD=A2=E5=8D=95=E4=B8=AA=E8=80=97=E6=97=B6=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3=E9=98=BB=E5=A1=9E=E6=9C=8D=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api.py | 70 ++++++++++++++++++++++++++---------------- test/api/test_async.py | 42 +++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 26 deletions(-) create mode 100644 test/api/test_async.py diff --git a/api.py b/api.py index 70dccc8d1..c71b106f1 100644 --- a/api.py +++ b/api.py @@ -8,6 +8,7 @@ import nltk import pydantic import uvicorn +import concurrent.futures from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel @@ -323,15 +324,22 @@ async def local_doc_chat( source_documents=[], ) else: - for resp, history in local_doc_qa.get_knowledge_based_answer( - query=question, vs_path=vs_path, chat_history=history, streaming=True - ): - pass - source_documents = [ - f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n""" - f"""相关度:{doc.metadata['score']}\n\n""" - for inum, doc in enumerate(resp["source_documents"]) - ] + def _sync_method(question, history): + # 同步方法的代码 + for resp, history in local_doc_qa.get_knowledge_based_answer( + query=question, vs_path=vs_path, chat_history=history, streaming=True + ): + pass + source_documents = [ + f"""出处 [{inum + 1}] {os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n""" + f"""相关度:{doc.metadata['score']}\n\n""" + for inum, doc in enumerate(resp["source_documents"]) + ] + return resp, source_documents + + loop = asyncio.get_event_loop() + executor = concurrent.futures.ThreadPoolExecutor() + resp, source_documents = await loop.run_in_executor(executor, _sync_method, question, history) return ChatMessage( question=question, @@ -354,15 +362,21 @@ async def bing_search_chat( ], ), ): - for resp, history in local_doc_qa.get_search_result_based_answer( - query=question, chat_history=history, streaming=True - ): - pass - source_documents = [ - f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n""" - for inum, doc in enumerate(resp["source_documents"]) - ] + def _sync_method(question, history): + # 同步方法的代码 + for resp, history in local_doc_qa.get_search_result_based_answer( + query=question, chat_history=history, streaming=True + ): + pass + source_documents = [ + f"""出处 [{inum + 1}] [{doc.metadata["source"]}]({doc.metadata["source"]}) \n\n{doc.page_content}\n\n""" + for inum, doc in enumerate(resp["source_documents"]) + ] + return resp, source_documents + loop = asyncio.get_event_loop() + executor = concurrent.futures.ThreadPoolExecutor() + resp, source_documents = await loop.run_in_executor(executor, _sync_method, question, history) return ChatMessage( question=question, response=resp["result"], @@ -370,7 +384,6 @@ async def bing_search_chat( source_documents=source_documents, ) - async def chat( question: str = Body(..., description="Question", example="工伤保险是什么?"), history: List[List[str]] = Body( @@ -384,14 +397,20 @@ async def chat( ], ), ): - answer_result_stream_result = local_doc_qa.llm_model_chain( - {"prompt": question, "history": history, "streaming": True}) - - for answer_result in answer_result_stream_result['answer_result_stream']: - resp = answer_result.llm_output["answer"] - history = answer_result.history - pass + def _sync_method(question, history): + # 同步方法的代码 + answer_result_stream_result = local_doc_qa.llm_model_chain( + {"prompt": question, "history": history, "streaming": True}) + + for answer_result in answer_result_stream_result['answer_result_stream']: + resp = answer_result.llm_output["answer"] + history = answer_result.history + pass + return resp, history + loop = asyncio.get_event_loop() + executor = concurrent.futures.ThreadPoolExecutor() + resp, history = await loop.run_in_executor(executor, _sync_method, question, history) return ChatMessage( question=question, response=resp, @@ -399,7 +418,6 @@ async def chat( source_documents=[], ) - async def stream_chat(websocket: WebSocket): await websocket.accept() turn = 1 diff --git a/test/api/test_async.py b/test/api/test_async.py new file mode 100644 index 000000000..0a59e5210 --- /dev/null +++ b/test/api/test_async.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@DOC : 测试接口并发性 +@Date :2023/7/17 15:25 +""" + +import requests +import time +import concurrent.futures + +inputBody = { + "knowledge_base_id": None, + "question": "写一个python语言实现的二叉树demo", + # "question": "介绍一下自己?", + "history": [] +} +headers = { + 'Content-Type': 'application/json', + 'Accept': 'application/json' +} + + +def send_request(url): + t1 = time.time() + print("start:", url) + response = requests.post(url, json=inputBody, headers=headers) + t2 = time.time() + print("Time taken:", t2 - t1) + print("end:", url) + return response.text + + +url = 'http://localhost:7861/chat' +# host = 'localhost:7861/local_doc_qa/bing_search_chat' + + +with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [executor.submit(send_request, url) for _ in range(5)] + for future in concurrent.futures.as_completed(futures): + response = future.result() + print("Response:", response) From 119f5b003714eee8e1f9fdb6db3b02ba4539c14e Mon Sep 17 00:00:00 2001 From: aleimu Date: Tue, 18 Jul 2023 19:31:21 +0800 Subject: [PATCH 2/2] =?UTF-8?q?[fix]=E4=BD=BF=E7=94=A8=E5=BA=93=E8=87=AA?= =?UTF-8?q?=E5=B8=A6=E5=87=BD=E6=95=B0=E7=AE=80=E5=8C=96=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api.py | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/api.py b/api.py index c71b106f1..13019ee9a 100644 --- a/api.py +++ b/api.py @@ -8,7 +8,6 @@ import nltk import pydantic import uvicorn -import concurrent.futures from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel @@ -299,7 +298,6 @@ async def update_doc( return BaseResponse(code=500, msg=file_status) - async def local_doc_chat( knowledge_base_id: str = Body(..., description="Knowledge Base Name", example="kb1"), question: str = Body(..., description="Question", example="工伤保险是什么?"), @@ -337,9 +335,7 @@ def _sync_method(question, history): ] return resp, source_documents - loop = asyncio.get_event_loop() - executor = concurrent.futures.ThreadPoolExecutor() - resp, source_documents = await loop.run_in_executor(executor, _sync_method, question, history) + resp, source_documents = await asyncio.to_thread(_sync_method, question=question, history=history) return ChatMessage( question=question, @@ -374,9 +370,7 @@ def _sync_method(question, history): ] return resp, source_documents - loop = asyncio.get_event_loop() - executor = concurrent.futures.ThreadPoolExecutor() - resp, source_documents = await loop.run_in_executor(executor, _sync_method, question, history) + resp, source_documents = await asyncio.to_thread(_sync_method, question=question, history=history) return ChatMessage( question=question, response=resp["result"], @@ -384,6 +378,7 @@ def _sync_method(question, history): source_documents=source_documents, ) + async def chat( question: str = Body(..., description="Question", example="工伤保险是什么?"), history: List[List[str]] = Body( @@ -408,9 +403,7 @@ def _sync_method(question, history): pass return resp, history - loop = asyncio.get_event_loop() - executor = concurrent.futures.ThreadPoolExecutor() - resp, history = await loop.run_in_executor(executor, _sync_method, question, history) + resp, history = await asyncio.to_thread(_sync_method, question=question, history=history) return ChatMessage( question=question, response=resp, @@ -418,6 +411,7 @@ def _sync_method(question, history): source_documents=[], ) + async def stream_chat(websocket: WebSocket): await websocket.accept() turn = 1 @@ -461,6 +455,7 @@ async def stream_chat(websocket: WebSocket): ) turn += 1 + async def stream_chat_bing(websocket: WebSocket): """ 基于bing搜索的流式问答 @@ -474,7 +469,8 @@ async def stream_chat_bing(websocket: WebSocket): await websocket.send_json({"question": question, "turn": turn, "flag": "start"}) last_print_len = 0 - for resp, history in local_doc_qa.get_search_result_based_answer(question, chat_history=history, streaming=True): + for resp, history in local_doc_qa.get_search_result_based_answer(question, chat_history=history, + streaming=True): await websocket.send_text(resp["result"][last_print_len:]) last_print_len = len(resp["result"]) @@ -497,6 +493,7 @@ async def stream_chat_bing(websocket: WebSocket): ) turn += 1 + async def document(): return RedirectResponse(url="/docs") @@ -540,7 +537,8 @@ def api_start(host, port, **kwargs): app.get("/local_doc_qa/list_files", response_model=ListDocsResponse, summary="获取知识库内的文件列表")(list_docs) app.delete("/local_doc_qa/delete_knowledge_base", response_model=BaseResponse, summary="删除知识库")(delete_kb) app.delete("/local_doc_qa/delete_file", response_model=BaseResponse, summary="删除知识库内的文件")(delete_doc) - app.post("/local_doc_qa/update_file", response_model=BaseResponse, summary="上传文件到知识库,并删除另一个文件")(update_doc) + app.post("/local_doc_qa/update_file", response_model=BaseResponse, summary="上传文件到知识库,并删除另一个文件")( + update_doc) local_doc_qa = LocalDocQA() local_doc_qa.init_cfg(