<a href="https://colab.research.google.com/github/juliawol/WB_Knowledge_Base/blob/main/Recall.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Load the fine-tuned model
model = SentenceTransformer('/content/fine_tuned_model_with_triplets')

# Load the candidate chunks to retrieve from and the (question, ground-truth chunk) pairs.
chunks_data_path = '/content/chunks.csv'
# NOTE(review): recall is measured on train_data1.csv. If the model was fine-tuned on
# these same pairs the numbers will be optimistic — confirm a held-out split is intended.
eval_data_path = '/content/train_data1.csv'
chunks_df = pd.read_csv(chunks_data_path)
eval_data_df = pd.read_csv(eval_data_path)

# The eval file stores the answer text in 'Chunk'; calculate_recall_at_k reads it as
# 'Ground Truth Chunk'. rename() silently ignores missing columns, so check the schema
# here and fail with a clear message instead of a KeyError deep in the evaluation loop.
eval_data_df = eval_data_df.rename(columns={'Chunk': 'Ground Truth Chunk'})
missing_eval_columns = {'Question', 'Ground Truth Chunk'} - set(eval_data_df.columns)
if missing_eval_columns:
    raise ValueError(
        f"{eval_data_path} is missing required column(s): {sorted(missing_eval_columns)}"
    )
if 'Chunk' not in chunks_df.columns:
    raise ValueError(f"{chunks_data_path} has no 'Chunk' column")

# Pre-compute embeddings for all candidate chunks once; row i of chunk_embeddings
# corresponds to original_chunks[i].
original_chunks = chunks_df['Chunk'].tolist()
chunk_embeddings = model.encode(original_chunks, convert_to_tensor=True)

# Function to calculate recall at k
def calculate_recall_at_k(model, eval_data, k=5):
    correct_at_k = 0
    total = 0

    for _, row in eval_data.iterrows():
        question = row['Question']
        ground_truth_chunk = row['Ground Truth Chunk']

        # Embed the question
        question_embedding = model.encode([question], convert_to_tensor=True)

        # Calculate cosine similarity between the question and all chunk embeddings
        cosine_similarities = cosine_similarity(question_embedding.cpu().numpy(), chunk_embeddings.cpu().numpy()).flatten()

        # Get top-k most similar chunks
        top_k_indices = cosine_similarities.argsort()[-k:][::-1]
        top_k_chunks = [original_chunks[i] for i in top_k_indices]

        # Check if the ground truth chunk is in the top-k retrieved chunks
        if ground_truth_chunk in top_k_chunks:
            correct_at_k += 1
        total += 1

    # Calculate recall at k
    recall_at_k = correct_at_k / total
    return recall_at_k

# Calculate recall@1, recall@3, and recall@5
recall_at_1 = calculate_recall_at_k(model, eval_data_df, k=1)
recall_at_2 = calculate_recall_at_k(model, eval_data_df, k=2)
recall_at_3 = calculate_recall_at_k(model, eval_data_df, k=3)
recall_at_5 = calculate_recall_at_k(model, eval_data_df, k=5)

print(f"Recall@1: {recall_at_1:.4f}")
print(f"Recall@1: {recall_at_2:.4f}")
print(f"Recall@3: {recall_at_3:.4f}")
print(f"Recall@5: {recall_at_5:.4f}")