diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py index a34e83018..e7524f2a7 100644 --- a/libs/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py +++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/search_local_knowledgebase.py @@ -46,7 +46,7 @@ def search_knowledgebase(query: str, database: str, config: dict): @regist_tool(description=template_knowledge, title="本地知识库") def search_local_knowledgebase( - database: str = Field(description="Database for Knowledge Search", choices=list_kbs().data), + database: str = Field(description="Database for Knowledge Search", choices=[kb.kb_name for kb in list_kbs().data]), query: str = Field(description="Query for Knowledge Search"), ): '''''' diff --git a/libs/chatchat-server/chatchat/server/db/models/knowledge_base_model.py b/libs/chatchat-server/chatchat/server/db/models/knowledge_base_model.py index 67e8060f3..c22b71543 100644 --- a/libs/chatchat-server/chatchat/server/db/models/knowledge_base_model.py +++ b/libs/chatchat-server/chatchat/server/db/models/knowledge_base_model.py @@ -1,5 +1,7 @@ from sqlalchemy import Column, Integer, String, DateTime, func - +from pydantic import BaseModel +from typing import Optional +from datetime import datetime from chatchat.server.db.base import Base @@ -18,3 +20,16 @@ class KnowledgeBaseModel(Base): def __repr__(self): return f"" + +# 创建一个对应的 Pydantic 模型 +class KnowledgeBaseSchema(BaseModel): + id: int + kb_name: str + kb_info: Optional[str] + vs_type: Optional[str] + embed_model: Optional[str] + file_count: Optional[int] + create_time: Optional[datetime] + + class Config: + from_attributes = True # 确保可以从 ORM 实例进行验证 \ No newline at end of file diff --git a/libs/chatchat-server/chatchat/server/db/repository/knowledge_base_repository.py b/libs/chatchat-server/chatchat/server/db/repository/knowledge_base_repository.py index 2231de8ce..785d95b94 100644 --- a/libs/chatchat-server/chatchat/server/db/repository/knowledge_base_repository.py +++ b/libs/chatchat-server/chatchat/server/db/repository/knowledge_base_repository.py @@ -1,4 +1,5 @@ from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseModel +from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseSchema from chatchat.server.db.session import with_session @@ -18,8 +19,8 @@ def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model): @with_session def list_kbs_from_db(session, min_file_count: int = -1): - kbs = session.query(KnowledgeBaseModel.kb_name).filter(KnowledgeBaseModel.file_count > min_file_count).all() - kbs = [kb[0] for kb in kbs] + kbs = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.file_count > min_file_count).all() + kbs = [KnowledgeBaseSchema.model_validate(kb) for kb in kbs] return kbs diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/base.py b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/base.py index 05724cb14..7cdbf1022 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/base.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/kb_service/base.py @@ -25,7 +25,7 @@ from typing import List, Union, Dict, Optional, Tuple from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId - +from chatchat.server.db.models.knowledge_base_model import KnowledgeBaseSchema class SupportedVSType: FAISS = 'faiss' @@ -325,7 +325,7 @@ def get_default(): def get_kb_details() -> List[Dict]: kbs_in_folder = list_kbs_from_folder() - kbs_in_db = KBService.list_kbs() + kbs_in_db:List[KnowledgeBaseSchema] = KBService.list_kbs() result = {} for kb in kbs_in_folder: @@ -340,15 +340,16 @@ def get_kb_details() -> List[Dict]: "in_db": False, } - for kb in kbs_in_db: - kb_detail = get_kb_detail(kb) - if kb_detail: - kb_detail["in_db"] = True - if kb in result: - result[kb].update(kb_detail) - else: - kb_detail["in_folder"] = False - result[kb] = kb_detail + for kb_detail in kbs_in_db: + kb_detail=kb_detail.model_dump() + kb_name=kb_detail["kb_name"] + kb_detail["in_db"] = True + if kb_name in result: + result[kb_name].update(kb_detail) + else: + kb_detail["in_folder"] = False + result[kb_name] = kb_detail + data = [] for i, v in enumerate(result.values()): diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/utils.py b/libs/chatchat-server/chatchat/server/knowledge_base/utils.py index c8ae0980f..c5dd442bc 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/utils.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/utils.py @@ -404,7 +404,7 @@ def files2docs_in_thread( except Exception as e: yield False, (kb_name, filename, str(e)) - for result in run_in_process_pool(func=files2docs_in_thread_file2docs, params=kwargs_list): + for result in run_in_thread_pool(func=files2docs_in_thread_file2docs, params=kwargs_list): yield result diff --git a/libs/chatchat-server/chatchat/server/utils.py b/libs/chatchat-server/chatchat/server/utils.py index 784e94d68..9fdcad527 100644 --- a/libs/chatchat-server/chatchat/server/utils.py +++ b/libs/chatchat-server/chatchat/server/utils.py @@ -302,7 +302,7 @@ class Config: class ListResponse(BaseResponse): - data: List[str] = Field(..., description="List of names") + data: List[Any] = Field(..., description="List of data") class Config: json_schema_extra = { diff --git a/libs/chatchat-server/pyproject.toml b/libs/chatchat-server/pyproject.toml index b5e305f48..b9b0d5fb7 100644 --- a/libs/chatchat-server/pyproject.toml +++ b/libs/chatchat-server/pyproject.toml @@ -62,53 +62,6 @@ xinference = ["xinference_client"] zhipuai = ["zhipuai"] cli = ["typer"] - -[tool.poetry.group.test] -optional = true - -[tool.poetry.group.test.dependencies] -# The only dependencies that should be added are -# dependencies used for running tests (e.g., pytest, freezegun, response). -# Any dependencies that do not meet that criteria will be removed. -pytest = "^7.3.0" -pytest-cov = "^4.0.0" -pytest-dotenv = "^0.5.2" -duckdb-engine = "^0.9.2" -pytest-watcher = "^0.2.6" -freezegun = "^1.2.2" -responses = "^0.22.0" -pytest-asyncio = "^0.23.2" -lark = "^1.1.5" -pytest-mock = "^3.10.0" -pytest-socket = "^0.6.0" -syrupy = "^4.0.2" -requests-mock = "^1.11.0" -model-providers = { path = "../model-providers", develop = true } - - -[tool.poetry.group.lint] -optional = true - -[tool.poetry.group.lint.dependencies] -ruff = "^0.1.5" - - -[tool.poetry.group.codespell] -optional = true - -[tool.poetry.group.codespell.dependencies] -codespell = "^2.2.0" - - -[tool.poetry.group.dev] -optional = true - -[tool.poetry.group.dev.dependencies] -jupyter = "^1.0.0" -setuptools = "^67.6.1" -model-providers = { path = "../model-providers", develop = true } - - # An extra used to be able to add extended testing. # Please use new-line on formatting to make it easier to add new packages without # merge-conflicts @@ -194,6 +147,53 @@ extended_testing = [ "friendli-client" ] +[tool.poetry.group.test] +optional = true + +[tool.poetry.group.test.dependencies] +# The only dependencies that should be added are +# dependencies used for running tests (e.g., pytest, freezegun, response). +# Any dependencies that do not meet that criteria will be removed. +pytest = "^7.3.0" +pytest-cov = "^4.0.0" +pytest-dotenv = "^0.5.2" +duckdb-engine = "^0.9.2" +pytest-watcher = "^0.2.6" +freezegun = "^1.2.2" +responses = "^0.22.0" +pytest-asyncio = "^0.23.2" +lark = "^1.1.5" +pytest-mock = "^3.10.0" +pytest-socket = "^0.6.0" +syrupy = "^4.0.2" +requests-mock = "^1.11.0" +model-providers = { path = "../model-providers", develop = true } + + +[tool.poetry.group.lint] +optional = true + +[tool.poetry.group.lint.dependencies] +ruff = "^0.1.5" + + +[tool.poetry.group.codespell] +optional = true + +[tool.poetry.group.codespell.dependencies] +codespell = "^2.2.0" + + +[tool.poetry.group.dev] +optional = true + +[tool.poetry.group.dev.dependencies] +jupyter = "^1.0.0" +setuptools = "^67.6.1" +model-providers = { path = "../model-providers", develop = true } + + + [tool.ruff] exclude = [ "tests/examples/non-utf8-encoding.py",