In [None]:
!pip install ir-datasets
import ir_datasets

dataset = ir_datasets.load('cranfield')

In [None]:
pip install nltk

In [None]:
import nltk
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
import string
import pandas as pd

# NLTK data needed by preprocess_text():
#   punkt / punkt_tab -> tokenizer models used by word_tokenize (recent NLTK
#                        releases look up 'punkt_tab', older ones 'punkt')
#   stopwords         -> English stop-word list
#   wordnet / omw-1.4 -> data used by WordNetLemmatizer
for nltk_resource in ('punkt', 'punkt_tab', 'stopwords', 'wordnet', 'omw-1.4'):
    nltk.download(nltk_resource)

# Shared by every preprocess_text() call; a set gives O(1) stop-word lookups.
lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words('english'))

def preprocess_text(text):
    """Turn raw text into a list of lemmatized content tokens.

    The text is lower-cased, every character in ``string.punctuation`` is
    removed, and the result is split with ``word_tokenize``. Tokens that are
    in ``stop_words`` or are not purely alphabetic are dropped; the remaining
    ones are passed through ``lemmatizer.lemmatize``.

    Args:
        text: The string to process.

    Returns:
        A list of lemmatized tokens, or an empty list when ``text`` is not a
        ``str``.
    """
    if not isinstance(text, str):
        return []

    punctuation_table = str.maketrans('', '', string.punctuation)
    cleaned = text.lower().translate(punctuation_table)

    lemmas = []
    for token in word_tokenize(cleaned):
        # Skip stop words and anything containing digits or other symbols.
        if token in stop_words or not token.isalpha():
            continue
        lemmas.append(lemmatizer.lemmatize(token))
    return lemmas

# One space-joined string of preprocessed tokens per document, in the order
# dataset.docs_iter() yields them.
processed_docs = list(
    map(lambda doc: ' '.join(preprocess_text(doc.text)), dataset.docs_iter())
)

In [None]:
# Preview a handful of documents instead of dumping the whole corpus into the
# cell output.
processed_docs[:5]

In [None]:
from rank_bm25 import BM25Okapi
# Token lists (one per document) in docs_iter() order, so a position in the
# BM25 score array corresponds to the same position in this list.
docs = []
for doc in dataset.docs_iter():
    docs.append(preprocess_text(doc.text))

bm25 = BM25Okapi(docs, k1=2, b=0.9)

In [None]:
from sklearn.metrics.pairwise import cosine_similarity

# Parallel lists: processed_queries[i] is the preprocessed text of the query
# whose id is query_ids[i].
processed_queries, query_ids = [], []
for query in dataset.queries_iter():
    query_ids.append(query.query_id)
    processed_queries.append(' '.join(preprocess_text(query.text)))

# query_id -> set of doc_ids judged relevant (only judgments with relevance > 0).
qrels_dict = {}
for qrel in dataset.qrels_iter():
    if qrel.relevance > 0:
        if qrel.query_id not in qrels_dict:
            qrels_dict[qrel.query_id] = set()
        qrels_dict[qrel.query_id].add(qrel.doc_id)

# Position in the corpus -> doc_id, used to translate ranked positions back to ids.
doc_list = list(dataset.docs_iter())
index_to_doc_id = []
for doc in doc_list:
    index_to_doc_id.append(doc.doc_id)

In [None]:
def calculate_average_precision(ranked_doc_indices, relevant_doc_ids, doc_index_to_id):
    """Compute average precision (AP) of one ranked result list.

    Args:
        ranked_doc_indices: Document positions ordered from best to worst.
        relevant_doc_ids: Collection of doc ids judged relevant for the query.
        doc_index_to_id: Sequence mapping a document position to its doc id.

    Returns:
        The sum of precision@k over every rank k that holds a relevant
        document, divided by the total number of relevant documents; 0.0 when
        ``relevant_doc_ids`` is empty.
    """
    num_relevant = len(relevant_doc_ids)
    if num_relevant == 0:
        # Nothing can be relevant, so skip the scan entirely.
        return 0.0

    hits = 0
    sum_precisions = 0.0
    for rank, doc_idx in enumerate(ranked_doc_indices, start=1):
        if doc_index_to_id[doc_idx] in relevant_doc_ids:
            hits += 1
            sum_precisions += hits / rank  # precision at this cut-off

    # Normalise by all relevant documents, not only the retrieved ones, so
    # relevant documents missing from the ranking lower the score.
    return sum_precisions / num_relevant

In [None]:
# Rank the whole corpus for every query: query_id -> document positions sorted
# by descending BM25 score.
bm25_results = {}
for query_id, query in zip(query_ids, processed_queries):
    scores = bm25.get_scores(query.split())
    bm25_results[query_id] = scores.argsort()[::-1]

# Average precision per query; queries without any relevant document are skipped.
bm25_ap_per_query = [
    calculate_average_precision(ranked_doc_indices, qrels_dict[query_id], index_to_doc_id)
    for query_id, ranked_doc_indices in bm25_results.items()
    if qrels_dict.get(query_id)
]

num_queries_evaluated_bm25 = len(bm25_ap_per_query)
total_bm25_avg_precision = sum(bm25_ap_per_query)

if num_queries_evaluated_bm25 > 0:
    bm25_map_score = total_bm25_avg_precision / num_queries_evaluated_bm25
else:
    bm25_map_score = 0
print(f"BM25 Mean Average Precision (MAP): {bm25_map_score:.4f}")

In [None]:
def calculate_interpolated_precision(ranked_doc_indices, relevant_doc_ids, doc_index_to_id, recall_levels):
    """Calculate interpolated precision at given recall levels for a single query.

    The interpolated precision at recall level r is the highest precision
    observed at any point of the ranking whose recall is at least r.

    Args:
        ranked_doc_indices: Document positions ordered from best to worst.
        relevant_doc_ids: Collection of doc ids judged relevant for the query.
        doc_index_to_id: Sequence mapping a document position to its doc id.
        recall_levels: Recall values at which precision is reported.

    Returns:
        A list with one precision value per entry of ``recall_levels``. Levels
        the ranking never reaches get 0.0. When ``relevant_doc_ids`` is empty
        the list contains only zeros.
    """
    num_relevant = len(relevant_doc_ids)
    if num_relevant == 0:
        return [0] * len(recall_levels)

    # Precision/recall after each retrieved document, preceded by the (0, 0)
    # origin. Recall never decreases along the ranking, so both lists are
    # already ordered by recall.
    recalls = [0.0]
    precisions = [0.0]
    hits = 0
    for num_retrieved, doc_idx in enumerate(ranked_doc_indices, start=1):
        if doc_index_to_id[doc_idx] in relevant_doc_ids:
            hits += 1
        recalls.append(hits / num_relevant)
        precisions.append(hits / num_retrieved)

    # Right-to-left running maximum: afterwards precisions[i] is the best
    # precision reachable at recall >= recalls[i].
    best_so_far = 0.0
    for position in range(len(precisions) - 1, -1, -1):
        best_so_far = max(best_so_far, precisions[position])
        precisions[position] = best_so_far

    interpolated_precisions_at_levels = []
    for recall_level in recall_levels:
        # First point whose recall reaches the level. No synthetic point is
        # added at recall 1.0: if the ranking stops short of a level, its
        # precision is 0.0.
        found_precision = next(
            (prec for rec, prec in zip(recalls, precisions) if rec >= recall_level),
            0.0,
        )
        interpolated_precisions_at_levels.append(found_precision)
    return interpolated_precisions_at_levels

In [None]:
import numpy as np  # needed here for np.mean; the notebook's other numpy import comes in a later cell

# The 11 standard recall levels 0.0, 0.1, ..., 1.0. Each level is computed as
# k / 10 instead of accumulating 0.1 steps, so a level such as 0.3 compares
# equal to a measured recall like 3 / 10.
recall_levels_11pt = [level / 10 for level in range(11)]

bm25_interpolated_precisions_per_query = []
num_queries_for_11pt_eval_bm25 = 0

for query_id, ranked_doc_indices in bm25_results.items():
    current_relevant_doc_ids = qrels_dict.get(query_id, set())
    if not current_relevant_doc_ids:
        # Queries without relevance judgments are left out of the average.
        continue

    num_queries_for_11pt_eval_bm25 += 1
    # index_to_doc_id is the position -> doc_id list built alongside qrels_dict.
    interpolated_precisions = calculate_interpolated_precision(
        ranked_doc_indices, current_relevant_doc_ids, index_to_doc_id, recall_levels_11pt
    )
    bm25_interpolated_precisions_per_query.append(interpolated_precisions)

# Average each recall level over all evaluated queries (axis 0 = queries).
if num_queries_for_11pt_eval_bm25 > 0:
    avg_bm25_interpolated_precision = np.mean(bm25_interpolated_precisions_per_query, axis=0)
else:
    avg_bm25_interpolated_precision = [0] * len(recall_levels_11pt)

In [None]:
import numpy as np

def calculate_interpolated_f1(interpolated_precisions, rec):
    """Combine precision and recall values pairwise into F1 scores.

    Args:
        interpolated_precisions: Iterable of precision values.
        rec: Iterable of recall values, paired positionally with the
            precisions (pairing stops at the shorter input).

    Returns:
        A list holding ``2 * p * r / (p + r)`` for each pair, or 0.0 for a
        pair whose precision and recall sum to zero.
    """
    return [
        0.0 if precision + recall == 0 else 2 * precision * recall / (precision + recall)
        for precision, recall in zip(interpolated_precisions, rec)
    ]

# NOTE(review): `avg_bm25_rec` is not assigned in any earlier cell, so a fresh
# Restart & Run All stops here with a NameError. It is assumed to be the 11
# standard recall levels (0.0, 0.1, ..., 1.0) at which the averaged
# precisions were interpolated — confirm this matches the intended metric.
avg_bm25_rec = [level / 10 for level in range(11)]

avg_bm25_f1 = calculate_interpolated_f1(avg_bm25_interpolated_precision, avg_bm25_rec)
avg_bm25_f1

In [None]:
# One row per recall level: the recall value, the averaged interpolated
# precision, and the F1 score derived from the two.
avg_metrics_df = pd.DataFrame(
    dict(
        interpolated_recalls_at_levels=avg_bm25_rec,
        interpolated_precisions_at_levels=avg_bm25_interpolated_precision,
        interpolated_f1_at_levels=avg_bm25_f1,
    )
)

# Save without the index column.
avg_metrics_df.to_csv('avg_bm25_metrics.csv', index=False)