Add file chat mode (#2071)
For developers:
- Add /chat/file_chat and /knowledge_base/upload_temp_docs API endpoints
- Add CACHED_MEMO_VS_NUM and BASE_TEMP_DIR configuration options
liunux4odoo committed Nov 15, 2023
1 parent 2adfa42 commit 3b3d948
Showing 11 changed files with 339 additions and 10 deletions.
8 changes: 8 additions & 0 deletions configs/basic_config.py.example
@@ -1,6 +1,8 @@
import logging
import os
import langchain
import tempfile
import shutil


# Whether to display verbose logs
@@ -23,3 +25,9 @@ logging.basicConfig(format=LOG_FORMAT)
LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
if not os.path.exists(LOG_PATH):
    os.mkdir(LOG_PATH)

# Temporary file directory, mainly used for file chat
BASE_TEMP_DIR = os.path.join(tempfile.gettempdir(), "chatchat")
if os.path.isdir(BASE_TEMP_DIR):
    shutil.rmtree(BASE_TEMP_DIR)
os.makedirs(BASE_TEMP_DIR)
3 changes: 3 additions & 0 deletions configs/kb_config.py.example
@@ -9,6 +9,9 @@ DEFAULT_VS_TYPE = "faiss"
# Number of cached vector stores (for FAISS)
CACHED_VS_NUM = 1

# Number of cached temporary vector stores (for FAISS), used for file chat
CACHED_MEMO_VS_NUM = 10

# Max length of a single text chunk in the knowledge base (does not apply to MarkdownHeaderTextSplitter)
CHUNK_SIZE = 250

11 changes: 11 additions & 0 deletions server/api.py
@@ -142,6 +142,7 @@ def get_server_prompt_template(

def mount_knowledge_routes(app: FastAPI):
    from server.chat.knowledge_base_chat import knowledge_base_chat
    from server.chat.file_chat import upload_temp_docs, file_chat
    from server.chat.agent_chat import agent_chat
    from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
    from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
@@ -152,6 +153,11 @@ def mount_knowledge_routes(app: FastAPI):
tags=["Chat"],
summary="与知识库对话")(knowledge_base_chat)

app.post("/chat/file_chat",
tags=["Knowledge Base Management"],
summary="文件对话"
)(file_chat)

app.post("/chat/agent_chat",
tags=["Chat"],
summary="与agent对话")(agent_chat)
@@ -218,6 +224,11 @@ def mount_knowledge_routes(app: FastAPI):
summary="根据content中文档重建向量库,流式输出处理进度。"
)(recreate_vector_store)

app.post("/knowledge_base/upload_temp_docs",
tags=["Knowledge Base Management"],
summary="上传文件到临时目录,用于文件对话。"
)(upload_temp_docs)


def run_api(host, port, **kwargs):
    if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
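The registrations above exploit the fact that FastAPI's app.post(...) is a decorator factory: calling it returns a decorator, and applying that decorator directly to an imported handler registers the route without touching the handler's definition. A minimal self-contained sketch of the same pattern (the echo handler is a hypothetical stand-in, not part of this commit):

import uvicorn
from fastapi import Body, FastAPI

app = FastAPI()

async def echo(text: str = Body(..., description="Text to echo back", embed=True)):
    # An ordinary coroutine defined anywhere, with no decorator at its definition site.
    return {"echo": text}

# app.post(...) returns a decorator; calling it on the handler registers the route,
# mirroring the app.post("/chat/file_chat", ...)(file_chat) calls above.
app.post("/echo", tags=["Demo"], summary="Echo the input")(echo)

if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000)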
167 changes: 167 additions & 0 deletions server/chat/file_chat.py
@@ -0,0 +1,167 @@
from fastapi import Body, File, Form, UploadFile
from fastapi.responses import StreamingResponse
from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE,
                     CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
from server.utils import (wrap_done, get_ChatOpenAI,
                          BaseResponse, get_prompt_template, get_temp_dir, run_in_thread_pool)
from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable, List, Optional
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History
from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
from server.knowledge_base.utils import KnowledgeFile
import json
import os
from pathlib import Path


def _parse_files_in_thread(
    files: List[UploadFile],
    dir: str,
    zh_title_enhance: bool,
    chunk_size: int,
    chunk_overlap: int,
):
    """
    Save the uploaded files to the target directory using a thread pool.
    A generator yielding the result for each file: [success or error, filename, msg, docs]
    """
    def parse_file(file: UploadFile) -> tuple:
        '''
        Save a single file and split it into documents.
        '''
        filename = file.filename
        try:
            file_path = os.path.join(dir, filename)
            file_content = file.file.read()  # read the raw content of the uploaded file
            with open(file_path, "wb") as f:
                f.write(file_content)
            kb_file = KnowledgeFile(filename=filename, knowledge_base_name="temp")
            kb_file.filepath = file_path
            docs = kb_file.file2text(zh_title_enhance=zh_title_enhance,
                                     chunk_size=chunk_size,
                                     chunk_overlap=chunk_overlap)
            return True, filename, f"Successfully uploaded file {filename}", docs
        except Exception as e:
            msg = f"Failed to upload file {filename}; error: {e}"
            return False, filename, msg, []

params = [{"file": file} for file in files]
for result in run_in_thread_pool(parse_file, params=params):
yield result


def upload_temp_docs(
    files: List[UploadFile] = File(..., description="Uploaded files; multiple files supported"),
    prev_id: str = Form(None, description="ID of the previous temporary knowledge base"),
    chunk_size: int = Form(CHUNK_SIZE, description="Max length of a single text chunk in the knowledge base"),
    chunk_overlap: int = Form(OVERLAP_SIZE, description="Overlap length between adjacent text chunks"),
    zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="Whether to enable Chinese title enhancement"),
) -> BaseResponse:
    '''
    Save the files to a temporary directory and vectorize them.
    Returns the name of the temporary directory as the ID, which is also the ID of the temporary vector store.
    '''
    if prev_id is not None:
        memo_faiss_pool.pop(prev_id)

    failed_files = []
    documents = []
    path, id = get_temp_dir(prev_id)
    for success, file, msg, docs in _parse_files_in_thread(files=files,
                                                           dir=path,
                                                           zh_title_enhance=zh_title_enhance,
                                                           chunk_size=chunk_size,
                                                           chunk_overlap=chunk_overlap):
        if success:
            documents += docs
        else:
            failed_files.append({file: msg})

    with memo_faiss_pool.load_vector_store(id).acquire() as vs:
        vs.add_documents(documents)
    return BaseResponse(data={"id": id, "failed_files": failed_files})


async def file_chat(query: str = Body(..., description="User input", examples=["你好"]),
                    knowledge_id: str = Body(..., description="Temporary knowledge base ID"),
                    top_k: int = Body(VECTOR_SEARCH_TOP_K, description="Number of matched vectors"),
                    score_threshold: float = Body(SCORE_THRESHOLD, description="Knowledge base matching score threshold; a lower score means higher relevance, and the maximum value is equivalent to no filtering; around 0.5 is recommended", ge=0, le=2),
                    history: List[History] = Body([],
                                                  description="Chat history",
                                                  examples=[[
                                                      {"role": "user",
                                                       "content": "我们来玩成语接龙,我先来,生龙活虎"},
                                                      {"role": "assistant",
                                                       "content": "虎头虎脑"}]]
                                                  ),
                    stream: bool = Body(False, description="Stream the output"),
                    model_name: str = Body(LLM_MODELS[0], description="Name of the LLM model."),
                    temperature: float = Body(TEMPERATURE, description="LLM sampling temperature", ge=0.0, le=1.0),
                    max_tokens: Optional[int] = Body(None, description="Limit on the number of tokens the LLM may generate; None means the model's maximum"),
                    prompt_name: str = Body("default", description="Name of the prompt template to use (configured in configs/prompt_config.py)"),
                    ):
    if knowledge_id not in memo_faiss_pool.keys():
        return BaseResponse(code=404, msg=f"Temporary knowledge base {knowledge_id} not found; please upload files first")

    history = [History.from_data(h) for h in history]

    async def knowledge_base_chat_iterator() -> AsyncIterable[str]:
        callback = AsyncIteratorCallbackHandler()
        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[callback],
        )
        embed_func = EmbeddingsFunAdapter()
        embeddings = embed_func.embed_query(query)
        with memo_faiss_pool.acquire(knowledge_id) as vs:
            docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
            docs = [x[0] for x in docs]

        context = "\n".join([doc.page_content for doc in docs])
        if len(docs) == 0:  # fall back to the "Empty" template when no relevant documents are found
            prompt_template = get_prompt_template("knowledge_base_chat", "Empty")
        else:
            prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
        chat_prompt = ChatPromptTemplate.from_messages(
            [i.to_msg_template() for i in history] + [input_msg])

        chain = LLMChain(prompt=chat_prompt, llm=model)

        # Begin a task that runs in the background.
        task = asyncio.create_task(wrap_done(
            chain.acall({"context": context, "question": query}),
            callback.done),
        )

        source_documents = []
        doc_path = get_temp_dir(knowledge_id)[0]
        for inum, doc in enumerate(docs):
            filename = Path(doc.metadata["source"]).resolve().relative_to(doc_path)
            text = f"""Source [{inum + 1}] [{filename}] \n\n{doc.page_content}\n\n"""
            source_documents.append(text)

        if len(source_documents) == 0:  # no relevant documents were found
            source_documents.append("<span style='color:red'>No relevant documents were found; this answer is based on the LLM's own knowledge!</span>")

        if stream:
            async for token in callback.aiter():
                # Use server-sent events to stream the response
                yield json.dumps({"answer": token}, ensure_ascii=False)
            yield json.dumps({"docs": source_documents}, ensure_ascii=False)
        else:
            answer = ""
            async for token in callback.aiter():
                answer += token
            yield json.dumps({"answer": answer,
                              "docs": source_documents},
                             ensure_ascii=False)
        await task

    return StreamingResponse(knowledge_base_chat_iterator(), media_type="text/event-stream")
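_parse_files_in_thread above leans on run_in_thread_pool from server/utils, which is not part of this diff. Its contract — run func once per kwargs dict and yield results as the workers finish — can be sketched roughly like this (an assumption about the helper's shape, not the repository's actual code):

from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable, Dict, Generator, List

def run_in_thread_pool(func: Callable,
                       params: List[Dict] = [],
                       ) -> Generator:
    '''
    Submit one task per kwargs dict and yield each result as it completes,
    so callers can stream per-file outcomes instead of waiting for the whole batch.
    '''
    tasks = []
    with ThreadPoolExecutor() as pool:
        for kwargs in params:
            tasks.append(pool.submit(func, **kwargs))
        for task in as_completed(tasks):
            yield task.result()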
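End to end, a client first posts files to /knowledge_base/upload_temp_docs to obtain a temporary knowledge base ID, then chats against that ID via /chat/file_chat. A sketch with requests — the base URL and file name are assumptions about a typical local deployment, not anything fixed by this commit:

import json
import requests

base_url = "http://127.0.0.1:7861"  # assumed local API address; adjust to your deployment

# 1. Upload files to a temporary directory; the returned id names the temporary vector store.
with open("report.pdf", "rb") as f:
    resp = requests.post(f"{base_url}/knowledge_base/upload_temp_docs",
                         files=[("files", ("report.pdf", f))])
knowledge_id = resp.json()["data"]["id"]

# 2. Ask a question against the uploaded documents. With stream=False the endpoint
#    emits a single JSON object holding the full answer and its source documents.
resp = requests.post(f"{base_url}/chat/file_chat",
                     json={"query": "Summarize the report",
                           "knowledge_id": knowledge_id,
                           "stream": False})
result = json.loads(resp.text)
print(result["answer"])
print("\n".join(result["docs"]))

With stream=True the response is instead a sequence of bare JSON chunks with no delimiters between them, so a streaming client needs incremental parsing (for example, json.JSONDecoder.raw_decode over an accumulating buffer) rather than line-by-line reads.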
2 changes: 1 addition & 1 deletion server/db/repository/knowledge_file_repository.py
@@ -17,7 +17,7 @@ def list_docs_from_db(session,
    '''
    docs = session.query(FileDocModel).filter_by(kb_name=kb_name)
    if file_name:
-        docs = docs.filter_by(file_name=file_name)
+        docs = docs.filter(FileDocModel.file_name.ilike(file_name))
    for k, v in metadata.items():
        docs = docs.filter(FileDocModel.meta_data[k].as_string() == str(v))

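The one-line change above swaps an exact, case-sensitive filter_by equality for ilike, making file-name lookups case-insensitive and wildcard-capable. A small illustration of the difference with a toy table (hypothetical model and in-memory SQLite, not the repository's schema):

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()

class FileDoc(Base):  # toy stand-in for FileDocModel
    __tablename__ = "file_doc"
    id = Column(Integer, primary_key=True)
    file_name = Column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()
session.add(FileDoc(file_name="README.md"))

# filter_by compares exactly; ilike ignores case and supports % wildcards.
print(session.query(FileDoc).filter_by(file_name="readme.md").count())              # 0
print(session.query(FileDoc).filter(FileDoc.file_name.ilike("readme.md")).count())  # 1
print(session.query(FileDoc).filter(FileDoc.file_name.ilike("%.md")).count())       # 1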
3 changes: 2 additions & 1 deletion server/knowledge_base/kb_cache/base.py
@@ -1,4 +1,5 @@
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.faiss import FAISS
import threading
from configs import (EMBEDDING_MODEL, CHUNK_SIZE,
                     logger, log_verbose)
@@ -25,7 +26,7 @@ def key(self):
        return self._key

    @contextmanager
-    def acquire(self, owner: str = "", msg: str = ""):
+    def acquire(self, owner: str = "", msg: str = "") -> FAISS:
        owner = owner or f"thread {threading.get_native_id()}"
        try:
            self._lock.acquire()
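The new -> FAISS annotation on acquire is a readability hint: because the method is wrapped in @contextmanager, it actually returns a context manager whose with-block yields the cached vector store while holding the object's lock (strictly, a type checker would expect Iterator[FAISS] for a generator-based context manager). The locking pattern, reduced to its core (simplified names; the real class tracks owners and logs acquisitions):

import threading
from contextlib import contextmanager

class ThreadSafeObject:
    def __init__(self, obj):
        self._obj = obj
        self._lock = threading.RLock()

    @contextmanager
    def acquire(self, owner: str = "", msg: str = ""):
        # Hold the lock for the duration of the caller's with-block and
        # release it even if that block raises.
        try:
            self._lock.acquire()
            yield self._obj
        finally:
            self._lock.release()

Callers then write `with obj.acquire() as vs: vs.add_documents(...)`, which is exactly how upload_temp_docs uses the pool in this commit.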
4 changes: 2 additions & 2 deletions server/knowledge_base/kb_cache/faiss_cache.py
@@ -1,4 +1,4 @@
-from configs import CACHED_VS_NUM
+from configs import CACHED_VS_NUM, CACHED_MEMO_VS_NUM
from server.knowledge_base.kb_cache.base import *
from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
from server.utils import load_local_embeddings
@@ -123,7 +123,7 @@ def load_vector_store(


kb_faiss_pool = KBFaissPool(cache_num=CACHED_VS_NUM)
-memo_faiss_pool = MemoFaissPool()
+memo_faiss_pool = MemoFaissPool(cache_num=CACHED_MEMO_VS_NUM)


if __name__ == "__main__":
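Passing cache_num=CACHED_MEMO_VS_NUM caps the number of temporary vector stores held in memory at once; when the pool is full, the oldest entry is evicted. The pool class itself is outside this diff, but the bounded-cache behavior it provides can be sketched with an OrderedDict (an assumption-level illustration, not the repository's CachePool):

from collections import OrderedDict

class BoundedPool:
    '''A sketch of a size-capped cache that drops its oldest entry when full.'''

    def __init__(self, cache_num: int = 10):
        self._cache_num = cache_num
        self._cache = OrderedDict()

    def set(self, key, obj):
        self._cache[key] = obj
        while len(self._cache) > self._cache_num:
            self._cache.popitem(last=False)  # evict the least recently added entry

    def get(self, key, default=None):
        return self._cache.get(key, default)

    def pop(self, key, default=None):
        return self._cache.pop(key, default)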
2 changes: 1 addition & 1 deletion server/knowledge_base/kb_doc_api.py
@@ -13,8 +13,8 @@
import json
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.repository.knowledge_file_repository import get_file_detail
-from typing import List
from langchain.docstore.document import Document
+from typing import List


class DocumentWithScore(Document):
Expand Down
18 changes: 17 additions & 1 deletion server/utils.py
Expand Up @@ -12,7 +12,7 @@
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI, AzureOpenAI, Anthropic
import httpx
-from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union
+from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union, Tuple


async def wrap_done(fn: Awaitable, event: asyncio.Event):
@@ -700,3 +700,19 @@ def load_local_embeddings(model: str = None, device: str = embedding_device()):

    model = model or EMBEDDING_MODEL
    return embeddings_pool.load_embeddings(model=model, device=device)


def get_temp_dir(id: str = None) -> Tuple[str, str]:
    '''
    Create a temporary directory and return (path, folder name).
    '''
    from configs.basic_config import BASE_TEMP_DIR
    import tempfile

    if id is not None:  # if the specified temporary directory already exists, return it directly
        path = os.path.join(BASE_TEMP_DIR, id)
        if os.path.isdir(path):
            return path, id

    path = tempfile.mkdtemp(dir=BASE_TEMP_DIR)
    return path, os.path.basename(path)
