In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from langchain.text_splitter import CharacterTextSplitter
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import numpy as np
import torch
import faiss
import json
import math
import time
import os

# --- Filesystem locations (relative to the notebook's working directory) ---
ENCODER_PATH = "../../bge-small-en"          # sentence-embedding model directory
DOCS_PATH = "../../dataset_txt/train"        # corpus: one .txt file per document
QUESTIONS_PATH = "../rag_questions_json"     # questions: one .json file per query

# --- Compute device: first CUDA GPU when one is visible, otherwise the CPU ---
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class LLMEngine:
    """Thin wrapper around the sentence-embedding model used for retrieval.

    The encoder is loaded from ``ENCODER_PATH`` and placed on ``DEVICE``.
    Both embedding methods request L2-normalized vectors so that the
    inner-product scores computed by the FAISS indexes are cosine
    similarities, and so that L2- and IP-based indexes rank identically.
    """

    def __init__(self):
        # Hand the device to the constructor rather than calling .to() afterwards:
        # the model then records it as its target device and encode() keeps
        # using it instead of falling back to its own default.
        self.encoder = SentenceTransformer(ENCODER_PATH, device=str(DEVICE))

    def embed_documents(self, docs):
        """Embed a list of document strings.

        Args:
            docs: list of raw text chunks.

        Returns:
            Whatever ``self.encoder.encode`` returns for ``docs``
            (unit-normalized embeddings, one row per chunk).
        """
        return self.encoder.encode(docs, normalize_embeddings=True)

    def embed_queries(self, queries):
        """Embed the question text of each query record.

        Args:
            queries: iterable of ``(query_id, payload)`` pairs where
                ``payload["question"]`` holds the question string.

        Returns:
            Whatever ``self.encoder.encode`` returns for the extracted
            questions (unit-normalized embeddings, one row per query).
        """
        questions = [entry[1]["question"] for entry in queries]
        return self.encoder.encode(questions, normalize_embeddings=True)
    
# Single shared engine instance; constructing it loads the encoder from ENCODER_PATH.
engine = LLMEngine()

In [3]:
def _read_text(path):
    """Return the full UTF-8 text of ``path``, closing the file afterwards."""
    with open(path, 'r', encoding='utf-8') as fh:
        return fh.read()


def _read_json(path):
    """Parse the UTF-8 JSON file at ``path``, closing the file afterwards."""
    with open(path, 'r', encoding='utf-8') as fh:
        return json.load(fh)


# Each entry is (id, content). The id is the part of the filename before the
# FIRST dot, so "a.b.txt" yields "a" (kept as-is for backward compatibility).
# The directory listing is sorted because os.listdir returns entries in
# arbitrary order; a stable order keeps FAISS row ids reproducible across runs.
docs = [
    (fn.split(".")[0], _read_text(os.path.join(DOCS_PATH, fn)))
    for fn in tqdm(sorted(os.listdir(DOCS_PATH)))
    if fn.endswith(".txt")
]
queries = [
    (fn.split(".")[0], _read_json(os.path.join(QUESTIONS_PATH, fn)))
    for fn in tqdm(sorted(os.listdir(QUESTIONS_PATH)))
    if fn.endswith(".json")
]

100%|██████████| 10858/10858 [01:10<00:00, 154.65it/s]
100%|██████████| 747/747 [00:00<00:00, 2327.56it/s]


In [4]:
# Chunking parameters, measured in characters.
MAX_CHAR_LEN = 4000
MAX_CHAR_OVERLAP = 500

splitter = CharacterTextSplitter(
    separator=" ",
    chunk_size=MAX_CHAR_LEN,
    chunk_overlap=MAX_CHAR_OVERLAP,
)

# Flat list of text chunks: documents longer than MAX_CHAR_LEN are split,
# everything else is kept verbatim as a single chunk.
split_docs = []
for _doc_id, text in tqdm(docs):
    if len(text) > MAX_CHAR_LEN:
        chunks = splitter.split_text(text)
    else:
        chunks = [text]
    split_docs += chunks

print(f"Number of documents: {len(split_docs)}")

100%|██████████| 10858/10858 [00:12<00:00, 862.74it/s]

Number of documents: 23924





In [5]:
# Encode every text chunk and every question with the shared encoder.
# NOTE(review): later cells read doc_embeddings.shape[1] and iterate over
# query_embeddings row by row, so both are presumably 2-D arrays of shape
# (n_items, dim) — confirm against the encoder's return type.
doc_embeddings = engine.embed_documents(split_docs)
query_embeddings = engine.embed_queries(queries)

### FlatIP (Exhaustive)

In [6]:
TOP_K_DOCS = 3                      # neighbours retrieved per question
D = doc_embeddings.shape[1]         # embedding dimensionality

# Exhaustive inner-product index: the exact baseline for the approximate
# indexes benchmarked in the following sections.
index_flatip = faiss.IndexFlatIP(D)
index_flatip.add(doc_embeddings)

# perf_counter is monotonic and high-resolution, which matters for
# sub-second measurements; time.time() can jump with clock adjustments.
begin = time.perf_counter()

for q in tqdm(query_embeddings):
    # search() takes a 2-D batch, so the single query row is lifted to (1, D).
    distances, indices = index_flatip.search(q[np.newaxis, :], TOP_K_DOCS)

print(f"Total time to index all questions (FlatIP): {(time.perf_counter() - begin):.2f} seconds")

100%|██████████| 747/747 [00:01<00:00, 435.34it/s]

Total time to index all questions (FlatIP): 1.72 seconds





### PQ

In [7]:
TOP_K_DOCS = 3                      # neighbours retrieved per question
D = doc_embeddings.shape[1]         # embedding dimensionality
m = 8                               # number of PQ sub-vectors
assert D % m == 0                   # PQ needs D to split evenly into m sub-vectors
nbits = 5                           # bits per sub-vector code (2**5 centroids each)

# Use the same metric as the exhaustive FlatIP baseline; IndexPQ defaults to
# L2, which would make this index answer a different question.
index_pq = faiss.IndexPQ(D, m, nbits, faiss.METRIC_INNER_PRODUCT)
index_pq.train(doc_embeddings)      # learn the PQ codebooks
index_pq.add(doc_embeddings)

# perf_counter is monotonic and high-resolution, which matters for
# sub-second measurements; time.time() can jump with clock adjustments.
begin = time.perf_counter()

for q in tqdm(query_embeddings):
    # search() takes a 2-D batch, so the single query row is lifted to (1, D).
    distances, indices = index_pq.search(q[np.newaxis, :], TOP_K_DOCS)

print(f"Total time to index all questions (PQ): {(time.perf_counter() - begin):.2f} seconds")

100%|██████████| 747/747 [00:00<00:00, 936.53it/s]

Total time to index all questions (PQ): 0.80 seconds





### IVFPQ

In [8]:
TOP_K_DOCS = 3                      # neighbours retrieved per question
D = doc_embeddings.shape[1]         # embedding dimensionality
m = 8                               # number of PQ sub-vectors
assert D % m == 0                   # PQ needs D to split evenly into m sub-vectors
nlist = 2**5                        # number of inverted lists (coarse clusters)
nbits = 5                           # bits per sub-vector code (2**5 centroids each)
NPROBE = 1                          # inverted lists visited per query (FAISS default)

# The coarse quantizer and the IVFPQ index must agree on the metric. The
# quantizer is inner-product, so the index is created with the same metric
# instead of the L2 default.
quantizer = faiss.IndexFlatIP(D)
index_ivfpq = faiss.IndexIVFPQ(quantizer, D, nlist, m, nbits, faiss.METRIC_INNER_PRODUCT)
index_ivfpq.train(doc_embeddings)   # learn coarse centroids and PQ codebooks
index_ivfpq.add(doc_embeddings)

# NOTE(review): with NPROBE = 1 only one of the nlist clusters is scanned per
# query, which is fast but can miss neighbours; raise it to trade speed for recall.
index_ivfpq.nprobe = NPROBE

# perf_counter is monotonic and high-resolution, which matters for
# sub-second measurements; time.time() can jump with clock adjustments.
begin = time.perf_counter()

for q in tqdm(query_embeddings):
    # search() takes a 2-D batch, so the single query row is lifted to (1, D).
    distances, indices = index_ivfpq.search(q[np.newaxis, :], TOP_K_DOCS)

print(f"Total time to index all questions (IVFPQ): {(time.perf_counter() - begin):.2f} seconds")

100%|██████████| 747/747 [00:00<00:00, 15862.74it/s]

Total time to index all questions (IVFPQ): 0.05 seconds





### HNSW

In [9]:
TOP_K_DOCS = 3                      # neighbours retrieved per question
D = doc_embeddings.shape[1]         # embedding dimensionality
M = 32                              # HNSW graph connectivity (links per node)

# Use the same metric as the exhaustive FlatIP baseline; IndexHNSWFlat
# defaults to L2 otherwise. HNSW needs no training step.
index_hnsw = faiss.IndexHNSWFlat(D, M, faiss.METRIC_INNER_PRODUCT)
index_hnsw.add(doc_embeddings)

# perf_counter is monotonic and high-resolution, which matters for
# sub-second measurements; time.time() can jump with clock adjustments.
begin = time.perf_counter()

for q in tqdm(query_embeddings):
    # search() takes a 2-D batch, so the single query row is lifted to (1, D).
    distances, indices = index_hnsw.search(q[np.newaxis, :], TOP_K_DOCS)

print(f"Total time to index all questions (HNSW): {(time.perf_counter() - begin):.2f} seconds")

100%|██████████| 747/747 [00:00<00:00, 12304.02it/s]

Total time to index all questions (HNSW): 0.06 seconds



