In [1]:
import os
import math
from collections import defaultdict

import ir_datasets
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from tqdm import tqdm

os.environ["IR_DATASETS_HOME"] = "/storage/kliffeup/kliffeup/allknower_data"

In [2]:
DATASET_PREFIX = "wikir"
DATASET_VERSION = "en1k"
STAGE = "validation"

In [3]:
wikir_dataset = ir_datasets.load(f"{DATASET_PREFIX}/{DATASET_VERSION}/{STAGE}")

In [4]:
wikir_dataset.docs_cls()._fields

('doc_id', 'text')

In [5]:
for doc in wikir_dataset.docs_iter()[:2]:
    print(doc)

GenericDoc(doc_id='1781133', text='it was used in landing craft during world war ii and is used today in private boats and training facilities the 6 71 is an inline six cylinder diesel engine the 71 refers to the displacement in cubic inches of each cylinder the firing order of the engine is 1 5 3 6 2 4 the engine s compression ratio is 18 7 1 with a 4 250 inch bore and a 5 00 inch stroke the engine weighs and is 54 inches long 29 inches wide and 41 inches tall at 2 100 revolutions per minute the engine is capable of producing 230 horse power 172 kilowatts v type versions of the 71 series were developed in 1957 the 6 71 is a two stroke engine as the engine will not naturally aspirate air is provided via a roots type blower however on the 6 71t models a turbocharger and a supercharger are utilized fuel is provided by unit injectors one per cylinder the amount of fuel injected into the engine is controlled by the engine s governor the engine cooling is via liquid in a water jacket in a b

In [6]:
query_idx = -1

for query in wikir_dataset.queries_iter():
    query_idx += 1
    if query_idx >= 5:
        break
    print(query)

GenericQuery(query_id='1402535', text='irish sea')
GenericQuery(query_id='91198', text='phillips exeter academy')
GenericQuery(query_id='1015979', text='president of chile')
GenericQuery(query_id='111134', text='university of kentucky')
GenericQuery(query_id='201459', text='johnnie to')


In [7]:
qrel_idx = -1

for qrel in wikir_dataset.qrels_iter():
    qrel_idx += 1
    if qrel_idx >= 5:
        break
    print(qrel)

TrecQrel(query_id='1402535', doc_id='1402535', relevance=2, iteration='0')
TrecQrel(query_id='1402535', doc_id='488489', relevance=1, iteration='0')
TrecQrel(query_id='1402535', doc_id='1813456', relevance=1, iteration='0')
TrecQrel(query_id='1402535', doc_id='1668563', relevance=1, iteration='0')
TrecQrel(query_id='1402535', doc_id='1232652', relevance=1, iteration='0')


# Feature engineering

In [8]:
MAX_RELEVANCE = 2

In [9]:
queries = pd.DataFrame(
    sorted(
        map(
            lambda item: (int(item.query_id), item.text),
            wikir_dataset.queries_iter()
        )  
    ),
    columns=["query_id", "query"],
)

In [10]:
docs = pd.DataFrame(
    sorted(
        map(
            lambda item: (int(item.doc_id), item.text),
            wikir_dataset.docs_iter()
        ),
        key=lambda item: item[0]
    ),
    columns=["doc_id", "doc"],
)

In [11]:
query_id_to_doc_id_rels = pd.DataFrame(
    sorted(
        map(
            lambda item: (int(item.query_id), int(item.doc_id), item.relevance / MAX_RELEVANCE),
            wikir_dataset.qrels_iter()
        ),
        key=lambda item: (item[0], item[1])
    ),
    columns=["query_id", "doc_id", "relevance"],
)

In [12]:
query_id_to_doc_id_bm25_scores = pd.DataFrame(
    sorted(
        map(
            lambda item: (int(item.query_id), int(item.doc_id), item.score),
            wikir_dataset.scoreddocs_iter()
        ),
        key=lambda item: (item[0], item[1])
    ),
    columns=["query_id", "doc_id", "bm25_score"],    
)

In [13]:
docs_count_per_queries = query_id_to_doc_id_rels.groupby("query_id")["doc_id"].count().values

In [14]:
data = query_id_to_doc_id_rels\
.merge(queries, how="left")\
.merge(docs, how="left")

In [15]:
data

Unnamed: 0,query_id,doc_id,relevance,query,doc
0,28,28,1.0,andorra,believed to have been created by charlemagne a...
1,28,23160,0.5,andorra,it is located high in the east pyrenees betwee...
2,28,103379,0.5,andorra,enric marfany bons composed the music while th...
3,28,201770,0.5,andorra,the archetypal alfajor entered iberia during t...
4,28,219989,0.5,andorra,the club currently plays in primera divisi and...
...,...,...,...,...,...
4974,2053270,1167065,0.5,occitanie,farmland woods meadows harmoniously closed by ...
4975,2053270,1176183,0.5,occitanie,formed in 1934 the club compete in the elite o...
4976,2053270,2053270,1.0,occitanie,the conseil d tat approved occitanie as the ne...
4977,2053270,2229244,0.5,occitanie,the castle overlooks the vicdessos valley and ...


# Recover BM25 scores

In [16]:
corpus_json = docs.apply(lambda row: dict(row), axis=1).values.tolist()

In [17]:
corpus_text = list(map(lambda item: item["doc"], corpus_json))

In [18]:
import bm25s
import Stemmer


stemmer = Stemmer.Stemmer("english")

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [19]:
tokenizer = bm25s.tokenization.Tokenizer(
    stemmer=stemmer,
    lower=True,  # lowercase the tokens
    stopwords="english",  # or pass a list of stopwords
    splitter=r"\w+",  # by default r"(?u)\b\w\w+\b", can also be a function
)

In [20]:
# Tokenize the corpus
corpus_tokens = tokenizer.tokenize(
    corpus_text, 
    update_vocab=True, # update the vocab as we tokenize
    return_as="ids"
)

Tokenize texts:   0%|          | 0/369721 [00:00<?, ?it/s]

In [21]:
retriever = bm25s.BM25(corpus=corpus_json)

In [22]:
retriever.index(corpus_tokens)

BM25S Create Vocab:   0%|          | 0/369721 [00:00<?, ?it/s]

BM25S Convert tokens to indices:   0%|          | 0/369721 [00:00<?, ?it/s]

BM25S Count Tokens:   0%|          | 0/369721 [00:00<?, ?it/s]

BM25S Compute Scores:   0%|          | 0/369721 [00:00<?, ?it/s]

# Calc BM25S scores

In [23]:
N_THREADS = 10
TOP_DOCS_BY_BM25_SCORE = 100
top_docs_by_bm25_score_per_query = []

In [24]:
for query_id, row in tqdm(data.groupby("query_id")):
    query = row["query"].values[0]

    query_tokens = tokenizer.tokenize(
        [query], 
        update_vocab=False,
        show_progress=False,
    )

    results, scores = retriever.retrieve(
        query_tokens, 
        k=TOP_DOCS_BY_BM25_SCORE, 
        show_progress=False,
        n_threads=N_THREADS,
    )

    top_docs_by_bm25_score_per_query.extend(
        [
            [query_id, query, doc_desc["doc_id"], doc_desc["doc"], doc_score]
            for doc_desc, doc_score in zip(results[0], scores[0])
        ]
    )

100%|████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 64.86it/s]


In [25]:
top_docs_by_bm25_score_per_query = pd.DataFrame(
    data=top_docs_by_bm25_score_per_query,
    columns=["query_id", "query", "doc_id", "doc", "bm25_score"],
)

In [26]:
data = data.merge(
    top_docs_by_bm25_score_per_query, 
    on=["query_id", "doc_id", "query", "doc"], 
    how="outer",
)

In [27]:
data["relevance"] = data["relevance"].fillna(0)

In [28]:
data["bm25_score"] = data.groupby("query_id")["bm25_score"].transform(lambda group: group.fillna(group.min()))

# Calc pos statistics

In [158]:
def calc_query_term_ngram_coverage(query_tokens: list[int], doc_tokens: list[int], n: int = 1) -> int | float:
    if n > min(len(query_tokens), len(doc_tokens)):
        return 0, 0

    query_ngram_counter = {}
    for i in range(len(query_tokens) - n + 1):
        query_ngram_counter[tuple(query_tokens[i:i + n])] = 0

    for i in range(len(doc_tokens) - n + 1):
        doc_ngram = tuple(doc_tokens[i:i + n])
        if doc_ngram in query_ngram_counter and query_ngram_counter[doc_ngram] == 0:
            query_ngram_counter[doc_ngram] = 1

    query_term_ngram_coverage = sum(query_ngram_counter.values())
    return query_term_ngram_coverage, query_term_ngram_coverage / len(query_ngram_counter)

In [160]:
calc_query_term_ngram_coverage([1, 2, 3], [1, 2, 4, 2, 4], n=3)

(0, 0.0)

In [161]:
def calc_doc_length(doc_tokens: list[int]) -> int:
    return len(doc_tokens)

In [162]:
def calc_query_length(query_tokens: list[int]) -> int:
    return len(query_tokens)

In [163]:
def calc_query_term_first_occurrence(query_tokens: list[int], doc_tokens: list[int], normalize: bool = False) -> int:
    first_occurrence = len(doc_tokens)
    
    for i, doc_token in enumerate(doc_tokens):
        if doc_token in query_tokens:
            first_occurrence = i
            break

    return first_occurrence, first_occurrence / len(doc_tokens)

In [164]:
calc_query_term_first_occurrence([1, 2, 3], [4, 1, 4, 2, 4], normalize=True)

(1, 0.2)

In [53]:
def calculate_idf(
    query_tokens: list[int], 
    docs_tokens: list[list[int]],
) -> dict[int, int]:
    query_tokens_counts_per_document = defaultdict(int)

    for doc_tokens in docs_tokens:
        doc_tokens_occurred_in_query = set()
        for token in doc_tokens:
            if token in query_tokens and token not in doc_tokens_occurred_in_query:
                query_tokens_counts_per_document[token] += 1
                doc_tokens_occurred_in_query.add(token)

    num_documents = len(docs_tokens)
    idf_dict = {}
    for token, doc_count in query_tokens_counts_per_document.items():
        idf_dict[token] = math.log(num_documents / (doc_count + 1)) + 1

    return idf_dict


def calculate_hh_proximity(
    query_tokens: list[int], 
    doc_tokens: list[list[int]], 
    idf_dict: dict[int, int], 
    z: float = 1.75,
) -> float:
    term_positions = defaultdict(list)
    for idx, token in enumerate(doc_tokens):
        if token in query_tokens:
            term_positions[token].append(idx)

    hh_proximity = 0

    for cur_term, cur_term_positions in term_positions.items():
        cur_term_idf = idf_dict.get(cur_term, 1)
        for cur_term_position in cur_term_positions:
            for other_term in query_tokens:
                if other_term == cur_term:
                    term_weight = 0.25
                else:
                    term_weight = 1
                    
                other_positions = term_positions.get(other_term, [])
                
                if other_positions:
                    lmd = float("inf")
                    rmd = float("inf")
                    other_term_idf = idf_dict.get(other_term, 1)
                    for other_position in other_positions:
                        if other_position < cur_term_position:
                            lmd = min(lmd, cur_term_position - other_position)
                        elif other_position > cur_term_position:
                            rmd = min(rmd, other_position - cur_term_position)
                            break
                    

                    hh_proximity += term_weight * cur_term_idf * ((other_term_idf / (lmd ** z)) + (other_term_idf / (rmd ** z)))

    return math.log(1 + hh_proximity)

In [49]:
query_tokens = ["data", "science"]
doc_tokens = ["data", "is", "the", "new", "science", "of", "data", "analysis"]
docs_tokens = [
    ["data", "is", "the", "new", "science", "of", "data", "analysis"],
    ["machine", "learning", "is", "a", "field", "of", "artificial", "intelligence"],
    ["data", "science", "involves", "statistics", "and", "programming"]
]

idf_dict = calculate_idf(query_tokens, docs_tokens)
print(idf_dict)
hh_proximity_score = calculate_hh_proximity(query_tokens, doc_tokens, idf_dict)
print(f"HHProximity Score: {hh_proximity_score}")

{'data': 1.0, 'science': 1.0}
HHProximity Score: 0.5839557466475039


In [42]:
query_length = []
doc_length = []
query_term_unigram_coverage_unnormalized = []
query_term_bigram_coverage_unnormalized = []
query_term_trigram_coverage_unnormalized = []
query_term_first_occurrence_unnormalized = []
query_term_last_occurrence_unnormalized = []
span_length_unnormalized = []
query_term_unigram_coverage_normalized = []
query_term_bigram_coverage_normalized = []
query_term_trigram_coverage_normalized = []
query_term_first_occurrence_normalized = []
query_term_last_occurrence_normalized = []
span_length_normalized_doc = []
span_length_normalized_query = []

In [43]:
doc_ids = docs["doc_id"].values.tolist()

In [44]:
for _, row in tqdm(queries.iterrows()):
    query, query_id = row.query, row.query_id
    
    doc_ids_query = data[data["query_id"] == query_id][["doc_id"]].values.flatten().tolist()

    query_tokens = tokenizer.tokenize(
        [query], 
        update_vocab=False,
        show_progress=False,
    )[0]

    docs_tokens_query = list(map(lambda doc_id: corpus_tokens[doc_ids.index(doc_id)], doc_ids_query))
    # docs_tokens_query = [corpus_tokens[doc_ids.index(doc_id)] for doc_id in doc_ids_query]
    query_length += [calc_query_length(query_tokens)] * len(doc_ids_query)

    for doc_tokens in docs_tokens_query:
        doc_length.append(calc_doc_length(doc_tokens))
        query_term_unigram_coverage_unnormalized.append(calc_query_term_ngram_coverage(query_tokens, doc_tokens, n=1))
        query_term_bigram_coverage_unnormalized.append(calc_query_term_ngram_coverage(query_tokens, doc_tokens, n=2))
        query_term_trigram_coverage_unnormalized.append(calc_query_term_ngram_coverage(query_tokens, doc_tokens, n=3))
        query_term_first_occurrence_unnormalized.append(calc_query_term_first_occurrence(query_tokens, doc_tokens))
        query_term_last_occurrence_unnormalized.append((doc_length[-1] - 1) - calc_query_term_first_occurrence(query_tokens, doc_tokens[::-1]))
        span_length_unnormalized.append(query_term_last_occurrence_unnormalized[-1] - query_term_first_occurrence_unnormalized[-1] + 1)
        query_term_unigram_coverage_normalized.append(calc_query_term_ngram_coverage(query_tokens, doc_tokens, n=1, normalize=True))
        query_term_bigram_coverage_normalized.append(calc_query_term_ngram_coverage(query_tokens, doc_tokens, n=2, normalize=True))
        query_term_trigram_coverage_normalized.append(calc_query_term_ngram_coverage(query_tokens, doc_tokens, n=3, normalize=True))
        query_term_first_occurrence_normalized.append(calc_query_term_first_occurrence(query_tokens, doc_tokens, normalize=True))
        query_term_last_occurrence_normalized.append(query_term_last_occurrence_unnormalized[-1] / doc_length[-1])
        span_length_normalized_doc.append(span_length_unnormalized[-1] / doc_length[-1])
        if query_length[-1]:
            span_length_normalized_query.append(span_length_unnormalized[-1] / query_length[-1])
        else:
            span_length_normalized_query.append(0)

1444it [05:41,  4.23it/s]


In [45]:
(
    len(doc_scores),
    len(query_length), 
    len(doc_length), 
    len(query_term_unigram_coverage_unnormalized),
    len(query_term_bigram_coverage_unnormalized),
    len(query_term_trigram_coverage_unnormalized),
    len(query_term_first_occurrence_unnormalized),
    len(query_term_last_occurrence_unnormalized),
    len(span_length_unnormalized),
    len(query_term_unigram_coverage_normalized),
    len(query_term_bigram_coverage_normalized),
    len(query_term_trigram_coverage_normalized),
    len(query_term_first_occurrence_normalized),
    len(query_term_last_occurrence_normalized),
    len(span_length_normalized_doc),
    len(span_length_normalized_query),
)

(185926,
 185926,
 185926,
 185926,
 185926,
 185926,
 185926,
 185926,
 185926,
 185926,
 185926,
 185926,
 185926,
 185926,
 185926,
 185926)

In [46]:
data["bm25s_score"] = doc_scores
data["query_length"] = query_length
data["doc_length"] = doc_length
data["query_term_unigram_coverage_unnormalized"] = query_term_unigram_coverage_unnormalized
data["query_term_bigram_coverage_unnormalized"] = query_term_bigram_coverage_unnormalized
data["query_term_trigram_coverage_unnormalized"] = query_term_trigram_coverage_unnormalized
data["query_term_first_occurrence_unnormalized"] = query_term_first_occurrence_unnormalized
data["query_term_last_occurrence_unnormalized"] = query_term_last_occurrence_unnormalized
data["span_length_unnormalized"] = span_length_unnormalized
data["query_term_unigram_coverage_normalized"] = query_term_unigram_coverage_normalized
data["query_term_bigram_coverage_normalized"] = query_term_bigram_coverage_normalized
data["query_term_trigram_coverage_normalized"] = query_term_trigram_coverage_normalized
data["query_term_first_occurrence_normalized"] = query_term_first_occurrence_normalized
data["query_term_last_occurrence_normalized"] = query_term_last_occurrence_normalized
data["span_length_normalized_doc"] = span_length_normalized_doc
data["span_length_normalized_query"] = span_length_normalized_query

In [47]:
data.head()

Unnamed: 0,query_id,doc_id,bm25_score,relevance,query,doc,bm25s_score,query_length,doc_length,query_term_unigram_coverage_unnormalized,...,query_term_first_occurrence_unnormalized,query_term_last_occurrence_unnormalized,span_length_unnormalized,query_term_unigram_coverage_normalized,query_term_bigram_coverage_normalized,query_term_trigram_coverage_normalized,query_term_first_occurrence_normalized,query_term_last_occurrence_normalized,span_length_normalized_doc,span_length_normalized_query
0,79,79,16.106074,1.0,actinopterygii,these actinopterygian fin rays attach directly...,6.407508,1,139,1,...,88,133,46,1.0,0.0,0.0,0.633094,0.956835,0.330935,46.0
1,79,12139,0.0,,actinopterygii,even the acts of launching discharging artille...,0.0,1,133,0,...,133,-1,-133,0.0,0.0,0.0,1.0,-0.007519,-1.0,-133.0
2,79,34714,0.0,,actinopterygii,state of nebraska as of the 2010 united states...,0.0,1,135,0,...,135,-1,-135,0.0,0.0,0.0,1.0,-0.007407,-1.0,-135.0
3,79,57100,0.0,,actinopterygii,as of the 2010 united states census the cdp s ...,0.0,1,146,0,...,146,-1,-146,0.0,0.0,0.0,1.0,-0.006849,-1.0,-146.0
4,79,92466,0.0,,actinopterygii,its lyrics were written by the buenos aires bo...,0.0,1,135,0,...,135,-1,-135,0.0,0.0,0.0,1.0,-0.007407,-1.0,-135.0


In [48]:
data = data.fillna(0)

In [49]:
data["doc_length_scaled"] = data.groupby("query_id")["doc_length"].transform(lambda x: (x - x.min()) / (x.max() - x.min()))
# df['Z_Score_Normalized'] = df.groupby('Group')['Value'].transform(lambda x: (x - x.mean()) / x.std())
# df['Mean_Normalized'] = df.groupby('Group')['Value'].transform(lambda x: (x - x.mean()) / (x.max() - x.min()))

In [53]:
# data.to_csv("wikir_en1k_training_preprocessed.csv", index=None)

In [51]:
data.columns

Index(['query_id', 'doc_id', 'bm25_score', 'relevance', 'query', 'doc',
       'bm25s_score', 'query_length', 'doc_length',
       'query_term_unigram_coverage_unnormalized',
       'query_term_bigram_coverage_unnormalized',
       'query_term_trigram_coverage_unnormalized',
       'query_term_first_occurrence_unnormalized',
       'query_term_last_occurrence_unnormalized', 'span_length_unnormalized',
       'query_term_unigram_coverage_normalized',
       'query_term_bigram_coverage_normalized',
       'query_term_trigram_coverage_normalized',
       'query_term_first_occurrence_normalized',
       'query_term_last_occurrence_normalized', 'span_length_normalized_doc',
       'span_length_normalized_query', 'doc_length_scaled'],
      dtype='object')

In [52]:
data[['query_id', 'doc_id', 'bm25_score', 'relevance',
       'bm25s_score', 'query_length', 'doc_length',
       'query_term_unigram_coverage_unnormalized',
       'query_term_bigram_coverage_unnormalized',
       'query_term_trigram_coverage_unnormalized',
       'query_term_first_occurrence_unnormalized',
       'query_term_last_occurrence_unnormalized', 'span_length_unnormalized',
       'query_term_unigram_coverage_normalized',
       'query_term_bigram_coverage_normalized',
       'query_term_trigram_coverage_normalized',
       'query_term_first_occurrence_normalized',
       'query_term_last_occurrence_normalized', 'span_length_normalized_doc',
       'span_length_normalized_query', 'doc_length_scaled']].to_csv(f"{DATASET_PREFIX}_{DATASET_VERSION}_{STAGE}_preprocessed_light.csv", index=None)