In [46]:
import os

import ir_datasets
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from tqdm import tqdm


os.environ["IR_DATASETS_HOME"] = "/storage/kliffeup/kliffeup/allknower_data"

In [47]:
DATASET_PREFIX = "wikir"
DATASET_VERSION = "en78k"
STAGE = "training"

In [48]:
wikir_dataset = ir_datasets.load(f"{DATASET_PREFIX}/{DATASET_VERSION}/{STAGE}")

In [3]:
wikir_dataset.docs_cls()._fields

('doc_id', 'text')

In [4]:
for doc in wikir_dataset.docs_iter()[:2]:
    print(doc)

GenericDoc(doc_id='1781133', text='it was used in landing craft during world war ii and is used today in private boats and training facilities the 6 71 is an inline six cylinder diesel engine the 71 refers to the displacement in cubic inches of each cylinder the firing order of the engine is 1 5 3 6 2 4 the engine s compression ratio is 18 7 1 with a 4 250 inch bore and a 5 00 inch stroke the engine weighs and is 54 inches long 29 inches wide and 41 inches tall at 2 100 revolutions per minute the engine is capable of producing 230 horse power 172 kilowatts v type versions of the 71 series were developed in 1957 the 6 71 is a two stroke engine as the engine will not naturally aspirate air is provided via a roots type blower however on the 6 71t models a turbocharger and a supercharger are utilized fuel is provided by unit injectors one per cylinder the amount of fuel injected into the engine is controlled by the engine s governor the engine cooling is via liquid in a water jacket in a b

In [5]:
query_idx = -1

for query in wikir_dataset.queries_iter():
    query_idx += 1
    if query_idx >= 5:
        break
    print(query)

GenericQuery(query_id='123839', text='yanni')
GenericQuery(query_id='188629', text='k pop')
GenericQuery(query_id='13898', text='venice film festival')
GenericQuery(query_id='316959', text='downtown brooklyn')
GenericQuery(query_id='515031', text='pennsylvania house of representatives')


In [6]:
qrel_idx = -1

for qrel in wikir_dataset.qrels_iter():
    qrel_idx += 1
    if qrel_idx >= 5:
        break
    print(qrel)

TrecQrel(query_id='123839', doc_id='123839', relevance=2, iteration='0')
TrecQrel(query_id='123839', doc_id='1793430', relevance=1, iteration='0')
TrecQrel(query_id='123839', doc_id='806300', relevance=1, iteration='0')
TrecQrel(query_id='123839', doc_id='806075', relevance=1, iteration='0')
TrecQrel(query_id='123839', doc_id='836567', relevance=1, iteration='0')


In [7]:
bm25_score_idx = -1

for bm25_score in wikir_dataset.scoreddocs_iter():
    bm25_score_idx += 1
    if bm25_score_idx >= 5:
        break
    print(bm25_score)

GenericScoredDoc(query_id='123839', doc_id='806300', score=20.720094194011075)
GenericScoredDoc(query_id='123839', doc_id='123839', score=19.91782871489318)
GenericScoredDoc(query_id='123839', doc_id='836567', score=18.824522997710037)
GenericScoredDoc(query_id='123839', doc_id='806326', score=18.824522997710037)
GenericScoredDoc(query_id='123839', doc_id='806075', score=17.246712972547066)


# Feature engineering

In [8]:
MAX_RELEVANCE = 2

In [9]:
queries = pd.DataFrame(
    sorted(
        map(
            lambda item: (int(item.query_id), item.text),
            wikir_dataset.queries_iter()
        )  
    ),
    columns=["query_id", "query"],
)

In [10]:
docs = pd.DataFrame(
    sorted(
        map(
            lambda item: (int(item.doc_id), item.text),
            wikir_dataset.docs_iter()
        ),
        key=lambda item: item[0]
    ),
    columns=["doc_id", "doc"],
)

In [11]:
query_id_to_doc_id_rels = pd.DataFrame(
    sorted(
        map(
            lambda item: (int(item.query_id), int(item.doc_id), item.relevance / MAX_RELEVANCE),
            wikir_dataset.qrels_iter()
        ),
        key=lambda item: (item[0], item[1])
    ),
    columns=["query_id", "doc_id", "relevance"],
)

In [12]:
query_id_to_doc_id_bm25_scores = pd.DataFrame(
    sorted(
        map(
            lambda item: (int(item.query_id), int(item.doc_id), item.score),
            wikir_dataset.scoreddocs_iter()
        ),
        key=lambda item: (item[0], item[1])
    ),
    columns=["query_id", "doc_id", "bm25_score"],    
)

In [13]:
data = query_id_to_doc_id_bm25_scores\
.merge(query_id_to_doc_id_rels, how="outer")\
.merge(queries, how="left")\
.merge(docs, how="left")

In [14]:
data

Unnamed: 0,query_id,doc_id,bm25_score,relevance,query,doc
0,79,79,16.106074,1.0,actinopterygii,these actinopterygian fin rays attach directly...
1,79,12139,0.000000,,actinopterygii,even the acts of launching discharging artille...
2,79,34714,0.000000,,actinopterygii,state of nebraska as of the 2010 united states...
3,79,57100,0.000000,,actinopterygii,as of the 2010 united states census the cdp s ...
4,79,92466,0.000000,,actinopterygii,its lyrics were written by the buenos aires bo...
...,...,...,...,...,...,...
185921,2433785,2433250,13.206970,0.5,dhanbad sadar subdivision,one of the many spurs of pareshnath hill 1 365...
185922,2433785,2433522,15.420900,0.5,dhanbad sadar subdivision,the damodar river the most important river of ...
185923,2433785,2433549,9.231121,0.5,dhanbad sadar subdivision,while the damodar flows along the southern bou...
185924,2433785,2433785,37.233094,1.0,dhanbad sadar subdivision,initially the district was split into two subd...


# Recover BM25 scores

In [15]:
import bm25s
import Stemmer

stemmer = Stemmer.Stemmer("english")

  from .autonotebook import tqdm as notebook_tqdm
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [16]:
for _, row in docs.iterrows():
    print(dict(row))
    break

{'doc_id': 5, 'doc': 'he was the son of the nereid thetis and peleus king of phthia achilles most notable feat during the trojan war was the slaying of the trojan prince hector outside the gates of troy although the death of achilles is not presented in the iliad other sources concur that he was killed near the end of the trojan war by paris who shot him in the heel with an arrow later legends beginning with statius unfinished epic achilleid written in the 1st century ad state that achilles was invulnerable in all of his body except for his heel because when his mother thetis dipped him in the river styx as an infant she held him by one of his heels alluding to these legends the term achilles heel has come to mean a point of weakness especially in someone or something with an otherwise strong constitution the achilles tendon is also named after him due to these legends linear b tablets attest to the personal name achilleus in the forms a ki re u and a ki re we the latter being the dati

# Calc scores over whole doc corpus

In [17]:
corpus_json = docs.apply(lambda row: dict(row), axis=1).values.tolist()

In [18]:
corpus_text = list(map(lambda item: item["doc"], corpus_json))

In [19]:
tokenizer = bm25s.tokenization.Tokenizer(
    stemmer=stemmer,
    lower=True, # lowercase the tokens
    stopwords="english",  # or pass a list of stopwords
    splitter=r"\w+",  # by default r"(?u)\b\w\w+\b", can also be a function
)

In [20]:
# Tokenize the corpus
corpus_tokens = tokenizer.tokenize(
    corpus_text, 
    update_vocab=True, # update the vocab as we tokenize
    return_as="ids"
)

# corpus_tokens = bm25s.tokenize(corpus_text, 
#                                # stopwords="en", stemmer=stemmer
#                               )

                                                                                                                                                                              

In [21]:
retriever = bm25s.BM25(corpus=corpus_json)

In [22]:
retriever.index(corpus_tokens)

                                                                                                                                                                              

In [23]:
query = "actinopterygii"

In [24]:
query_tokens = tokenizer.tokenize(
    [query], 
    update_vocab=False
)

# query_tokens = bm25s.tokenize(query,
#                               stopwords="en", stemmer=stemmer
#                              )

                                                                                                                                                                              

In [25]:
# query_tokens

[[850]]

In [26]:
# docs_79_query = data[data["query_id"] == 79][["doc_id", "doc"]]

In [27]:
# corpus_79_json = docs_79_query.apply(lambda row: dict(row), axis=1).values.tolist()

In [28]:
# corpus_79_text = list(map(lambda item: item["doc"], corpus_79_json))

In [29]:
# corpus_79_tokens = tokenizer.tokenize(
#     corpus_79_text, 
#     update_vocab=False, # update the vocab as we tokenize
#     return_as="ids"
# )

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 146.17it/s]


In [30]:
# corpus_79_text_ids = list(map(lambda item: retriever.corpus.index(item), corpus_79_json))

In [31]:
# results, scores = retriever.retrieve(query_tokens, k=10, show_progress=False, return_as="tuple")

In [32]:
# data[data["query_id"] == 79][["doc_id"]].values.flatten().tolist()

# Calc BM25S scores

In [33]:
doc_scores = []

In [34]:
for _, row in tqdm(queries.iterrows()):
    query, query_id = row.query, row.query_id
    
    doc_ids_query = data[data["query_id"] == query_id][["doc_id"]].values.flatten().tolist()

    query_tokens = tokenizer.tokenize(
        [query], 
        update_vocab=False,
        show_progress=False,
    )
    
    results, scores = retriever.retrieve(query_tokens, k=len(corpus_json), show_progress=False)
    doc_ids_results = list(map(lambda item: item["doc_id"], results[0]))
    doc_scores += [float(scores[0, doc_ids_results.index(doc_id)]) for doc_id in doc_ids_query]

1444it [06:48,  3.53it/s]


# Calc pos statistics

In [29]:
def calc_query_token_ngram_coverage(query_tokens: list[int], doc_tokens: list[int], n: int = 1) -> int | float:
    if n > min(len(query_tokens), len(doc_tokens)):
        return 0

    query_ngram_counter = {}
    for i in range(len(query_tokens) - n + 1):
        query_ngram_counter[tuple(query_tokens[i:i + n])] = 0

    for i in range(len(doc_tokens) - n + 1):
        doc_ngram = tuple(doc_tokens[i:i + n])
        if doc_ngram in query_ngram_counter and query_ngram_counter[doc_ngram] == 0:
            query_ngram_counter[doc_ngram] = 1

    query_token_ngram_coverage = sum(query_ngram_counter.values())
    return query_token_ngram_coverage, query_token_ngram_coverage / len(query_ngram_counter)

In [30]:
calc_query_token_ngram_coverage([1, 2, 3], [1, 2, 4, 2, 4], n=3)

0.5

In [31]:
def calc_doc_length(doc_tokens: list[int]) -> int:
    return len(doc_tokens)

In [32]:
def calc_query_length(query_tokens: list[int]) -> int:
    return len(query_tokens)

In [33]:
def calc_query_token_first_occurrence(query_tokens: list[int], doc_tokens: list[int]) -> int:
    first_occurrence = len(doc_tokens)
    
    for i, doc_token in enumerate(doc_tokens):
        if doc_token in query_tokens:
            first_occurrence = i
            break
     
    if normalize:
        first_occurrence /= len(doc_tokens)

    return first_occurrence

In [34]:
calc_query_token_first_occurrence([1, 2, 3], [4, 1, 4, 2, 4])

0.2

In [35]:
def calc_idf(
    query_tokens: list[int], 
    docs_tokens: list[list[int]],
) -> dict[int, int]:
    query_tokens_counts_per_document = defaultdict(int)

    for doc_tokens in docs_tokens:
        doc_tokens_occurred_in_query = set()
        for token in doc_tokens:
            if token in query_tokens and token not in doc_tokens_occurred_in_query:
                query_tokens_counts_per_document[token] += 1
                doc_tokens_occurred_in_query.add(token)

    num_documents = len(docs_tokens)
    idf_dict = {}
    for token, doc_count in query_tokens_counts_per_document.items():
        idf_dict[token] = math.log(num_documents / (doc_count + 1)) + 1

    return idf_dict


def calc_hh_proximity(
    query_tokens: list[int], 
    doc_tokens: list[list[int]], 
    idf_dict: dict[int, int], 
    z: float = 1.75,
) -> float:
    token_positions = defaultdict(list)
    for idx, token in enumerate(doc_tokens):
        if token in query_tokens:
            token_positions[token].append(idx)

    hh_proximity = 0

    for cur_token, cur_token_positions in token_positions.items():
        cur_token_idf = idf_dict.get(cur_token, 1)
        for cur_token_position in cur_token_positions:
            for other_token in query_tokens:
                if other_token == cur_token:
                    token_weight = 0.25
                else:
                    token_weight = 1
                    
                other_positions = token_positions.get(other_token, [])
                
                if other_positions:
                    lmd = float("inf")
                    rmd = float("inf")
                    other_token_idf = idf_dict.get(other_token, 1)
                    for other_position in other_positions:
                        if other_position < cur_token_position:
                            lmd = min(lmd, cur_token_position - other_position)
                        elif other_position > cur_token_position:
                            rmd = min(rmd, other_position - cur_token_position)
                            break
                    

                    hh_proximity += token_weight * cur_token_idf * ((other_token_idf / (lmd ** z)) + (other_token_idf / (rmd ** z)))

    return math.log(1 + hh_proximity)

In [36]:
query_tokens = ["data", "science"]
doc_tokens = ["data", "is", "the", "new", "science", "of", "data", "analysis"]
docs_tokens = [
    ["data", "is", "the", "new", "science", "of", "data", "analysis"],
    ["machine", "learning", "is", "a", "field", "of", "artificial", "intelligence"],
    ["data", "science", "involves", "statistics", "and", "programming"]
]

idf_dict = calc_idf(query_tokens, docs_tokens)
print(idf_dict)
hh_proximity_score = calc_hh_proximity(query_tokens, doc_tokens, idf_dict)
print(f"HHProximity Score:", {hh_proximity_score})

{'data': 1.0, 'science': 1.0}
HHProximity Score: {0.5839557466475039}


In [37]:
doc_ids = docs["doc_id"].values.tolist()
features = []

In [38]:
for query_id, row in tqdm(data.groupby("query_id")):
    query = row["query"].values[0]

    query_tokens = tokenizer.tokenize(
        [query], 
        update_vocab=False,
        show_progress=False,
    )[0]

    docs_tokens_query = list(map(lambda doc_id: corpus_tokens[doc_ids.index(doc_id)], row["doc_id"].values))
    # we can calculate idf from scratch by code line below
    # but it is better to use fair idf scores from lucene and BM25 index
    idf_dict = calc_idf(query_tokens, docs_tokens_query)

    query_length = calc_query_length(query_tokens)

    for doc_tokens in docs_tokens_query:
        doc_length = calc_doc_length(doc_tokens)
        query_token_first_occurrence_unnormalized, query_token_first_occurrence_normalized = calc_query_token_first_occurrence(
            query_tokens, 
            doc_tokens,
        )
        query_token_last_occurrence_unnormalized, _ = calc_query_token_first_occurrence(
            query_tokens, 
            doc_tokens[::-1],
        )
        query_token_last_occurrence_unnormalized = (doc_length - 1) - query_token_last_occurrence_unnormalized
        query_token_last_occurrence_normalized = query_token_last_occurrence_unnormalized / doc_length
        span_length_unnormalized = query_token_last_occurrence_unnormalized - query_token_first_occurrence_unnormalized + 1
        span_length_normalized = span_length_unnormalized / doc_length
        hh_proximity_score = calc_hh_proximity(query_tokens, doc_tokens, idf_dict)
        
        features.append(
            [
                hh_proximity_score,
                query_length,
                doc_length,
                *calc_query_token_ngram_coverage(query_tokens, doc_tokens, n=1),
                *calc_query_token_ngram_coverage(query_tokens, doc_tokens, n=2),
                *calc_query_token_ngram_coverage(query_tokens, doc_tokens, n=3),
                query_token_first_occurrence_unnormalized,
                query_token_first_occurrence_normalized,
                query_token_last_occurrence_unnormalized,
                query_token_last_occurrence_normalized,
                span_length_unnormalized,
                span_length_normalized,
            ]
        )

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [01:58<00:00,  1.18s/it]


In [39]:
features = pd.DataFrame(
    data=features,
    columns=[
        "hh_proximity_score",
        "query_length",
        "doc_length",
        "query_token_unigram_coverage_unnormalized",
        "query_token_unigram_coverage_normalized",
        "query_token_bigram_coverage_unnormalized",
        "query_token_bigram_coverage_normalized",
        "query_token_trigram_coverage_unnormalized",
        "query_token_trigram_coverage_normalized",
        "query_token_first_occurrence_unnormalized",
        "query_token_first_occurrence_normalized",
        "query_token_last_occurrence_unnormalized",
        "query_token_last_occurrence_normalized",
        "span_length_unnormalized",
        "span_length_normalized",
    ],
)

In [40]:
data = pd.concat([data, features], axis=1)

In [41]:
data = data.fillna(0)

In [42]:
data.columns

Index(['query_id', 'doc_id', 'relevance', 'query', 'doc', 'bm25_score',
       'hh_proximity_score', 'query_length', 'doc_length',
       'query_token_unigram_coverage_unnormalized',
       'query_token_unigram_coverage_normalized',
       'query_token_bigram_coverage_unnormalized',
       'query_token_bigram_coverage_normalized',
       'query_token_trigram_coverage_unnormalized',
       'query_token_trigram_coverage_normalized',
       'query_token_first_occurrence_unnormalized',
       'query_token_first_occurrence_normalized',
       'query_token_last_occurrence_unnormalized',
       'query_token_last_occurrence_normalized', 'span_length_unnormalized',
       'span_length_normalized'],
      dtype='object')

In [43]:
for column_name in [
    "bm25_score",
    "hh_proximity_score",
    "query_length",
    "doc_length",
    "query_token_unigram_coverage_unnormalized",
    "query_token_bigram_coverage_unnormalized",
    "query_token_trigram_coverage_unnormalized",
    "query_token_first_occurrence_unnormalized",
    "query_token_last_occurrence_unnormalized",
    "span_length_unnormalized",
]:
    data[f"{column_name}_max"] = data.groupby("query_id")[f"{column_name}"].transform(lambda item: item.max())
    data[f"{column_name}_min"] = data.groupby("query_id")[f"{column_name}"].transform(lambda item: item.min())
    data[f"{column_name}_mean"] = data.groupby("query_id")[f"{column_name}"].transform(lambda item: item.mean())
    data[f"{column_name}_median"] = data.groupby("query_id")[f"{column_name}"].transform(lambda item: item.median())

In [44]:
data.columns

Index(['query_id', 'doc_id', 'relevance', 'query', 'doc', 'bm25_score',
       'hh_proximity_score', 'query_length', 'doc_length',
       'query_token_unigram_coverage_unnormalized',
       'query_token_unigram_coverage_normalized',
       'query_token_bigram_coverage_unnormalized',
       'query_token_bigram_coverage_normalized',
       'query_token_trigram_coverage_unnormalized',
       'query_token_trigram_coverage_normalized',
       'query_token_first_occurrence_unnormalized',
       'query_token_first_occurrence_normalized',
       'query_token_last_occurrence_unnormalized',
       'query_token_last_occurrence_normalized', 'span_length_unnormalized',
       'span_length_normalized', 'bm25_score_max', 'bm25_score_min',
       'bm25_score_mean', 'bm25_score_median', 'hh_proximity_score_max',
       'hh_proximity_score_min', 'hh_proximity_score_mean',
       'hh_proximity_score_median', 'query_length_max', 'query_length_min',
       'query_length_mean', 'query_length_median', 'doc_l

In [45]:
data[
    [
        "bm25_score",
        "hh_proximity_score",
        "query_length",
        "doc_length",
        "query_token_unigram_coverage_unnormalized",
        "query_token_unigram_coverage_normalized",
        "query_token_bigram_coverage_unnormalized",
        "query_token_bigram_coverage_normalized",
        "query_token_trigram_coverage_unnormalized",
        "query_token_trigram_coverage_normalized",
        "query_token_first_occurrence_unnormalized",
        "query_token_first_occurrence_normalized",
        "query_token_last_occurrence_unnormalized",
        "query_token_last_occurrence_normalized",
        "span_length_unnormalized",
        "span_length_normalized", 
        "bm25_score_max",
        "bm25_score_min",
        "bm25_score_mean",
        "bm25_score_median",
        "hh_proximity_score_max",
        "hh_proximity_score_min",
        "hh_proximity_score_mean",
        "hh_proximity_score_median",
        "query_length_max",
        "query_length_min",
        "query_length_mean",
        "query_length_median",
        "doc_length_max",
        "doc_length_min",
        "doc_length_mean",
        "doc_length_median",
        "query_token_unigram_coverage_unnormalized_max",
        "query_token_unigram_coverage_unnormalized_min",
        "query_token_unigram_coverage_unnormalized_mean",
        "query_token_unigram_coverage_unnormalized_median",
        "query_token_bigram_coverage_unnormalized_max",
        "query_token_bigram_coverage_unnormalized_min",
        "query_token_bigram_coverage_unnormalized_mean",
        "query_token_bigram_coverage_unnormalized_median",
        "query_token_trigram_coverage_unnormalized_max",
        "query_token_trigram_coverage_unnormalized_min",
        "query_token_trigram_coverage_unnormalized_mean",
        "query_token_trigram_coverage_unnormalized_median",
        "query_token_first_occurrence_unnormalized_max",
        "query_token_first_occurrence_unnormalized_min",
        "query_token_first_occurrence_unnormalized_mean",
        "query_token_first_occurrence_unnormalized_median",
        "query_token_last_occurrence_unnormalized_max",
        "query_token_last_occurrence_unnormalized_min",
        "query_token_last_occurrence_unnormalized_mean",
        "query_token_last_occurrence_unnormalized_median",
        "span_length_unnormalized_max",
        "span_length_unnormalized_min",
        "span_length_unnormalized_mean",
        "span_length_unnormalized_median",
        "doc_id",
        "query_id",
        "relevance",
    ]
].to_csv(f"{DATASET_PREFIX}_{DATASET_VERSION}_{STAGE}_preprocessed_light_v2.csv", index=None)