From d5e39c111c6f4656ca112ebe84671568a6d15d33 Mon Sep 17 00:00:00 2001 From: srszzw <741992282@qq.com> Date: Sat, 1 Jun 2024 18:28:24 +0800 Subject: [PATCH 1/3] =?UTF-8?q?1=E3=80=81=E4=BF=AE=E6=94=B9=E7=9F=A5?= =?UTF-8?q?=E8=AF=86=E5=BA=93=E5=88=97=E8=A1=A8=E6=8E=A5=E5=8F=A3=EF=BC=8C?= =?UTF-8?q?=E8=BF=94=E5=9B=9E=E5=85=A8=E9=87=8F=E5=B1=9E=E6=80=A7=E5=AD=97?= =?UTF-8?q?=E6=AE=B5=EF=BC=8C=E5=90=8C=E6=97=B6=E4=BF=AE=E6=94=B9=E5=8F=97?= =?UTF-8?q?=E5=BD=B1=E5=93=8D=E7=9A=84=E7=9B=B8=E5=85=B3=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E3=80=82=202=E3=80=81run=5Fin=5Fprocess=5Fpool=E6=94=B9?= =?UTF-8?q?=E4=B8=BArun=5Fin=5Fthread=5Fpool=EF=BC=8C=E8=A7=A3=E5=86=B3?= =?UTF-8?q?=E5=85=BC=E5=AE=B9=E6=80=A7=E9=97=AE=E9=A2=98=E3=80=82=203?= =?UTF-8?q?=E3=80=81poetry=E9=85=8D=E7=BD=AE=E6=96=87=E4=BB=B6=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../search_local_knowledgebase.py | 2 +- .../server/db/models/knowledge_base_model.py | 17 +++- .../repository/knowledge_base_repository.py | 5 +- .../server/knowledge_base/kb_service/base.py | 23 ++--- .../chatchat/server/knowledge_base/utils.py | 2 +- libs/chatchat-server/chatchat/server/utils.py | 2 +- libs/chatchat-server/pyproject.toml | 94 +++++++++---------- 7 files changed, 81 insertions(+), 64 deletions(-) diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py index a34e83018..e7524f2a7 100644 --- a/libs/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py +++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py @@ -46,7 +46,7 @@ def search_knowledgebase(query: str, database: str, config: dict): @regist_tool(description=template_knowledge, title="本地知识库") def search_local_knowledgebase( - database: str = Field(description="Database for Knowledge Search", choices=list_kbs().data), + database: str = Field(description="Database for Knowledge Search", choices=[kb.kb_name for kb in list_kbs().data]), query: str = Field(description="Query for Knowledge Search"), ): '''''' diff --git a/libs/chatchat-server/chatchat/server/db/models/knowledge_base_model.py b/libs/chatchat-server/chatchat/server/db/models/knowledge_base_model.py index 67e8060f3..c22b71543 100644 --- a/libs/chatchat-server/chatchat/server/db/models/knowledge_base_model.py +++ b/libs/chatchat-server/chatchat/server/db/models/knowledge_base_model.py @@ -1,5 +1,7 @@ from sqlalchemy import Column, Integer, String, DateTime, func - +from pydantic import BaseModel +from typing import Optional +from datetime import datetime from chatchat.server.db.base import Base @@ -18,3 +20,16 @@ class KnowledgeBaseModel(Base): def __repr__(self): return f"" + +# 创建一个对应的 Pydantic 模型 +class KnowledgeBaseSchema(BaseModel): + id: int + kb_name: str + kb_info: Optional[str] + vs_type: Optional[str] + embed_model: Optional[str] + file_count: Optional[int] + create_time: Optional[datetime] + + class Config: + from_attributes = True # 确保可以从 ORM 实例进行验证 \ No newline at end of file diff --git a/libs/chatchat-server/chatchat/server/db/repository/knowledge_base_repository.py b/libs/chatchat-server/chatchat/server/db/repository/knowledge_base_repository.py index 2231de8ce..785d95b94 100644 --- a/libs/chatchat-server/chatchat/server/db/repository/knowledge_base_repository.py +++ b/libs/chatchat-server/chatchat/server/db/repository/knowledge_base_repository.py @@ -1,4 +1,5 @@ from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseModel +from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseSchema from chatchat.server.db.session import with_session @@ -18,8 +19,8 @@ def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model): @with_session def list_kbs_from_db(session, min_file_count: int = -1): - kbs = session.query(KnowledgeBaseModel.kb_name).filter(KnowledgeBaseModel.file_count > min_file_count).all() - kbs = [kb[0] for kb in kbs] + kbs = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.file_count > min_file_count).all() + kbs = [KnowledgeBaseSchema.model_validate(kb) for kb in kbs] return kbs diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/base.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/base.py index 05724cb14..7cdbf1022 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/base.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/base.py @@ -25,7 +25,7 @@ from typing import List, Union, Dict, Optional, Tuple from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId - +from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseSchema class SupportedVSType: FAISS = 'faiss' @@ -325,7 +325,7 @@ def get_default(): def get_kb_details() -> List[Dict]: kbs_in_folder = list_kbs_from_folder() - kbs_in_db = KBService.list_kbs() + kbs_in_db:List[KnowledgeBaseSchema] = KBService.list_kbs() result = {} for kb in kbs_in_folder: @@ -340,15 +340,16 @@ def get_kb_details() -> List[Dict]: "in_db": False, } - for kb in kbs_in_db: - kb_detail = get_kb_detail(kb) - if kb_detail: - kb_detail["in_db"] = True - if kb in result: - result[kb].update(kb_detail) - else: - kb_detail["in_folder"] = False - result[kb] = kb_detail + for kb_detail in kbs_in_db: + kb_detail=kb_detail.model_dump() + kb_name=kb_detail["kb_name"] + kb_detail["in_db"] = True + if kb_name in result: + result[kb_name].update(kb_detail) + else: + kb_detail["in_folder"] = False + result[kb_name] = kb_detail + data = [] for i, v in enumerate(result.values()): diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/utils.py b/libs/chatchat-server/chatchat/server/knowledge_base/utils.py index c8ae0980f..c5dd442bc 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/utils.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/utils.py @@ -404,7 +404,7 @@ def files2docs_in_thread( except Exception as e: yield False, (kb_name, filename, str(e)) - for result in run_in_process_pool(func=files2docs_in_thread_file2docs, params=kwargs_list): + for result in run_in_thread_pool(func=files2docs_in_thread_file2docs, params=kwargs_list): yield result diff --git a/libs/chatchat-server/chatchat/server/utils.py b/libs/chatchat-server/chatchat/server/utils.py index 784e94d68..9fdcad527 100644 --- a/libs/chatchat-server/chatchat/server/utils.py +++ b/libs/chatchat-server/chatchat/server/utils.py @@ -302,7 +302,7 @@ class Config: class ListResponse(BaseResponse): - data: List[str] = Field(..., description="List of names") + data: List[Any] = Field(..., description="List of data") class Config: json_schema_extra = { diff --git a/libs/chatchat-server/pyproject.toml b/libs/chatchat-server/pyproject.toml index b5e305f48..b9b0d5fb7 100644 --- a/libs/chatchat-server/pyproject.toml +++ b/libs/chatchat-server/pyproject.toml @@ -62,53 +62,6 @@ xinference = ["xinference_client"] zhipuai = ["zhipuai"] cli = ["typer"] - -[tool.poetry.group.test] -optional = true - -[tool.poetry.group.test.dependencies] -# The only dependencies that should be added are -# dependencies used for running tests (e.g., pytest, freezegun, response). -# Any dependencies that do not meet that criteria will be removed. -pytest = "^7.3.0" -pytest-cov = "^4.0.0" -pytest-dotenv = "^0.5.2" -duckdb-engine = "^0.9.2" -pytest-watcher = "^0.2.6" -freezegun = "^1.2.2" -responses = "^0.22.0" -pytest-asyncio = "^0.23.2" -lark = "^1.1.5" -pytest-mock = "^3.10.0" -pytest-socket = "^0.6.0" -syrupy = "^4.0.2" -requests-mock = "^1.11.0" -model-providers = { path = "../model-providers", develop = true } - - -[tool.poetry.group.lint] -optional = true - -[tool.poetry.group.lint.dependencies] -ruff = "^0.1.5" - - -[tool.poetry.group.codespell] -optional = true - -[tool.poetry.group.codespell.dependencies] -codespell = "^2.2.0" - - -[tool.poetry.group.dev] -optional = true - -[tool.poetry.group.dev.dependencies] -jupyter = "^1.0.0" -setuptools = "^67.6.1" -model-providers = { path = "../model-providers", develop = true } - - # An extra used to be able to add extended testing. # Please use new-line on formatting to make it easier to add new packages without # merge-conflicts @@ -194,6 +147,53 @@ extended_testing = [ "friendli-client" ] +[tool.poetry.group.test] +optional = true + +[tool.poetry.group.test.dependencies] +# The only dependencies that should be added are +# dependencies used for running tests (e.g., pytest, freezegun, response). +# Any dependencies that do not meet that criteria will be removed. +pytest = "^7.3.0" +pytest-cov = "^4.0.0" +pytest-dotenv = "^0.5.2" +duckdb-engine = "^0.9.2" +pytest-watcher = "^0.2.6" +freezegun = "^1.2.2" +responses = "^0.22.0" +pytest-asyncio = "^0.23.2" +lark = "^1.1.5" +pytest-mock = "^3.10.0" +pytest-socket = "^0.6.0" +syrupy = "^4.0.2" +requests-mock = "^1.11.0" +model-providers = { path = "../model-providers", develop = true } + + +[tool.poetry.group.lint] +optional = true + +[tool.poetry.group.lint.dependencies] +ruff = "^0.1.5" + + +[tool.poetry.group.codespell] +optional = true + +[tool.poetry.group.codespell.dependencies] +codespell = "^2.2.0" + + +[tool.poetry.group.dev] +optional = true + +[tool.poetry.group.dev.dependencies] +jupyter = "^1.0.0" +setuptools = "^67.6.1" +model-providers = { path = "../model-providers", develop = true } + + + [tool.ruff] exclude = [ "tests/examples/non-utf8-encoding.py", From f67319a047bfaa12ecd215e64009dceed671106f Mon Sep 17 00:00:00 2001 From: srszzw <741992282@qq.com> Date: Sun, 2 Jun 2024 16:00:05 +0800 Subject: [PATCH 2/3] =?UTF-8?q?1=E3=80=81=E5=8A=A8=E6=80=81=E6=9B=B4?= =?UTF-8?q?=E6=96=B0Prompt=E4=B8=AD=E7=9A=84=E7=9F=A5=E8=AF=86=E5=BA=93?= =?UTF-8?q?=E6=8F=8F=E8=BF=B0=E4=BF=A1=E6=81=AF=EF=BC=8C=E4=BD=BF=E5=A4=A7?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E6=9B=B4=E5=AE=B9=E6=98=93=E5=88=A4=E6=96=AD?= =?UTF-8?q?=E4=BD=BF=E7=94=A8=E5=93=AA=E4=B8=AA=E7=9F=A5=E8=AF=86=E5=BA=93?= =?UTF-8?q?=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- libs/chatchat-server/chatchat/server/utils.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/libs/chatchat-server/chatchat/server/utils.py b/libs/chatchat-server/chatchat/server/utils.py index 9fdcad527..a810c26fb 100644 --- a/libs/chatchat-server/chatchat/server/utils.py +++ b/libs/chatchat-server/chatchat/server/utils.py @@ -710,12 +710,29 @@ def get_temp_dir(id: str = None) -> Tuple[str, str]: return path, id +# 动态更新知识库信息 +def update_search_local_knowledgebase_tool(): + import re + from chatchat.server.agent.tools_factory import tools_registry + from chatchat.server.db.repository.knowledge_base_repository import list_kbs_from_db + kbs=list_kbs_from_db() + template = "Use local knowledgebase from one or more of these:\n{KB_info}\n to get information,Only local data on this knowledge use this tool. The 'database' should be one of the above [{key}]." + KB_info_str = '\n'.join([f"{kb.kb_name}: {kb.kb_info}" for kb in kbs]) + KB_name_info_str = '\n'.join([f"{kb.kb_name}" for kb in kbs]) + template_knowledge = template.format(KB_info=KB_info_str, key=KB_name_info_str) + + search_local_knowledgebase_tool=tools_registry._TOOLS_REGISTRY.get("search_local_knowledgebase") + if search_local_knowledgebase_tool: + search_local_knowledgebase_tool.description = " ".join(re.split(r"\n+\s*", template_knowledge)) + + def get_tool(name: str = None) -> Union[BaseTool, Dict[str, BaseTool]]: import importlib from chatchat.server.agent import tools_factory importlib.reload(tools_factory) from chatchat.server.agent.tools_factory import tools_registry + update_search_local_knowledgebase_tool() if name is None: return tools_registry._TOOLS_REGISTRY else: From 7e728c6d18571e8aa046d00826777a5d914a0354 Mon Sep 17 00:00:00 2001 From: srszzw <741992282@qq.com> Date: Mon, 3 Jun 2024 16:26:21 +0800 Subject: [PATCH 3/3] =?UTF-8?q?1=E3=80=81=E7=9F=A5=E8=AF=86=E5=BA=93?= =?UTF-8?q?=E5=B7=A5=E5=85=B7=E7=9A=84=E4=B8=8B=E6=8B=89=E5=88=97=E8=A1=A8?= =?UTF-8?q?=E6=94=B9=E4=B8=BA=E5=8A=A8=E6=80=81=E8=8E=B7=E5=8F=96=EF=BC=8C?= =?UTF-8?q?=E4=B8=8D=E5=BF=85=E9=87=8D=E5=90=AF=E6=9C=8D=E5=8A=A1=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- libs/chatchat-server/chatchat/server/utils.py | 1 + libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/libs/chatchat-server/chatchat/server/utils.py b/libs/chatchat-server/chatchat/server/utils.py index a810c26fb..314ef0ccc 100644 --- a/libs/chatchat-server/chatchat/server/utils.py +++ b/libs/chatchat-server/chatchat/server/utils.py @@ -724,6 +724,7 @@ def update_search_local_knowledgebase_tool(): search_local_knowledgebase_tool=tools_registry._TOOLS_REGISTRY.get("search_local_knowledgebase") if search_local_knowledgebase_tool: search_local_knowledgebase_tool.description = " ".join(re.split(r"\n+\s*", template_knowledge)) + search_local_knowledgebase_tool.args["database"]["choices"]=[kb.kb_name for kb in kbs] def get_tool(name: str = None) -> Union[BaseTool, Dict[str, BaseTool]]: diff --git a/libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py b/libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py index d211097e9..cdd40aa6a 100644 --- a/libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py +++ b/libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py @@ -106,7 +106,7 @@ def clear_conv(name: str = None): chat_box.reset_history(name=name or None) -@st.cache_data +# @st.cache_data def list_tools(_api: ApiRequest): return _api.list_tools()