# Demo retrieval notebook

In [6]:
import os
import sys
import glob
import pandas as pd
import numpy as np

# Add the parent directory (one level up) to sys.path to access our modules
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

from modules.extraction.preprocessing import DocumentProcessing
from modules.extraction.embedding import Embedding
from modules.retrieval.index.bruteforce import FaissBruteForce


# Notebook-wide settings.
# Documents live in "<cwd>/../storage"; abspath collapses the parent-dir hop.
STORAGE_DIRECTORY = os.path.abspath(
    os.path.join(os.getcwd(), os.pardir, "storage")
)
# Fixed-length chunking: characters per chunk.
CHUNK_SIZE = 500
# Fixed-length chunking: characters shared by consecutive chunks.
OVERLAP_SIZE = 2
# How many nearest neighbours to fetch per query.
TOP_K = 5

def build_index(embedding_model):
    """
    Build a FAISS brute-force index over every document in the storage folder.

    Each regular file directly inside STORAGE_DIRECTORY is split into
    fixed-length chunks (CHUNK_SIZE / OVERLAP_SIZE), every non-blank chunk is
    embedded with the given model, and the vectors are added to a
    FaissBruteForce index together with per-chunk metadata.

    Args:
        embedding_model: Model name handed to ``Embedding``
            (e.g. "all-MiniLM-L6-v2").

    Returns:
        The populated ``FaissBruteForce`` index. Its metadata entries are
        dicts of the form ``{"article": <article id>, "chunk": <chunk text>}``.

    Raises:
        ValueError: If no non-blank chunk was produced from the storage folder.
    """
    processing = DocumentProcessing()
    embedding_instance = Embedding(embedding_model)

    # glob returns entries in arbitrary order and also matches sub-directories.
    # Keep regular files only and sort them so the index layout (and therefore
    # the position of each metadata entry) is reproducible between runs.
    document_files = sorted(
        path
        for path in glob.glob(os.path.join(STORAGE_DIRECTORY, "*"))
        if os.path.isfile(path)
    )

    suffix = ".txt.clean"
    all_embeddings = []
    all_metadata = []

    for file_path in document_files:
        # Extract the article identifier from the filename.
        # For example, "S08_set3_a4.txt.clean" -> "S08_set3_a4".
        # Only a trailing suffix is stripped; other names are used as-is.
        base = os.path.basename(file_path)
        article_id = base[: -len(suffix)] if base.endswith(suffix) else base

        # Use fixed-length chunking (implemented in preprocessing.py)
        chunks = processing.fixed_length_chunking(
            file_path, chunk_size=CHUNK_SIZE, overlap_size=OVERLAP_SIZE
        )

        for chunk in chunks:
            # Whitespace-only chunks carry no content worth indexing.
            if not chunk.strip():
                continue
            all_embeddings.append(embedding_instance.encode(chunk))
            # Metadata keeps both the article id and the chunk text so a hit
            # can be traced back to its source article.
            all_metadata.append({"article": article_id, "chunk": chunk})

    if not all_embeddings:
        raise ValueError("No embeddings generated. Check your document files and chunking parameters.")

    # Assume all embeddings have the same dimensionality.
    dim = all_embeddings[0].shape[0]

    # Create the FAISS index using the FaissBruteForce class (Euclidean metric).
    index = FaissBruteForce(dim, metric='euclidean')
    index.add_embeddings(all_embeddings, all_metadata)
    return index

# Build one index per embedding model (two in total).
print("Building index for all-MiniLM-L6-v2 ...")
index_mini = build_index("all-MiniLM-L6-v2")
print("Building index for all-mpnet-base-v2 ...")
index_mpnet = build_index("all-mpnet-base-v2")

# The questions TSV sits one level above the working directory, in qa_resources/.
_parent_dir = os.path.join(os.getcwd(), os.pardir)
qa_file = os.path.abspath(
    os.path.join(_parent_dir, "qa_resources", "question.tsv")
)
questions_df = pd.read_csv(qa_file, sep="\t")

# The questions table is expected to have at least these two columns:
# "Question" and "ArticleFile".
def evaluate_index(index, embedding_model, questions=None, top_k=None):
    """
    Measure how often the index retrieves a chunk from the correct article.

    For every question, the question text is embedded, the ``top_k`` nearest
    chunks are retrieved, and the 1-based rank of the first chunk whose
    ``"article"`` metadata equals the row's ``"ArticleFile"`` is recorded.

    Args:
        index: Index wrapper exposing ``.index`` (searched via
            ``.search(vectors, k)``) and ``.metadata`` (indexable by result id).
        embedding_model: Model name handed to ``Embedding``; it should be the
            model the index was built with.
        questions: Optional DataFrame with "Question" and "ArticleFile"
            columns. Defaults to the module-level ``questions_df``.
        top_k: Optional number of neighbours to retrieve. Defaults to ``TOP_K``.

    Returns:
        ``(accuracy, avg_rank)`` where ``accuracy`` is the fraction of
        questions with a matching chunk among the retrieved neighbours (0 when
        there are no questions) and ``avg_rank`` is the mean rank of the first
        match over the questions that had one, or ``None`` if none did.
    """
    if questions is None:
        questions = questions_df
    if top_k is None:
        top_k = TOP_K

    embedding_instance = Embedding(embedding_model)
    hit_ranks = []  # rank of the first true positive, one entry per hit
    total = 0

    for _, row in questions.iterrows():
        # Convert question to a string in case it is not already one.
        question = str(row["Question"])
        target_article = row["ArticleFile"]  # e.g. "S08_set3_a4"

        # Encode the question as a single-row float32 matrix for the search call.
        q_vector = embedding_instance.encode(question)
        q_vector = np.array(q_vector).astype('float32').reshape(1, -1)

        # Retrieve top K neighbors from the index.
        _, indices = index.index.search(q_vector, top_k)

        for position, idx in enumerate(indices[0], start=1):
            # FAISS fills missing results with -1 when fewer than k vectors
            # are available; indexing metadata with -1 would silently read the
            # last entry and could produce a false hit.
            if idx < 0:
                continue
            if index.metadata[idx]["article"] == target_article:
                hit_ranks.append(position)  # ranks are 1-indexed
                break
        total += 1

    accuracy = len(hit_ranks) / total if total > 0 else 0
    avg_rank = sum(hit_ranks) / len(hit_ranks) if hit_ranks else None
    return accuracy, avg_rank

def _print_performance(heading, accuracy, avg_rank):
    """Print the heading, the top-K accuracy as a percentage, and the average rank."""
    print(heading)
    print("  Accuracy (true positive in top {}): {:.2f}%".format(TOP_K, accuracy * 100))
    print("  Average rank of first true positive:", avg_rank)


# Score each index with the same model it was built from.
accuracy_mini, avg_rank_mini = evaluate_index(index_mini, "all-MiniLM-L6-v2")
accuracy_mpnet, avg_rank_mpnet = evaluate_index(index_mpnet, "all-mpnet-base-v2")

# Report the two models one after the other, separated by a blank line.
_print_performance("Performance for all-MiniLM-L6-v2:", accuracy_mini, avg_rank_mini)
_print_performance("\nPerformance for all-mpnet-base-v2:", accuracy_mpnet, avg_rank_mpnet)



Building index for all-MiniLM-L6-v2 ...
Building index for all-mpnet-base-v2 ...
Performance for all-MiniLM-L6-v2:
  Accuracy (true positive in top 5): 83.82%
  Average rank of first true positive: 1.1713961407491487

Performance for all-mpnet-base-v2:
  Accuracy (true positive in top 5): 84.40%
  Average rank of first true positive: 1.125140924464487
