In [1]:
import os
import torch
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from tqdm.auto import tqdm
import huggingface_hub as hf
from dotenv import load_dotenv
import matplotlib.pyplot as plt
from typing import List, Dict, Union, Tuple
from transformers import AutoTokenizer, AutoModel

pd.set_option('display.max_columns', None)
pd.set_option('display.max_colwidth', 256)

plt.style.use('seaborn-v0_8')
load_dotenv()
hf.login(os.environ["HF_TOKEN"])
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
print("CUDA_VISIBLE_DEVICES:", os.environ["CUDA_VISIBLE_DEVICES"], "HF_HOME:", os.environ["HF_HOME"])

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /home/mohsenfayyaz/.cache/huggingface/token
Login successful
CUDA_VISIBLE_DEVICES: 2 HF_HOME: /local1/mohsenfayyaz/.hfcache/


In [2]:
paths = hf.HfFileSystem().ls("hf://datasets/Retriever-Contextualization/datasets/re-docred/", detail=False)
paths = [p for p in paths if ".pkl.gz" in p]
results = []
for path in tqdm(paths):
    df = pd.read_pickle(f"hf://{path}")
    try:
        results.append({
            "query_model": df.attrs["query_model"],
            "context_model": df.attrs["context_model"],
            "corpus_size": df.attrs["corpus_size"],
            "pooling": df.attrs["pooling"],
            "eval_ndcg": df.attrs["eval"]["ndcg"],
            "eval_map": df.attrs["eval"]["map"],
            "eval_recall": df.attrs["eval"]["recall"],
            "eval_precision": df.attrs["eval"]["precision"],
        })
    except Exception as e:
        print(e)
        print(path)
pd.DataFrame(results)

  0%|          | 0/10 [00:00<?, ?it/s]

'query_model'
datasets/Retriever-Contextualization/datasets/re-docred/facebook--contriever-msmarco_corpus105925_sentencebert.pkl.gz


Unnamed: 0,query_model,context_model,corpus_size,pooling,eval_ndcg,eval_map,eval_recall,eval_precision
0,OpenMatch/cocodr-base-msmarco,OpenMatch/cocodr-base-msmarco,105925,cls,"{'NDCG@1': 0.42636, 'NDCG@3': 0.4586, 'NDCG@5': 0.46749, 'NDCG@10': 0.47563, 'NDCG@100': 0.49961, 'NDCG@1000': 0.51485}","{'MAP@1': 0.42636, 'MAP@3': 0.45091, 'MAP@5': 0.45583, 'MAP@10': 0.45918, 'MAP@100': 0.46368, 'MAP@1000': 0.46416}","{'Recall@1': 0.42636, 'Recall@3': 0.48075, 'Recall@5': 0.50237, 'Recall@10': 0.52762, 'Recall@100': 0.64477, 'Recall@1000': 0.77001}","{'P@1': 0.42636, 'P@3': 0.16025, 'P@5': 0.10047, 'P@10': 0.05276, 'P@100': 0.00645, 'P@1000': 0.00077}"
1,OpenMatch/cocodr-large-msmarco,OpenMatch/cocodr-large-msmarco,105925,cls,"{'NDCG@1': 0.44045, 'NDCG@3': 0.47401, 'NDCG@5': 0.48295, 'NDCG@10': 0.49287, 'NDCG@100': 0.51603, 'NDCG@1000': 0.53244}","{'MAP@1': 0.44045, 'MAP@3': 0.46592, 'MAP@5': 0.47087, 'MAP@10': 0.47493, 'MAP@100': 0.47932, 'MAP@1000': 0.47985}","{'Recall@1': 0.44045, 'Recall@3': 0.49735, 'Recall@5': 0.51911, 'Recall@10': 0.54993, 'Recall@100': 0.66304, 'Recall@1000': 0.79679}","{'P@1': 0.44045, 'P@3': 0.16578, 'P@5': 0.10382, 'P@10': 0.05499, 'P@100': 0.00663, 'P@1000': 0.0008}"
2,Shitao/RetroMAE_MSMARCO,Shitao/RetroMAE_MSMARCO,105925,cls,"{'NDCG@1': 0.17211, 'NDCG@3': 0.20955, 'NDCG@5': 0.22407, 'NDCG@10': 0.23791, 'NDCG@100': 0.27226, 'NDCG@1000': 0.29704}","{'MAP@1': 0.17211, 'MAP@3': 0.20028, 'MAP@5': 0.20833, 'MAP@10': 0.21409, 'MAP@100': 0.22029, 'MAP@1000': 0.22107}","{'Recall@1': 0.17211, 'Recall@3': 0.2364, 'Recall@5': 0.27169, 'Recall@10': 0.31423, 'Recall@100': 0.48494, 'Recall@1000': 0.68926}","{'P@1': 0.17211, 'P@3': 0.0788, 'P@5': 0.05434, 'P@10': 0.03142, 'P@100': 0.00485, 'P@1000': 0.00069}"
3,Shitao/RetroMAE_MSMARCO_finetune,Shitao/RetroMAE_MSMARCO_finetune,105925,cls,"{'NDCG@1': 0.4357, 'NDCG@3': 0.47328, 'NDCG@5': 0.48231, 'NDCG@10': 0.49238, 'NDCG@100': 0.5148, 'NDCG@1000': 0.53246}","{'MAP@1': 0.4357, 'MAP@3': 0.46409, 'MAP@5': 0.4691, 'MAP@10': 0.47324, 'MAP@100': 0.47734, 'MAP@1000': 0.4779}","{'Recall@1': 0.4357, 'Recall@3': 0.49986, 'Recall@5': 0.52176, 'Recall@10': 0.553, 'Recall@100': 0.66346, 'Recall@1000': 0.80921}","{'P@1': 0.4357, 'P@3': 0.16662, 'P@5': 0.10435, 'P@10': 0.0553, 'P@100': 0.00663, 'P@1000': 0.00081}"
4,Shitao/RetroMAE,Shitao/RetroMAE,105925,cls,"{'NDCG@1': 0.21771, 'NDCG@3': 0.26187, 'NDCG@5': 0.27656, 'NDCG@10': 0.29206, 'NDCG@100': 0.32304, 'NDCG@1000': 0.34523}","{'MAP@1': 0.21771, 'MAP@3': 0.25105, 'MAP@5': 0.25923, 'MAP@10': 0.26567, 'MAP@100': 0.27142, 'MAP@1000': 0.27213}","{'Recall@1': 0.21771, 'Recall@3': 0.29317, 'Recall@5': 0.32873, 'Recall@10': 0.37643, 'Recall@100': 0.52887, 'Recall@1000': 0.71074}","{'P@1': 0.21771, 'P@3': 0.09772, 'P@5': 0.06575, 'P@10': 0.03764, 'P@100': 0.00529, 'P@1000': 0.00071}"
5,facebook/contriever-msmarco,facebook/contriever-msmarco,105925,avg,"{'NDCG@1': 0.46109, 'NDCG@3': 0.49878, 'NDCG@5': 0.50989, 'NDCG@10': 0.5215, 'NDCG@100': 0.54809, 'NDCG@1000': 0.56627}","{'MAP@1': 0.46109, 'MAP@3': 0.48954, 'MAP@5': 0.49568, 'MAP@10': 0.50046, 'MAP@100': 0.50531, 'MAP@1000': 0.50593}","{'Recall@1': 0.46109, 'Recall@3': 0.52552, 'Recall@5': 0.55258, 'Recall@10': 0.58856, 'Recall@100': 0.72008, 'Recall@1000': 0.86709}","{'P@1': 0.46109, 'P@3': 0.17517, 'P@5': 0.11052, 'P@10': 0.05886, 'P@100': 0.0072, 'P@1000': 0.00087}"
6,facebook/contriever,facebook/contriever,105925,avg,"{'NDCG@1': 0.41074, 'NDCG@3': 0.46846, 'NDCG@5': 0.48409, 'NDCG@10': 0.49891, 'NDCG@100': 0.53099, 'NDCG@1000': 0.54827}","{'MAP@1': 0.41074, 'MAP@3': 0.45463, 'MAP@5': 0.46329, 'MAP@10': 0.46937, 'MAP@100': 0.47543, 'MAP@1000': 0.476}","{'Recall@1': 0.41074, 'Recall@3': 0.50837, 'Recall@5': 0.5463, 'Recall@10': 0.59233, 'Recall@100': 0.74868, 'Recall@1000': 0.88926}","{'P@1': 0.41074, 'P@3': 0.16946, 'P@5': 0.10926, 'P@10': 0.05923, 'P@100': 0.00749, 'P@1000': 0.00089}"
7,facebook/dragon-plus-query-encoder,facebook/dragon-plus-context-encoder,105925,cls,"{'NDCG@1': 0.47685, 'NDCG@3': 0.52523, 'NDCG@5': 0.53646, 'NDCG@10': 0.54955, 'NDCG@100': 0.58002, 'NDCG@1000': 0.59556}","{'MAP@1': 0.47685, 'MAP@3': 0.51341, 'MAP@5': 0.51959, 'MAP@10': 0.52496, 'MAP@100': 0.53058, 'MAP@1000': 0.53109}","{'Recall@1': 0.47685, 'Recall@3': 0.55941, 'Recall@5': 0.58689, 'Recall@10': 0.62748, 'Recall@100': 0.77741, 'Recall@1000': 0.90349}","{'P@1': 0.47685, 'P@3': 0.18647, 'P@5': 0.11738, 'P@10': 0.06275, 'P@100': 0.00777, 'P@1000': 0.0009}"
8,facebook/dragon-roberta-query-encoder,facebook/dragon-roberta-context-encoder,105925,cls,"{'NDCG@1': 0.46318, 'NDCG@3': 0.50342, 'NDCG@5': 0.51355, 'NDCG@10': 0.52556, 'NDCG@100': 0.55455, 'NDCG@1000': 0.57139}","{'MAP@1': 0.46318, 'MAP@3': 0.49351, 'MAP@5': 0.49912, 'MAP@10': 0.50406, 'MAP@100': 0.50941, 'MAP@1000': 0.50995}","{'Recall@1': 0.46318, 'Recall@3': 0.53208, 'Recall@5': 0.55676, 'Recall@10': 0.594, 'Recall@100': 0.73668, 'Recall@1000': 0.87476}","{'P@1': 0.46318, 'P@3': 0.17736, 'P@5': 0.11135, 'P@10': 0.0594, 'P@100': 0.00737, 'P@1000': 0.00087}"


In [4]:
pd.set_option('display.max_colwidth', 60)
df = pd.DataFrame(results)
df["Recall@10"] = df["eval_recall"].apply(lambda x: x["Recall@10"])
df["nDCG@10"] = df["eval_ndcg"].apply(lambda x: x["NDCG@10"])
df = df.sort_values("nDCG@10", ascending=False)
df

Unnamed: 0,query_model,context_model,corpus_size,pooling,eval_ndcg,eval_map,eval_recall,eval_precision,Recall@10,nDCG@10
7,facebook/dragon-plus-query-encoder,facebook/dragon-plus-context-encoder,105925,cls,"{'NDCG@1': 0.47685, 'NDCG@3': 0.52523, 'NDCG@5': 0.53646...","{'MAP@1': 0.47685, 'MAP@3': 0.51341, 'MAP@5': 0.51959, '...","{'Recall@1': 0.47685, 'Recall@3': 0.55941, 'Recall@5': 0...","{'P@1': 0.47685, 'P@3': 0.18647, 'P@5': 0.11738, 'P@10':...",0.62748,0.54955
8,facebook/dragon-roberta-query-encoder,facebook/dragon-roberta-context-encoder,105925,cls,"{'NDCG@1': 0.46318, 'NDCG@3': 0.50342, 'NDCG@5': 0.51355...","{'MAP@1': 0.46318, 'MAP@3': 0.49351, 'MAP@5': 0.49912, '...","{'Recall@1': 0.46318, 'Recall@3': 0.53208, 'Recall@5': 0...","{'P@1': 0.46318, 'P@3': 0.17736, 'P@5': 0.11135, 'P@10':...",0.594,0.52556
5,facebook/contriever-msmarco,facebook/contriever-msmarco,105925,avg,"{'NDCG@1': 0.46109, 'NDCG@3': 0.49878, 'NDCG@5': 0.50989...","{'MAP@1': 0.46109, 'MAP@3': 0.48954, 'MAP@5': 0.49568, '...","{'Recall@1': 0.46109, 'Recall@3': 0.52552, 'Recall@5': 0...","{'P@1': 0.46109, 'P@3': 0.17517, 'P@5': 0.11052, 'P@10':...",0.58856,0.5215
6,facebook/contriever,facebook/contriever,105925,avg,"{'NDCG@1': 0.41074, 'NDCG@3': 0.46846, 'NDCG@5': 0.48409...","{'MAP@1': 0.41074, 'MAP@3': 0.45463, 'MAP@5': 0.46329, '...","{'Recall@1': 0.41074, 'Recall@3': 0.50837, 'Recall@5': 0...","{'P@1': 0.41074, 'P@3': 0.16946, 'P@5': 0.10926, 'P@10':...",0.59233,0.49891
1,OpenMatch/cocodr-large-msmarco,OpenMatch/cocodr-large-msmarco,105925,cls,"{'NDCG@1': 0.44045, 'NDCG@3': 0.47401, 'NDCG@5': 0.48295...","{'MAP@1': 0.44045, 'MAP@3': 0.46592, 'MAP@5': 0.47087, '...","{'Recall@1': 0.44045, 'Recall@3': 0.49735, 'Recall@5': 0...","{'P@1': 0.44045, 'P@3': 0.16578, 'P@5': 0.10382, 'P@10':...",0.54993,0.49287
3,Shitao/RetroMAE_MSMARCO_finetune,Shitao/RetroMAE_MSMARCO_finetune,105925,cls,"{'NDCG@1': 0.4357, 'NDCG@3': 0.47328, 'NDCG@5': 0.48231,...","{'MAP@1': 0.4357, 'MAP@3': 0.46409, 'MAP@5': 0.4691, 'MA...","{'Recall@1': 0.4357, 'Recall@3': 0.49986, 'Recall@5': 0....","{'P@1': 0.4357, 'P@3': 0.16662, 'P@5': 0.10435, 'P@10': ...",0.553,0.49238
0,OpenMatch/cocodr-base-msmarco,OpenMatch/cocodr-base-msmarco,105925,cls,"{'NDCG@1': 0.42636, 'NDCG@3': 0.4586, 'NDCG@5': 0.46749,...","{'MAP@1': 0.42636, 'MAP@3': 0.45091, 'MAP@5': 0.45583, '...","{'Recall@1': 0.42636, 'Recall@3': 0.48075, 'Recall@5': 0...","{'P@1': 0.42636, 'P@3': 0.16025, 'P@5': 0.10047, 'P@10':...",0.52762,0.47563
4,Shitao/RetroMAE,Shitao/RetroMAE,105925,cls,"{'NDCG@1': 0.21771, 'NDCG@3': 0.26187, 'NDCG@5': 0.27656...","{'MAP@1': 0.21771, 'MAP@3': 0.25105, 'MAP@5': 0.25923, '...","{'Recall@1': 0.21771, 'Recall@3': 0.29317, 'Recall@5': 0...","{'P@1': 0.21771, 'P@3': 0.09772, 'P@5': 0.06575, 'P@10':...",0.37643,0.29206
2,Shitao/RetroMAE_MSMARCO,Shitao/RetroMAE_MSMARCO,105925,cls,"{'NDCG@1': 0.17211, 'NDCG@3': 0.20955, 'NDCG@5': 0.22407...","{'MAP@1': 0.17211, 'MAP@3': 0.20028, 'MAP@5': 0.20833, '...","{'Recall@1': 0.17211, 'Recall@3': 0.2364, 'Recall@5': 0....","{'P@1': 0.17211, 'P@3': 0.0788, 'P@5': 0.05434, 'P@10': ...",0.31423,0.23791


In [8]:
df = df[['query_model', 'pooling', 'nDCG@10', 'Recall@10']]
# df = df.iloc[:6]

def clean_table(table_str):
    mappings = {
        "attention": "Attention",
    }
    model_mappings = {
        "OpenMatch/cocodr-base-msmarco": ("COCO-DR", "Base MSMARCO"),
        "Shitao/RetroMAE\_MSMARCO\_finetune": ("RetroMAE", "MSMARCO FT"),
        "Shitao/RetroMAE\_MSMARCO": ("RetroMAE", "MSMARCO"),
        "Shitao/RetroMAE": ("RetroMAE", ""),
        "facebook/contriever-msmarco": ("Contriever", "MSMARCO"),
        "facebook/contriever": ("Contriever", ""),
        "facebook/dragon-plus-query-encoder": ("Dragon+", ""),
        "facebook/dragon-roberta-query-encoder": ("Dragon RoBERTa", ""),
    }
    back = "\\"
    raw_mappings = {
        # r"Method & Selection &  &  &  &  &  &  &  &  &  &  \\": "",
        "query\_model": "Model",
        "llllrrrrrr": r"p{1.2cm}p{1.2cm}p{1.2cm}p{1.2cm}rrrrrr",
        r"\cline{1-10} \cline{2-10} \cline{3-10}": "\\midrule",
    }
    for k, v in mappings.items():
        key = k.replace('_', '\\_')
        table_str = table_str.replace("\\textbf{" + key + "}", "\\textsc{" + v + "}")
    for k, v in model_mappings.items():
        key = k
        value = v[0] + r" " + v[1]
        table_str = table_str.replace(key, value)
    for k, v in raw_mappings.items():
        table_str = table_str.replace(k, v)
    return table_str
print(clean_table(df.round(3).to_latex(float_format="%.3f", bold_rows=True, index=False).replace("_", "\_")))


df

\begin{tabular}{llrr}
\toprule
Model & pooling & nDCG@10 & Recall@10 \\
\midrule
Dragon+  & cls & 0.550 & 0.627 \\
Dragon RoBERTa  & cls & 0.526 & 0.594 \\
Contriever MSMARCO & avg & 0.522 & 0.589 \\
Contriever  & avg & 0.499 & 0.592 \\
OpenMatch/cocodr-large-msmarco & cls & 0.493 & 0.550 \\
RetroMAE MSMARCO FT & cls & 0.492 & 0.553 \\
COCO-DR Base MSMARCO & cls & 0.476 & 0.528 \\
RetroMAE  & cls & 0.292 & 0.376 \\
RetroMAE MSMARCO & cls & 0.238 & 0.314 \\
\bottomrule
\end{tabular}



  "Shitao/RetroMAE\_MSMARCO\_finetune": ("RetroMAE", "MSMARCO FT"),
  "Shitao/RetroMAE\_MSMARCO": ("RetroMAE", "MSMARCO"),
  "query\_model": "Model",
  print(clean_table(df.round(3).to_latex(float_format="%.3f", bold_rows=True, index=False).replace("_", "\_")))


Unnamed: 0,query_model,pooling,nDCG@10,Recall@10
7,facebook/dragon-plus-query-encoder,cls,0.54955,0.62748
8,facebook/dragon-roberta-query-encoder,cls,0.52556,0.594
5,facebook/contriever-msmarco,avg,0.5215,0.58856
6,facebook/contriever,avg,0.49891,0.59233
1,OpenMatch/cocodr-large-msmarco,cls,0.49287,0.54993
3,Shitao/RetroMAE_MSMARCO_finetune,cls,0.49238,0.553
0,OpenMatch/cocodr-base-msmarco,cls,0.47563,0.52762
4,Shitao/RetroMAE,cls,0.29206,0.37643
2,Shitao/RetroMAE_MSMARCO,cls,0.23791,0.31423
