In [1]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
import os
import pickle

In [2]:
# Location of the pickled documents produced by an earlier pipeline step.
pickle_file = "../data/all_docs_infoboxes_final.pkl"

# Fail fast: every later cell needs `all_docs`, so a missing file must stop
# the run here instead of surfacing later as an unrelated NameError.
if not os.path.exists(pickle_file):
    raise FileNotFoundError(f"Could not find {pickle_file}")

# NOTE(security): pickle.load can execute arbitrary code while deserialising.
# Only load this file if it was produced by a trusted source.
with open(pickle_file, "rb") as f:
    all_docs = pickle.load(f)

print(f"Loaded all_docs from {pickle_file}")

Loaded all_docs from ../data/all_docs_infoboxes_final.pkl


In [3]:
# Break every document in all_docs into overlapping chunks
# (chunk_size=512, chunk_overlap=56) so each piece can be embedded on its own.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=512,
    chunk_overlap=56,
)
chunked_docs = text_splitter.split_documents(all_docs)

# Report the resulting chunk count as a quick sanity check.
print(f"\nNumber of total chunks: {len(chunked_docs)}")


Number of total chunks: 35057


In [4]:
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from tqdm import tqdm

# Embedding model used to vectorise every chunk before indexing.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/msmarco-distilbert-base-tas-b",
    # Optional device override, left disabled:
    #model_kwargs={"device": "cpu"}
)

# Number of chunks embedded and indexed per step
batch_size = 10  # Adjust based on memory constraints

# Ceiling division: the last batch may be shorter than batch_size
total_batches = (len(chunked_docs) + batch_size - 1) // batch_size

# Remains None if chunked_docs is empty (the loop below never runs)
faiss_db = None

# Build a single FAISS store incrementally. The first batch creates the store;
# later batches are appended in place with add_documents, which avoids
# allocating a throwaway index + docstore per batch and merging it afterwards.
with tqdm(total=total_batches, desc="Embedding and Building FAISS DB") as pbar:
    for i in range(0, len(chunked_docs), batch_size):
        batch = chunked_docs[i : i + batch_size]

        if faiss_db is None:
            faiss_db = FAISS.from_documents(batch, embeddings)
        else:
            # Reuses the embedding function the store was created with
            faiss_db.add_documents(batch)

        # One tick per processed batch
        pbar.update(1)

  embeddings = HuggingFaceEmbeddings(
Embedding and Building FAISS DB: 100%|██████████| 3506/3506 [10:52<00:00,  5.37it/s]


In [5]:
# Save the FAISS vector store
# faiss_db is still None when no chunks were indexed; report that explicitly
# instead of failing with an AttributeError on None.
if faiss_db is None:
    raise ValueError("No FAISS index was built (chunked_docs was empty); nothing to save.")
faiss_db.save_local("../data/faiss_db_infoboxes_final")