<a href="https://colab.research.google.com/github/eduseiti/bm25_explore/blob/main/bm25_ranking_with_CISI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install rank_bm25

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import rank_bm25
import pandas as pd
import numpy as np

import nltk
from nltk.corpus import stopwords
import string

import os
import sys
import pickle

import regex as re
import urllib
from scipy import stats

from sklearn.metrics import precision_score, recall_score, f1_score

In [None]:
# NLTK resources used by the tokenization step: the English stop-word list and
# the Punkt sentence-splitter data needed by `nltk.word_tokenize`.
# Recent NLTK releases (3.8.2+) load Punkt from the 'punkt_tab' package rather
# than 'punkt', so both are requested to keep the notebook runnable on old and
# new NLTK versions (an older release may just report 'punkt_tab' as unknown).
nltk.download('stopwords')
nltk.download('punkt')
nltk.download('punkt_tab')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [None]:
# Parser state machine for CISI files. For each state (the field the parser
# expects next), the regexes are tried in order against every line. A pattern
# with a capture group carries data (an identifier or a content line); a
# pattern without one marks a tag-only line (.T, .A, .W, .X). The shared
# content pattern captures CRLF-terminated lines in group 1 and LF-terminated
# lines in group 2. Raw strings keep regex escapes such as \. and \s away from
# Python's own string-escape processing.
PARSING_FIELDS_REGEXS = {
  'identifier': [r"^\.I\s+([0-9]+)"],
  'title_or_words': [r"^\.T\s*\r?\n?$|^\.W\s*\r?\n?$"],
  'title_content_or_author': [r"^\.A\s*\r?\n?$", r"^(.+)\r\n$|^(.+)\n$"],
  'author_content_or_words': [r"^\.W\s*\r?\n?$", r"^(.+)\r\n$|^(.+)\n$"],
  'words_content_or_xref_or_identifier': [r"^\.I\s+([0-9]+)", r"^\.X\s*\r?\n?$", r"^(.+)\r\n$|^(.+)\n$"],
  'words_content_or_identifier': [r"^\.I\s+([0-9]+)", r"^(.+)\r\n$|^(.+)\n$"],
  'xref_content_or_identifier': [r"^\.I\s+([0-9]+)", r"^(.+)\r\n$|^(.+)\n$"]
}


def _matched_content(m):
  """Return the line text captured by the content pattern, or None for a blank line.

  The content pattern puts CRLF-terminated lines in group 1 and LF-terminated
  lines in group 2, so the text is taken from whichever group participated in
  the match. A CRLF blank line matches the LF alternative with a lone '\\r',
  which is stripped and reported as blank (None).
  """
  content = next((group for group in m.groups() if group is not None), None)
  if content is None:
    return None
  content = content.rstrip('\r')
  return content or None


def read_cisi_docs_and_queries(file_url):
  """Parse a CISI collection file (e.g. CISI.ALL or CISI.QRY) into a DataFrame.

  Each record starts with an ``.I <id>`` line and is followed either by
  ``.T`` (title) / ``.A`` (author) / ``.W`` (words) sections, optionally
  followed by ``.X`` (cross-references), or directly by a ``.W`` section.
  Parsing is a state machine driven by ``PARSING_FIELDS_REGEXS``.

  Args:
    file_url: URL accepted by ``urllib.request.urlopen`` (``https://``,
      ``file://``, ...). Lines are decoded as UTF-8; both CRLF and LF line
      endings are handled, and blank content lines are skipped.

  Returns:
    DataFrame with one row per record and the columns seen in the file (a
    subset of identifier, title, author, words, xref). Title and words lines
    are joined with trailing spaces; author and xref lines each end with ';'.
    Identifiers are kept as strings. Also prints the number of records parsed.
  """
  import urllib.request  # `import urllib` alone does not guarantee the `request` submodule is loaded

  all_elements = []

  current_element = {'next_field': 'identifier'}

  # The context manager closes the HTTP response / file handle once parsing ends.
  with urllib.request.urlopen(file_url) as response:
    for line in response:
      text = line.decode()
      state = current_element['next_field']

      # Try the patterns allowed in the current state, in order; keep the first match.
      m = None
      for each_regex in PARSING_FIELDS_REGEXS[state]:
        m = re.match(each_regex, text)
        if m is not None:
          break

      if m is None:
        # No expected pattern fits this line (e.g. an LF-only blank line): skip it.
        continue

      if m.groups():
        # The match carries data: an identifier or a content line.

        if state == 'identifier':
          current_element['identifier'] = m.group(1)
          current_element['next_field'] = 'title_or_words'

        elif state == 'title_content_or_author':
          content = _matched_content(m)
          if content is not None:
            current_element['title'] += content + ' '

        elif state == 'author_content_or_words':
          content = _matched_content(m)
          if content is not None:
            current_element['author'] += content + ';'

        elif state in ('xref_content_or_identifier',
                       'words_content_or_identifier',
                       'words_content_or_xref_or_identifier'):

          if m.group(0)[0:2] == '.I':
            # An ".I" line completes the current record and starts the next one.
            all_elements.append(current_element)
            current_element = {'identifier': m.group(1),
                               'next_field': 'title_or_words'}
          else:
            content = _matched_content(m)
            if content is not None:
              if state == 'xref_content_or_identifier':
                current_element['xref'] += content + ';'
              else:
                current_element['words'] += content + ' '

      else:
        # Tag-only line: open the field it introduces and advance the state.

        if state == 'title_or_words':
          if m.group(0)[0:2] == '.T':
            current_element['title'] = ""
            current_element['next_field'] = 'title_content_or_author'
          else:
            current_element['words'] = ""
            current_element['next_field'] = 'words_content_or_identifier'

        elif state == 'title_content_or_author':
          current_element['author'] = ""
          current_element['next_field'] = 'author_content_or_words'

        elif state == 'author_content_or_words':
          current_element['words'] = ""
          current_element['next_field'] = 'words_content_or_xref_or_identifier'

        elif state == 'words_content_or_xref_or_identifier':
          current_element['xref'] = ""
          current_element['next_field'] = 'xref_content_or_identifier'

        elif state in ('xref_content_or_identifier', 'words_content_or_identifier'):
          # Defensive: every pattern for these states has a capture group, so
          # this branch is not expected to run; if it does, close the record.
          all_elements.append(current_element)
          current_element = {'next_field': 'identifier'}

        else:
          print("Just ignore the line")

  # Flush the last record. In the 'identifier' / 'title_or_words' states there
  # is no record content pending (e.g. an empty file), so nothing is appended.
  if current_element['next_field'] not in ('identifier', 'title_or_words'):
    all_elements.append(current_element)

  print("Parsed {} elements...".format(len(all_elements)))

  # errors='ignore': with no records the frame has no 'next_field' column to drop.
  return pd.DataFrame(all_elements).drop(columns='next_field', errors='ignore')

# Read the CISI documents, queries, and relevance judgments

In [None]:
# Reference documents (CISI.ALL), one row per parsed record.
docs_df = read_cisi_docs_and_queries(
    'https://raw.githubusercontent.com/eduseiti/bm25_explore/main/cisi/CISI.ALL'
)

Parsed 1460 elements...


In [None]:
# Queries (CISI.QRY), parsed with the same record format as the documents.
queries_df = read_cisi_docs_and_queries(
    'https://raw.githubusercontent.com/eduseiti/bm25_explore/main/cisi/CISI.QRY'
)

Parsed 112 elements...


In [None]:
# Relevance judgments (CISI.REL): tab-separated, no header row, four columns.
qrels_df = pd.read_csv(
    'https://raw.githubusercontent.com/eduseiti/bm25_explore/main/cisi/CISI.REL',
    delimiter='\t',
    header=None,
    names=['query_id', 'doc_id', 'Q0', 'rel'],
)

# Tokenize the reference documents and queries, removing stop words and punctuation

In [None]:
def tokenize_and_remove_stop_words(which_df, stop_words, punctuation):

  all_tokens = [nltk.word_tokenize(doc.lower()) for doc in which_df['words']]

  cleaned_tokens = [[token for token in doc_tokens if token not in stop_words and token not in punctuation] for doc_tokens in all_tokens]

  return cleaned_tokens

In [None]:
# Lookup sets for token filtering: NLTK's English stop words and the ASCII
# punctuation characters from `string.punctuation`.
stop_words = {word for word in stopwords.words('english')}
punctuation = {char for char in string.punctuation}

In [None]:
# One cleaned token list per row of docs_df.
docs_tokens = tokenize_and_remove_stop_words(which_df=docs_df,
                                             stop_words=stop_words,
                                             punctuation=punctuation)

In [None]:
# Invariant: the tokenizer yields exactly one token list per parsed document.
assert len(docs_tokens) == len(docs_df), "expected one token list per document"
len(docs_tokens)

1460

In [None]:
# One cleaned token list per row of queries_df.
queries_tokens = tokenize_and_remove_stop_words(which_df=queries_df,
                                                stop_words=stop_words,
                                                punctuation=punctuation)

In [None]:
# Invariant: the tokenizer yields exactly one token list per parsed query.
assert len(queries_tokens) == len(queries_df), "expected one token list per query"
len(queries_tokens)

112

# Compute BM25 scores for each query/document pair and evaluate retrieval (precision, recall, F1) over a grid of k1 and b values

In [None]:
def compute_BM25(docs_tokens, queries_tokens, qrels_df, k1, b, score_threshold = 1e-5):

  # print("k1={}, b={}".format(k1, b))

  docs_bm25_scores = rank_bm25.BM25Okapi(docs_tokens, k1=k1, b=b)

  docs_queries_scores = []

  for query_tokens in queries_tokens:
    query_scores = docs_bm25_scores.get_scores(query_tokens)

    docs_queries_scores.append(query_scores)

  results = []

  # Evaluate the retrieval performance using precision, recall, and F1-score

  query_ids = qrels_df['query_id'].unique()

  for query_id in query_ids:

      inferred_relevant_docs = docs_queries_scores[query_id - 1] > score_threshold

      gt_relevant_docs = np.zeros(inferred_relevant_docs.shape[0], dtype=bool)
      gt_relevant_docs[qrels_df[qrels_df['query_id'] == query_id]['doc_id'].to_numpy() - 1] = True

      precision = precision_score(gt_relevant_docs, inferred_relevant_docs)
      recall = recall_score(gt_relevant_docs, inferred_relevant_docs)
      f1 = f1_score(gt_relevant_docs, inferred_relevant_docs)

      results.append({'query_id': query_id,
                      'k1': k1,
                      'b': b,
                      'score_threshold': score_threshold,
                      'precision': precision, 
                      'recall': recall, 
                      'f1': f1})

  results_df = pd.DataFrame(results)

  return results_df


In [None]:
# BM25 grid: k1 from 1.0 to 2.0 in steps of 0.1 (11 values) and b from 0.2 to
# 0.75 in steps of 0.05 (12 values). Rounding removes float accumulation error.
k1_values = [round(1.0 + 0.1 * step, 1) for step in range(11)]
b_values = [round(0.2 + 0.05 * step, 2) for step in range(12)]

In [None]:
# Grid search: evaluate every (k1, b) pair with a fixed score threshold of 5,
# print the mean metrics, and keep each per-query results frame.
all_results = []

for k1_value in k1_values:
  for b_value in b_values:
    results_df = compute_BM25(docs_tokens, queries_tokens, qrels_df,
                              k1_value, b_value,
                              score_threshold = 5)

    print("k1={}, b={}, mean precision={}, mean recall={}, mean f1={}".format(
        k1_value,
        b_value,
        *(results_df[metric].mean() for metric in ('precision', 'recall', 'f1'))))

    all_results.append(results_df)

k1=1.0, b=0.2, mean precision=0.09510112422434318, mean recall=0.583446351714902, mean f1=0.11379923088107419
k1=1.0, b=0.25, mean precision=0.09576328469948428, mean recall=0.5834260920933909, mean f1=0.11395588253647035
k1=1.0, b=0.3, mean precision=0.09632017632584725, mean recall=0.5832158787201877, mean f1=0.11396464551766776
k1=1.0, b=0.35, mean precision=0.09543929295083055, mean recall=0.582439356184242, mean f1=0.11362739273892163
k1=1.0, b=0.4, mean precision=0.09570733996034833, mean recall=0.5812685529921969, mean f1=0.11380334154328979
k1=1.0, b=0.45, mean precision=0.09653116437695333, mean recall=0.5827875094961208, mean f1=0.11464107726268122
k1=1.0, b=0.5, mean precision=0.09678071498103105, mean recall=0.5835129366597365, mean f1=0.114977777425425
k1=1.0, b=0.55, mean precision=0.09704709461283612, mean recall=0.5824205774552492, mean f1=0.11537287435588107
k1=1.0, b=0.6, mean precision=0.09664963163449536, mean recall=0.5820369753784244, mean f1=0.11487717739368532
k

In [None]:
# Stack the per-(k1, b) frames. ignore_index=True because every frame is
# indexed 0..n-1, so a plain concat would repeat those labels once per grid point.
all_results_df = pd.concat(all_results, ignore_index=True)

In [None]:
# Mean per-query precision / recall / F1 for each (k1, b) combination, with
# k1 and b kept as regular columns.
all_results_stats_df = all_results_df.groupby(['k1', 'b'], as_index=False)[['precision', 'recall', 'f1']].mean()

In [None]:
# The grid has one row per (k1, b) pair, too many to read unsorted; rank by
# mean F1 so the best-performing settings are shown first.
all_results_stats_df.sort_values('f1', ascending=False).head(10)

Unnamed: 0,k1,b,precision,recall,f1
0,1.0000000000,0.2000000000,0.0951011242,0.5834463517,0.1137992309
1,1.0000000000,0.2500000000,0.0957632847,0.5834260921,0.1139558825
2,1.0000000000,0.3000000000,0.0963201763,0.5832158787,0.1139646455
3,1.0000000000,0.3500000000,0.0954392930,0.5824393562,0.1136273927
4,1.0000000000,0.4000000000,0.0957073400,0.5812685530,0.1138033415
...,...,...,...,...,...
127,2.0000000000,0.5500000000,0.0927194619,0.6002291063,0.1159164508
128,2.0000000000,0.6000000000,0.0927108227,0.5999591962,0.1158999193
129,2.0000000000,0.6500000000,0.0921164607,0.5990388733,0.1152181691
130,2.0000000000,0.7000000000,0.0916130079,0.6012427313,0.1152837174


In [None]:
# Highest mean precision reached over the (k1, b) grid.
all_results_stats_df.precision.max()

0.09704709461283613

In [None]:
# Highest mean recall reached over the (k1, b) grid.
all_results_stats_df.recall.max()

0.6036350821844166

In [None]:
# The (k1, b) setting(s) achieving the highest mean recall; ties are all kept.
all_results_stats_df.loc[lambda frame: frame['recall'] == frame['recall'].max()]

Unnamed: 0,k1,b,precision,recall,f1
120,2.0,0.2,0.0925484499,0.6036350822,0.1164159114
