In [1]:
import os
import torch
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from tqdm.auto import tqdm
import huggingface_hub as hf
from dotenv import load_dotenv
import matplotlib.pyplot as plt
from typing import List, Dict, Union, Tuple
from transformers import AutoTokenizer, AutoModel

# Display settings: show every column and allow long cell text in rendered frames.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_colwidth', 256)

plt.style.use('seaborn-v0_8')
# Load environment variables from a local .env file (presumably HF_TOKEN and
# HF_HOME — verify against the project's .env).
load_dotenv()
# NOTE(review): huggingface_hub is imported above, before load_dotenv() runs.
# The recorded output shows the token being saved under ~/.cache/huggingface
# while HF_HOME prints a different directory, which suggests the library fixed
# its cache location at import time. If HF_HOME must be honoured, load the .env
# before importing huggingface_hub/transformers — TODO confirm.
hf.login(os.environ["HF_TOKEN"])
# Pin this kernel to a single GPU; presumably this must happen before the first
# CUDA call for torch to respect it — TODO confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
# HF_HOME is optional, so read it with .get(): a missing value should not crash
# this purely diagnostic print after login has already succeeded.
print("CUDA_VISIBLE_DEVICES:", os.environ["CUDA_VISIBLE_DEVICES"], "HF_HOME:", os.environ.get("HF_HOME"))

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /home/mohsenfayyaz/.cache/huggingface/token
Login successful
CUDA_VISIBLE_DEVICES: 2 HF_HOME: /local1/mohsenfayyaz/.hfcache/


In [4]:
DATASET = "nq"  # "nq" or "re-docred"
# List the files published for this dataset and keep only the gzipped pickles.
paths = hf.HfFileSystem().ls(f"hf://datasets/Retriever-Contextualization/datasets/{DATASET}/", detail=False)
# Match on the suffix: a substring test would also accept names such as
# "x.pkl.gz.bak" that merely contain the extension.
paths = [p for p in paths if p.endswith(".pkl.gz")]
results = []
for path in tqdm(paths):
    # NOTE(security): unpickling executes arbitrary code from the file. This is
    # only acceptable while the Hub repository above is fully trusted.
    df = pd.read_pickle(f"hf://{path}")
    try:
        # Run metadata is read from DataFrame.attrs; the frame's rows are not used here.
        attrs = df.attrs
        eval_metrics = attrs["eval"]
        results.append({
            "query_model": attrs["query_model"],
            "context_model": attrs["context_model"],
            "corpus_size": attrs["corpus_size"],
            "pooling": attrs["pooling"],
            "eval_ndcg": eval_metrics["ndcg"],
            "eval_map": eval_metrics["map"],
            "eval_recall": eval_metrics["recall"],
            "eval_precision": eval_metrics["precision"],
        })
    except Exception as e:
        # Best effort: report the file whose metadata could not be read and move on.
        print(e)
        print(path)
pd.DataFrame(results)

  0%|          | 0/6 [00:00<?, ?it/s]

Unnamed: 0,query_model,context_model,corpus_size,pooling,eval_ndcg,eval_map,eval_recall,eval_precision
0,OpenMatch/cocodr-base-msmarco,OpenMatch/cocodr-base-msmarco,2681468,cls,"{'NDCG@1': 0.32184, 'NDCG@3': 0.42313, 'NDCG@5': 0.46486...","{'MAP@1': 0.28438, 'MAP@3': 0.38558, 'MAP@5': 0.41034, '...","{'Recall@1': 0.28438, 'Recall@3': 0.49998, 'Recall@5': 0...","{'P@1': 0.32184, 'P@3': 0.1939, 'P@5': 0.1405, 'P@10': 0..."
1,Shitao/RetroMAE_MSMARCO_finetune,Shitao/RetroMAE_MSMARCO_finetune,2681468,cls,"{'NDCG@1': 0.30881, 'NDCG@3': 0.40739, 'NDCG@5': 0.44582...","{'MAP@1': 0.27815, 'MAP@3': 0.37295, 'MAP@5': 0.39585, '...","{'Recall@1': 0.27815, 'Recall@3': 0.4782, 'Recall@5': 0....","{'P@1': 0.30881, 'P@3': 0.18299, 'P@5': 0.13204, 'P@10':..."
2,facebook/contriever-msmarco,facebook/contriever-msmarco,2681468,avg,"{'NDCG@1': 0.30736, 'NDCG@3': 0.41327, 'NDCG@5': 0.45537...","{'MAP@1': 0.27141, 'MAP@3': 0.37456, 'MAP@5': 0.3995, 'M...","{'Recall@1': 0.27141, 'Recall@3': 0.49355, 'Recall@5': 0...","{'P@1': 0.30736, 'P@3': 0.19119, 'P@5': 0.13853, 'P@10':..."
3,facebook/contriever,facebook/contriever,2681468,avg,"{'NDCG@1': 0.12601, 'NDCG@3': 0.18426, 'NDCG@5': 0.21274...","{'MAP@1': 0.11061, 'MAP@3': 0.16281, 'MAP@5': 0.17905, '...","{'Recall@1': 0.11061, 'Recall@3': 0.22663, 'Recall@5': 0...","{'P@1': 0.12601, 'P@3': 0.08864, 'P@5': 0.069, 'P@10': 0..."
4,facebook/dragon-plus-query-encoder,facebook/dragon-plus-context-encoder,2681468,cls,"{'NDCG@1': 0.34878, 'NDCG@3': 0.45609, 'NDCG@5': 0.49928...","{'MAP@1': 0.31216, 'MAP@3': 0.41763, 'MAP@5': 0.4435, 'M...","{'Recall@1': 0.31216, 'Recall@3': 0.53546, 'Recall@5': 0...","{'P@1': 0.34878, 'P@3': 0.20616, 'P@5': 0.14861, 'P@10':..."
5,facebook/dragon-roberta-query-encoder,facebook/dragon-roberta-context-encoder,2681468,cls,"{'NDCG@1': 0.36269, 'NDCG@3': 0.47072, 'NDCG@5': 0.50976...","{'MAP@1': 0.3258, 'MAP@3': 0.43229, 'MAP@5': 0.4559, 'MA...","{'Recall@1': 0.3258, 'Recall@3': 0.5491, 'Recall@5': 0.6...","{'P@1': 0.36269, 'P@3': 0.21089, 'P@5': 0.14948, 'P@10':..."


In [5]:
# Narrow the column width so the nested metric dicts are truncated when displayed.
pd.set_option('display.max_colwidth', 60)

# Build one row per run, lift the @10 cut-offs out of the nested metric dicts
# into flat columns, and order the runs from best to worst nDCG@10.
df = (
    pd.DataFrame(results)
    .assign(**{
        "Recall@10": lambda frame: frame["eval_recall"].apply(lambda scores: scores["Recall@10"]),
        "nDCG@10": lambda frame: frame["eval_ndcg"].apply(lambda scores: scores["NDCG@10"]),
    })
    .sort_values(by="nDCG@10", ascending=False)
)
df

Unnamed: 0,query_model,context_model,corpus_size,pooling,eval_ndcg,eval_map,eval_recall,eval_precision,Recall@10,nDCG@10
5,facebook/dragon-roberta-query-encoder,facebook/dragon-roberta-context-encoder,2681468,cls,"{'NDCG@1': 0.36269, 'NDCG@3': 0.47072, 'NDCG@5': 0.50976...","{'MAP@1': 0.3258, 'MAP@3': 0.43229, 'MAP@5': 0.4559, 'MA...","{'Recall@1': 0.3258, 'Recall@3': 0.5491, 'Recall@5': 0.6...","{'P@1': 0.36269, 'P@3': 0.21089, 'P@5': 0.14948, 'P@10':...",0.74664,0.54731
4,facebook/dragon-plus-query-encoder,facebook/dragon-plus-context-encoder,2681468,cls,"{'NDCG@1': 0.34878, 'NDCG@3': 0.45609, 'NDCG@5': 0.49928...","{'MAP@1': 0.31216, 'MAP@3': 0.41763, 'MAP@5': 0.4435, 'M...","{'Recall@1': 0.31216, 'Recall@3': 0.53546, 'Recall@5': 0...","{'P@1': 0.34878, 'P@3': 0.20616, 'P@5': 0.14861, 'P@10':...",0.74068,0.53571
0,OpenMatch/cocodr-base-msmarco,OpenMatch/cocodr-base-msmarco,2681468,cls,"{'NDCG@1': 0.32184, 'NDCG@3': 0.42313, 'NDCG@5': 0.46486...","{'MAP@1': 0.28438, 'MAP@3': 0.38558, 'MAP@5': 0.41034, '...","{'Recall@1': 0.28438, 'Recall@3': 0.49998, 'Recall@5': 0...","{'P@1': 0.32184, 'P@3': 0.1939, 'P@5': 0.1405, 'P@10': 0...",0.70606,0.50248
2,facebook/contriever-msmarco,facebook/contriever-msmarco,2681468,avg,"{'NDCG@1': 0.30736, 'NDCG@3': 0.41327, 'NDCG@5': 0.45537...","{'MAP@1': 0.27141, 'MAP@3': 0.37456, 'MAP@5': 0.3995, 'M...","{'Recall@1': 0.27141, 'Recall@3': 0.49355, 'Recall@5': 0...","{'P@1': 0.30736, 'P@3': 0.19119, 'P@5': 0.13853, 'P@10':...",0.71468,0.49775
1,Shitao/RetroMAE_MSMARCO_finetune,Shitao/RetroMAE_MSMARCO_finetune,2681468,cls,"{'NDCG@1': 0.30881, 'NDCG@3': 0.40739, 'NDCG@5': 0.44582...","{'MAP@1': 0.27815, 'MAP@3': 0.37295, 'MAP@5': 0.39585, '...","{'Recall@1': 0.27815, 'Recall@3': 0.4782, 'Recall@5': 0....","{'P@1': 0.30881, 'P@3': 0.18299, 'P@5': 0.13204, 'P@10':...",0.67586,0.48315
3,facebook/contriever,facebook/contriever,2681468,avg,"{'NDCG@1': 0.12601, 'NDCG@3': 0.18426, 'NDCG@5': 0.21274...","{'MAP@1': 0.11061, 'MAP@3': 0.16281, 'MAP@5': 0.17905, '...","{'Recall@1': 0.11061, 'Recall@3': 0.22663, 'Recall@5': 0...","{'P@1': 0.12601, 'P@3': 0.08864, 'P@5': 0.069, 'P@10': 0...",0.41292,0.25367


In [8]:
# Keep only the columns that are rendered in the LaTeX table (all rows retained).
df = df.loc[:, ['query_model', 'pooling', 'nDCG@10', 'Recall@10']]

def clean_table(table_str: str) -> str:
    r"""Rewrite a LaTeX table string with display-friendly labels.

    Three rounds of literal ``str.replace`` substitutions are applied, in order:

    1. ``\textbf{<name>}`` for each name in ``mappings`` becomes
       ``\textsc{<Label>}``.
    2. Model identifiers in ``model_mappings`` become ``"<family> <variant>"``.
    3. Fragments in ``raw_mappings`` (header name, a column spec, a run of
       ``\cline`` rules) are swapped for their replacements.

    Args:
        table_str: LaTeX table source. Keys that contain an underscore are
            written in escaped form (``\_``), so they only match text whose
            underscores were escaped before this function is called.

    Returns:
        The table source with every substitution applied; text that matches
        no key is returned unchanged.
    """
    mappings = {
        "attention": "Attention",
    }
    # Raw strings keep the literal backslash of the escaped underscore without
    # relying on Python passing an invalid escape sequence ("\_") through.
    model_mappings = {
        "OpenMatch/cocodr-base-msmarco": ("COCO-DR", "Base MSMARCO"),
        r"Shitao/RetroMAE\_MSMARCO\_finetune": ("RetroMAE", "MSMARCO FT"),
        r"Shitao/RetroMAE\_MSMARCO": ("RetroMAE", "MSMARCO"),
        "Shitao/RetroMAE": ("RetroMAE", ""),
        "facebook/contriever-msmarco": ("Contriever", "MSMARCO"),
        "facebook/contriever": ("Contriever", ""),
        "facebook/dragon-plus-query-encoder": ("Dragon+", ""),
        "facebook/dragon-roberta-query-encoder": ("Dragon RoBERTa", ""),
    }
    raw_mappings = {
        r"query\_model": "Model",
        "llllrrrrrr": r"p{1.2cm}p{1.2cm}p{1.2cm}p{1.2cm}rrrrrr",
        r"\cline{1-10} \cline{2-10} \cline{3-10}": r"\midrule",
    }

    for name, label in mappings.items():
        escaped_name = name.replace("_", "\\_")
        table_str = table_str.replace("\\textbf{" + escaped_name + "}", "\\textsc{" + label + "}")

    # Some identifiers are prefixes of others (e.g. "facebook/contriever" and
    # "facebook/contriever-msmarco"). Replacing the longest keys first keeps the
    # result correct regardless of the order the dict happens to be written in.
    longest_first = sorted(model_mappings.items(), key=lambda item: len(item[0]), reverse=True)
    for model_id, (family, variant) in longest_first:
        # An empty variant leaves a trailing space after the family name.
        table_str = table_str.replace(model_id, family + " " + variant)

    for fragment, replacement in raw_mappings.items():
        table_str = table_str.replace(fragment, replacement)
    return table_str
# Render the ranking as LaTeX. float_format alone does the rounding: rounding to
# 3 decimals first and then formatting to 2 rounds twice, which can push a value
# across a 2-decimal boundary (e.g. 0.5446 -> 0.545 -> 0.55 instead of 0.54).
# Underscores are escaped before clean_table runs because its underscore-bearing
# keys are written in escaped form; "\\_" spells the backslash explicitly rather
# than relying on the invalid escape sequence "\_".
latex_table = df.to_latex(float_format="%.2f", bold_rows=True, index=False).replace("_", "\\_")
print(clean_table(latex_table))


df

\begin{tabular}{llrr}
\toprule
Model & pooling & nDCG@10 & Recall@10 \\
\midrule
Dragon RoBERTa  & cls & 0.55 & 0.75 \\
Dragon+  & cls & 0.54 & 0.74 \\
COCO-DR Base MSMARCO & cls & 0.50 & 0.71 \\
Contriever MSMARCO & avg & 0.50 & 0.71 \\
RetroMAE MSMARCO FT & cls & 0.48 & 0.68 \\
Contriever  & avg & 0.25 & 0.41 \\
\bottomrule
\end{tabular}



  "Shitao/RetroMAE\_MSMARCO\_finetune": ("RetroMAE", "MSMARCO FT"),
  "Shitao/RetroMAE\_MSMARCO": ("RetroMAE", "MSMARCO"),
  "query\_model": "Model",
  print(clean_table(df.round(3).to_latex(float_format="%.2f", bold_rows=True, index=False).replace("_", "\_")))


Unnamed: 0,query_model,pooling,nDCG@10,Recall@10
5,facebook/dragon-roberta-query-encoder,cls,0.54731,0.74664
4,facebook/dragon-plus-query-encoder,cls,0.53571,0.74068
0,OpenMatch/cocodr-base-msmarco,cls,0.50248,0.70606
2,facebook/contriever-msmarco,avg,0.49775,0.71468
1,Shitao/RetroMAE_MSMARCO_finetune,cls,0.48315,0.67586
3,facebook/contriever,avg,0.25367,0.41292
