In [1]:
import pyterrier as pt
import datasets
import pandas as pd
from pathlib import Path

In [2]:
def rrf(dfs, i=1, K=100):
    scores = {}

    for df in dfs:
        for _, row in df.iterrows():
            docno = row["docno"]
            rrf_score = (1 / (i+row["rank"]))
            if docno in scores:
                scores[docno] += rrf_score
            else:
                scores[docno] = rrf_score
    # main_qid is used here to evaluate performance of merged data frame
    merged_df = pd.DataFrame(
        [{"qid": '1', "docno": k, "score": v} for k, v in sorted(scores.items(), key=lambda item: item[1], reverse=True)] 
    )
    # print(merged_df[merged_df["docno"]==244319])
    merged_df["rank"] = list(range(len(merged_df)))
    return merged_df[:K]

In [3]:
dataset = datasets.load_dataset("jonathanli/eurlex")

In [4]:
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['celex_id', 'title', 'text', 'eurovoc_concepts'],
        num_rows: 45000
    })
    test: Dataset({
        features: ['celex_id', 'title', 'text', 'eurovoc_concepts'],
        num_rows: 6000
    })
    validation: Dataset({
        features: ['celex_id', 'title', 'text', 'eurovoc_concepts'],
        num_rows: 6000
    })
})


In [5]:
ds1 = dataset['train'].to_pandas()
ds2 = dataset['test'].to_pandas()
ds3 = dataset['validation'].to_pandas()

In [6]:
ds4 = pd.concat([ds1, ds2], axis=0)
pd_ds = pd.concat([ds4, ds3], axis=0)


In [7]:
print(pd_ds)

        celex_id                                              title  \
0     32014R0727  Commission Implementing Regulation (EU) No 727...   
1     31975R2481  Regulation (EEC) No 2481/75 of the Council of ...   
2     32010D0008  2010/8/EU, Euratom: Commission Decision of 22 ...   
3     31982D0211  82/211/EEC: Commission Decision of 17 March 19...   
4     31996D0084  96/84/Euratom, ECSC, EC: Commission Decision o...   
...          ...                                                ...   
5995  32007R0522  Commission Regulation (EC) No 522/2007 of 11 M...   
5996  32005R0245  Commission Regulation (EC) No 245/2005 of 11 F...   
5997  31995D0380  95/380/EC: Commission Decision of 18 September...   
5998  31989R1200  Commission Regulation (EEC) No 1200/89 of 3 Ma...   
5999  32015D0205  Commission Implementing Decision (EU) 2015/205...   

                                                   text  \
0     1.7.2014 EN Official Journal of the European U...   
1     REGULATION (EEC) No 248

In [8]:
index_ref = None
cache_dir = Path("cache/")
index_dir = cache_dir / "indices" / "eur_lex"

pd_ds_rename = pd_ds.rename(columns={'celex_id': 'docno'}, inplace=False)

pd_ds_dict = pd_ds_rename.to_dict(orient='records')

try:
    index_ref = pt.IndexFactory.of(str(index_dir.absolute()))
except:
    indexer = pt.index.IterDictIndexer(str(index_dir.absolute()))
    index_ref = indexer.index(
        pd_ds_dict
    )   

Java started (triggered by IndexFactory.of) and loaded: pyterrier.java, pyterrier.terrier.java [version=5.11 (build: craig.macdonald 2025-01-13 21:29), helper_version=0.0.8]


In [9]:
index_ref_title = None
cache_dir = Path("cache/")
index_dir2 = cache_dir / "indices" / "eur_lex_titles"

pd_ds_rename = pd_ds.rename(columns={'celex_id': 'docno', 'text':'not_text', 'title':'text'}, inplace=False)

pd_ds_dict = pd_ds_rename.to_dict(orient='records')

try:
    index_ref_title = pt.IndexFactory.of(str(index_dir2.absolute()))
except:
    indexer_title = pt.index.IterDictIndexer(str(index_dir2.absolute()))
    index_ref_title = indexer_title.index(
        pd_ds_dict
    )   

In [10]:
bm25_text = pt.terrier.Retriever(index_ref, wmodel="BM25")
bm25_title = pt.terrier.Retriever(index_ref_title, wmodel="BM25")

In [11]:
def get_text(row):
    return list(pd_ds[pd_ds['celex_id']==row['docno']]['text'])[0]

def get_title(row):
    return list(pd_ds[pd_ds['celex_id']==row['docno']]['title'])[0]

In [12]:
retr_text = bm25_text.search('Journal')
retr_title = bm25_title.search('Journal')
results = rrf([retr_text, retr_title], K=5)
print(results)

  qid       docno     score  rank
0   1  32013R0216  2.000000     0
1   1  31988L0665  0.750000     1
2   1  32012R0623  0.500000     2
3   1  32011D0479  0.333333     3
4   1  32006D0178  0.333333     4


In [13]:
results['title'] = results.apply(get_title, axis=1, raw=False)
results['text'] = results.apply(get_text, axis=1, raw=False)
print(results)

  qid       docno     score  rank  \
0   1  32013R0216  2.000000     0   
1   1  31988L0665  0.750000     1   
2   1  32012R0623  0.500000     2   
3   1  32011D0479  0.333333     3   
4   1  32006D0178  0.333333     4   

                                               title  \
0  Council Regulation (EU) No 216/2013 of 7 March...   
1  Council Directive 88/665/EEC of 21 December 19...   
2  Commission Regulation (EU) No 623/2012 of 11 J...   
3  2011/479/: Commission Decision of 27 July 2011...   
4  2006/178/EC: Commission Decision of  27 Februa...   

                                                text  
0  13.3.2013 EN Official Journal of the European ...  
1  COUNCIL DIRECTIVE of 21 December 1988 amending...  
2  12.7.2012 EN Official Journal of the European ...  
3  29.7.2011 EN Official Journal of the European ...  
4  4.3.2006 EN Official Journal of the European U...  
