<a href="https://colab.research.google.com/github/eduseiti/bm25_explore/blob/main/bm25_ranking_with_CISI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install rank_bm25

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting rank_bm25
  Downloading rank_bm25-0.2.2-py3-none-any.whl (8.6 kB)
Installing collected packages: rank_bm25
Successfully installed rank_bm25-0.2.2


In [2]:
import rank_bm25
import pandas as pd
import numpy as np

import nltk
from nltk.corpus import stopwords
import string

import os
import sys
import pickle

import regex as re
import urllib

In [19]:
from scipy import stats

In [3]:
nltk.download('stopwords')
nltk.download('punkt')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [4]:
PARSING_FIELDS_REGEXS={
  'identifier': ["^\.I\s+([0-9]+)"],
  'title_or_words': ["^\.T\s*\r?\n?$|^\.W\s*\r?\n?$"],
  'title_content_or_author': ["^\.A\s*\r?\n?$", "^(.+)\r\n$|^(.+)\n$"],
  'author_content_or_words': ["^\.W\s*\r?\n?$", "^(.+)\r\n$|^(.+)\n$"],
  'words_content_or_xref_or_identifier': ["^\.I\s+([0-9]+)", "^\.X\s*\r?\n?$", "^(.+)\r\n$|^(.+)\n$"],
  'words_content_or_identifier': ["^\.I\s+([0-9]+)", "^(.+)\r\n$|^(.+)\n$"],
  'xref_content_or_identifier': ["^\.I\s+([0-9]+)", "^(.+)\r\n$|^(.+)\n$"]
}

def read_cisi_docs_and_queries(file_url):

  all_elements = []

  current_element = {'next_field': 'identifier'}

  for line in urllib.request.urlopen(file_url):

    # print(line)
    # print(current_element)

    regex_list = PARSING_FIELDS_REGEXS[current_element['next_field']]

    for each_regex in regex_list:
      m = re.match(each_regex, line.decode())

      if m is not None:
        break;

    if m is not None:

      # Check if this match has data to store

      if len(m.groups()) > 0:

        # As there is data, check what it is and store it properly

        if current_element['next_field'] == 'identifier':
          current_element['identifier'] = m.group(1)
          current_element['next_field'] = 'title_or_words'

        elif current_element['next_field'] == 'title_content_or_author':
          current_element['title'] += m.group(1) + ' '

        elif current_element['next_field'] == 'author_content_or_words':
          current_element['author'] += m.group(1) + ';'

        elif (current_element['next_field'] == 'xref_content_or_identifier') or \
             (current_element['next_field'] == 'words_content_or_identifier') or \
             (current_element['next_field'] == 'words_content_or_xref_or_identifier'):

          if m.group(0)[0:2] == '.I':
            # Document complete
            
            # print(current_element)

            all_elements.append(current_element)
          
            current_element = {'identifier': m.group(1),
                               'next_field': 'title_or_words'}
          else:
            if current_element['next_field'] == 'xref_content_or_identifier':
              current_element['xref'] += m.group(1) + ';'
            else:
              if m.group(1) is not None:
                current_element['words'] += m.group(1) + ' '

      else:

        # This is a tag-only entry

        if current_element['next_field'] == 'title_or_words':
          if m.group(0)[0:2] == '.T':
            current_element['title'] = ""
            current_element['next_field'] = 'title_content_or_author'
          else:
            current_element['words'] = ""
            current_element['next_field'] = 'words_content_or_identifier'

        elif current_element['next_field'] == 'title_content_or_author':
          current_element['author'] = ""
          current_element['next_field'] = 'author_content_or_words'

        elif current_element['next_field'] == 'author_content_or_words':
          current_element['words'] = ""
          current_element['next_field'] = 'words_content_or_xref_or_identifier'

        elif current_element['next_field'] == 'words_content_or_xref_or_identifier':
          current_element['xref'] = ""
          current_element['next_field'] = 'xref_content_or_identifier'

        elif (current_element['next_field'] == 'xref_content_or_identifier') \
             (current_element['next_field'] == 'words_content_or_identifier'):
          
          # Document complete
          
          # print(current_element)

          all_elements.append(current_element)

          current_element = {'next_field': 'identifier'}

        else:
          print("Just ignore the line")

    # break

  if current_element['next_field'] != 'title_or_words':
    all_elements.append(current_element)

  print("Parsed {} elements...".format(len(all_elements)))  

  return pd.DataFrame(all_elements).drop(columns='next_field')

# Read CISI files

In [5]:
docs_df = read_cisi_docs_and_queries('https://raw.githubusercontent.com/eduseiti/bm25_explore/main/cisi/CISI.ALL')

Parsed 1460 elements...


In [6]:
queries_df = read_cisi_docs_and_queries('https://raw.githubusercontent.com/eduseiti/bm25_explore/main/cisi/CISI.QRY')

Parsed 112 elements...


In [25]:
qrels_df = pd.read_csv('https://raw.githubusercontent.com/eduseiti/bm25_explore/main/cisi/CISI.REL', 
                       sep='\t', 
                       header=None, 
                       names=['query_id', 'doc_id', 'Q0', 'rel'])

# Tokenize and clean stop words from reference docs and queries

In [8]:
def tokenize_and_remove_stop_words(which_df, stop_words, punctuation):

  all_tokens = [nltk.word_tokenize(doc.lower()) for doc in which_df['words']]

  cleaned_tokens = [[token for token in doc_tokens if token not in stop_words and token not in punctuation] for doc_tokens in all_tokens]

  return cleaned_tokens

In [9]:
stop_words = set(stopwords.words('english'))
punctuation = set(string.punctuation)

In [10]:
docs_tokens = tokenize_and_remove_stop_words(docs_df, stop_words, punctuation)

In [11]:
len(docs_tokens)

1460

In [12]:
queries_tokens = tokenize_and_remove_stop_words(queries_df, stop_words, punctuation)

In [13]:
len(queries_tokens)

112

# Compute BM25 scores for each query / document pair

In [14]:
docs_bm25_scores = rank_bm25.BM25Okapi(docs_tokens)

In [15]:
docs_queries_scores = []

for query_tokens in queries_tokens:
  query_scores = docs_bm25_scores.get_scores(query_tokens)

  docs_queries_scores.append(query_scores)

In [16]:
len(docs_queries_scores)

112

In [20]:
stats.describe(docs_queries_scores[0])

DescribeResult(nobs=1460, minmax=(0.0, 26.730362664515948), mean=1.949901965823141, variance=12.167242493146299, skewness=2.4218850913284964, kurtosis=7.038081458985664)

In [24]:
np.histogram(docs_queries_scores[100])

(array([463, 434, 222, 142, 104,  57,  22,  12,   2,   2]),
 array([ 0.        ,  4.8035264 ,  9.60705281, 14.41057921, 19.21410562,
        24.01763202, 28.82115843, 33.62468483, 38.42821124, 43.23173764,
        48.03526405]))

In [None]:
aa

In [None]:
docs_queries_scores[0][1280]

In [None]:
docs_queries_scores[0][docs_queries_scores[0] > 9]

In [32]:
np.where(docs_queries_scores[0] > 9)[0] + 1

array([  17,   28,   34,   38,   65,   76,  106,  135,  150,  192,  193,
        196,  201,  204,  212,  215,  219,  221,  225,  227,  234,  244,
        269,  415,  429,  440,  447,  449,  465,  466,  477,  483,  485,
        493,  495,  510,  524,  546,  573,  576,  582,  589,  604,  609,
        616,  622,  650,  676,  711,  722,  726,  757,  759,  767,  790,
        804,  811,  813,  814,  820,  831,  861,  863,  869,  886,  901,
        920,  921,  953,  958, 1055, 1059, 1089, 1090, 1091, 1118, 1160,
       1162, 1164, 1195, 1197, 1281, 1286, 1299, 1323, 1338, 1364, 1369,
       1373, 1383, 1387, 1436, 1440])

# Now compare the computed relevance score with the groud truth

In [26]:
from sklearn.metrics import precision_score, recall_score, f1_score

In [57]:
score_threshold = 1e-5

results = []

# Evaluate the retrieval performance using precision, recall, and F1-score

query_ids = qrels_df['query_id'].unique()

for query_id in query_ids:

    inferred_relevant_docs = docs_queries_scores[query_id - 1] > score_threshold

    gt_relevant_docs = np.zeros(inferred_relevant_docs.shape[0], dtype=bool)
    gt_relevant_docs[qrels_df[qrels_df['query_id'] == query_id]['doc_id'].to_numpy() - 1] = True

    precision = precision_score(gt_relevant_docs, inferred_relevant_docs)
    recall = recall_score(gt_relevant_docs, inferred_relevant_docs)
    f1 = f1_score(gt_relevant_docs, inferred_relevant_docs)

    results.append({'query_id': query_id, 'precision': precision, 'recall': recall, 'f1': f1})

results_df = pd.DataFrame(results)

In [58]:
results_df.describe()

Unnamed: 0,query_id,precision,recall,f1
count,76.0,76.0,76.0,76.0
mean,46.776316,0.038334,0.885,0.070852
std,32.293074,0.033457,0.119853,0.057963
min,1.0,0.00073,0.580247,0.00146
25%,19.75,0.011838,0.815559,0.023357
50%,41.5,0.025024,0.909384,0.048563
75%,69.5,0.05383,1.0,0.101087
max,111.0,0.132548,1.0,0.228443
