From 1bae93069161f1ea89ddda7c08fb168f402ea60b Mon Sep 17 00:00:00 2001 From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com> Date: Sun, 17 Sep 2023 16:31:44 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dfaiss=5Fpool=E7=9F=A5?= =?UTF-8?q?=E8=AF=86=E5=BA=93=E7=BC=93=E5=AD=98key=E9=94=99=E8=AF=AF=20(#1?= =?UTF-8?q?507)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/knowledge_base/kb_cache/base.py | 12 ++++++++---- server/knowledge_base/kb_cache/faiss_cache.py | 14 +++++++------- .../knowledge_base/kb_service/faiss_kb_service.py | 7 +++++-- 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/server/knowledge_base/kb_cache/base.py b/server/knowledge_base/kb_cache/base.py index d6b72c7be..cd60aa439 100644 --- a/server/knowledge_base/kb_cache/base.py +++ b/server/knowledge_base/kb_cache/base.py @@ -22,7 +22,11 @@ def __init__(self, key: Union[str, Tuple], obj: Any = None, pool: "CachePool" = def __repr__(self) -> str: cls = type(self).__name__ - return f"<{cls}: key: {self._key}, obj: {self._obj}>" + return f"<{cls}: key: {self.key}, obj: {self._obj}>" + + @property + def key(self): + return self._key @contextmanager def acquire(self, owner: str = "", msg: str = ""): @@ -30,13 +34,13 @@ def acquire(self, owner: str = "", msg: str = ""): try: self._lock.acquire() if self._pool is not None: - self._pool._cache.move_to_end(self._key) + self._pool._cache.move_to_end(self.key) if log_verbose: - logger.info(f"{owner} 开始操作:{self._key}。{msg}") + logger.info(f"{owner} 开始操作:{self.key}。{msg}") yield self._obj finally: if log_verbose: - logger.info(f"{owner} 结束操作:{self._key}。{msg}") + logger.info(f"{owner} 结束操作:{self.key}。{msg}") self._lock.release() def start_loading(self): diff --git a/server/knowledge_base/kb_cache/faiss_cache.py b/server/knowledge_base/kb_cache/faiss_cache.py index 8cc3c3113..801e4a6b2 100644 --- a/server/knowledge_base/kb_cache/faiss_cache.py +++ b/server/knowledge_base/kb_cache/faiss_cache.py @@ -7,7 +7,7 @@ class ThreadSafeFaiss(ThreadSafeObject): def __repr__(self) -> str: cls = type(self).__name__ - return f"<{cls}: key: {self._key}, obj: {self._obj}, docs_count: {self.docs_count()}>" + return f"<{cls}: key: {self.key}, obj: {self._obj}, docs_count: {self.docs_count()}>" def docs_count(self) -> int: return len(self._obj.docstore._dict) @@ -17,7 +17,7 @@ def save(self, path: str, create_path: bool = True): if not os.path.isdir(path) and create_path: os.makedirs(path) ret = self._obj.save_local(path) - logger.info(f"已将向量库 {self._key} 保存到磁盘") + logger.info(f"已将向量库 {self.key} 保存到磁盘") return ret def clear(self): @@ -27,7 +27,7 @@ def clear(self): if ids: ret = self._obj.delete(ids) assert len(self._obj.docstore._dict) == 0 - logger.info(f"已将向量库 {self._key} 清空") + logger.info(f"已将向量库 {self.key} 清空") return ret @@ -66,10 +66,10 @@ def load_vector_store( embed_device: str = embedding_device(), ) -> ThreadSafeFaiss: self.atomic.acquire() - cache = self.get(kb_name+vector_name) + cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些 if cache is None: - item = ThreadSafeFaiss(kb_name, pool=self) - self.set(kb_name+vector_name, item) + item = ThreadSafeFaiss((kb_name, vector_name), pool=self) + self.set((kb_name, vector_name), item) with item.acquire(msg="初始化"): self.atomic.release() logger.info(f"loading vector store in '{kb_name}/{vector_name}' from disk.") @@ -90,7 +90,7 @@ def load_vector_store( item.finish_loading() else: self.atomic.release() - return self.get(kb_name+vector_name) + return self.get((kb_name, vector_name)) class MemoFaissPool(_FaissPool): diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index 2671f55c0..e37c93ebb 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -18,18 +18,21 @@ class FaissKBService(KBService): vs_path: str kb_path: str + vector_name: str = "vector_store" def vs_type(self) -> str: return SupportedVSType.FAISS def get_vs_path(self): - return os.path.join(self.get_kb_path(), "vector_store") + return os.path.join(self.get_kb_path(), self.vector_name) def get_kb_path(self): return os.path.join(KB_ROOT_PATH, self.kb_name) def load_vector_store(self) -> ThreadSafeFaiss: - return kb_faiss_pool.load_vector_store(kb_name=self.kb_name, embed_model=self.embed_model) + return kb_faiss_pool.load_vector_store(kb_name=self.kb_name, + vector_name=self.vector_name, + embed_model=self.embed_model) def save_vector_store(self): self.load_vector_store().save(self.vs_path)