Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

新功能:知识库管理界面支持查看、编辑、删除向量库文档 #2471

Merged
merged 2 commits into from Dec 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 10 additions & 2 deletions server/api.py
Expand Up @@ -149,7 +149,8 @@ def mount_knowledge_routes(app: FastAPI):
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
update_docs, download_doc, recreate_vector_store,
search_docs, DocumentWithScore, update_info)
search_docs, DocumentWithVSId, update_info,
update_docs_by_id,)

app.post("/chat/knowledge_base_chat",
tags=["Chat"],
Expand Down Expand Up @@ -190,10 +191,17 @@ def mount_knowledge_routes(app: FastAPI):

app.post("/knowledge_base/search_docs",
tags=["Knowledge Base Management"],
response_model=List[DocumentWithScore],
response_model=List[DocumentWithVSId],
summary="搜索知识库"
)(search_docs)

app.post("/knowledge_base/update_docs_by_id",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="直接更新知识库文档"
)(update_docs_by_id)


app.post("/knowledge_base/upload_docs",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
Expand Down
2 changes: 1 addition & 1 deletion server/chat/chat.py
Expand Up @@ -29,7 +29,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
),
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=2.0),
max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
Expand Down
10 changes: 5 additions & 5 deletions server/db/repository/knowledge_base_repository.py
Expand Up @@ -5,7 +5,7 @@
@with_session
def add_kb_to_db(session, kb_name, kb_info, vs_type, embed_model):
# 创建知识库实例
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first()
kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first()
if not kb:
kb = KnowledgeBaseModel(kb_name=kb_name, kb_info=kb_info, vs_type=vs_type, embed_model=embed_model)
session.add(kb)
Expand All @@ -25,14 +25,14 @@ def list_kbs_from_db(session, min_file_count: int = -1):

@with_session
def kb_exists(session, kb_name):
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first()
kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first()
status = True if kb else False
return status


@with_session
def load_kb_from_db(session, kb_name):
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first()
kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first()
if kb:
kb_name, vs_type, embed_model = kb.kb_name, kb.vs_type, kb.embed_model
else:
Expand All @@ -42,15 +42,15 @@ def load_kb_from_db(session, kb_name):

@with_session
def delete_kb_from_db(session, kb_name):
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first()
kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first()
if kb:
session.delete(kb)
return True


@with_session
def get_kb_detail(session, kb_name: str) -> dict:
kb: KnowledgeBaseModel = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_name).first()
kb: KnowledgeBaseModel = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_name)).first()
if kb:
return {
"kb_name": kb.kb_name,
Expand Down
41 changes: 23 additions & 18 deletions server/db/repository/knowledge_file_repository.py
Expand Up @@ -15,7 +15,7 @@ def list_docs_from_db(session,
列出某知识库某文件对应的所有Document。
返回形式:[{"id": str, "metadata": dict}, ...]
'''
docs = session.query(FileDocModel).filter_by(kb_name=kb_name)
docs = session.query(FileDocModel).filter(FileDocModel.kb_name.ilike(kb_name))
if file_name:
docs = docs.filter(FileDocModel.file_name.ilike(file_name))
for k, v in metadata.items():
Expand All @@ -34,10 +34,10 @@ def delete_docs_from_db(session,
返回形式:[{"id": str, "metadata": dict}, ...]
'''
docs = list_docs_from_db(kb_name=kb_name, file_name=file_name)
query = session.query(FileDocModel).filter_by(kb_name=kb_name)
query = session.query(FileDocModel).filter(FileDocModel.kb_name.ilike(kb_name))
if file_name:
query = query.filter_by(file_name=file_name)
query.delete()
query = query.filter(FileDocModel.file_name.ilike(file_name))
query.delete(synchronize_session=False)
session.commit()
return docs

Expand Down Expand Up @@ -68,12 +68,12 @@ def add_docs_to_db(session,

@with_session
def count_files_from_db(session, kb_name: str) -> int:
return session.query(KnowledgeFileModel).filter_by(kb_name=kb_name).count()
return session.query(KnowledgeFileModel).filter(KnowledgeFileModel.kb_name.ilike(kb_name)).count()


@with_session
def list_files_from_db(session, kb_name):
files = session.query(KnowledgeFileModel).filter_by(kb_name=kb_name).all()
files = session.query(KnowledgeFileModel).filter(KnowledgeFileModel.kb_name.ilike(kb_name)).all()
docs = [f.file_name for f in files]
return docs

Expand All @@ -89,8 +89,8 @@ def add_file_to_db(session,
if kb:
# 如果已经存在该文件,则更新文件信息与版本号
existing_file: KnowledgeFileModel = (session.query(KnowledgeFileModel)
.filter_by(file_name=kb_file.filename,
kb_name=kb_file.kb_name)
.filter(KnowledgeFileModel.kb_name.ilike(kb_file.kb_name),
KnowledgeFileModel.file_name.ilike(kb_file.filename))
.first())
mtime = kb_file.get_mtime()
size = kb_file.get_size()
Expand Down Expand Up @@ -122,14 +122,16 @@ def add_file_to_db(session,

@with_session
def delete_file_from_db(session, kb_file: KnowledgeFile):
existing_file = session.query(KnowledgeFileModel).filter_by(file_name=kb_file.filename,
kb_name=kb_file.kb_name).first()
existing_file = (session.query(KnowledgeFileModel)
.filter(KnowledgeFileModel.file_name.ilike(kb_file.filename),
KnowledgeFileModel.kb_name.ilike(kb_file.kb_name))
.first())
if existing_file:
session.delete(existing_file)
delete_docs_from_db(kb_name=kb_file.kb_name, file_name=kb_file.filename)
session.commit()

kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()
kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(kb_file.kb_name)).first()
if kb:
kb.file_count -= 1
session.commit()
Expand All @@ -138,9 +140,9 @@ def delete_file_from_db(session, kb_file: KnowledgeFile):

@with_session
def delete_files_from_db(session, knowledge_base_name: str):
session.query(KnowledgeFileModel).filter_by(kb_name=knowledge_base_name).delete()
session.query(FileDocModel).filter_by(kb_name=knowledge_base_name).delete()
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=knowledge_base_name).first()
session.query(KnowledgeFileModel).filter(KnowledgeFileModel.kb_name.ilike(knowledge_base_name)).delete(synchronize_session=False)
session.query(FileDocModel).filter(FileDocModel.kb_name.ilike(knowledge_base_name)).delete(synchronize_session=False)
kb = session.query(KnowledgeBaseModel).filter(KnowledgeBaseModel.kb_name.ilike(knowledge_base_name)).first()
if kb:
kb.file_count = 0

Expand All @@ -150,16 +152,19 @@ def delete_files_from_db(session, knowledge_base_name: str):

@with_session
def file_exists_in_db(session, kb_file: KnowledgeFile):
existing_file = session.query(KnowledgeFileModel).filter_by(file_name=kb_file.filename,
kb_name=kb_file.kb_name).first()
existing_file = (session.query(KnowledgeFileModel)
.filter(KnowledgeFileModel.file_name.ilike(kb_file.filename),
KnowledgeFileModel.kb_name.ilike(kb_file.kb_name))
.first())
return True if existing_file else False


@with_session
def get_file_detail(session, kb_name: str, filename: str) -> dict:
file: KnowledgeFileModel = (session.query(KnowledgeFileModel)
.filter_by(file_name=filename,
kb_name=kb_name).first())
.filter(KnowledgeFileModel.file_name.ilike(filename),
KnowledgeFileModel.kb_name.ilike(kb_name))
.first())
if file:
return {
"kb_name": file.kb_name,
Expand Down
8 changes: 4 additions & 4 deletions server/db/repository/knowledge_metadata_repository.py
Expand Up @@ -12,7 +12,7 @@ def list_summary_from_db(session,
列出某知识库chunk summary。
返回形式:[{"id": str, "summary_context": str, "doc_ids": str}, ...]
'''
docs = session.query(SummaryChunkModel).filter_by(kb_name=kb_name)
docs = session.query(SummaryChunkModel).filter(SummaryChunkModel.kb_name.ilike(kb_name))

for k, v in metadata.items():
docs = docs.filter(SummaryChunkModel.meta_data[k].as_string() == str(v))
Expand All @@ -33,8 +33,8 @@ def delete_summary_from_db(session,
返回形式:[{"id": str, "summary_context": str, "doc_ids": str}, ...]
'''
docs = list_summary_from_db(kb_name=kb_name)
query = session.query(SummaryChunkModel).filter_by(kb_name=kb_name)
query.delete()
query = session.query(SummaryChunkModel).filter(SummaryChunkModel.kb_name.ilike(kb_name))
query.delete(synchronize_session=False)
session.commit()
return docs

Expand Down Expand Up @@ -63,4 +63,4 @@ def add_summary_to_db(session,

@with_session
def count_summary_from_db(session, kb_name: str) -> int:
return session.query(SummaryChunkModel).filter_by(kb_name=kb_name).count()
return session.query(SummaryChunkModel).filter(SummaryChunkModel.kb_name.ilike(kb_name)).count()
14 changes: 14 additions & 0 deletions server/knowledge_base/kb_cache/faiss_cache.py
Expand Up @@ -4,10 +4,24 @@
from server.utils import load_local_embeddings
from server.knowledge_base.utils import get_vs_path
from langchain.vectorstores.faiss import FAISS
from langchain.docstore.in_memory import InMemoryDocstore
from langchain.schema import Document
import os
from langchain.schema import Document


# Patch the in-memory docstore used by FAISS so that each Document returned by
# a lookup carries its own docstore id under metadata["id"].
def _new_ds_search(self, search: str) -> Union[str, Document]:
    """
    Look up ``search`` in the docstore's ``_dict``.

    Returns the stored object when the id is present; if that object is a
    ``Document`` its metadata is updated in place with ``"id": search`` first.
    Returns a "not found" message string when the id is absent.
    """
    if search in self._dict:
        found = self._dict[search]
        if isinstance(found, Document):
            # The stored Document itself is mutated, not a copy.
            found.metadata["id"] = search
        return found
    return f"ID {search} not found."


InMemoryDocstore.search = _new_ds_search


class ThreadSafeFaiss(ThreadSafeObject):
def __repr__(self) -> str:
cls = type(self).__name__
Expand Down
40 changes: 29 additions & 11 deletions server/knowledge_base/kb_doc_api.py
Expand Up @@ -15,31 +15,49 @@
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.repository.knowledge_file_repository import get_file_detail
from langchain.docstore.document import Document
from typing import List


class DocumentWithScore(Document):
score: float = None
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
from typing import List, Dict


def search_docs(
query: str = Body(..., description="用户输入", examples=["你好"]),
query: str = Body("", description="用户输入", examples=["你好"]),
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
score_threshold: float = Body(SCORE_THRESHOLD,
description="知识库匹配相关度阈值,取值范围在0-1之间,"
"SCORE越小,相关度越高,"
"取到1相当于不筛选,建议设置在0.5左右",
ge=0, le=1),
) -> List[DocumentWithScore]:
file_name: str = Body("", description="文件名称,支持 sql 通配符"),
metadata: dict = Body({}, description="根据 metadata 进行过滤,仅支持一级键"),
) -> List[DocumentWithVSId]:
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return []
docs = kb.search_docs(query, top_k, score_threshold)
data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
data = []
if kb is not None:
if query:
docs = kb.search_docs(query, top_k, score_threshold)
data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
elif file_name or metadata:
data = kb.list_docs(file_name=file_name, metadata=metadata)
return data


def update_docs_by_id(
        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
        docs: Dict[str, Document] = Body(..., description="要更新的文档内容,形如:{id: Document, ...}")
) -> BaseResponse:
    """
    Update knowledge-base documents by their document ids.

    Args:
        knowledge_base_name: name of the knowledge base to update.
        docs: mapping of ``{doc_id: Document}`` with the new content.

    Returns:
        A ``BaseResponse``. ``code=500`` is set when the knowledge base does
        not exist or when the service reports that the update failed; the
        success response keeps the default code.
    """
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=500, msg=f"指定的知识库 {knowledge_base_name} 不存在")
    if kb.update_doc_by_ids(docs=docs):
        return BaseResponse(msg="文档更新成功")
    # Report failure with an error code so clients can tell it apart from the
    # success response, consistent with the missing-knowledge-base branch.
    return BaseResponse(code=500, msg="文档更新失败")


def list_files(
knowledge_base_name: str
) -> ListResponse:
Expand Down
39 changes: 31 additions & 8 deletions server/knowledge_base/kb_service/base.py
Expand Up @@ -121,8 +121,9 @@ def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
for doc in docs:
try:
source = doc.metadata.get("source", "")
rel_path = Path(source).relative_to(self.doc_path)
doc.metadata["source"] = str(rel_path.as_posix().strip("/"))
if os.path.isabs(source):
rel_path = Path(source).relative_to(self.doc_path)
doc.metadata["source"] = str(rel_path.as_posix().strip("/"))
except Exception as e:
print(f"cannot convert absolute path ({source}) to relative path. error is : {e}")
self.delete_doc(kb_file)
Expand Down Expand Up @@ -176,24 +177,44 @@ def search_docs(self,
query: str,
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
):
) ->List[Document]:
docs = self.do_search(query, top_k, score_threshold)
return docs

def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
    """
    Return the documents stored under ``ids``.

    This implementation ignores ``ids`` and always returns an empty list.
    NOTE(review): presumably concrete knowledge-base services override this
    with a real lookup -- confirm against the subclasses.
    """
    return []

def del_doc_by_ids(self, ids: List[str]) -> bool:
    """
    Delete the documents stored under ``ids``.

    This implementation always raises ``NotImplementedError``.
    NOTE(review): presumably concrete knowledge-base services must override
    it before ``update_doc_by_ids`` can be used -- confirm against the
    subclasses.
    """
    raise NotImplementedError

def update_doc_by_ids(self, docs: Dict[str, Document]) -> bool:
    """
    Replace documents in the vector store, keyed by document id.

    Args:
        docs: mapping of ``{doc_id: Document, ...}``. If the value for a
            ``doc_id`` is ``None`` (or otherwise falsy), or its
            ``page_content`` is blank, that document is only deleted and
            not re-added.

    Returns:
        ``True`` once the delete and add calls have returned.

    TODO: decide whether adding brand-new docs should be supported here.
    """
    # Every id passed in is removed first; the surviving entries are then
    # re-added under the same ids.
    self.del_doc_by_ids(list(docs))
    # Collect into separate names: rebinding ``docs`` to a list here would
    # discard the caller's mapping before it is iterated.
    kept_docs = []
    kept_ids = []
    for doc_id, doc in docs.items():
        if not doc or not doc.page_content.strip():
            continue
        kept_ids.append(doc_id)
        kept_docs.append(doc)
    self.do_add_doc(docs=kept_docs, ids=kept_ids)
    return True

def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[DocumentWithVSId]:
'''
通过file_name或metadata检索Document
'''
doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata)
docs = []
for x in doc_infos:
doc_info_s = self.get_doc_by_ids([x["id"]])
if doc_info_s is not None and doc_info_s != []:
doc_info = self.get_doc_by_ids([x["id"]])[0]
if doc_info is not None:
# 处理非空的情况
doc_with_id = DocumentWithVSId(**doc_info_s[0].dict(), id=x["id"])
doc_with_id = DocumentWithVSId(**doc_info.dict(), id=x["id"])
docs.append(doc_with_id)
else:
# 处理空的情况
Expand Down Expand Up @@ -249,6 +270,7 @@ def do_search(self,
@abstractmethod
def do_add_doc(self,
docs: List[Document],
**kwargs,
) -> List[Dict]:
"""
向知识库添加文档子类实自己逻辑
Expand Down Expand Up @@ -371,12 +393,13 @@ def get_kb_file_details(kb_name: str) -> List[Dict]:
"in_folder": True,
"in_db": False,
}
lower_names = {x.lower(): x for x in result}
for doc in files_in_db:
doc_detail = get_file_detail(kb_name, doc)
if doc_detail:
doc_detail["in_db"] = True
if doc in result:
result[doc].update(doc_detail)
if doc.lower() in lower_names:
result[lower_names[doc.lower()]].update(doc_detail)
else:
doc_detail["in_folder"] = False
result[doc] = doc_detail
Expand Down