In [1]:
import os
import re
import math

from tqdm.auto import tqdm
from collections import defaultdict

import pandas as pd

## read data

In [2]:
def process_title(title: pd.Series) -> pd.Series:
    """Normalize article titles so they can be matched against file names.

    Each title is lower-cased and stripped of surrounding whitespace, then
    every character that is neither a word character nor whitespace is
    replaced by an underscore (hyphens and underscores end up as ``_`` too).

    Args:
        title: Series of raw title strings.

    Returns:
        Series of normalized titles with the same index as ``title``.
    """
    normalized = title.str.lower().str.strip()
    return normalized.str.replace(r'([^\w\s])|_|-', '_', regex=True)

# def process_title2(title: pd.Series):
#     title = title.str.lower().str.strip()
#     title = title.str.replace(r"[:,!#'\(\);\.\/\-–]", '_', regex=True)
#     return title

### parsed texts from wiki raw html pages

In [3]:
wiki_dp = '/media/rtn/Windows 10/work/univier/wiki_extract/wiki_parsed'
fp = os.path.join(wiki_dp, 'filepaths.csv')
filemap = pd.read_csv(fp)
assert filemap.duplicated('filename').sum() == 0
print(filemap.shape)

(223619, 3)


In [4]:
filemap.head(2)

Unnamed: 0,filename,path,html_path
0,000_emergency,/media/rtn/Windows 10/work/univier/wiki_extrac...,/media/rtn/Windows 10/work/univier/wiki_extrac...
1,0s_bc,/media/rtn/Windows 10/work/univier/wiki_extrac...,/media/rtn/Windows 10/work/univier/wiki_extrac...


### domain data to narrow search
* 13112 of the 15187 unique domain titles (15190 before de-duplication), i.e. about 86%, are found in our wiki extract.
* check which articles are missing

In [5]:
domain_articles_raw = pd.read_csv('data/selected_docs.tsv', sep='\t', header=None).iloc[:, 0]
domain_articles = process_title(domain_articles_raw)
domain_articles.shape

(15190,)

In [6]:
# there are some duplicates in domain
domain_articles_raw[domain_articles.duplicated(keep=False)].sort_values()

4336                         Body_Double
14619                        Body_double
13476           Kennelly-Heaviside_layer
10031           Kennelly–Heaviside_layer
997      Voyage_To_The_Bottom_Of_The_Sea
3991     Voyage_to_the_Bottom_of_the_Sea
Name: 0, dtype: object

In [7]:
# Rebind to the de-duplicated Series rather than mutating it in place.
domain_articles = domain_articles.drop_duplicates()
domain_articles.shape

(15187,)

In [8]:
len(set(domain_articles).intersection(filemap['filename']))

13112

In [9]:
not_matched = pd.DataFrame({'title': list(set(domain_articles).difference(filemap['filename']))})
not_matched.shape

(2075, 1)

In [10]:
not_matched['is_alphanum'] = not_matched['title'].str.fullmatch(r'[a-z0-9_]+')
not_matched['is_alphanum'].value_counts()

True     1842
False     233
Name: is_alphanum, dtype: int64

In [11]:
not_matched.loc[~not_matched['is_alphanum'], 'title'].sample(10).tolist()

['the_beyoncé_experience__live',
 'roman_à_clef',
 'inō_tadataka',
 'tōkaidō_shinkansen',
 'álvaro_cunhal',
 'pokémon_black_2_and_white_2',
 'folie_à_deux',
 'pokémon_mystery_dungeon__blue_rescue_team_and_red_rescue_team',
 'jaraguá_do_sul',
 'rosé_wine']

In [12]:
not_matched.loc[not_matched['is_alphanum'], 'title'].sample(10).tolist()

['hayal_pass',
 'south_african_cheetah',
 '1_giant_leap',
 'manyas_spirlin',
 'king_s_cross_st__pancras_tube_station',
 'classical_unities',
 'rene_sylva',
 'die_deutschen_inschriften',
 'brenda',
 'dermal_bone']

In [13]:
filemap[filemap['filename'].str.startswith('cinnamon')]

Unnamed: 0,filename,path,html_path
44429,cinnamon,/media/rtn/Windows 10/work/univier/wiki_extrac...,/media/rtn/Windows 10/work/univier/wiki_extrac...


In [14]:
filemap[filemap['filename'].str.contains("'")]

Unnamed: 0,filename,path,html_path


### OK, let's proceed with what we have

In [15]:
domain_articles.shape

(15187,)

In [16]:
len(set(domain_articles).intersection(filemap['filename']))

13112

In [17]:
filemap_raw = filemap.copy()
filemap = filemap[filemap['filename'].isin(domain_articles)].copy()
print(filemap_raw.shape)
print(filemap.shape)

(223619, 3)
(13112, 3)


### read target

In [18]:
target = pd.read_csv('data/queries.tsv', sep='\t', header=None)
target.columns = ['query', 'title']
target['title'] = process_title(target['title'])

print(target.shape)
target.head(3)

(200, 2)


Unnamed: 0,query,title
0,animals that have shells and live in water,shell__zoology_
1,how many different types of scorpions are there,scorpion
2,describe the structure of a scientific name fo...,binomial_nomenclature


In [19]:
print(target.duplicated("title").sum())
print(target.duplicated().sum())

58
0


In [20]:
common = list(set(target['title']).intersection(filemap['filename']))
print(f'titles in intersection: {len(common)}')

target = target[target['title'].isin(common)].copy()
print(f'target.shape after filtering: {target.shape}')

titles in intersection: 137
target.shape after filtering: (194, 2)


## functions

In [21]:
def get_article_path(title):
    """Return the parsed-text path stored in ``filemap`` for ``title``.

    The title is lower-cased and stripped before it is compared with the
    ``filename`` column of the notebook-level ``filemap`` frame; the ``path``
    of the first matching row is returned.

    Raises:
        IndexError: if no row of ``filemap`` matches the title.
    """
    matching_rows = filemap.query('filename == @title.lower().strip()')
    return matching_rows['path'].iloc[0]

def get_article_text(fp, encoding='utf-8'):
    """Read and return the whole content of the text file at ``fp``.

    Args:
        fp: path of the file to read.
        encoding: text encoding used to decode the file. It is passed
            explicitly so that the result does not depend on the platform's
            default locale encoding (the article texts contain non-ASCII
            characters).

    Returns:
        The file content as a single string.
    """
    with open(fp, encoding=encoding) as fin:
        return fin.read()

In [22]:
def tokenize(text):
    """Lower-case ``text`` and split it into whitespace-separated tokens.

    Every character that is not a word character, whitespace or a hyphen is
    removed (so "Noah's" becomes "noahs" while "well-known" stays intact),
    and runs of whitespace are collapsed before splitting.

    Args:
        text: string to tokenize.

    Returns:
        List of token strings; an empty list when nothing is left after
        cleaning (empty or punctuation-only input).
    """
    text = text.lower()
    text = re.sub(r'[^\w\s\-]+', '', text)
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    # ''.split(' ') yields [''], which would register a bogus empty token.
    if not text:
        return []
    return text.split(' ')

In [23]:
def accuracy(titles_true, title_predictions, k=1):
    """Share of queries whose true title appears among the first ``k`` predictions.

    Assumes each entry of ``title_predictions`` is already ordered from the
    best to the worst candidate.

    Args:
        titles_true: expected title for every query.
        title_predictions: ranked candidate titles for every query, aligned
            with ``titles_true``.
        k: number of leading candidates to inspect.

    Returns:
        Number of hits divided by ``len(titles_true)``.
    """
    assert len(titles_true) == len(title_predictions)
    hits = sum(
        1
        for expected, ranked in zip(titles_true, title_predictions)
        if expected in ranked[:k]
    )
    return hits / len(titles_true)

def mean_reciprocal_rank(titles_true, title_predictions, k=1):
    """Mean reciprocal rank of the true title within the first ``k`` predictions.

    Assumes each entry of ``title_predictions`` is already ordered from the
    best to the worst candidate. A query whose true title is not among its
    first ``k`` candidates contributes 0.

    Args:
        titles_true: expected title for every query.
        title_predictions: ranked candidate titles for every query, aligned
            with ``titles_true``; a list may hold fewer than ``k`` items.
        k: number of leading candidates to inspect.

    Returns:
        Sum of reciprocal ranks divided by ``len(titles_true)``.
    """
    assert len(titles_true) == len(title_predictions)
    rr = 0
    for t_true, t_preds in zip(titles_true, title_predictions):
        # Slice instead of indexing with range(k): a query can yield fewer
        # than k candidates (even none), which used to raise IndexError.
        for rank, t_pred in enumerate(t_preds[:k], start=1):
            if t_true == t_pred:
                rr += 1 / rank
                break
    return rr / len(titles_true)

In [24]:
def evaluate(predict_fn, queries, titles_true, **kwargs):
    """Run ``predict_fn`` on every query and summarize the ranking quality.

    Each query is tokenized with ``tokenize`` and handed to ``predict_fn``
    together with ``kwargs``; only the first element of every returned item
    (the title) is kept for scoring.

    Args:
        predict_fn: callable taking the query tokens plus ``kwargs`` and
            returning ranked items whose first element is the title.
        queries: raw query strings.
        titles_true: expected title for every query, aligned with ``queries``.
        **kwargs: forwarded unchanged to ``predict_fn``.

    Returns:
        dict with keys ``acc1``, ``acc10``, ``mrr10`` and ``n_queries``.
    """
    ranked_titles = [
        [hit[0] for hit in predict_fn(tokenize(query), **kwargs)]
        for query in tqdm(queries)
    ]

    return {
        'acc1': accuracy(titles_true=titles_true, title_predictions=ranked_titles, k=1),
        'acc10': accuracy(titles_true=titles_true, title_predictions=ranked_titles, k=10),
        'mrr10': mean_reciprocal_rank(titles_true=titles_true, title_predictions=ranked_titles, k=10),
        'n_queries': len(queries),
    }

In [25]:
class BM25:
    """In-memory bag-of-words ranker offering BM25 and plain tf-idf scoring.

    Call ``fit`` once to build per-article token counts and an inverted
    index, then ``predict`` (BM25) or ``predict_tfidf`` to rank the fitted
    articles for an already tokenized query.
    """

    def __init__(self):
        self.n_articles = 0
        self.article_token_cnt = None       # article -> dict with token count in this article
        self.inverted_index = None          # token   -> set of article titles with this token
        self.n_articles_w_token = None      # token   -> number of articles with this token
        self.article_len = None             # article -> number of tokens in article
        self.mean_article_len = 0.0         # mean number of tokens per fitted article

    def fit(self, titles, load_text=None):
        """Build the index for the given articles.

        Args:
            titles: unique article titles to index.
            load_text: optional callable ``title -> str`` returning the text
                of an article. When omitted, the text is read with
                ``get_article_text(get_article_path(title))``.
        """
        titles = list(titles)
        assert len(titles) == len(set(titles))
        if load_text is None:
            def load_text(title):
                return get_article_text(get_article_path(title))

        self.n_articles = len(titles)
        self.article_token_cnt = dict()
        self.inverted_index = defaultdict(set)
        self.article_len = dict()

        for title in titles:
            tokens = tokenize(load_text(title))
            self.article_len[title] = len(tokens)
            local_article_token_cnt = defaultdict(int)
            for tok in tokens:
                local_article_token_cnt[tok] += 1
                self.inverted_index[tok].add(title)
            self.article_token_cnt[title] = local_article_token_cnt

        # Guard the empty corpus: nothing is indexed, so the mean is never used.
        self.mean_article_len = (
            sum(self.article_len.values()) / self.n_articles if self.n_articles else 0.0
        )
        self.n_articles_w_token = {tok: len(articles) for tok, articles in self.inverted_index.items()}

    def _idf(self, tok):
        """Return ``log(n_articles + 1) - log(df)`` for a token known to the index."""
        return math.log(self.n_articles + 1) - math.log(self.n_articles_w_token[tok])

    @staticmethod
    def _top_k(article_score, top_k):
        """Return the ``top_k`` ``(title, score)`` pairs, highest score first."""
        ranked = sorted(article_score.items(), key=lambda x: x[1], reverse=True)
        return ranked[:top_k]

    def predict(self, q_tokens, k1, k2, b, top_k=10):
        """Rank the fitted articles for a query with BM25.

        Args:
            q_tokens: query tokens; repeated tokens are taken into account
                through the query-frequency factor controlled by ``k2``.
            k1: saturation parameter for the token frequency in the article.
            k2: saturation parameter for the token frequency in the query.
            b: strength of the article-length normalization.
            top_k: maximum number of results to return.

        Returns:
            List of ``(title, score)`` pairs sorted by descending score.
            Only articles containing at least one query token are returned.
        """
        query_token_cnt = defaultdict(int)
        for tok in q_tokens:
            query_token_cnt[tok] += 1

        article_score = defaultdict(float)

        # Iterate over distinct tokens: the query frequency is already part of
        # the formula, so looping over duplicates would count it twice.
        for tok, query_cnt in query_token_cnt.items():
            if self.n_articles_w_token.get(tok, 0) == 0:  # token unknown to the index
                continue

            idf = self._idf(tok)
            query_weight = (k2 + 1) * query_cnt / (k2 + query_cnt)
            # Only articles that contain the token can receive a non-zero score.
            for title in self.inverted_index[tok]:
                K = k1 * (1 - b) + k1 * b * self.article_len[title] / self.mean_article_len
                article_cnt = self.article_token_cnt[title].get(tok, 0)  # get() never inserts keys
                x = idf * ((k1 + 1) * article_cnt / (K + article_cnt))
                x *= query_weight
                article_score[title] += x

        return self._top_k(article_score, top_k)

    def predict_tfidf(self, q_tokens, top_k=10):
        """Rank the fitted articles for a query with simple tf-idf features.

        Every occurrence of a token in ``q_tokens`` adds
        ``idf * log(article_cnt + 1)`` to the score of each article that
        contains the token, so repeated query tokens weigh proportionally more.

        Args:
            q_tokens: query tokens.
            top_k: maximum number of results to return.

        Returns:
            List of ``(title, score)`` pairs sorted by descending score.
        """
        article_score = defaultdict(float)

        for tok in q_tokens:
            if self.n_articles_w_token.get(tok, 0) == 0:  # token unknown to the index
                continue

            idf = self._idf(tok)
            # Only articles that contain the token can receive a non-zero score.
            for title in self.inverted_index[tok]:
                article_cnt = self.article_token_cnt[title].get(tok, 0)  # get() never inserts keys
                article_score[title] += idf * math.log(article_cnt + 1)

        return self._top_k(article_score, top_k)

### example

In [26]:
# Fixed seed so the sampled example article is the same on every re-run.
title = filemap.sample(random_state=42)['filename'].iloc[0]
print(f'title: "{title}"')
text = get_article_text(get_article_path(title))
tokens = tokenize(text)
print(text)
print(tokens)

title: "deluge_myth"
A Deluge myth or Flood myth is a mythical story about a flood. Usually this flood is sent by a deity to destroy a civilisation as a punishment. The theme can be found in many cultures. Well-known examples that are believed to be myths include the story of Noah's Ark in the Bible, the Hindu Puranic story of Manu, Deucalion in Greek mythology or Utnapishtim in the Epic of Gilgamesh.
['a', 'deluge', 'myth', 'or', 'flood', 'myth', 'is', 'a', 'mythical', 'story', 'about', 'a', 'flood', 'usually', 'this', 'flood', 'is', 'sent', 'by', 'a', 'deity', 'to', 'destroy', 'a', 'civilisation', 'as', 'a', 'punishment', 'the', 'theme', 'can', 'be', 'found', 'in', 'many', 'cultures', 'well-known', 'examples', 'that', 'are', 'believed', 'to', 'be', 'myths', 'include', 'the', 'story', 'of', 'noahs', 'ark', 'in', 'the', 'bible', 'the', 'hindu', 'puranic', 'story', 'of', 'manu', 'deucalion', 'in', 'greek', 'mythology', 'or', 'utnapishtim', 'in', 'the', 'epic', 'of', 'gilgamesh']


## fit model

In [27]:
# titles = filemap['filename'].sample(2000).tolist()
titles = filemap['filename'].tolist()

bm25 = BM25()
bm25.fit(titles)

print(f'# of articles model was fit on:\t{bm25.n_articles}')
print(f'# of tokens in fitted model:\t{len(bm25.n_articles_w_token)}')

# of articles model was fit on:	13112
# of tokens in fitted model:	187248


## evaluate bm25 (default params) vs tfidf

In [28]:
for query in ["coronovirus in belarus",
              "who won junior eurovision in 2005",
              "science about full-text search",
             ]:
    q_tokens = tokenize(query)
    result = bm25.predict(q_tokens, k1=1, k2=1, b=1)[:5]
    print(f"[{query}]")
    for article_name, score in result:
        print(f"{score:7.2f}  {article_name}")
    print("\n")

[coronovirus in belarus]
  10.89  covid_19_pandemic_in_belarus
  10.51  daugava_river
   9.28  junior_eurovision_song_contest_2015
   9.19  junior_eurovision_song_contest_2014
   8.88  bug_river


[who won junior eurovision in 2005]
  22.52  eurovision__europe_shine_a_light
  20.61  junior_eurovision_song_contest_2004
  17.19  junior_eurovision_song_contest_2019
  15.44  blue__group_
  14.15  darin_zanyar


[science about full-text search]
  17.41  information_retrieval
  13.63  optimization__disambiguation_
  11.04  citizen_science
  10.31  ask_com
  10.29  monty_python_and_the_holy_grail




In [29]:
for query in ["coronovirus in belarus",
              "who won junior eurovision in 2005",
              "science about full-text search",
             ]:
    q_tokens = tokenize(query)
    result = bm25.predict_tfidf(q_tokens)[:5]
    print(f"[{query}]")
    for article_name, score in result:
        print(f"{score:7.2f}  {article_name}")
    print("\n")

[coronovirus in belarus]
  10.43  covid_19_pandemic_in_belarus
   8.54  poland
   8.40  nuclear_accident
   8.32  list_of_ice_hockey_leagues
   7.84  junior_eurovision_song_contest_2014


[who won junior eurovision in 2005]
  27.20  list_of_people_from_texas
  26.87  2011_australian_open
  26.67  rockefeller_family
  26.35  jackie_robinson
  25.29  list_of_dinosaurs


[science about full-text search]
  23.22  google_search
  16.10  science_fiction
  15.50  ursula_k__le_guin
  15.10  information_retrieval
  14.57  philosophy_of_science




In [30]:
queries = target['query'].tolist()
titles_true = target['title'].tolist()

In [31]:
bm25_metrics = evaluate(bm25.predict, queries, titles_true, k1=1, k2=1, b=1, top_k=10)
tfidf_metrics = evaluate(bm25.predict_tfidf, queries, titles_true, top_k=10)
metrics = pd.DataFrame.from_records([bm25_metrics, tfidf_metrics], index=['bm25', 'tfidf'])
metrics

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

Unnamed: 0,acc1,acc10,mrr10,n_queries
bm25,0.185567,0.474227,0.271167,194
tfidf,0.030928,0.190722,0.079287,194


* BM25 clearly outperforms plain tf-idf on all three metrics.
  Likely reasons:
  * BM25 takes the term frequency in the query into account (the `k2` factor)
  * BM25 saturates the term frequency in the article and normalizes it by article length (the `k1` and `b` parameters)

## hyperparameters search

* because the evaluation set is small (<= 200 records), we do not split it into validation and test subsets;
  the parameters are tuned on the whole sample, so the resulting metrics are optimistic estimates

In [34]:
metrics_hpsearch = []

for b in [0.6, 0.8, 1]:
    for k1 in [0.8, 1, 1.2, 1.5]:
        for k2 in [0.8, 1, 1.3]:
            m = evaluate(bm25.predict, queries, titles_true, k1=k1, k2=k2, b=b, top_k=10)
            m.update(b=b, k1=k1, k2=k2)
            metrics_hpsearch.append(m)

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

In [37]:
metrics_hpsearch_df = pd.DataFrame.from_records(
    metrics_hpsearch, 
    index=list(range(len(metrics_hpsearch)))
)
metrics_hpsearch_df.sort_values('mrr10', ascending=False)

Unnamed: 0,acc1,acc10,mrr10,n_queries,b,k1,k2
22,0.226804,0.505155,0.313157,194,0.8,1.5,1.0
23,0.226804,0.505155,0.313028,194,0.8,1.5,1.3
21,0.226804,0.505155,0.312332,194,0.8,1.5,0.8
9,0.21134,0.5,0.299812,194,0.6,1.5,0.8
10,0.21134,0.5,0.299689,194,0.6,1.5,1.0
11,0.21134,0.5,0.299654,194,0.6,1.5,1.3
20,0.21134,0.494845,0.296482,194,0.8,1.2,1.3
18,0.21134,0.494845,0.296439,194,0.8,1.2,0.8
19,0.21134,0.494845,0.296367,194,0.8,1.2,1.0
6,0.190722,0.479381,0.287555,194,0.6,1.2,0.8


In [39]:
for m in ['mrr10', 'acc10', 'acc1']:
    print(f'best params for "{m}" metric:')
    display(metrics_hpsearch_df.loc[[metrics_hpsearch_df[m].idxmax()]])

best params for "mrr10" metric:


Unnamed: 0,acc1,acc10,mrr10,n_queries,b,k1,k2
22,0.226804,0.505155,0.313157,194,0.8,1.5,1.0


best params for "acc10" metric:


Unnamed: 0,acc1,acc10,mrr10,n_queries,b,k1,k2
21,0.226804,0.505155,0.312332,194,0.8,1.5,0.8


best params for "acc1" metric:


Unnamed: 0,acc1,acc10,mrr10,n_queries,b,k1,k2
21,0.226804,0.505155,0.312332,194,0.8,1.5,0.8


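* `b = 0.8, k1 = 1.5, k2 = 1.0` are the best params on the specified grid in terms of MRR@10; for Accuracy@10 and Accuracy@1 all three `k2` values tie at `b = 0.8, k1 = 1.5` (`idxmax` simply reports the first of them, `k2 = 0.8`)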

### compare bm25 search with best and default params

In [42]:
for query in ["coronovirus in belarus",
              "who won junior eurovision in 2005",
              "science about full-text search",
              'music'
             ]:
    q_tokens = tokenize(query)
    result = bm25.predict(q_tokens, k1=1.5, k2=1.0, b=0.8)[:10]
    print(f"[{query}]")
    for article_name, score in result:
        print(f"{score:7.2f}  {article_name}")
    print("\n")

[coronovirus in belarus]
  12.74  covid_19_pandemic_in_belarus
  11.40  daugava_river
  10.38  junior_eurovision_song_contest_2015
  10.26  junior_eurovision_song_contest_2014
   8.75  bug_river
   8.34  government_in_exile
   8.30  byelorussian_soviet_socialist_republic
   8.27  covid_19_pandemic_in_turkmenistan
   8.09  jagiellon_dynasty
   7.32  boy


[who won junior eurovision in 2005]
  24.87  eurovision__europe_shine_a_light
  21.22  junior_eurovision_song_contest_2004
  16.82  junior_eurovision_song_contest_2019
  15.25  darin_zanyar
  15.19  eurovision_song_contest_2017
  15.09  blue__group_
  13.68  tine_thing_helseth
  13.66  eurovision_song_contest_2007
  13.53  goodbye
  13.44  eurovision_song_contest_2011


[science about full-text search]
  18.04  information_retrieval
  14.18  optimization__disambiguation_
  11.74  citizen_science
  11.30  ask_com
  11.12  monty_python_and_the_holy_grail
  10.86  google_search
  10.12  stranger_things
   9.95  chemical_database
   9.88  

In [44]:
for query in ["coronovirus in belarus",
              "who won junior eurovision in 2005",
              "science about full-text search",
              'music'
             ]:
    q_tokens = tokenize(query)
    result = bm25.predict(q_tokens, k1=1, k2=1, b=1)[:10]
    print(f"[{query}]")
    for article_name, score in result:
        print(f"{score:7.2f}  {article_name}")
    print("\n")

[coronovirus in belarus]
  10.89  covid_19_pandemic_in_belarus
  10.51  daugava_river
   9.28  junior_eurovision_song_contest_2015
   9.19  junior_eurovision_song_contest_2014
   8.88  bug_river
   8.44  government_in_exile
   8.39  byelorussian_soviet_socialist_republic
   8.37  covid_19_pandemic_in_turkmenistan
   8.16  jagiellon_dynasty
   6.89  eastern_europe


[who won junior eurovision in 2005]
  22.52  eurovision__europe_shine_a_light
  20.61  junior_eurovision_song_contest_2004
  17.19  junior_eurovision_song_contest_2019
  15.44  blue__group_
  14.15  darin_zanyar
  13.96  eurovision_song_contest_2017
  13.37  tine_thing_helseth
  13.24  junior_eurovision_song_contest_2015
  12.93  junior_eurovision_song_contest_2014
  12.87  goodbye


[science about full-text search]
  17.41  information_retrieval
  13.63  optimization__disambiguation_
  11.04  citizen_science
  10.31  ask_com
  10.29  monty_python_and_the_holy_grail
  10.17  stranger_things
  10.05  the_saga_of_the_viking_wo