# Dataset Info

## Metadata for vaswani dataset

```json
{
  "docs": {
    "count": 11429,
    "fields": {
      "doc_id": {
        "max_len": 5,
        "common_prefix": ""
      }
    }
  },
  "queries": {
    "count": 93
  },
  "qrels": {
    "count": 2083,
    "fields": {
      "relevance": {
        "counts_by_value": {
          "1": 2083
        }
      }
    }
  }
}
```


# Implementation

In [49]:
import pyterrier as pt

if not pt.java.started():
    pt.init()


In [50]:
dataset = pt.get_dataset('irds:vaswani')

In [51]:
import os

raw_index_path = './indices/vaswani_positional_raw'
index_path = './indices/vaswani_positional'

def create_index(index_path, stemmer=None, stop_words=None):
    indexer = pt.IterDictIndexer(
        index_path=index_path, 
        blocks=True, # to save positional information
        overwrite=True, 
        text_attrs=['text'], 
        meta_reverse=['docno'], 
        meta={'docno': 5, 'text': 4096},
        verbose=1,
        stemmer=stemmer,
        stopwords=stop_words,
        tokeniser=pt.TerrierTokeniser.english,
        type=pt.IndexingType.CLASSIC
    )

    index_ref = indexer.index(dataset.get_corpus_iter()) # type: ignore
    return pt.IndexFactory.of(index_ref) # type: ignore

# if indices are already created, do not create again
if not os.path.exists(os.path.join(raw_index_path, "data.properties")):
    raw_index = create_index(
        index_path=raw_index_path
    )
else:
    raw_index = pt.IndexFactory.of(raw_index_path)

if not os.path.exists(os.path.join(index_path, "data.properties")):
    index = create_index(
        index_path=index_path,
        stemmer=pt.TerrierStemmer.porter,
        stop_words=pt.TerrierStopwords.terrier
    )
else:
    index = pt.IndexFactory.of(index_path)

In [52]:
print("Raw index statistics:")
print(raw_index.getCollectionStatistics().toString())
print("Preprocessed index statistics:")
print(index.getCollectionStatistics().toString())

Raw index statistics:
Number of documents: 11429
Number of terms: 12188
Number of postings: 351589
Number of fields: 0
Number of tokens: 479162
Field names: []
Positions:   true

Preprocessed index statistics:
Number of documents: 11429
Number of terms: 7756
Number of postings: 224573
Number of fields: 0
Number of tokens: 271581
Field names: []
Positions:   true


In [53]:
raw_bm25 = pt.terrier.Retriever(raw_index, wmodel="BM25", controls={"qe": "off", "proximity": "on"}, metadata=["docno", "text"]) # type: ignore
bm25 = pt.terrier.Retriever(index, wmodel="BM25", controls={"qe": "off", "proximity": "on"}, metadata=["docno", "text"]) # type: ignore

In [54]:
x = bm25.search(query="electronic analogue computer") # type: ignore
x[:10]

Unnamed: 0,qid,docid,docno,text,rank,score,query
0,1,5139,5140,an introduction to electronic analogue computers,0,19.413575,electronic analogue computer
1,1,139,140,the simulation of equations with analogue comp...,1,18.968913,electronic analogue computer
2,1,2932,2933,electronic computers the application of analo...,2,18.590034,electronic analogue computer
3,1,3138,3139,an error analysis of electronic analogue compu...,3,18.450685,electronic analogue computer
4,1,5131,5132,high speed electronic analogue computing techn...,4,18.450685,electronic analogue computer
5,1,5831,5832,electronic analogue computing a survey of mod...,5,18.43094,electronic analogue computer
6,1,5016,5017,principles and application of electronic analo...,6,18.101647,electronic analogue computer
7,1,1156,1157,electronic computers annual review of interna...,7,17.887571,electronic analogue computer
8,1,3137,3138,a multipurpose electronic switch for analogue ...,8,17.173041,electronic analogue computer
9,1,5021,5022,the hyperbolic field tube an electron beam tub...,9,17.173041,electronic analogue computer


In [55]:
# PRF
bo1 = pt.rewrite.Bo1QueryExpansion(index_like=index, fb_terms=20, fb_docs=3)
kl = pt.rewrite.KLQueryExpansion(index_like=index, fb_terms=20, fb_docs=3)
rm3 = pt.rewrite.RM3(index_like=index, fb_terms=20, fb_docs=3) # type: ignore

pipeline = bm25 >> bo1 >> bm25
pipeline_2 = bm25 >> kl >> bm25
pipeline_3 = bm25 >> rm3 >> bm25

In [56]:
docno_text_dict = {d['docno']: d['text'] for d in dataset.get_corpus_iter()}

vaswani documents: 100%|██████████| 11429/11429 [00:00<00:00, 557481.28it/s]


In [57]:
from sentence_transformers import CrossEncoder
import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()

class CrossEncoderReranker(pt.Transformer):
    def __init__(self, model_name='cross-encoder/ms-marco-MiniLM-L-6-v2', top_k=50, text_field='text'):
        self.model = CrossEncoder(model_name)
        self.top_k = top_k
        self.text_field = text_field
    
    def transform(self, res):
        if self.text_field not in res.columns:
            raise ValueError(f"Document text field '{self.text_field}' not found in results")
        
        reranked_dfs = []
        for qid, group in res.groupby('qid'):
            query = group['query'].iloc[0]
            doc_nos = group["docno"].tolist()
            docs = [docno_text_dict.get(doc_no, "") for doc_no in doc_nos]
            group = group.copy()
            group[self.text_field] = docs

            top_k = min(self.top_k, len(docs))
            limited_group = group.head(top_k)
            pairs = [(query, doc) for doc in limited_group[self.text_field]]

            scores = self.model.predict(pairs)

            full_scores = np.zeros(len(group))
            full_scores[:len(scores)] = scores

            group['crossencoder_score'] = full_scores 
            group['score_norm'] = scaler.fit_transform(group[['score']])
            group['crossencoder_score_norm'] = scaler.fit_transform(full_scores.reshape(-1,1))

            group['combined_score'] = 0.4 * group['crossencoder_score_norm'] + 0.6 * group['score_norm']

            reranked = group.sort_values(['combined_score', 'rank'], ascending=[False, True])
            reranked.reset_index(drop=True, inplace=True)
            reranked['old_rank'] = reranked['rank']
            reranked['rank'] = reranked.index

            reranked_dfs.append(reranked)

        final = pd.concat(reranked_dfs) if reranked_dfs else pd.DataFrame()
        return final
    
ce_pipeline = bm25 >> CrossEncoderReranker(text_field="text", top_k=100)


In [58]:
from collections import defaultdict

def precision_at_k(results, qrels, k=10):
    qrels_dict = qrels.groupby('qid')['docno'].apply(set).to_dict()
    precisions = []

    for qid, group in results.groupby('qid'):
        relevant = qrels_dict.get(qid, set())
        top_k_docs = group.head(k)['docno'].tolist()
        num_relevant = sum([1 for docno in top_k_docs if docno in relevant])
        precisions.append(num_relevant / k)
    
    return sum(precisions) / len(precisions)


def recall_at_k(results, qrels, k=10):
    qrels_dict = qrels.groupby('qid')['docno'].apply(set).to_dict()
    recalls = []

    for qid, group in results.groupby('qid'):
        relevant = qrels_dict.get(qid, set())
        if not relevant:
            continue
        top_k_docs = group.head(k)['docno'].tolist()
        num_relevant = sum([1 for docno in top_k_docs if docno in relevant])
        recalls.append(num_relevant / len(relevant))
    
    return sum(recalls) / len(recalls)


def mean_average_precision(results, qrels):
    qrels_dict = qrels.groupby('qid')['docno'].apply(set).to_dict()
    average_precisions = []

    for qid, group in results.groupby('qid'):
        relevant = qrels_dict.get(qid, set())
        if not relevant:
            continue

        hits = 0
        precisions = []

        for i, docno in enumerate(group['docno'].tolist(), start=1):
            if docno in relevant:
                hits += 1
                precisions.append(hits / i)

        if precisions:
            average_precisions.append(sum(precisions) / len(relevant))
        else:
            average_precisions.append(0.0)

    return sum(average_precisions) / len(average_precisions)


In [59]:
def calculate_metrics(model, name):
    results_df = model.transform(dataset.get_topics())

    precision = precision_at_k(results_df, dataset.get_qrels(), k=5)
    precision_10 = precision_at_k(results_df, dataset.get_qrels(), k=10)
    recall = recall_at_k(results_df, dataset.get_qrels(), k=5)
    recall_10 = recall_at_k(results_df, dataset.get_qrels(), k=10)
    map_score = mean_average_precision(results_df, dataset.get_qrels())
    
    results = [name, precision, precision_10, recall, recall_10, map_score]
    return results

In [60]:
raw_bm25_results = calculate_metrics(raw_bm25, "Raw BM25")

In [61]:
bm25_results = calculate_metrics(bm25, "BM25")


In [62]:
pipeline_results = calculate_metrics(pipeline, "BM25 + Bo1")

In [63]:
pipeline_2_results = calculate_metrics(pipeline_2, "BM25 + KL")

In [64]:
pipeline_3_results = calculate_metrics(pipeline_3, "BM25 + RM3")

In [65]:
ce_bert_results = calculate_metrics(ce_pipeline, "BM25 + Cross Encoder")

In [66]:
table = [["Model", "Precision@5", "Precision@10",  "Recall@5", "Recall@10", "MAP"], raw_bm25_results, bm25_results,
        pipeline_results, pipeline_2_results, pipeline_3_results, ce_bert_results]
df = pd.DataFrame(table[1:], columns=table[0])
df

Unnamed: 0,Model,Precision@5,Precision@10,Recall@5,Recall@10,MAP
0,Raw BM25,0.236559,0.195699,0.086218,0.126533,0.143478
1,BM25,0.460215,0.351613,0.162592,0.217617,0.296513
2,BM25 + Bo1,0.470968,0.370968,0.164555,0.222465,0.305298
3,BM25 + KL,0.468817,0.365591,0.164229,0.218986,0.302741
4,BM25 + RM3,0.475269,0.370968,0.16486,0.224261,0.305449
5,BM25 + Cross Encoder,0.492473,0.401075,0.169478,0.241,0.297633
