Skip to content

Commit

Permalink
修复faiss_pool知识库缓存key错误 (#1507)
Browse files (browse the repository at this point in the history)
  • Loading branch information
liunux4odoo committed Sep 17, 2023
1 parent ec85cd1 commit 1bae930
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 13 deletions.
12 changes: 8 additions & 4 deletions server/knowledge_base/kb_cache/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,25 @@ def __init__(self, key: Union[str, Tuple], obj: Any = None, pool: "CachePool" =

def __repr__(self) -> str:
cls = type(self).__name__
return f"<{cls}: key: {self._key}, obj: {self._obj}>"
return f"<{cls}: key: {self.key}, obj: {self._obj}>"

@property
def key(self):
return self._key

@contextmanager
def acquire(self, owner: str = "", msg: str = ""):
owner = owner or f"thread {threading.get_native_id()}"
try:
self._lock.acquire()
if self._pool is not None:
self._pool._cache.move_to_end(self._key)
self._pool._cache.move_to_end(self.key)
if log_verbose:
logger.info(f"{owner} 开始操作:{self._key}{msg}")
logger.info(f"{owner} 开始操作:{self.key}{msg}")
yield self._obj
finally:
if log_verbose:
logger.info(f"{owner} 结束操作:{self._key}{msg}")
logger.info(f"{owner} 结束操作:{self.key}{msg}")
self._lock.release()

def start_loading(self):
Expand Down
14 changes: 7 additions & 7 deletions server/knowledge_base/kb_cache/faiss_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
class ThreadSafeFaiss(ThreadSafeObject):
def __repr__(self) -> str:
cls = type(self).__name__
return f"<{cls}: key: {self._key}, obj: {self._obj}, docs_count: {self.docs_count()}>"
return f"<{cls}: key: {self.key}, obj: {self._obj}, docs_count: {self.docs_count()}>"

def docs_count(self) -> int:
return len(self._obj.docstore._dict)
Expand All @@ -17,7 +17,7 @@ def save(self, path: str, create_path: bool = True):
if not os.path.isdir(path) and create_path:
os.makedirs(path)
ret = self._obj.save_local(path)
logger.info(f"已将向量库 {self._key} 保存到磁盘")
logger.info(f"已将向量库 {self.key} 保存到磁盘")
return ret

def clear(self):
Expand All @@ -27,7 +27,7 @@ def clear(self):
if ids:
ret = self._obj.delete(ids)
assert len(self._obj.docstore._dict) == 0
logger.info(f"已将向量库 {self._key} 清空")
logger.info(f"已将向量库 {self.key} 清空")
return ret


Expand Down Expand Up @@ -66,10 +66,10 @@ def load_vector_store(
embed_device: str = embedding_device(),
) -> ThreadSafeFaiss:
self.atomic.acquire()
cache = self.get(kb_name+vector_name)
cache = self.get((kb_name, vector_name)) # 用元组比拼接字符串好一些
if cache is None:
item = ThreadSafeFaiss(kb_name, pool=self)
self.set(kb_name+vector_name, item)
item = ThreadSafeFaiss((kb_name, vector_name), pool=self)
self.set((kb_name, vector_name), item)
with item.acquire(msg="初始化"):
self.atomic.release()
logger.info(f"loading vector store in '{kb_name}/{vector_name}' from disk.")
Expand All @@ -90,7 +90,7 @@ def load_vector_store(
item.finish_loading()
else:
self.atomic.release()
return self.get(kb_name+vector_name)
return self.get((kb_name, vector_name))


class MemoFaissPool(_FaissPool):
Expand Down
7 changes: 5 additions & 2 deletions server/knowledge_base/kb_service/faiss_kb_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,21 @@
class FaissKBService(KBService):
vs_path: str
kb_path: str
vector_name: str = "vector_store"

def vs_type(self) -> str:
return SupportedVSType.FAISS

def get_vs_path(self):
return os.path.join(self.get_kb_path(), "vector_store")
return os.path.join(self.get_kb_path(), self.vector_name)

def get_kb_path(self):
return os.path.join(KB_ROOT_PATH, self.kb_name)

def load_vector_store(self) -> ThreadSafeFaiss:
return kb_faiss_pool.load_vector_store(kb_name=self.kb_name, embed_model=self.embed_model)
return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
vector_name=self.vector_name,
embed_model=self.embed_model)

def save_vector_store(self):
self.load_vector_store().save(self.vs_path)
Expand Down

0 comments on commit 1bae930

Please sign in to comment.