In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
import warnings
import re

from sklearn.metrics import precision_score, recall_score, f1_score
from sentence_transformers import SentenceTransformer
from OnlineKMeans import OnlineKMeans
import faiss

warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Sentence-embedding model used to encode query text. It is read as a module-level
# global by evaluate_top_k_accuracy_faiss, so this cell must run before any evaluation.
model = SentenceTransformer("all-MiniLM-L6-v2")

In [3]:
def is_answer_in_chunk(answer_start, chunk_start, chunk_length):
    """
    Tell whether an answer offset lies inside a chunk's half-open span.

    The span is ``[chunk_start, chunk_start + chunk_length)``: the start offset
    is included, the end offset is excluded. If any of the three arguments is
    ``None`` the answer is reported as not contained.
    """
    if any(value is None for value in (answer_start, chunk_start, chunk_length)):
        return False
    return chunk_start <= answer_start and answer_start < chunk_start + chunk_length

In [4]:
def retrieve_top_chunks_faiss(
    query_embedding,
    df_chunks,
    chunk_embeddings,
    faiss_index,
    top_k=5
):
    """
    Retrieve the top-k chunks for a single query embedding using a FAISS index.

    Parameters
    ----------
    query_embedding : 1-D array-like
        Embedding of one query. It is cast to float32 and given a leading
        batch axis before being passed to ``faiss_index.search``.
    df_chunks : pandas.DataFrame
        Chunk metadata with columns ``context_id``, ``chunk_id``, ``title``,
        ``chunk_embed_text``, ``chunk_start`` and ``chunk_end``. The id
        returned by the index is used as a *row position* (``iloc``), so rows
        must be in the same order as the vectors were added to the index.
    chunk_embeddings : any
        Unused; kept so existing callers do not break.
    faiss_index : object
        Anything exposing ``search(queries, k) -> (scores, ids)``.
    top_k : int, default 5
        Number of neighbours requested from the index.

    Returns
    -------
    pandas.DataFrame
        One row per valid hit with the chunk columns above plus
        ``similarity``, sorted by ``similarity`` descending. The frame is
        empty (but has all columns) when the index returns no valid hit.

    Notes
    -----
    The score returned by ``search`` is stored as-is. For an inner-product
    index that is a similarity; for an L2 index it is a distance, in which
    case the descending sort lists the farthest hit first.
    """
    chunk_columns = [
        "context_id",
        "chunk_id",
        "title",
        "chunk_embed_text",
        "chunk_start",
        "chunk_end",
    ]

    query_vec = np.expand_dims(np.asarray(query_embedding, dtype=np.float32), axis=0)
    sims, idxs = faiss_index.search(query_vec, top_k)

    results = []
    for sim, idx in zip(sims[0], idxs[0]):
        # FAISS pads the result with id -1 when the index holds fewer than
        # top_k vectors; iloc[-1] would silently return the last chunk instead.
        if idx < 0:
            continue
        chunk = df_chunks.iloc[idx]  # fetch the row once instead of once per column
        record = {column: chunk[column] for column in chunk_columns}
        record["similarity"] = sim
        results.append(record)

    # Passing the columns explicitly keeps sort_values working on an empty result.
    hits = pd.DataFrame(results, columns=chunk_columns + ["similarity"])
    return hits.sort_values("similarity", ascending=False)

In [5]:
def evaluate_top_k_accuracy_faiss(
    df_queries,
    df_chunks,
    chunk_embeddings,
    faiss_index,
    top_k=5
):
    """
    Evaluate top-k retrieval quality of a FAISS index over a set of queries.

    For every query the question text is embedded with the module-level
    ``model``, the ``top_k`` nearest chunks are retrieved, and two hits are
    recorded:

    * document hit - at least one retrieved chunk has the query's ``context_id``;
    * chunk hit    - at least one of those same-document chunks has a
      ``[chunk_start, chunk_end)`` span containing the query's ``answer_start``.

    Parameters
    ----------
    df_queries : pandas.DataFrame
        Needs the columns ``question``, ``context_id`` and ``answer_start``.
    df_chunks : pandas.DataFrame
        Chunk metadata, forwarded to ``retrieve_top_chunks_faiss``.
    chunk_embeddings : any
        Forwarded unchanged to ``retrieve_top_chunks_faiss``.
    faiss_index : object
        Index to search, forwarded to ``retrieve_top_chunks_faiss``.
    top_k : int, default 5
        Number of chunks retrieved per query.

    Returns
    -------
    dict
        ``doc_accuracy`` / ``chunk_accuracy`` are the hit rates;
        ``*_precision`` / ``*_recall`` / ``*_f1`` are the scikit-learn scores
        computed against an all-positive ground truth (every query is assumed
        to have an answer); ``correct_chunk_accuracy`` is the mean fraction of
        retrieved chunks that belong to the query's document.
        Every value is 0.0 when ``df_queries`` is empty.
    """
    metric_names = (
        "doc_accuracy",
        "chunk_accuracy",
        "doc_precision",
        "doc_recall",
        "doc_f1",
        "chunk_precision",
        "chunk_recall",
        "chunk_f1",
        "correct_chunk_accuracy",
    )

    n_queries = len(df_queries)
    if n_queries == 0:
        # Nothing to evaluate: avoid dividing by zero / scoring empty label lists.
        return dict.fromkeys(metric_names, 0.0)

    # One batched encode call instead of one model call per query.
    query_embeddings = model.encode(
        df_queries["question"].tolist(),
        show_progress_bar=False
    )

    y_pred_doc, y_pred_chunk = [], []
    same_doc_ratios = []

    query_rows = zip(df_queries.iterrows(), query_embeddings)
    for (_, row), query_emb in tqdm(query_rows, total=n_queries):
        results = retrieve_top_chunks_faiss(
            query_embedding=query_emb,
            df_chunks=df_chunks,
            chunk_embeddings=chunk_embeddings,
            faiss_index=faiss_index,
            top_k=top_k
        )

        # Document-level: any retrieved chunk from the query's own document.
        correct_doc_chunks = results[results["context_id"] == row["context_id"]]
        y_pred_doc.append(1 if len(correct_doc_chunks) > 0 else 0)

        # Chunk-level: one of those chunks must also cover the answer offset.
        found_chunk_context = any(
            is_answer_in_chunk(
                row["answer_start"],
                chunk["chunk_start"],
                chunk["chunk_end"] - chunk["chunk_start"]
            )
            for _, chunk in correct_doc_chunks.iterrows()
        )
        y_pred_chunk.append(1 if found_chunk_context else 0)

        # Share of the retrieved chunks that come from the right document.
        total_chunks = len(results)
        same_doc_ratios.append(
            len(correct_doc_chunks) / total_chunks if total_chunks > 0 else 0.0
        )

    # Every evaluated query counts as a positive in the ground truth.
    y_true_doc = [1] * n_queries
    y_true_chunk = [1] * n_queries

    # Compute metrics
    mean_same_doc_ratio = sum(same_doc_ratios) / len(same_doc_ratios)
    metrics = {
        "doc_accuracy": sum(y_pred_doc) / n_queries,
        "chunk_accuracy": sum(y_pred_chunk) / n_queries,
        "doc_precision": precision_score(y_true_doc, y_pred_doc, zero_division=0),
        "doc_recall": recall_score(y_true_doc, y_pred_doc, zero_division=0),
        "doc_f1": f1_score(y_true_doc, y_pred_doc, zero_division=0),
        "chunk_precision": precision_score(y_true_chunk, y_pred_chunk, zero_division=0),
        "chunk_recall": recall_score(y_true_chunk, y_pred_chunk, zero_division=0),
        "chunk_f1": f1_score(y_true_chunk, y_pred_chunk, zero_division=0),
        "correct_chunk_accuracy": mean_same_doc_ratio
    }

    return metrics

In [6]:
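# NOTE: superseded draft kept for reference only. It rebuilds the FAISS index from all
# seen embeddings on every batch; the active definition of this function is in the next cell.
# def online_kmeans_faiss_retrieval(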
#     chunk_embeddings,
#     df_chunks,
#     df_queries,
#     n_clusters=500,
#     batch_size=2000,
#     top_k=5,
#     init_fraction=0.5,
#     max_clusters=None,
#     metric="cosine",
#     new_cluster_threshold=None,
#     merge_threshold=None,
#     decay=None
# ):
#     n_samples = chunk_embeddings.shape[0]
#     init_size = int(n_samples * init_fraction)
#     remaining_size = n_samples - init_size

#     # --- Step 1: Initialize OnlineKMeans ---
#     print(f"🔧 Using {init_fraction*100:.0f}% of data ({init_size} samples) for initialization")
#     okm = OnlineKMeans(
#         n_clusters=n_clusters,
#         max_clusters=max_clusters,
#         metric=metric,
#         new_cluster_threshold=new_cluster_threshold,
#         merge_threshold=merge_threshold,
#         random_state=42,
#         decay=decay
#     )
#     init_start = time.time()
#     okm.partial_fit(chunk_embeddings[:init_size])
#     init_end = time.time()
#     init_time = init_end - init_start
#     print(f"✅ Initialization done in {init_time:.4f} s")

#     results = []

#     # --- Step 2: Online batch updates ---
#     for batch_idx in tqdm(range(1, int(np.ceil(remaining_size / batch_size)) + 1)):
#         start_idx = (batch_idx - 1) * batch_size
#         end_idx = min(batch_idx * batch_size, remaining_size)
#         batch_embeddings = chunk_embeddings[init_size + start_idx : init_size + end_idx].astype(np.float32, copy=False)
#         batch_embeddings = np.ascontiguousarray(batch_embeddings)

#         # --- Online update ---
#         update_start = time.time()
#         okm.partial_fit(batch_embeddings)
#         update_end = time.time()
#         update_time = update_end - update_start

#         # --- Seen data so far ---
#         seen_end_idx = init_size + end_idx
#         seen_embeddings = chunk_embeddings[:seen_end_idx].astype(np.float32, copy=False)
#         seen_df_chunks = df_chunks.iloc[:seen_end_idx].reset_index(drop=True)

#         # --- Build FAISS index ---
#         d = seen_embeddings.shape[1]
#         if metric == "cosine":
#             faiss_index = faiss.IndexFlatIP(d)  # inner product for cosine similarity
#             faiss.normalize_L2(seen_embeddings)
#         else:
#             faiss_index = faiss.IndexFlatL2(d)
#         faiss_index.add(seen_embeddings)

#         # --- Filter queries ---
#         df_queries_seen = df_queries[df_queries["context_id"].isin(seen_df_chunks["context_id"].unique())].reset_index(drop=True)

#         # --- Evaluate retrieval ---
#         retrieval_start = time.time()
#         metrics = evaluate_top_k_accuracy_faiss(
#             df_queries=df_queries_seen,
#             df_chunks=seen_df_chunks,
#             chunk_embeddings=seen_embeddings,
#             faiss_index=faiss_index,
#             top_k=top_k
#         )
#         retrieval_end = time.time()
#         retrieval_time = retrieval_end - retrieval_start

#         results.append({
#             "batch": batch_idx,
#             "init_time": init_time if batch_idx == 1 else 0,
#             "update_time": update_time,
#             "retrieval_time": retrieval_time,
#             "metrics": metrics,
#             "n_clusters": len(okm.centroids)
#         })

#         print(f"[Batch {batch_idx}] Seen chunks: {seen_end_idx}, Doc acc: {metrics['doc_accuracy']:.4f}, Chunk acc: {metrics['chunk_accuracy']:.4f}, Clusters: {len(okm.centroids)}")

#     return pd.DataFrame(results)

In [7]:
def online_kmeans_faiss_retrieval(
    chunk_embeddings,
    df_chunks,
    df_queries,
    n_clusters=500,
    batch_size=2000,
    top_k=5,
    init_fraction=0.5,
    max_clusters=None,
    metric="cosine",
    new_cluster_threshold=None,
    merge_threshold=None,
    decay=None
):
    """
    Stream chunk embeddings into OnlineKMeans and a FAISS index, evaluating
    retrieval after every batch.

    The first ``init_fraction`` of the rows initialises both the clustering
    model and the FAISS index. The remaining rows are then fed in batches of
    ``batch_size``; after each batch the queries whose ``context_id`` has
    already been indexed are evaluated with ``evaluate_top_k_accuracy_faiss``.

    The clustering model is updated and timed, but retrieval itself is done by
    the flat FAISS index; only the centroid count is reported from the model.

    Parameters
    ----------
    chunk_embeddings : numpy.ndarray, 2-D
        One embedding per row, aligned by position with ``df_chunks``.
        The array is never modified; float32 copies are made internally.
    df_chunks : pandas.DataFrame
        Chunk metadata with a ``context_id`` column, row ``i`` describing
        ``chunk_embeddings[i]``.
    df_queries : pandas.DataFrame
        Queries with a ``context_id`` column (plus whatever
        ``evaluate_top_k_accuracy_faiss`` reads).
    n_clusters, max_clusters, metric, new_cluster_threshold, merge_threshold, decay
        Forwarded to ``OnlineKMeans``. ``metric == "cosine"`` additionally
        selects an inner-product FAISS index over L2-normalised vectors;
        any other value selects an L2 index.
    batch_size : int, default 2000
        Number of rows added per online step.
    top_k : int, default 5
        Neighbours retrieved per query during evaluation.
    init_fraction : float, default 0.5
        Fraction of the rows used for initialisation.

    Returns
    -------
    pandas.DataFrame
        One row per batch with the columns ``batch``, ``init_time``
        (non-zero on the first batch only), ``update_time``,
        ``retrieval_time``, ``metrics`` (dict) and ``n_clusters``.
    """
    n_samples = chunk_embeddings.shape[0]
    init_size = int(n_samples * init_fraction)
    remaining_size = n_samples - init_size

    # --- Step 1: Initialize OnlineKMeans ---
    print(f"🔧 Using {init_fraction*100:.0f}% of data ({init_size} samples) for initialization")
    okm = OnlineKMeans(
        n_clusters=n_clusters,
        max_clusters=max_clusters,
        metric=metric,
        new_cluster_threshold=new_cluster_threshold,
        merge_threshold=merge_threshold,
        random_state=42,
        decay=decay
    )

    # astype() always copies, so the in-place normalisation further down never
    # touches the caller's array. FAISS additionally needs C-contiguous input.
    init_embeddings = np.ascontiguousarray(chunk_embeddings[:init_size].astype(np.float32))

    init_start = time.time()
    okm.partial_fit(init_embeddings)
    init_time = time.time() - init_start
    print(f"✅ Initialization done in {init_time:.4f} s")

    # --- Step 2: Create FAISS index once ---
    d = chunk_embeddings.shape[1]
    use_cosine = metric == "cosine"
    if use_cosine:
        faiss_index = faiss.IndexFlatIP(d)  # inner product on unit vectors
        faiss.normalize_L2(init_embeddings)
    else:
        faiss_index = faiss.IndexFlatL2(d)

    faiss_index.add(init_embeddings)

    # Context ids that already have at least one chunk in the index.
    # A set keeps the per-batch query filter cheap as the index grows.
    chunk_context_ids = df_chunks["context_id"]
    seen_context_ids = set(chunk_context_ids.iloc[:init_size])

    results = []

    # --- Step 3: Online batch updates ---
    n_batches = int(np.ceil(remaining_size / batch_size))
    for batch_idx in tqdm(range(1, n_batches + 1)):
        batch_lo = init_size + (batch_idx - 1) * batch_size
        batch_hi = init_size + min(batch_idx * batch_size, remaining_size)
        batch_embeddings = np.ascontiguousarray(
            chunk_embeddings[batch_lo:batch_hi].astype(np.float32)
        )

        # --- Online update ---
        update_start = time.time()
        okm.partial_fit(batch_embeddings)
        update_time = time.time() - update_start

        # --- Incrementally add batch to FAISS ---
        if use_cosine:
            faiss.normalize_L2(batch_embeddings)
        faiss_index.add(batch_embeddings)
        seen_context_ids.update(chunk_context_ids.iloc[batch_lo:batch_hi])

        # --- Prepare queries for seen chunks ---
        # Only keep queries whose context is in the FAISS index
        seen_mask = df_queries["context_id"].isin(seen_context_ids)
        df_queries_seen = df_queries[seen_mask].reset_index(drop=True)

        # --- Evaluate retrieval ---
        # Vectors are added in row order, so a FAISS id equals the row position
        # in the full df_chunks and no separate id mapping has to be passed.
        retrieval_start = time.time()
        metrics = evaluate_top_k_accuracy_faiss(
            df_queries=df_queries_seen,
            df_chunks=df_chunks,
            chunk_embeddings=None,  # embeddings not needed, FAISS handles search
            faiss_index=faiss_index,
            top_k=top_k
        )
        retrieval_time = time.time() - retrieval_start

        n_centroids = len(okm.centroids)
        results.append({
            "batch": batch_idx,
            "init_time": init_time if batch_idx == 1 else 0,
            "update_time": update_time,
            "retrieval_time": retrieval_time,
            "metrics": metrics,
            "n_clusters": n_centroids
        })

        print(f"[Batch {batch_idx}] Seen chunks: {batch_hi}, Doc acc: {metrics['doc_accuracy']:.4f}, Chunk acc: {metrics['chunk_accuracy']:.4f}, Clusters: {n_centroids}")

    return pd.DataFrame(results)


# Workflow

In [8]:
# Chunk embeddings (one row per chunk). copy=False avoids duplicating the matrix
# when the file is already stored as float32.
X_semantic_train = np.load("../data/tensors/squad_train_v2_semantic_chunking.npy").astype(np.float32, copy=False)
# Chunk metadata and evaluation queries.
df_semantic_train = pd.read_excel("../data/prepared/squad_train_v2_semantic_chunking.xlsx")
df_queries_train = pd.read_excel("../data/prepared/squad_train_v2_queries.xlsx")

# Retrieval maps a FAISS id back to a chunk by row position, so the embedding
# matrix and the chunk table must be aligned row for row.
assert X_semantic_train.ndim == 2, "expected a 2-D (n_chunks, dim) embedding matrix"
assert len(df_semantic_train) == X_semantic_train.shape[0], (
    f"{X_semantic_train.shape[0]} embeddings vs {len(df_semantic_train)} chunk rows"
)
# Columns read by the retrieval / evaluation functions.
assert {"context_id", "chunk_id", "title", "chunk_embed_text", "chunk_start", "chunk_end"} <= set(df_semantic_train.columns), "missing chunk columns"
assert {"question", "context_id", "answer_start"} <= set(df_queries_train.columns), "missing query columns"

In [9]:
# Run the streaming experiment on the semantic-chunking training split.
results_df = online_kmeans_faiss_retrieval(
    # data
    chunk_embeddings=X_semantic_train,
    df_chunks=df_semantic_train,
    df_queries=df_queries_train,
    # streaming / evaluation schedule
    init_fraction=0.5,
    batch_size=2000,
    top_k=5,
    # OnlineKMeans settings
    n_clusters=500,
    max_clusters=2000,
    new_cluster_threshold=0.8,
    merge_threshold=0.08,
    decay=1.0
)

🔧 Using 50% of data (42003 samples) for initialization


: 

In [None]:
# Shared x-axis and per-batch metric series pulled out of the metrics dicts.
batches = results_df["batch"]
doc_accuracy = results_df["metrics"].apply(lambda m: m["doc_accuracy"])
chunk_accuracy = results_df["metrics"].apply(lambda m: m["chunk_accuracy"])

# --- Plot Accuracy ---
fig_acc, ax_acc = plt.subplots(figsize=(10, 5))
ax_acc.plot(batches, doc_accuracy, marker="o", label="Doc Accuracy")
ax_acc.plot(batches, chunk_accuracy, marker="s", label="Chunk Accuracy")
ax_acc.set_xlabel("Batch")
ax_acc.set_ylabel("Accuracy")
ax_acc.set_title("📊 Retrieval Accuracy per Batch (OnlineKMeans + FAISS)")
ax_acc.legend()
ax_acc.grid(True)
plt.show()

# --- Plot Runtime ---
fig_time, ax_time = plt.subplots(figsize=(10, 5))
ax_time.plot(batches, results_df["update_time"], label="Update Time", marker='o')
ax_time.plot(batches, results_df["retrieval_time"], label="Retrieval Time", marker='s')
ax_time.set_xlabel("Batch")
ax_time.set_ylabel("Time (s)")
ax_time.set_title("⚙️ Runtime per Batch (OnlineKMeans + FAISS)")
ax_time.legend()
ax_time.grid(True)
plt.show()

# --- Plot Cluster count ---
fig_clusters, ax_clusters = plt.subplots(figsize=(10, 5))
ax_clusters.plot(batches, results_df["n_clusters"], marker='o', color='purple')
ax_clusters.set_xlabel("Batch")
ax_clusters.set_ylabel("# Clusters")
ax_clusters.set_title("📈 Cluster Count Evolution (OnlineKMeans)")
ax_clusters.grid(True)
plt.show()