
[BUG] Vector search returns no results after switching knowledge bases in the Python install version #401

Open
xuzhenjun130 opened this issue Jun 14, 2024 · 6 comments

@xuzhenjun130

Is there an existing issue / discussion for this?

  • I have searched the existing issues / discussions

Is there an existing answer for this in FAQ?

  • I have searched FAQ

Current Behavior

Right after startup, question answering works fine.
As soon as I switch to a different knowledge base, vector search returns no results.

Expected Behavior

Switching between knowledge bases should work normally.

Environment

- OS: centos 7
- NVIDIA Driver: 535.161.07
- CUDA: 12.2
- NVIDIA GPU: rtx 4090
- NVIDIA GPU Memory: 32G

QAnything logs

No response

Steps To Reproduce

No response

Anything else?

qanything_kernel/connector/database/faiss/faiss_client.py

After changing it to load all knowledge bases once, then search each index separately and merge the results at the end, it works correctly.

from langchain_community.vectorstores import FAISS
from langchain_community.docstore import InMemoryDocstore
from langchain_core.documents import Document
from qanything_kernel.configs.model_config import VECTOR_SEARCH_TOP_K, FAISS_LOCATION, FAISS_CACHE_SIZE
from typing import Optional, Union, Callable, Dict, Any, List, Tuple
from langchain_community.vectorstores.faiss import dependable_faiss_import
from qanything_kernel.utils.custom_log import debug_logger
from qanything_kernel.connector.database.mysql.mysql_client import KnowledgeBaseManager
from qanything_kernel.utils.general_utils import num_tokens
from functools import lru_cache
import shutil
import stat
import os
import platform

os_system = platform.system()

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'  # possibly needed because of macOS


class SelfInMemoryDocstore(InMemoryDocstore):
    def add(self, texts: Dict[str, Document]) -> None:
        """Add texts to in memory dictionary.

        Args:
            texts: dictionary of id -> document.

        Returns:
            None
        """
        self._dict.update(texts)


@lru_cache(FAISS_CACHE_SIZE)
def load_vector_store(faiss_index_path, embeddings):
    debug_logger.info(f'load faiss index: {faiss_index_path}')
    return FAISS.load_local(faiss_index_path, embeddings, allow_dangerous_deserialization=True)


class FaissClient:
    def __init__(self, mysql_client: KnowledgeBaseManager, embeddings):
        self.mysql_client: KnowledgeBaseManager = mysql_client
        self.embeddings = embeddings
        self.faiss_clients: Dict[str, FAISS] = {}  # one FAISS client per kb_id

    def _load_all_kbs_to_memory(self):
        for kb_id in os.listdir(FAISS_LOCATION):
            faiss_index_path = os.path.join(FAISS_LOCATION, kb_id, 'faiss_index')
            if os.path.exists(faiss_index_path):
                faiss_client: FAISS = load_vector_store(faiss_index_path, self.embeddings)
            else:
                faiss = dependable_faiss_import()
                index = faiss.IndexFlatL2(768)
                docstore = SelfInMemoryDocstore()
                debug_logger.info(f'init FAISS kb_id: {kb_id}')
                faiss_client: FAISS = FAISS(self.embeddings, index, docstore, index_to_docstore_id={})
            self.faiss_clients[kb_id] = faiss_client
        debug_logger.info(f'FAISS loaded all kb_ids')

    async def search(self, kb_ids, query, filter: Optional[Union[Callable, Dict[str, Any]]] = None,
                     top_k=VECTOR_SEARCH_TOP_K):
        if not self.faiss_clients:
            self._load_all_kbs_to_memory()

        all_docs_with_score = []
        for kb_id in kb_ids:
            faiss_client = self.faiss_clients.get(kb_id)
            if not faiss_client:
                continue

            if filter is None:
                filter = {}
            debug_logger.info(f'FAISS search: {query}, {filter}, {top_k} for kb_id: {kb_id}')
            docs_with_score = await faiss_client.asimilarity_search_with_score(query, k=top_k, filter=filter,
                                                                               fetch_k=200)
            all_docs_with_score.extend(docs_with_score)

        all_docs_with_score.sort(key=lambda x: x[1])  # sort by score
        merged_docs_with_score = self.merge_docs(all_docs_with_score[:top_k])  # keep only the top_k results
        return merged_docs_with_score

    def merge_docs(self, docs_with_score):
        merged_docs = []
        docs_with_score = sorted(docs_with_score, key=lambda x: (x[0].metadata['file_id'], x[0].metadata['chunk_id']))
        for doc, score in docs_with_score:
            doc.metadata['score'] = score
            if not merged_docs or merged_docs[-1].metadata['file_id'] != doc.metadata['file_id']:
                merged_docs.append(doc)
            else:
                if merged_docs[-1].metadata['chunk_id'] == doc.metadata['chunk_id'] - 1:
                    if num_tokens(merged_docs[-1].page_content + doc.page_content) <= 800:
                        merged_docs[-1].page_content += '\n' + doc.page_content
                        merged_docs[-1].metadata['chunk_id'] = doc.metadata['chunk_id']
                    else:
                        merged_docs.append(doc)
                else:
                    merged_docs.append(doc)
        return merged_docs

    async def add_document(self, docs):
        kb_id = docs[0].metadata['kb_id']
        if kb_id not in self.faiss_clients:
            self._load_all_kbs_to_memory()
        faiss_client = self.faiss_clients.get(kb_id)

        if not faiss_client:
            raise ValueError(f"KB with id {kb_id} not found")

        add_ids = await faiss_client.aadd_documents(docs)
        chunk_id = 0
        for doc, add_id in zip(docs, add_ids):
            self.mysql_client.add_document(add_id, chunk_id, doc.metadata['file_id'], doc.metadata['file_name'],
                                           doc.metadata['kb_id'])
            chunk_id += 1

        debug_logger.info(f'add documents number: {len(add_ids)}')
        faiss_index_path = os.path.join(FAISS_LOCATION, kb_id, 'faiss_index')
        faiss_client.save_local(faiss_index_path)
        debug_logger.info(f'save faiss index: {faiss_index_path}')
        os.chmod(os.path.dirname(faiss_index_path), stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
        return add_ids

    def delete_documents(self, kb_id, file_ids=None):
        if kb_id not in self.faiss_clients:
            self._load_all_kbs_to_memory()
        faiss_client = self.faiss_clients.get(kb_id)

        if not faiss_client:
            raise ValueError(f"KB with id {kb_id} not found")

        if file_ids is None:
            kb_index_path = os.path.join(FAISS_LOCATION, kb_id)
            if os.path.exists(kb_index_path):
                shutil.rmtree(kb_index_path)
                del self.faiss_clients[kb_id]
                debug_logger.info(f'delete kb_id: {kb_id}, {kb_index_path}')
            # deleting a whole KB needs no per-document pass, whether or not the on-disk index existed
            return
        else:
            doc_ids = self.mysql_client.get_documents_by_file_ids(file_ids)
            doc_ids = [doc_id[0] for doc_id in doc_ids]

        if not doc_ids:
            debug_logger.info(f'no documents to delete')
            return

        try:
            res = faiss_client.delete(doc_ids)
            debug_logger.info(f'delete documents: {res}')
            faiss_index_path = os.path.join(FAISS_LOCATION, kb_id, 'faiss_index')
            faiss_client.save_local(faiss_index_path)
            debug_logger.info(f'save faiss index: {faiss_index_path}')
            os.chmod(os.path.dirname(faiss_index_path), stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
        except ValueError as e:
            debug_logger.warning(f'delete documents: docs not found: {e}')
@jiumar19

The root cause is that faiss_client is a singleton whose internal FAISS index keeps getting merged over and over, so different KBs, and different versions of the same KB, contaminate each other. You either have to do what you did here, search each index separately and then merge the results, though that loses some of the point of top-k, or cache the FAISS index per bot and do the merging inside that cache. The open-sourced Python code is a mess, and its quality is quite worrying.
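To make that contamination concrete, here is a minimal sketch, assuming (not verified against the upstream QAnything code) that per-KB indexes are merged into a single cached FAISS object; get_merged_index and _merged_cache are made-up names for illustration only:

# Hypothetical illustration of the failure mode described above: FAISS.merge_from()
# mutates the base store in place, so the first KB's own index ends up holding the
# other KBs' vectors, and the cached merged object never sees documents added later.
from typing import Dict, List, Tuple
from langchain_community.vectorstores import FAISS

_merged_cache: Dict[Tuple[str, ...], FAISS] = {}  # made-up cache keyed by the set of kb_ids

def get_merged_index(kb_ids: List[str], per_kb_indexes: Dict[str, FAISS]) -> FAISS:
    key = tuple(sorted(kb_ids))
    if key not in _merged_cache:
        base = per_kb_indexes[kb_ids[0]]
        for kb_id in kb_ids[1:]:
            base.merge_from(per_kb_indexes[kb_id])  # in-place merge: pollutes `base`
        _merged_cache[key] = base
    # stale after any later add/delete on a member KB; per-KB search + result merge avoids this
    return _merged_cache[key]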

@jiumar19

Also, judging from the code, persisting the FAISS index to disk does not consider concurrency at all. I suggest adding your own locking, otherwise you will run into all kinds of strange behavior.
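A minimal sketch of one way to add such locking, assuming an asyncio-based server; the lock table and helper names are illustrative, not part of QAnything:

# Illustrative per-kb_id locking so concurrent add/delete calls cannot interleave
# the in-memory index mutation with save_local() for the same knowledge base.
import asyncio
from typing import Dict

_kb_locks: Dict[str, asyncio.Lock] = {}  # hypothetical module-level lock table

def lock_for(kb_id: str) -> asyncio.Lock:
    # one Lock per kb_id, created lazily on first use
    return _kb_locks.setdefault(kb_id, asyncio.Lock())

async def save_index_safely(faiss_client, faiss_index_path: str, kb_id: str) -> None:
    # serialize the write so two coroutines never save the same index at once
    async with lock_for(kb_id):
        faiss_client.save_local(faiss_index_path)

add_document and delete_documents would acquire the same per-KB lock around their in-memory mutations as well.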

@lycfight

I deployed the qanything-python branch with concurrency, and several colleagues may have been uploading documents for testing at the same time. During testing I also ran into retrieval recalling 0 documents. I debugged down to faiss_client.asimilarity_search_with_score: even when the query was copied verbatim from a passage in the document, it simply could not retrieve it.
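One way to narrow this down is to probe the saved index directly, bypassing the in-memory client; a minimal sketch, assuming the directory layout from the code above and an embeddings object compatible with the saved index:

# Debugging sketch: load the saved index straight from disk and search it, to
# separate "the vectors were never written" from "the in-memory client is stale".
import os
from langchain_community.vectorstores import FAISS

def probe_kb(faiss_location: str, kb_id: str, embeddings, query: str, k: int = 5):
    index_path = os.path.join(faiss_location, kb_id, 'faiss_index')
    store = FAISS.load_local(index_path, embeddings, allow_dangerous_deserialization=True)
    hits = store.similarity_search_with_score(query, k=k)
    for doc, score in hits:
        print(score, doc.metadata.get('file_id'), doc.page_content[:80])
    return hits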

@lycfight

> (quoting the original issue report and first-version patch above)

This code throws a path-not-found error when uploading a file to a newly created knowledge base, here:

    def _load_all_kbs_to_memory(self):
        for kb_id in os.listdir(FAISS_LOCATION):
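One way to guard against that, shown only as a sketch (the author's second version below takes a different route via _init_faiss_client): make sure the FAISS_LOCATION root and the kb_id subfolder exist before listing or saving.

# Illustrative guard: create the FAISS root / per-KB folder if missing so that
# os.listdir(FAISS_LOCATION) and save_local() never hit a nonexistent path.
import os

def ensure_kb_dir(faiss_location: str, kb_id: str) -> str:
    os.makedirs(os.path.join(faiss_location, kb_id), exist_ok=True)  # no-op if it already exists
    return os.path.join(faiss_location, kb_id, 'faiss_index')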

@xuzhenjun130
Author

@lycfight
The code was only meant as a reference; the first version did indeed have some problems.
Here is the second version:

from langchain_community.vectorstores import FAISS
from langchain_community.docstore import InMemoryDocstore
from langchain_core.documents import Document
from qanything_kernel.configs.model_config import VECTOR_SEARCH_TOP_K, FAISS_LOCATION, FAISS_CACHE_SIZE
from typing import Optional, Union, Callable, Dict, Any, List, Tuple
from langchain_community.vectorstores.faiss import dependable_faiss_import
from qanything_kernel.utils.custom_log import debug_logger
from qanything_kernel.connector.database.mysql.mysql_client import KnowledgeBaseManager
from qanything_kernel.utils.general_utils import num_tokens
from functools import lru_cache
import shutil
import stat
import os
import platform

os_system = platform.system()

os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'  # possibly needed because of macOS


class SelfInMemoryDocstore(InMemoryDocstore):
    def add(self, texts: Dict[str, Document]) -> None:
        """Add texts to in memory dictionary.

        Args:
            texts: dictionary of id -> document.

        Returns:
            None
        """
        self._dict.update(texts)


@lru_cache(FAISS_CACHE_SIZE)
def load_vector_store(faiss_index_path, embeddings):
    debug_logger.info(f'load faiss index: {faiss_index_path}')
    return FAISS.load_local(faiss_index_path, embeddings, allow_dangerous_deserialization=True)


class FaissClient:
    def __init__(self, mysql_client: KnowledgeBaseManager, embeddings):
        self.mysql_client: KnowledgeBaseManager = mysql_client
        self.embeddings = embeddings
        self.faiss_clients: Dict[str, FAISS] = {}  # one FAISS client per kb_id

    def _load_all_kbs_to_memory(self):
        for kb_id in os.listdir(FAISS_LOCATION):
            faiss_index_path = os.path.join(FAISS_LOCATION, kb_id, 'faiss_index')
            debug_logger.info(f'FAISS loaded kb_id {kb_id} - {faiss_index_path}')
            if os.path.exists(faiss_index_path):
                faiss_client: FAISS = load_vector_store(faiss_index_path, self.embeddings)
                self.faiss_clients[kb_id] = faiss_client
        debug_logger.info(f'FAISS loaded all kb_ids')

    def _init_faiss_client(self, kb_id):
        faiss = dependable_faiss_import()
        index = faiss.IndexFlatL2(768)
        docstore = SelfInMemoryDocstore()
        debug_logger.info(f'init FAISS kb_id: {kb_id}')
        faiss_client: FAISS = FAISS(self.embeddings, index, docstore, index_to_docstore_id={})
        self.faiss_clients[kb_id] = faiss_client


    async def search(self, kb_ids, query, filter: Optional[Union[Callable, Dict[str, Any]]] = None,
                     top_k=VECTOR_SEARCH_TOP_K):
        if not self.faiss_clients:
            self._load_all_kbs_to_memory()

        all_docs_with_score = []
        for kb_id in kb_ids:
            faiss_client = self.faiss_clients.get(kb_id)
            if not faiss_client:
                continue

            if filter is None:
                filter = {}
            debug_logger.info(f'FAISS search: {query}, {filter}, {top_k} for kb_id: {kb_id}')
            docs_with_score = await faiss_client.asimilarity_search_with_score(query, k=top_k, filter=filter,
                                                                               fetch_k=200)
            all_docs_with_score.extend(docs_with_score)

        all_docs_with_score.sort(key=lambda x: x[1])  # sort by score
        merged_docs_with_score = self.merge_docs(all_docs_with_score[:top_k])  # keep only the top_k results
        return merged_docs_with_score

    def merge_docs(self, docs_with_score):
        merged_docs = []
        docs_with_score = sorted(docs_with_score, key=lambda x: (x[0].metadata['file_id'], x[0].metadata['chunk_id']))
        for doc, score in docs_with_score:
            doc.metadata['score'] = score
            if not merged_docs or merged_docs[-1].metadata['file_id'] != doc.metadata['file_id']:
                merged_docs.append(doc)
            else:
                if merged_docs[-1].metadata['chunk_id'] == doc.metadata['chunk_id'] - 1:
                    if num_tokens(merged_docs[-1].page_content + doc.page_content) <= 800:
                        merged_docs[-1].page_content += '\n' + doc.page_content
                        merged_docs[-1].metadata['chunk_id'] = doc.metadata['chunk_id']
                    else:
                        merged_docs.append(doc)
                else:
                    merged_docs.append(doc)
        return merged_docs

    async def add_document(self, docs):
        kb_id = docs[0].metadata['kb_id']
        if kb_id not in self.faiss_clients:
            self._init_faiss_client(kb_id)
        faiss_client = self.faiss_clients.get(kb_id)

        if not faiss_client:
            raise ValueError(f"KB with id {kb_id} not found")

        add_ids = await faiss_client.aadd_documents(docs)
        chunk_id = 0
        for doc, add_id in zip(docs, add_ids):
            self.mysql_client.add_document(add_id, chunk_id, doc.metadata['file_id'], doc.metadata['file_name'],
                                           doc.metadata['kb_id'])
            chunk_id += 1

        debug_logger.info(f'add documents number: {len(add_ids)}')
        faiss_index_path = os.path.join(FAISS_LOCATION, kb_id, 'faiss_index')
        faiss_client.save_local(faiss_index_path)
        debug_logger.info(f'save faiss index: {faiss_index_path}')
        os.chmod(os.path.dirname(faiss_index_path), stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
        return add_ids

    def delete_documents(self, kb_id, file_ids=None):
        if kb_id not in self.faiss_clients:
            self._load_all_kbs_to_memory()
        faiss_client = self.faiss_clients.get(kb_id)

        if not faiss_client:
            # if the document has not finished parsing yet, there are no vectors for it
            debug_logger.info(f"KB with id {kb_id} not found")
            return

        if file_ids is None:
            kb_index_path = os.path.join(FAISS_LOCATION, kb_id)
            if os.path.exists(kb_index_path):
                shutil.rmtree(kb_index_path)
                del self.faiss_clients[kb_id]
                debug_logger.info(f'delete kb_id: {kb_id}, {kb_index_path}')
            # deleting a whole KB needs no per-document pass, whether or not the on-disk index existed
            return
        else:
            doc_ids = self.mysql_client.get_documents_by_file_ids(file_ids)
            doc_ids = [doc_id[0] for doc_id in doc_ids]

        if not doc_ids:
            debug_logger.info(f'no documents to delete')
            return

        try:
            res = faiss_client.delete(doc_ids)
            debug_logger.info(f'delete documents: {res}')
            faiss_index_path = os.path.join(FAISS_LOCATION, kb_id, 'faiss_index')
            faiss_client.save_local(faiss_index_path)
            debug_logger.info(f'save faiss index: {faiss_index_path}')
            os.chmod(os.path.dirname(faiss_index_path), stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
        except ValueError as e:
            debug_logger.warning(f'delete documents: docs not found: {e}')

@lycfight

lycfight commented Aug 6, 2024

> (quoting the second-version patch above)

I studied the source code carefully, but I still don't quite understand why switching between knowledge bases causes this bug. Is it a write inconsistency caused by the async code?
