In [1]:
import os
import torch
import pandas as pd
import pickle
from tqdm import tqdm
import time
from langchain.schema import Document
from langchain.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores.faiss import DistanceStrategy

In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {DEVICE}')

DATA_PATH = './data/'

Using device: cuda


In [3]:
# Load data
pickle_file = os.path.join(DATA_PATH, 'prepd_data.pkl')
data = pd.read_pickle(pickle_file)

In [4]:
print(len(data))

2354


In [5]:
data.head()

Unnamed: 0,book_id,headline,combined_info,combined_info_pp,lang
0,241112,"חזון עתיד מתוקן, חגי ישראל, חזון הנביאים, תנופ...","ד. בן-גוריון י""ח אייר תש""ט 17.5.49\n\n\n\nא...","ד. בן-גוריון י""ח אייר תש""ט 17.5.49 אין עמים רב...",he
1,241170,"israel and the middle east, חברות ישראל באו""ם,...",For Israel and Middle East\n\nBy David Ben-Gur...,For Israel and Middle East By David Ben-Gurion...,en
2,241180,חוק שירות הביטחון,"ישיבה ע' ו, 5.9.1949 חוק שירות הבטחון תש""ט-...","ישיבה ע' ו, 5.9.1949 חוק שירות הבטחון תש""ט- 19...",he
3,241225,how ambassador Morgenthau saved Palestine jewr...,HOW AMBASSADOR MORGENTHAU SAVED PALESTINE JEWR...,HOW AMBASSADOR MORGENTHAU SAVED PALESTINE JEWR...,en
4,241459,"[ח""א?] במטה המורחב, צה""ל, מערכת סיני, מלחמת סי...","12.12.57 ח""א במטה המורחב\n\nאני מודה, שהדברים ...","12.12.57 ח""א במטה המורחב אני מודה, שהדברים של ...",he


In [22]:
# Convert dataframe rows to LangChain Documents
docs = [
    Document(page_content=row['combined_info_pp'], metadata={"book_id": row['book_id'], "headline": row["headline"], "source": "DBGH"})
    for idx, row in data.iterrows()
]

In [7]:
def split_text_by_words(text, max_words, overlap_words):
    # Split the text into words
    words = text.split()

    chunks = []
    start_idx = 0
    while start_idx < len(words):
        end_idx = min(start_idx + max_words, len(words))
        chunk = words[start_idx:end_idx]

        # Join the words back into text
        chunk_text = ' '.join(chunk)
        chunks.append(chunk_text)

        # Update the starting position with overlap
        start_idx += max_words - overlap_words

    return chunks

In [8]:
chunks = []
idx = 0

for doc in docs:
    temp_chunks = split_text_by_words(
        doc.page_content, max_words=260, overlap_words=35
    )
    
    for chunk in temp_chunks:
        new_metadata = dict(doc.metadata) 
        new_metadata['idx'] = idx 
        chunks.append(Document(page_content=chunk, metadata=new_metadata))
        idx += 1

In [9]:
new_data_path = os.path.join(DATA_PATH, "chunks.pkl")
with open(new_data_path, "wb") as file: 
    pickle.dump(chunks, file)

In [10]:
# Check the number of chunks created
print(f"Number of chunks: {len(chunks)}")

Number of chunks: 21529


In [None]:
api_key = "your-api-key"
os.environ["OPENAI_API_KEY"] = api_key

In [13]:
embedding_model = OpenAIEmbeddings(model="text-embedding-3-large")

In [15]:
def faiss_from_documents(documents, embedding_model, batch_size=20, delay_sec=5):
    """Embeds LangChain Documents safely in batches and builds a FAISS index."""

    texts, metadatas = [], []
    for doc in documents:
        texts.append(doc.page_content)
        metadatas.append(doc.metadata)

    all_embeddings = []
    all_texts = []
    all_metas = []

    for i in tqdm(range(0, len(texts), batch_size)):
        batch_texts = texts[i:i+batch_size]
        batch_metas = metadatas[i:i+batch_size]
        try:
            batch_embeddings = embedding_model.embed_documents(batch_texts)
            all_embeddings.extend(batch_embeddings)
            all_texts.extend(batch_texts)
            all_metas.extend(batch_metas)
        except Exception as e:
            print(f"Error at batch {i}: {e}")
            print("Pausing before retry...")
            time.sleep(delay_sec * 2)
            continue
        time.sleep(delay_sec)

    text_embedding_pairs = list(zip(all_texts, all_embeddings))
    
    # create FAISS index from precomputed embeddings
    return FAISS.from_embeddings(
        text_embedding_pairs,  
        embedding=embedding_model,
        metadatas=all_metas,
        distance_strategy=DistanceStrategy.COSINE
    )

In [16]:
vectorstore = faiss_from_documents(
    documents=chunks,
    embedding_model=embedding_model,
    batch_size=15,
)

  0%|          | 0/1436 [00:00<?, ?it/s]

100%|██████████| 1436/1436 [2:38:58<00:00,  6.64s/it] 


In [None]:
# Save FAISS index
vectorstore.save_local("./faiss_index_openai_3textlarge")

In [None]:
# Load FAISS index
vectorstore = FAISS.load_local("./faiss_index_openai_3textlarge", 
                               embedding_model,
                            allow_dangerous_deserialization=True
                               )

In [21]:
print(len(vectorstore.docstore._dict))

21529
