<a href="https://colab.research.google.com/github/eduseiti/bm25_explore/blob/main/bm25_ranking_with_CISI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the rank_bm25 package, which provides rank_bm25.BM25Okapi used below.
!pip install rank_bm25

In [None]:
import rank_bm25

import pandas as pd
import numpy as np

import nltk
from nltk.corpus import stopwords
import string

import os
import sys
import pickle

import regex as re
import urllib
from scipy import stats

from sklearn.metrics import precision_score, recall_score, f1_score

from itertools import product

from multiprocessing import Pool

In [None]:
# NLTK resources used below: the English stop-word list (stopwords.words) and
# the Punkt tokenizer models that nltk.word_tokenize relies on.
nltk.download('stopwords')
nltk.download('punkt')
# Recent NLTK releases (3.8.2 and later) make word_tokenize load the
# pickle-free 'punkt_tab' resource instead of 'punkt'; without it tokenization
# fails with a LookupError. On releases that do not know this package,
# nltk.download reports the error and returns False instead of raising.
nltk.download('punkt_tab')

In [None]:
# Regexes driving the line-by-line state machine of read_cisi_docs_and_queries.
# Each key is a parser state (the element's 'next_field'); its list is tried in
# order and the first matching pattern wins.
#   * Patterns without a capture group match tag-only lines (.T, .A, .W, .X).
#   * '^\.I\s+([0-9]+)' captures the numeric identifier of a new element.
#   * '^(.+)\r\n$|^(.+)\n$' matches a content line: the text is captured in
#     group 1 for CRLF-terminated lines and in group 2 for LF-terminated ones.
# Raw strings keep '\.' and '\s' as regex escapes rather than invalid Python
# string escapes (a SyntaxWarning since Python 3.12).
PARSING_FIELDS_REGEXS = {
    'identifier': [r"^\.I\s+([0-9]+)"],
    'title_or_words': [r"^\.T\s*\r?\n?$|^\.W\s*\r?\n?$"],
    'title_content_or_author': [r"^\.A\s*\r?\n?$", r"^(.+)\r\n$|^(.+)\n$"],
    'author_content_or_words': [r"^\.W\s*\r?\n?$", r"^(.+)\r\n$|^(.+)\n$"],
    'words_content_or_xref_or_identifier': [r"^\.I\s+([0-9]+)", r"^\.X\s*\r?\n?$", r"^(.+)\r\n$|^(.+)\n$"],
    'words_content_or_identifier': [r"^\.I\s+([0-9]+)", r"^(.+)\r\n$|^(.+)\n$"],
    'xref_content_or_identifier': [r"^\.I\s+([0-9]+)", r"^(.+)\r\n$|^(.+)\n$"],
}

In [None]:
def _first_matching_regex(regex_list, text):
    """Return ``(regex, match)`` for the first pattern in *regex_list* matching *text*.

    Returns ``(None, None)`` when no pattern matches.
    """
    for each_regex in regex_list:
        m = re.match(each_regex, text)
        if m is not None:
            return each_regex, m
    return None, None


def _content_text(m):
    r"""Return the text captured by a content-line match, or None for a blank line.

    The content pattern ``^(.+)\r\n$|^(.+)\n$`` stores the text in group 1 for
    CRLF-terminated lines and in group 2 for LF-terminated lines, so the first
    non-None group is used. A blank CRLF line matches the second alternative
    with only ``"\r"`` captured; after stripping it is empty and reported as None.
    """
    captured = next((g for g in m.groups() if g is not None), None)
    if captured is None:
        return None
    return captured.rstrip('\r') or None


def read_cisi_docs_and_queries(file_url):
    """Download and parse a CISI documents or queries file into a DataFrame.

    The file is read line by line by a state machine driven by
    ``PARSING_FIELDS_REGEXS``. A ``.I <n>`` line starts a new element, and the
    tag lines ``.T``, ``.A``, ``.W`` and ``.X`` open the title, author, words
    and cross-reference fields. Content lines are accumulated into the open
    field: title and words are joined with spaces, while authors and
    cross-references are each terminated with ``;``. Blank lines inside a
    field are skipped. Both CRLF and LF line endings are accepted.

    NOTE(review): tags other than .I/.T/.A/.W/.X, and a repeated .A inside the
    author field, are not recognised by the regex table. Such lines are stored
    as content of the open field. Confirm this against the data files.

    Args:
        file_url: Any URL accepted by ``urllib.request.urlopen``
            (e.g. ``https://`` or ``file://``).

    Returns:
        DataFrame with one row per parsed element and columns among
        ``identifier`` (string), ``title``, ``author``, ``words`` and ``xref``.
        Fields an element does not have are NaN.
    """
    # `import urllib` alone does not guarantee the `request` submodule is loaded.
    import urllib.request

    identifier_regex = PARSING_FIELDS_REGEXS['identifier'][0]

    all_elements = []
    current_element = {'next_field': 'identifier'}

    # The context manager closes the response even if parsing raises.
    with urllib.request.urlopen(file_url) as response:
        for raw_line in response:
            line = raw_line.decode()
            state = current_element['next_field']

            matched_regex, m = _first_matching_regex(PARSING_FIELDS_REGEXS[state], line)

            if m is None:
                # Lines no pattern accepts (e.g. an LF-only blank line) are ignored.
                continue

            if len(m.groups()) == 0:
                # Tag-only line: open the next field.
                if state == 'title_or_words':
                    if m.group(0)[0:2] == '.T':
                        current_element['title'] = ""
                        current_element['next_field'] = 'title_content_or_author'
                    else:
                        current_element['words'] = ""
                        current_element['next_field'] = 'words_content_or_identifier'

                elif state == 'title_content_or_author':
                    current_element['author'] = ""
                    current_element['next_field'] = 'author_content_or_words'

                elif state == 'author_content_or_words':
                    current_element['words'] = ""
                    current_element['next_field'] = 'words_content_or_xref_or_identifier'

                elif state == 'words_content_or_xref_or_identifier':
                    current_element['xref'] = ""
                    current_element['next_field'] = 'xref_content_or_identifier'

                elif state in ('xref_content_or_identifier', 'words_content_or_identifier'):
                    # Document complete. With the current regex table these
                    # states only have capturing patterns, so this branch is
                    # kept for tag-only patterns that may be added later.
                    all_elements.append(current_element)
                    current_element = {'next_field': 'identifier'}

                else:
                    print("Just ignore the line")

            elif state == 'identifier':
                current_element['identifier'] = m.group(1)
                current_element['next_field'] = 'title_or_words'

            elif matched_regex == identifier_regex:
                # A ".I <n>" line inside a content field closes the current
                # element and starts the next one. Checking which pattern
                # matched (rather than the line prefix) keeps content lines
                # such as ".In ..." from being mistaken for identifiers.
                all_elements.append(current_element)
                current_element = {'identifier': m.group(1),
                                   'next_field': 'title_or_words'}

            else:
                content = _content_text(m)
                if content is None:
                    continue  # blank line inside a field

                if state == 'title_content_or_author':
                    current_element['title'] += content + ' '
                elif state == 'author_content_or_words':
                    current_element['author'] += content + ';'
                elif state == 'xref_content_or_identifier':
                    current_element['xref'] += content + ';'
                elif state in ('words_content_or_identifier', 'words_content_or_xref_or_identifier'):
                    current_element['words'] += content + ' '

    # Flush the element being built, unless it never got past its ".I" line
    # or no element was started at all (e.g. an empty file).
    if current_element['next_field'] not in ('identifier', 'title_or_words'):
        all_elements.append(current_element)

    print("Parsed {} elements...".format(len(all_elements)))

    # errors='ignore': an empty result has no 'next_field' column to drop.
    return pd.DataFrame(all_elements).drop(columns='next_field', errors='ignore')

# Read the CISI files (documents, queries and relevance judgements)

In [None]:
# Parse the CISI document collection (CISI.ALL) into one DataFrame row per document.
docs_df = read_cisi_docs_and_queries('https://raw.githubusercontent.com/eduseiti/bm25_explore/main/cisi/CISI.ALL')

In [None]:
# Parse the CISI queries (CISI.QRY) into one DataFrame row per query.
queries_df = read_cisi_docs_and_queries('https://raw.githubusercontent.com/eduseiti/bm25_explore/main/cisi/CISI.QRY')

In [None]:
# Relevance judgements (CISI.REL): tab-separated lines without a header row,
# read into the columns query_id, doc_id, Q0 and rel.
qrels_df = pd.read_csv(
    'https://raw.githubusercontent.com/eduseiti/bm25_explore/main/cisi/CISI.REL',
    header=None,
    names=['query_id', 'doc_id', 'Q0', 'rel'],
    sep='\t',
)

# Tokenize the reference documents and queries, removing stop words and punctuation

In [None]:
def tokenize_and_remove_stop_words(which_df, stop_words, punctuation):
    """Tokenize the ``words`` column of *which_df*, dropping stop words and punctuation.

    Each text is lower-cased and split with ``nltk.word_tokenize``. A token is
    kept only when it appears in neither *stop_words* nor *punctuation*.

    Returns:
        A list holding one token list per entry of ``which_df['words']``, in order.
    """
    def is_kept(token):
        # Same short-circuit order as `not in stop_words and not in punctuation`.
        return not (token in stop_words or token in punctuation)

    tokenized_texts = [nltk.word_tokenize(text.lower()) for text in which_df['words']]

    return [
        [token for token in tokens if is_kept(token)]
        for tokens in tokenized_texts
    ]

In [None]:
# Tokens removed during tokenization: NLTK's English stop-word list and the
# single ASCII punctuation characters of string.punctuation.
# NOTE(review): multi-character punctuation tokens (e.g. '``', "''", '...')
# are not members of this set and therefore survive the filtering.
stop_words = set(stopwords.words('english'))
punctuation = set(string.punctuation)

In [None]:
# One cleaned token list per document, in docs_df row order.
docs_tokens = tokenize_and_remove_stop_words(docs_df, stop_words, punctuation)

In [None]:
# Number of tokenized documents.
len(docs_tokens)

In [None]:
# One cleaned token list per query, in queries_df row order.
queries_tokens = tokenize_and_remove_stop_words(queries_df, stop_words, punctuation)

In [None]:
# Number of tokenized queries.
len(queries_tokens)

# Compute BM25 scores for each query/document pair and evaluate retrieval (precision, recall, F1)

In [None]:
def compute_BM25(docs_tokens, queries_tokens, qrels_df, bm25_params, score_threshold = 1e-5):
    """Score all queries with BM25 and evaluate a thresholded retrieval per query.

    A document counts as retrieved for a query when its BM25 score is strictly
    greater than ``score_threshold``. The retrieved set is compared with the
    relevant documents listed in ``qrels_df``.

    Args:
        docs_tokens: One token list per document. Position ``i`` corresponds
            to ``doc_id == i + 1`` in ``qrels_df``.
        queries_tokens: One token list per query. Position ``i`` corresponds
            to ``query_id == i + 1`` in ``qrels_df``.
        qrels_df: DataFrame with integer ``query_id`` and ``doc_id`` columns
            listing the relevant (query, document) pairs.
        bm25_params: Sequence whose first two items are ``k1`` and ``b`` for
            ``rank_bm25.BM25Okapi``.
        score_threshold: Exclusive minimum BM25 score for a document to be
            considered retrieved.

    Returns:
        DataFrame with one row per unique ``query_id`` of ``qrels_df``, in
        order of first appearance, and columns ``query_id``, ``k1``, ``b``,
        ``score_threshold``, ``precision``, ``recall`` and ``f1``.
    """
    # Index rather than unpack, so sequences longer than two keep working.
    k1, b = bm25_params[0], bm25_params[1]

    bm25 = rank_bm25.BM25Okapi(docs_tokens, k1=k1, b=b)

    # Per-document BM25 scores, one array per query, in queries_tokens order.
    docs_queries_scores = [bm25.get_scores(query_tokens) for query_tokens in queries_tokens]

    results = []

    # Evaluate the retrieval performance using precision, recall, and F1-score.
    # NOTE(review): ids are assumed to be 1-based positions into
    # queries_tokens / docs_tokens; an id of 0 would silently index the last
    # element.
    for query_id in qrels_df['query_id'].unique():

        inferred_relevant_docs = docs_queries_scores[query_id - 1] > score_threshold

        gt_relevant_docs = np.zeros(inferred_relevant_docs.shape[0], dtype=bool)
        gt_relevant_docs[qrels_df.loc[qrels_df['query_id'] == query_id, 'doc_id'].to_numpy() - 1] = True

        # zero_division=0 yields the same 0.0 as sklearn's default for
        # ill-defined metrics (e.g. no retrieved documents), but without
        # emitting an UndefinedMetricWarning for every such query.
        precision = precision_score(gt_relevant_docs, inferred_relevant_docs, zero_division=0)
        recall = recall_score(gt_relevant_docs, inferred_relevant_docs, zero_division=0)
        f1 = f1_score(gt_relevant_docs, inferred_relevant_docs, zero_division=0)

        results.append({'query_id': query_id,
                        'k1': k1,
                        'b': b,
                        'score_threshold': score_threshold,
                        'precision': precision,
                        'recall': recall,
                        'f1': f1})

    return pd.DataFrame(results)


In [None]:
def check_score_threshold(docs_tokens, queries_tokens, qrels_df, k1, b, score_thresholds):
    """Evaluate one BM25 configuration over a range of retrieval score thresholds.

    The corpus is scored once with ``rank_bm25.BM25Okapi(k1=k1, b=b)``. Then,
    for each threshold, a document counts as retrieved for a query when its
    BM25 score is strictly greater than the threshold, and precision, recall
    and F1 are computed against the relevant documents in ``qrels_df``.

    Args:
        docs_tokens: One token list per document. Position ``i`` corresponds
            to ``doc_id == i + 1`` in ``qrels_df``.
        queries_tokens: One token list per query. Position ``i`` corresponds
            to ``query_id == i + 1`` in ``qrels_df``.
        qrels_df: DataFrame with integer ``query_id`` and ``doc_id`` columns
            listing the relevant (query, document) pairs.
        k1, b: BM25 hyperparameters.
        score_thresholds: Iterable of thresholds to evaluate.

    Returns:
        ``(results_df, results_stats_df)``. ``results_df`` holds one row per
        (threshold, query). ``results_stats_df`` holds the mean precision,
        recall and F1 per threshold.
    """
    bm25 = rank_bm25.BM25Okapi(docs_tokens, k1=k1, b=b)

    # Per-document BM25 scores, one array per query, in queries_tokens order.
    docs_queries_scores = [bm25.get_scores(query_tokens) for query_tokens in queries_tokens]

    query_ids = qrels_df['query_id'].unique()

    # The ground-truth relevance masks do not depend on the threshold, so they
    # are built once per query instead of once per (threshold, query) pair.
    gt_relevant_by_query = []
    for query_id in query_ids:
        gt_relevant_docs = np.zeros(docs_queries_scores[query_id - 1].shape[0], dtype=bool)
        gt_relevant_docs[qrels_df.loc[qrels_df['query_id'] == query_id, 'doc_id'].to_numpy() - 1] = True
        gt_relevant_by_query.append((query_id, gt_relevant_docs))

    results = []

    for score_threshold in score_thresholds:

        print("Evaluating score threshold={:.2f}...".format(score_threshold))

        for query_id, gt_relevant_docs in gt_relevant_by_query:

            inferred_relevant_docs = docs_queries_scores[query_id - 1] > score_threshold

            # zero_division=0 yields the same 0.0 as sklearn's default for
            # ill-defined metrics (common at high thresholds, where nothing is
            # retrieved), without an UndefinedMetricWarning per query.
            precision = precision_score(gt_relevant_docs, inferred_relevant_docs, zero_division=0)
            recall = recall_score(gt_relevant_docs, inferred_relevant_docs, zero_division=0)
            f1 = f1_score(gt_relevant_docs, inferred_relevant_docs, zero_division=0)

            results.append({'query_id': query_id,
                            'k1': k1,
                            'b': b,
                            'score_threshold': score_threshold,
                            'precision': precision,
                            'recall': recall,
                            'f1': f1})

    results_df = pd.DataFrame(results)

    results_stats_df = results_df.groupby(['score_threshold'])[['precision', 'recall', 'f1']].mean().reset_index()

    return results_df, results_stats_df

## Grid search over the BM25 hyperparameters (k1, b) with a fixed score threshold

In [None]:
# BM25 hyperparameter grid: k1 from 1.0 to 2.0 in steps of 0.1, and b from 0.2
# to 0.75 in steps of 0.05. Dividing integers keeps every value equal to its
# decimal literal (no accumulated floating-point error).
k1_values = [step / 10 for step in range(10, 21)]
b_values = [step / 100 for step in range(20, 80, 5)]

In [None]:
# Every (k1, b) combination, with k1 varying slowest (itertools.product order).
bm25_parameters = [(k1, b) for k1 in k1_values for b in b_values]

In [None]:
# Display all (k1, b) combinations to be evaluated.
bm25_parameters

In [None]:
# Evaluate every (k1, b) configuration in parallel, all with the same fixed
# score threshold. functools.partial binds the arguments shared by every
# configuration, instead of zipping len(bm25_parameters)-long lists of
# identical references. Results come back in bm25_parameters order.
from functools import partial

FIXED_SCORE_THRESHOLD = 5

with Pool(processes=6) as pool:
    all_results = pool.map(
        partial(compute_BM25, docs_tokens, queries_tokens, qrels_df,
                score_threshold=FIXED_SCORE_THRESHOLD),
        bm25_parameters,
    )

In [None]:
# Stack the per-configuration result frames. Each frame returned by
# compute_BM25 has its own 0..n-1 index; ignore_index=True gives the combined
# frame a fresh index instead of repeating the same labels per configuration.
all_results_df = pd.concat(all_results, ignore_index=True)

In [None]:
# Mean precision, recall and F1 over queries for each (k1, b) configuration.
all_results_stats_df = all_results_df.groupby(['k1', 'b'])[['precision', 'recall', 'f1']].mean().reset_index()

In [None]:
# Configuration(s) with the highest mean recall (ties are all shown).
all_results_stats_df[all_results_stats_df['recall'] == all_results_stats_df['recall'].max()]

In [None]:
# Configuration(s) with the highest mean precision (ties are all shown).
all_results_stats_df[all_results_stats_df['precision'] == all_results_stats_df['precision'].max()]

## Grid search to select the best BM25 score threshold

In [None]:
# Candidate score thresholds 1.0, 1.1, ..., 19.9 (190 values). Dividing an
# integer range by 10 gives each value exactly as its decimal literal, whereas
# np.arange with a float step accumulates rounding error in the last bits.
# That error would otherwise leak into the 'score_threshold' groupby keys.
score_thresholds = np.arange(10, 200) / 10

In [None]:
# Display the candidate score thresholds.
score_thresholds

### Try applying the BM25 hyperparameters with the best recall

In [None]:
# Sweep all thresholds with k1=2.0, b=0.2.
# NOTE(review): these values are hard-coded, presumably copied from the
# best-recall row of the grid search above. Re-check them if the grid changes.
results_df, results_stats_df = check_score_threshold(docs_tokens, queries_tokens, qrels_df, 2.0, 0.2, score_thresholds)

In [None]:
# Threshold(s) with the highest mean recall for k1=2.0, b=0.2.
results_stats_df[results_stats_df['recall'] == results_stats_df['recall'].max()]

In [None]:
# Threshold(s) with the highest mean precision for k1=2.0, b=0.2.
results_stats_df[results_stats_df['precision'] == results_stats_df['precision'].max()]

In [None]:
# Threshold(s) with the highest mean F1 for k1=2.0, b=0.2.
results_stats_df[results_stats_df['f1'] == results_stats_df['f1'].max()]

### Now, check the BM25 hyperparameters with the best precision

In [None]:
# Sweep all thresholds with k1=1.0, b=0.55.
# NOTE(review): these values are hard-coded, presumably copied from the
# best-precision row of the grid search above. Re-check them if the grid changes.
results_2_df, results_2_stats_df = check_score_threshold(docs_tokens, queries_tokens, qrels_df, 1.0, 0.55, score_thresholds)

In [None]:
# Threshold(s) with the highest mean recall for k1=1.0, b=0.55.
results_2_stats_df[results_2_stats_df['recall'] == results_2_stats_df['recall'].max()]

In [None]:
# Threshold(s) with the highest mean precision for k1=1.0, b=0.55.
results_2_stats_df[results_2_stats_df['precision'] == results_2_stats_df['precision'].max()]

In [None]:
# Threshold(s) with the highest mean F1 for k1=1.0, b=0.55.
results_2_stats_df[results_2_stats_df['f1'] == results_2_stats_df['f1'].max()]