In [1]:
from __future__ import annotations

import json
import os
from typing import Dict, Iterable, List, Tuple
import faiss
from embedding import embedding
from collections import defaultdict
from langchain_core.documents import Document
from document_processor import _load_local_documents, split_documents_to_text_chunks
import numpy as np
from config import (
    DEFAULT_BATCH_SIZE,
    EMBEDDING_DIM,
    VECTOR_INDEX_PATH,
    META_PATH,
    DEFAULT_TOP_K,
    TEST_PDFS_DIR,
)


def _batched(items: List[str], batch_size: int) -> Iterable[List[str]]:
    """
    Yield successive n-sized chunks from items for batch processing.
    """
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size]


def _embed_texts(texts: List[str], batch_size: int = DEFAULT_BATCH_SIZE) -> np.ndarray:
    """
    Embed ``texts`` in batches and return them as one float32 matrix.

    Args:
        texts: strings to embed.
        batch_size: maximum number of strings sent to ``embedding`` per call.

    Returns:
        Array of shape (len(texts), EMBEDDING_DIM) and dtype float32. An empty
        ``texts`` gives an empty (0, EMBEDDING_DIM) array without calling ``embedding``.

    Raises:
        RuntimeError: if the ``embedding`` call raises; the original error is chained.
        ValueError: if a batch result is not 2-D, has the wrong width, or does not
            contain exactly one row per input text.
    """
    # np.vstack([]) raises, so handle the no-input case explicitly.
    if not texts:
        return np.empty((0, EMBEDDING_DIM), dtype="float32")

    batch_arrays: List[np.ndarray] = []
    for batch in _batched(texts, batch_size=batch_size):
        try:
            batch_vecs = embedding(batch)
        except Exception as e:
            raise RuntimeError(
                "Embedding request failed. Check `EMBEDDING_URL` in `embedding.py` "
                "and that the service is reachable."
            ) from e

        arr = np.asarray(batch_vecs, dtype="float32")
        if arr.ndim != 2 or arr.shape[1] != EMBEDDING_DIM:
            raise ValueError(f"Expected embeddings shape (n, {EMBEDDING_DIM}), got {arr.shape}")
        # Rows are later paired with texts/ids by position, so a count mismatch
        # must be caught here rather than producing misaligned data downstream.
        if arr.shape[0] != len(batch):
            raise ValueError(f"Expected {len(batch)} embeddings for the batch, got {arr.shape[0]}")
        batch_arrays.append(arr)

    # Every piece is already float32, so the stacked result is float32 too.
    return np.vstack(batch_arrays)


class FaissManager:
    """
    FAISS inner-product index with id mapping, plus JSON metadata stored next to it.

    Vectors are L2-normalised before being added and before searching, so the
    inner-product scores are cosine similarities.

    ``self.meta`` layout:
        next_id: integer id that the next added chunk will receive
        files:   file_name -> list of chunk ids that came from that file
        chunks:  chunk id (as str) -> (text, file_name)

    Note: after a JSON round-trip ``files`` is a plain dict and the chunk tuples
    come back as two-element lists; the methods below work with both forms.
    """

    def __init__(self, index_path: str = VECTOR_INDEX_PATH, meta_path: str = META_PATH):
        """
        Load the saved index and metadata when both files exist, else start empty.

        Args:
            index_path: location of the serialized FAISS index.
            meta_path: location of the JSON metadata file.
        """
        self.index_path = index_path
        self.meta_path = meta_path

        if os.path.exists(index_path) and os.path.exists(meta_path):
            self.index = self.load_index()
            self.meta = self.load_meta()
        else:  # create new index
            self.index = FaissManager.create_index()
            self.meta = FaissManager._empty_meta()

    @staticmethod
    def _empty_meta() -> Dict:
        """Return the metadata structure used for a brand-new, empty index."""
        # inside chunks, key: chunk_id (str), value: (text, file_name)
        return {"next_id": 0, "files": defaultdict(list), "chunks": dict()}

    @staticmethod
    def _ensure_parent_dir(path: str) -> None:
        """Create the directory that will contain ``path`` if it has one and it is missing."""
        parent = os.path.dirname(path)
        # dirname is "" for a bare file name, and os.makedirs("") raises.
        if parent:
            os.makedirs(parent, exist_ok=True)

    @staticmethod
    def create_index(dim: int = EMBEDDING_DIM) -> faiss.IndexIDMap2:
        """Create an empty flat inner-product index wrapped so vectors carry explicit ids."""
        base = faiss.IndexFlatIP(dim)
        return faiss.IndexIDMap2(base)

    def load_index(self) -> faiss.Index:
        """Read and return the FAISS index stored at ``self.index_path``."""
        return faiss.read_index(self.index_path)

    def save_index(self) -> None:
        """Write the in-memory index to ``self.index_path``."""
        # in case the config changes, and the new directory doesn't exist
        FaissManager._ensure_parent_dir(self.index_path)
        faiss.write_index(self.index, self.index_path)

    def load_meta(self) -> Dict:
        """Read and return the JSON metadata stored at ``self.meta_path``."""
        with open(self.meta_path, "r", encoding="utf-8") as f:
            meta = json.load(f)
        return meta

    def save_meta(self) -> None:
        """Write ``self.meta`` to ``self.meta_path`` as UTF-8 JSON."""
        # in case the config changes, and the new directory doesn't exist
        FaissManager._ensure_parent_dir(self.meta_path)
        with open(self.meta_path, "w", encoding="utf-8") as file:
            # indent=2 for readability, ensure_ascii=False for unicode support
            json.dump(self.meta, file, ensure_ascii=False, indent=2)

    def save(self) -> None:
        """Persist both the index and the metadata."""
        self.save_index()
        self.save_meta()

    def clear(self) -> None:
        """
        Delete the saved index and metadata files if they exist.

        Only the files on disk are removed; ``self.index`` and ``self.meta`` are untouched.
        """
        if os.path.exists(self.index_path):
            os.remove(self.index_path)
        if os.path.exists(self.meta_path):
            os.remove(self.meta_path)

    def reset(self) -> None:
        """
        Clear the saved index and meta file, then reset the FAISS index and metadata to initial state.
        """
        self.clear()
        self.index.reset()
        self.meta = FaissManager._empty_meta()

    def add_chunks(self, chunks: List[Document]) -> None:
        """
        Embed the given chunks and add them to the index and metadata.

        Chunks whose ``page_content`` is empty are skipped. The file name is taken
        from ``chunk.metadata["file_name"]`` and falls back to "unknown". Nothing is
        written to disk; call ``save()`` to persist the result.

        Args:
            chunks: documents to index.
        """
        if not chunks:
            return
        usable = [chunk for chunk in chunks if chunk.page_content]
        # Every chunk may be empty; there is then nothing to embed or index.
        if not usable:
            return
        texts = [chunk.page_content for chunk in usable]
        file_names = [chunk.metadata.get("file_name", "unknown") for chunk in usable]

        vectors = _embed_texts(texts)
        faiss.normalize_L2(vectors)  # in-place L2 normalization

        start_id = self.meta["next_id"]
        ids = np.arange(start_id, start_id + len(texts), dtype="int64")

        self.index.add_with_ids(vectors, ids)
        # Advance the counter only once FAISS has accepted the vectors.
        self.meta["next_id"] = start_id + len(texts)

        for chunk_id, text, file_name in zip(ids.tolist(), texts, file_names):
            self.meta["chunks"][str(chunk_id)] = (text, file_name)
            # setdefault works for the defaultdict of a fresh manager and for the
            # plain dict that json.load produces after a reload.
            self.meta["files"].setdefault(file_name, []).append(chunk_id)

    def search(self, query: str, top_k: int = DEFAULT_TOP_K) -> List[Dict[str, object]]:
        """
        Perform a top k similarity search in the FAISS index.

        Args:
            query: text to search for.
            top_k: maximum number of hits to return.

        Returns:
            One dict per hit with keys "id", "score", "file_name" and "text", in the
            order FAISS returns them. The list is empty when the index holds no vectors.
        """
        # Nothing can match in an empty index, so skip the embedding request.
        if self.index.ntotal == 0:
            return []

        q_vec = _embed_texts([query])
        faiss.normalize_L2(q_vec)

        scores, ids = self.index.search(q_vec, top_k)  # a KNN search
        out: List[Dict[str, object]] = []
        for score, _id in zip(scores[0].tolist(), ids[0].tolist()):
            # when FAISS doesn't have enough results to fill top_k, it pads score and id with -1
            if _id != -1:
                text_chunk, file_name = self.meta["chunks"].get(str(_id), ("", ""))
                out.append({"id": int(_id), "score": float(score), "file_name": file_name, "text": text_chunk})
        return out

In [2]:
# Build the manager: it reads the saved index and metadata when both files exist,
# otherwise it starts with an empty index and empty metadata.
index_manager = FaissManager()

In [4]:
# Inspect the file name -> chunk ids mapping kept in the metadata.
index_manager.meta['files']

{'TrainingLanguageModelsToFollowInstructionsWithHumanFeedback.pdf': [24,
  25,
  26,
  27,
  28,
  29,
  30,
  31,
  32,
  33,
  34,
  35,
  36,
  37,
  38,
  39,
  40,
  41,
  42,
  43,
  44,
  45,
  46,
  47,
  48,
  49,
  50,
  51,
  52,
  53,
  54,
  55,
  56,
  57,
  58,
  59,
  60,
  61,
  62,
  63,
  64,
  65,
  66,
  67,
  68,
  69,
  70,
  71,
  72,
  73,
  74,
  75,
  76,
  77,
  78,
  79,
  80,
  81,
  82,
  83,
  84,
  85,
  86,
  87,
  88,
  89,
  90,
  91,
  92,
  93,
  94,
  95,
  96,
  97,
  98,
  99,
  100,
  101,
  102,
  103,
  104,
  105,
  106,
  107,
  108,
  109,
  110,
  111,
  112,
  113,
  114,
  115,
  116,
  117,
  118,
  119,
  120,
  121,
  122,
  123,
  124,
  125,
  126,
  127,
  128,
  129,
  130,
  131,
  132,
  133,
  134,
  135]}

In [5]:
# Number of vectors currently stored in the FAISS index.
index_manager.index.ntotal

112

In [2]:
# Ingest the test data: load the documents under TEST_PDFS_DIR, split them into
# chunks, then embed the chunks and add them to the index.
# add_chunks only changes the in-memory index and metadata; call
# index_manager.save() to write them to disk.
index_manager = FaissManager()
test_documents = _load_local_documents(TEST_PDFS_DIR)
test_chunks = split_documents_to_text_chunks(test_documents)
index_manager.add_chunks(test_chunks)

C:\DAHOU\Business\go_tech\chat-pdf\data\test_data\AttentionIsAllYouNeed.pdf
C:\DAHOU\Business\go_tech\chat-pdf\data\test_data\TrainingLanguageModelsToFollowInstructionsWithHumanFeedback.pdf
Loaded 2 document(s) from 'C:\DAHOU\Business\go_tech\chat-pdf\data\test_data'


In [3]:
# Check how many vectors the index holds at this point.
index_manager.index.ntotal

136

In [None]:
# Chunk ids recorded in the metadata for one specific source file.
index_manager.meta['files']['AttentionIsAllYouNeed.pdf']

In [6]:
# Remove every vector that belongs to one file, selecting the vectors by chunk id.
# NOTE(review): only the FAISS index is modified here; meta['files'] and
# meta['chunks'] still list the removed chunks, so index and metadata disagree
# until the metadata is cleaned up as well.
ids_to_remove = np.array(index_manager.meta['files']['AttentionIsAllYouNeed.pdf'], dtype=np.int64)
index_manager.index.remove_ids(faiss.IDSelectorBatch(ids_to_remove))

24

In [7]:
# Vector count again, to compare with the count seen before remove_ids.
index_manager.index.ntotal

112

In [8]:
# Copy the ids held by the id-map wrapper into a NumPy array.
remaining_ids = faiss.vector_to_array(index_manager.index.id_map)

In [11]:
# Last entry of the id array.
remaining_ids[-1]

np.int64(135)

In [ ]:
# Run a sample similarity search and print, for each hit, its score, source file
# and the beginning of the matched text.
test_query = "What is natural language processing?"
results = index_manager.search(test_query, top_k=5)
print(f"Query: {test_query}")
for r in results:
    print(f"--------------------------------\nscore: {r['score']}, file name: {r['file_name']}\n")
    print(f"--------------------------------\n{r['text'][:100]}\n")  # print first 100 chars of each result

# clear() deletes the saved index/metadata files only; the in-memory index and
# metadata of index_manager stay as they are.
index_manager.clear()

In [4]:
# NOTE(review): scratch arithmetic; no other cell shown here uses the result —
# consider deleting this cell.
224+48

272