In [1]:
import os
import torch
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from tqdm.auto import tqdm
import huggingface_hub as hf
from dotenv import load_dotenv
import matplotlib.pyplot as plt
from typing import List, Dict, Union, Tuple
from transformers import AutoTokenizer, AutoModel

# Display / plotting configuration.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_colwidth', 256)

plt.style.use('seaborn-v0_8')
# Credentials are read from the environment (populated from a local .env file);
# never hardcode the token in the notebook.
load_dotenv()
hf.login(os.environ["HF_TOKEN"])
# Pin the notebook to GPU 0. This only takes effect if CUDA has not been initialized
# yet in this kernel (torch initializes CUDA lazily on first use).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# HF_HOME may legitimately be unset; use .get so this informational print cannot
# raise KeyError (it prints None in that case).
print("CUDA_VISIBLE_DEVICES:", os.environ["CUDA_VISIBLE_DEVICES"], "HF_HOME:", os.environ.get("HF_HOME"))

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


CUDA_VISIBLE_DEVICES: 0 HF_HOME: /local1/mohsenfayyaz/.hfcache/


In [2]:
# Pickled per-query retrieval results for one retriever on Re-DocRED; the retriever
# config and its NDCG/MAP/Recall/P@k scores are stored in df_raw.attrs.
# Uncomment one line to switch retriever.
# DATASET = "re-docred_facebook--contriever-msmarco_7170.pkl"
DATASET = "re-docred_facebook--dragon-plus-query-encoder_7170.pkl"
# DATASET = "re-docred_OpenMatch--cocodr-base-msmarco_7170.pkl.gz"

# Read from the path hf_hub_download returns rather than rebuilding it by hand, so the
# read location cannot drift from `local_dir`. read_pickle's default compression="infer"
# handles the ".gz" variant.
# SECURITY: unpickling can execute arbitrary code; only load files from a trusted repo.
dataset_path = hf.hf_hub_download(
    repo_id="Retriever-Contextualization/datasets",
    filename=f"results/{DATASET}",
    repo_type="dataset",
    local_dir="hf/",
)
df_raw = pd.read_pickle(dataset_path)
print(df_raw.attrs)
df_raw.head(1)

{'model': 'facebook/dragon-plus-query-encoder', 'query_model': 'facebook/dragon-plus-query-encoder', 'context_model': 'facebook/dragon-plus-context-encoder', 'pooling': 'cls', 'dataset': 're-docred', 'corpus_size': 105925, 'eval': {'ndcg': {'NDCG@1': 0.47685, 'NDCG@3': 0.52523, 'NDCG@5': 0.53646, 'NDCG@10': 0.54955, 'NDCG@100': 0.58002, 'NDCG@1000': 0.59556}, 'map': {'MAP@1': 0.47685, 'MAP@3': 0.51341, 'MAP@5': 0.51959, 'MAP@10': 0.52496, 'MAP@100': 0.53058, 'MAP@1000': 0.53109}, 'recall': {'Recall@1': 0.47685, 'Recall@3': 0.55941, 'Recall@5': 0.58689, 'Recall@10': 0.62748, 'Recall@100': 0.77741, 'Recall@1000': 0.90349}, 'precision': {'P@1': 0.47685, 'P@3': 0.18647, 'P@5': 0.11738, 'P@10': 0.06275, 'P@100': 0.00777, 'P@1000': 0.0009}}}


Unnamed: 0,query_id,query,gold_docs,gold_docs_text,scores_stats,scores_gold,scores_1000,predicted_docs_text_10,id,title,vertexSet,labels,sents,split,label,label_idx,head_entity,tail_entity,head_entity_names,tail_entity_names,head_entity_longest_name,tail_entity_longest_name,head_entity_types,tail_entity_types,evidence_sent_ids,evidence_sents,head_entity_in_evidence,tail_entity_in_evidence,relation,relation_name,query_question,duplicate_titles_len,duplicate_titles,hit_rank,gold_doc,gold_doc_title,gold_doc_text,gold_doc_score,pred_doc,pred_doc_title,pred_doc_text,pred_doc_score,gold_doc_len,pred_doc_len,query_decompx_tokens,query_decompx_tokenizer_word_ids,query_decompx_cls_or_mean_pooled,query_decompx_tokens_dot_scores,query_decompx_decompx_last_layer_pooled,gold_doc_decompx_tokens,gold_doc_decompx_tokenizer_word_ids,gold_doc_decompx_cls_or_mean_pooled,gold_doc_decompx_tokens_dot_scores,gold_doc_decompx_decompx_last_layer_pooled,pred_doc_decompx_tokens,pred_doc_decompx_tokenizer_word_ids,pred_doc_decompx_cls_or_mean_pooled,pred_doc_decompx_tokens_dot_scores,pred_doc_decompx_decompx_last_layer_pooled
0,test0,When was Loud Tour published?,[Loud Tour],"{'Loud Tour': {'text': 'The Loud Tour was the fourth overall and third world concert tour by Barbadian recording artist Rihanna . Performing in over twenty countries in the Americas and Europe , the tour was launched in support of Rihanna 's fifth stud...","{'len': 1000, 'max': 390.3378601074219, 'min': 377.525390625, 'std': 1.243663421340353, 'mean': 378.77503692626954, 'median': 378.4281463623047}",{'Loud Tour': 390.3378601074219},"{'Loud Tour': 390.3378601074219, 'Loud'n'proud': 385.71905517578125, 'Poetry Bus Tour': 385.4292907714844, 'Live &amp; Loud': 384.18218994140625, 'The Loudest Engine': 384.0265808105469, 'Young Wild Things Tour': 383.8572998046875, 'Guitar Rock Tour': ...","{'Loud Tour': {'text': 'The Loud Tour was the fourth overall and third world concert tour by Barbadian recording artist Rihanna . Performing in over twenty countries in the Americas and Europe , the tour was launched in support of Rihanna 's fifth stud...",test0,Loud Tour,"[[{'name': 'Loud', 'pos': [23, 24], 'sent_id': 1, 'type': 'MISC', 'global_pos': [41, 41], 'index': '0_0'}, {'name': 'Loud Tour', 'pos': [1, 3], 'sent_id': 6, 'type': 'MISC', 'global_pos': [128, 128], 'index': '0_1'}, {'name': 'Loud Tour', 'pos': [1, 3]...","[{'r': 'P577', 'h': 0, 't': 6, 'evidence': [1]}, {'r': 'P175', 'h': 0, 't': 2, 'evidence': [0, 1]}, {'r': 'P131', 'h': 10, 't': 8, 'evidence': [4]}, {'r': 'P17', 'h': 8, 't': 7, 'evidence': [3, 4]}, {'r': 'P17', 'h': 10, 't': 7, 'evidence': [3, 4]}, {'...","[[The, Loud, Tour, was, the, fourth, overall, and, third, world, concert, tour, by, Barbadian, recording, artist, Rihanna, .], [Performing, in, over, twenty, countries, in, the, Americas, and, Europe, ,, the, tour, was, launched, in, support, of, Rihan...",test,"{'r': 'P577', 'h': 0, 't': 6, 'evidence': [1]}",0,"[{'name': 'Loud', 'pos': [23, 24], 'sent_id': 1, 'type': 'MISC', 'global_pos': [41, 41], 'index': '0_0'}, {'name': 'Loud Tour', 'pos': [1, 3], 
'sent_id': 6, 'type': 'MISC', 'global_pos': [128, 128], 'index': '0_1'}, {'name': 'Loud Tour', 'pos': [1, 3],...","[{'pos': [25, 26], 'type': 'TIME', 'sent_id': 1, 'name': '2010', 'global_pos': [43, 43], 'index': '6_0'}]","{Loud Tour, Loud}",{2010},Loud Tour,2010,{MISC},{TIME},[1],"[[Performing, in, over, twenty, countries, in, the, Americas, and, Europe, ,, the, tour, was, launched, in, support, of, Rihanna, 's, fifth, studio, album, Loud, (, 2010, ), .]]","[{'name': 'Loud', 'pos': [23, 24], 'sent_id': 1, 'type': 'MISC', 'global_pos': [41, 41], 'index': '0_0'}]","[{'pos': [25, 26], 'type': 'TIME', 'sent_id': 1, 'name': '2010', 'global_pos': [43, 43], 'index': '6_0'}]",P577,publication date,When was Loud Tour published?,0,{},1.0,"Loud Tour The Loud Tour was the fourth overall and third world concert tour by Barbadian recording artist Rihanna . Performing in over twenty countries in the Americas and Europe , the tour was launched in support of Rihanna 's fifth studio album Loud ...",Loud Tour,"The Loud Tour was the fourth overall and third world concert tour by Barbadian recording artist Rihanna . Performing in over twenty countries in the Americas and Europe , the tour was launched in support of Rihanna 's fifth studio album Loud ( 2010 ) ....",390.33786,"Loud Tour The Loud Tour was the fourth overall and third world concert tour by Barbadian recording artist Rihanna . Performing in over twenty countries in the Americas and Europe , the tour was launched in support of Rihanna 's fifth studio album Loud ...",Loud Tour,"The Loud Tour was the fourth overall and third world concert tour by Barbadian recording artist Rihanna . 
Performing in over twenty countries in the Americas and Europe , the tour was launched in support of Rihanna 's fifth studio album Loud ( 2010 ) ....",390.33786,142,142,"[[CLS], when, was, loud, tour, published, ?, [SEP]]","[None, 0, 1, 2, 3, 4, 4, None]","[-0.17805682, -0.3927267, 0.34883702, -0.38739026, -0.23735791, -0.19460969, 0.21865264, 0.068975255, -0.1592264, 0.18711175, -0.20565934, 0.003034133, -0.18440822, 0.40548998, -0.4549966, 0.51666415, 0.09620502, -0.1836627, -0.4205021, -0.010630409, 0...","[2.2196622, 6.71451, 0.9866385, 58.316944, 37.08578, 4.3126516, 1.2738111, -1.2307678]","[[0.0026502553, 0.044497166, 0.009840142, -0.029498188, 0.047593728, 0.0005243204, 0.089234896, -0.058340102, -0.0002567456, -0.06561515, 0.012288873, -0.018892672, 0.0068592615, 0.031180702, 0.027442973, -0.06405719, 0.007814868, -0.030438174, 0.02620...","[[CLS], loud, tour, the, loud, tour, was, the, fourth, overall, and, third, world, concert, tour, by, bar, ##bad, ##ian, recording, artist, rihanna, ., performing, in, over, twenty, countries, in, the, americas, and, europe, ,, the, tour, was, launched...","[None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 54, 55, 56, 57, 58, 59,...","[-0.7096514, -0.43747085, 2.078466, -0.8606712, 2.3640666, 0.67811525, 3.0262432, 0.6547275, -1.481939, -2.838817, -1.2552446, 1.0732918, -3.318883, 3.0607197, -0.41772836, 3.4470546, 3.6913419, 0.77499884, 1.0027949, -1.8230458, 0.37280822, -1.2724396...","[650.5565, 112.46794, 110.70713, 35.217003, 88.90661, 93.24184, 93.337906, 43.745255, 70.52942, 71.49595, 12.454085, 71.52813, 16.037241, 15.657524, 53.929047, 21.389343, 5.0486135, 39.502014, 38.76303, 1.5653663, 10.186192, 42.767452, 3.310997, 14.803...","[[-0.06098142, 0.030208647, 0.35368052, -0.15786159, 0.43346453, 0.0317666, 0.49806064, 0.11205646, 
-0.21001092, -0.54779494, -0.16660528, 0.20015034, -0.5065915, 0.43021473, -0.07829579, 0.51261973, 0.5002409, 0.08070024, 0.15372154, -0.20520967, -0.0...","[[CLS], loud, tour, the, loud, tour, was, the, fourth, overall, and, third, world, concert, tour, by, bar, ##bad, ##ian, recording, artist, rihanna, ., performing, in, over, twenty, countries, in, the, americas, and, europe, ,, the, tour, was, launched...","[None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 54, 55, 56, 57, 58, 59,...","[-0.7096514, -0.43747085, 2.078466, -0.8606712, 2.3640666, 0.67811525, 3.0262432, 0.6547275, -1.481939, -2.838817, -1.2552446, 1.0732918, -3.318883, 3.0607197, -0.41772836, 3.4470546, 3.6913419, 0.77499884, 1.0027949, -1.8230458, 0.37280822, -1.2724396...","[650.5565, 112.46794, 110.70713, 35.217003, 88.90661, 93.24184, 93.337906, 43.745255, 70.52942, 71.49595, 12.454085, 71.52813, 16.037241, 15.657524, 53.929047, 21.389343, 5.0486135, 39.502014, 38.76303, 1.5653663, 10.186192, 42.767452, 3.310997, 14.803...","[[-0.06098142, 0.030208647, 0.35368052, -0.15786159, 0.43346453, 0.0317666, 0.49806064, 0.11205646, -0.21001092, -0.54779494, -0.16660528, 0.20015034, -0.5065915, 0.43021473, -0.07829579, 0.51261973, 0.5002409, 0.08070024, 0.15372154, -0.20520967, -0.0..."


In [3]:
# Keep only relations that are grounded in a single evidence sentence and whose head
# entity is one mention, with one surface name, that occurs in that evidence.
single_evidence_mask = (
    (df_raw["evidence_sent_ids"].str.len() == 1)           # 1 evidence sentence id
    & (df_raw["evidence_sents"].str.len() == 1)            # 1 evidence sentence text
    & (df_raw["head_entity"].str.len() == 1)               # head entity has a single mention
    & (df_raw["head_entity_in_evidence"].str.len() >= 1)   # the head appears in the evidence
    & (df_raw["head_entity_names"].str.len() == 1)         # all head mentions share one name
)
df = df_raw[single_evidence_mask].copy()
print(len(df))

# Filter Repeated Labels (Only 1 h->t)
def not_repeated_label(label, labels):
    """Return True iff exactly one entry in ``labels`` links the same (head, tail) pair as ``label``.

    Args:
        label: a relation dict with at least ``'h'`` and ``'t'`` keys (entity indices).
        labels: all relation dicts of the same document.

    Returns:
        True when the head->tail pair of ``label`` occurs once in ``labels``; False when
        it occurs several times (several relations between the same pair) or not at all.
    """
    same_pair_count = sum(
        1 for other in labels if other['h'] == label['h'] and other['t'] == label['t']
    )
    return same_pair_count == 1
# The column name is kept for compatibility, but it holds the result of
# not_repeated_label: True means the (head, tail) pair is NOT repeated, i.e. keep the row.
df["repeated_label"] = df.apply(lambda r: not_repeated_label(r["label"], r["labels"]), axis=1)
df = df[df["repeated_label"]]
print(len(df))

# Shuffle, then keep the first (i.e. a random) query per document title so that
# no document is used by more than one query.
df = df.sample(frac=1, random_state=0)
df = df.drop_duplicates(subset=["title"])
print(len(df))

# Fixed-size subsample. Clamp to the pool size so a smaller filtered pool does not
# make DataFrame.sample raise; when the pool is large enough the result is unchanged.
N_SAMPLES = 250
df = df.sample(min(N_SAMPLES, len(df)), random_state=0)
print(len(df))

def flatten(xss):
    """Concatenate an iterable of iterables into one flat list (one level deep)."""
    flat = []
    for inner in xss:
        flat.extend(inner)
    return flat

# Build the text variants that are scored against each query. `sents` and
# `evidence_sents` are lists of tokenized sentences (lists of tokens); flatten +
# " ".join turns them into one whitespace-separated string.
# Full document text.
df["sents_complete"] = df["sents"].apply(lambda x: " ".join(flatten(x)))
# The evidence sentence alone.
df["sents_evidence_single"] = df["evidence_sents"].apply(lambda x: " ".join(flatten(x)))
# Evidence followed by the full text of one document with a *different* title.
# NOTE(review): every row calls .sample(1, random_state=0) with the same seed on a frame
# that differs only by the row's own title, so the appended distractor is almost always
# the same document (both rows displayed in this cell's output get "Across the Black
# Waters"). If an independent distractor per query is intended, vary the seed per row.
# It also filters the whole frame once per row (O(n^2)).
df["sents_evidence_redundant"] = df.apply(lambda r: r["sents_evidence_single"] + " " + df[df['title'] != r['title']].sample(1, random_state=0)["sents_complete"].values[0], axis=1)
# Evidence moved to the front, followed by the rest of the same document with every exact
# occurrence of the evidence text removed (str.replace leaves extra whitespace behind).
df["sents_evidence_doc"] = df.apply(lambda r: r["sents_evidence_single"] + " " + r["sents_complete"].replace(r["sents_evidence_single"], ""), axis=1)

# Document variants to score against each query. The hypothesis being probed:
# evidence alone > evidence + redundant long text.
# (The trailing numbers are from the original notes; their meaning is not recorded here. TODO.)
sents_cols = [
    "sents_evidence_single",
    # "sents_evidence_redundant",  # 35
    "sents_evidence_doc",  # 17
]

# Sanity check: after drop_duplicates each title should appear exactly once.
print(df["title"].value_counts())
assert df["title"].is_unique, "expected one query per document title"

# Exported columns: the query/document metadata plus the text variants in sents_cols.
# The *_decompx_* columns (and query_id / scores_1000) are not exported.
save_cols = [
    "query", "gold_docs", "gold_docs_text", "scores_stats", "scores_gold", "predicted_docs_text_10", "id", "title", "vertexSet", 
    "labels", "sents", "split", "label", "label_idx", "head_entity", 
    "tail_entity", "head_entity_names", "tail_entity_names", 
    "head_entity_longest_name", "tail_entity_longest_name", 
    "head_entity_types", "tail_entity_types", "evidence_sent_ids", 
    "evidence_sents", "head_entity_in_evidence", "tail_entity_in_evidence", 
    "relation", "relation_name", "query_question", "duplicate_titles_len", 
    "duplicate_titles", "hit_rank", "gold_doc", "gold_doc_title", 
    "gold_doc_text", "gold_doc_score", "pred_doc", "pred_doc_title", 
    "pred_doc_text", "pred_doc_score", "gold_doc_len", "pred_doc_len",
] + sents_cols

# Create the output directory so the export also works on a fresh checkout.
OUTPUT_PATH = "dataset/brevity_bias.jsonl"
os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)
df[save_cols].to_json(OUTPUT_PATH, orient="records", lines=True)

# Seeded so the displayed rows are reproducible on Restart & Run All.
df.sample(2, random_state=0)

1819
1247
500
250
title
House of Angels                  1
Usain Bolt Sports Complex        1
USS Lyndon B. Johnson            1
New Haven Harbor                 1
Black Lake (Michigan)            1
                                ..
Louis Lombardi                   1
John Ripley (USMC)               1
Brother Man                      1
Township High School District    1
Across the Black Waters          1
Name: count, Length: 250, dtype: int64


Unnamed: 0,query_id,query,gold_docs,gold_docs_text,scores_stats,scores_gold,scores_1000,predicted_docs_text_10,id,title,vertexSet,labels,sents,split,label,label_idx,head_entity,tail_entity,head_entity_names,tail_entity_names,head_entity_longest_name,tail_entity_longest_name,head_entity_types,tail_entity_types,evidence_sent_ids,evidence_sents,head_entity_in_evidence,tail_entity_in_evidence,relation,relation_name,query_question,duplicate_titles_len,duplicate_titles,hit_rank,gold_doc,gold_doc_title,gold_doc_text,gold_doc_score,pred_doc,pred_doc_title,pred_doc_text,pred_doc_score,gold_doc_len,pred_doc_len,query_decompx_tokens,query_decompx_tokenizer_word_ids,query_decompx_cls_or_mean_pooled,query_decompx_tokens_dot_scores,query_decompx_decompx_last_layer_pooled,gold_doc_decompx_tokens,gold_doc_decompx_tokenizer_word_ids,gold_doc_decompx_cls_or_mean_pooled,gold_doc_decompx_tokens_dot_scores,gold_doc_decompx_decompx_last_layer_pooled,pred_doc_decompx_tokens,pred_doc_decompx_tokenizer_word_ids,pred_doc_decompx_cls_or_mean_pooled,pred_doc_decompx_tokens_dot_scores,pred_doc_decompx_decompx_last_layer_pooled,repeated_label,sents_complete,sents_evidence_single,sents_evidence_redundant,sents_evidence_doc
2728,test12647,When was Bill French born?,[Bill Warner (writer)],"{'Bill Warner (writer)': {'text': 'Bill French ( born 1941 , United States ) , known by the pseudonym Bill Warner , is a critic of Islam , a writer and the founder of the Center for the Study of Political Islam . He is a former Tennessee State Universi...","{'len': 1000, 'max': 368.3065185546875, 'min': 358.9645080566406, 'std': 1.0535337829940699, 'mean': 360.1074428100586, 'median': 359.8053741455078}",{'Bill Warner (writer)': 365.4850158691406},"{'Frederick John French': 368.3065185546875, 'Robert French': 367.2232666015625, 'Ray H. French': 366.0547180175781, 'Bill Warner (writer)': 365.4850158691406, 'Adam Billaut': 364.8941955566406, 'Bill Fraccio': 364.8255920410156, 'Antoine Omer Talon': ...","{'Frederick John French': {'text': 'Frederick John French , ( January 18 , 1847 - 1924 ) was an Ontario lawyer and political figure . He represented Grenville South and then Grenville in the Legislative Assembly of Ontario as a Conservative member from...",test12647,Bill Warner (writer),"[[{'name': 'Bill French', 'pos': [0, 2], 'sent_id': 0, 'type': 'PER', 'global_pos': [0, 0], 'index': '0_0'}], [{'type': 'TIME', 'pos': [4, 5], 'name': '1941', 'sent_id': 0, 'global_pos': [4, 4], 'index': '1_0'}], [{'name': 'United States', 'pos': [6, 8...","[{'r': 'P569', 'h': 0, 't': 1, 'evidence': [0]}, {'r': 'P27', 'h': 0, 't': 2, 'evidence': [0]}, {'r': 'P27', 'h': 3, 't': 2, 'evidence': [0]}, {'r': 'P17', 'h': 3, 't': 2, 'evidence': [0]}, {'r': 'P140', 'h': 3, 't': 4, 'evidence': [0]}, {'r': 'P140', ...","[[Bill, French, (, born, 1941, ,, United, States, ), ,, known, by, the, pseudonym, Bill, Warner, ,, is, a, critic, of, Islam, ,, a, writer, and, the, founder, of, the, Center, for, the, Study, of, Political, Islam, .], [He, is, a, former, Tennessee, St...",test,"{'r': 'P569', 'h': 0, 't': 1, 'evidence': [0]}",0,"[{'name': 'Bill French', 'pos': [0, 2], 'sent_id': 0, 'type': 'PER', 'global_pos': [0, 0], 'index': 
'0_0'}]","[{'type': 'TIME', 'pos': [4, 5], 'name': '1941', 'sent_id': 0, 'global_pos': [4, 4], 'index': '1_0'}]",{Bill French},{1941},Bill French,1941,{PER},{TIME},[0],"[[Bill, French, (, born, 1941, ,, United, States, ), ,, known, by, the, pseudonym, Bill, Warner, ,, is, a, critic, of, Islam, ,, a, writer, and, the, founder, of, the, Center, for, the, Study, of, Political, Islam, .]]","[{'name': 'Bill French', 'pos': [0, 2], 'sent_id': 0, 'type': 'PER', 'global_pos': [0, 0], 'index': '0_0'}]","[{'type': 'TIME', 'pos': [4, 5], 'name': '1941', 'sent_id': 0, 'global_pos': [4, 4], 'index': '1_0'}]",P569,date of birth,When was Bill French born?,0,{},4.0,"Bill Warner (writer) Bill French ( born 1941 , United States ) , known by the pseudonym Bill Warner , is a critic of Islam , a writer and the founder of the Center for the Study of Political Islam . He is a former Tennessee State University physics pro...",Bill Warner (writer),"Bill French ( born 1941 , United States ) , known by the pseudonym Bill Warner , is a critic of Islam , a writer and the founder of the Center for the Study of Political Islam . He is a former Tennessee State University physics professor . He is listed...",365.485016,"Frederick John French Frederick John French , ( January 18 , 1847 - 1924 ) was an Ontario lawyer and political figure . He represented Grenville South and then Grenville in the Legislative Assembly of Ontario as a Conservative member from 1879 to 1890 ...",Frederick John French,"Frederick John French , ( January 18 , 1847 - 1924 ) was an Ontario lawyer and political figure . He represented Grenville South and then Grenville in the Legislative Assembly of Ontario as a Conservative member from 1879 to 1890 . 
He was born in Burri...",368.306519,183,141,"[[CLS], when, was, bill, french, born, ?, [SEP]]","[None, 0, 1, 2, 3, 4, 4, None]","[-0.087123975, -0.30652568, -0.17900296, -0.5025232, 0.0045261537, 0.30944324, -0.0018868069, -0.09478069, -0.46635157, -0.20560545, -0.1561065, -0.07506508, -0.2580581, 0.027959017, 0.18539724, 0.26456964, 0.23967221, -0.0034961996, 0.07397056, -0.333...","[-0.69223404, 3.046039, -1.8092594, 26.00676, 57.541595, 25.667818, -0.9414638, -0.36833078]","[[0.05525274, 0.0033733915, 0.04733997, -0.023129841, 0.03558308, -0.041217428, 0.076585084, -0.04876297, -0.0078774, -0.06259745, -0.016005706, -0.012722935, 0.018480552, 0.0028626486, -0.037496254, -0.081890926, -0.03738486, -0.03574433, -0.005523898...","[[CLS], bill, warner, (, writer, ), bill, french, (, born, 1941, ,, united, states, ), ,, known, by, the, pseudonym, bill, warner, ,, is, a, critic, of, islam, ,, a, writer, and, the, founder, of, the, center, for, the, study, of, political, islam, ., ...","[None, 0, 1, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 6...","[-0.3835547, 0.20104654, 1.8976009, -1.0706238, 2.1885116, 0.56244934, 2.8140144, 0.39204317, -1.0732121, -3.2005625, -1.4833462, 1.4285593, -3.5568051, 2.7645829, -0.51157373, 3.3840845, 3.571874, 0.57340467, 0.6000074, -1.5863663, -0.11256978, -1.284...","[873.3777, 22.48209, 72.67711, 63.085205, 44.042576, 40.32502, 9.930698, 18.710537, 60.18251, 98.69971, 33.700844, 4.1486444, 40.443977, 31.863424, 55.24453, 41.56901, 73.50342, 1.705921, 13.573771, 59.137005, 21.012276, 44.57038, 40.39782, 54.448235, ...","[[-0.06815151, 0.05042049, 0.47108305, -0.23331961, 0.6014231, 0.06205207, 0.58569837, 0.104989104, -0.3064388, -0.70994985, -0.253498, 0.26285374, -0.6462882, 0.5901693, -0.09085502, 0.674271, 0.71233, 0.10613178, 0.2195051, 
-0.27148676, 0.022813978, ...","[[CLS], frederick, john, french, frederick, john, french, ,, (, january, 18, ,, 1847, -, 1924, ), was, an, ontario, lawyer, and, political, figure, ., he, represented, gr, ##en, ##ville, south, and, then, gr, ##en, ##ville, in, the, legislative, assemb...","[None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 25, 25, 26, 27, 28, 29, 29, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 49, 50, 50, 51, 52, 53, 54, 55, 56, 57,...","[-1.1777982, 0.5352428, 1.8575447, -0.61195457, 2.7613678, 1.1020998, 2.8921423, 0.6386914, -1.46767, -3.6072805, -1.4904162, 0.96460265, -3.2491841, 2.4434018, -0.43786564, 3.5257907, 3.6324801, 0.76826525, 0.9136227, -2.0321684, -0.1378155, -1.389258...","[1218.9647, 14.089399, 5.367064, 12.018502, 26.156895, 14.615202, 35.736176, 68.270546, 61.100594, 52.67967, 41.39744, 21.471891, 48.189766, 9.66394, 79.14222, 54.324757, 88.46645, 65.181786, 9.5254755, 97.53816, 16.560017, 24.606682, 4.579794, 16.2894...","[[-0.15239555, 0.045515217, 0.62077165, -0.28267956, 0.7992867, 0.04145582, 0.84470093, 0.17525473, -0.39432058, -0.9674504, -0.36761358, 0.36108175, -0.9399205, 0.7963315, -0.119723886, 0.91992766, 0.9560881, 0.16302425, 0.31721726, -0.47485098, 0.010...",True,"Bill French ( born 1941 , United States ) , known by the pseudonym Bill Warner , is a critic of Islam , a writer and the founder of the Center for the Study of Political Islam . He is a former Tennessee State University physics professor . He is listed...","Bill French ( born 1941 , United States ) , known by the pseudonym Bill Warner , is a critic of Islam , a writer and the founder of the Center for the Study of Political Islam .","Bill French ( born 1941 , United States ) , known by the pseudonym Bill Warner , is a critic of Islam , a writer and the founder of the Center for the Study of Political Islam . 
Across the Black Waters is an English novel by the Indian writer Mulk Raj ...","Bill French ( born 1941 , United States ) , known by the pseudonym Bill Warner , is a critic of Islam , a writer and the founder of the Center for the Study of Political Islam . He is a former Tennessee State University physics professor . He is liste..."
5793,validation9838,Where is 1988 Winter Olympics located?,[Bonnie Blair],"{'Bonnie Blair': {'text': 'Bonnie Kathleen Blair ( born March 18 , 1964 ) is a retired American speed skater . She is one of the top skaters of her era , and one of the most decorated athletes in Olympic history . Blair competed for the United States i...","{'len': 1000, 'max': 366.01519775390625, 'min': 356.85955810546875, 'std': 1.2988994313341742, 'mean': 358.2047901611328, 'median': 357.82530212402344}",{'Bonnie Blair': 360.8177185058594},"{'Hidy and Howdy': 366.01519775390625, 'Canadian Olympic Curling Trials': 365.96148681640625, 'Lake Placid Winter Olympic Museum': 364.78106689453125, 'Palasport Olimpico and Stadio Comunale area in Turin': 363.8551025390625, 'Bart Carpentier Alting': ...","{'Hidy and Howdy': {'text': 'Hidy and Howdy were the official mascots of the 1988 Winter Olympics in Calgary , Alberta , Canada . They were twin polar bears who wore western / cowboy style outfits . Students of Bishop Carroll High School in Calgary wer...",validation9838,Bonnie Blair,"[[{'name': 'Bonnie Kathleen Blair', 'pos': [0, 3], 'sent_id': 0, 'type': 'PER', 'global_pos': [0, 0], 'index': '0_0'}, {'name': 'Blair', 'pos': [4, 5], 'sent_id': 4, 'type': 'PER', 'global_pos': [80, 80], 'index': '0_1'}, {'name': 'Blair', 'pos': [0, 1...","[{'r': 'P569', 'h': 0, 't': 1, 'evidence': [0]}, {'r': 'P1344', 'h': 0, 't': 12, 'evidence': [6]}, {'r': 'P1344', 'h': 0, 't': 17, 'evidence': [8]}, {'r': 'P27', 'h': 0, 't': 2, 'evidence': [0, 2]}, {'r': 'P582', 'h': 12, 't': 11, 'evidence': [6]}, {'r...","[[Bonnie, Kathleen, Blair, (, born, March, 18, ,, 1964, ), is, a, retired, American, speed, skater, .], [She, is, one, of, the, top, skaters, of, her, era, ,, and, one, of, the, most, decorated, athletes, in, Olympic, history, .], [Blair, competed, for...",validation,"{'r': 'P276', 'h': 12, 't': 13, 'evidence': [6]}",6,"[{'name': '1988 Winter Olympics', 'pos': [14, 17], 'sent_id': 6, 'type': 'MISC', 'global_pos': 
[118, 118], 'index': '12_0'}]","[{'name': 'Calgary', 'pos': [18, 19], 'sent_id': 6, 'type': 'LOC', 'global_pos': [122, 122], 'index': '13_0'}]",{1988 Winter Olympics},{Calgary},1988 Winter Olympics,Calgary,{MISC},{LOC},[6],"[[Blair, returned, to, the, Olympics, in, 1988, competing, in, long, -, track, at, the, 1988, Winter, Olympics, in, Calgary, .]]","[{'name': '1988 Winter Olympics', 'pos': [14, 17], 'sent_id': 6, 'type': 'MISC', 'global_pos': [118, 118], 'index': '12_0'}]","[{'name': 'Calgary', 'pos': [18, 19], 'sent_id': 6, 'type': 'LOC', 'global_pos': [122, 122], 'index': '13_0'}]",P276,location,Where is 1988 Winter Olympics located?,0,{},44.0,"Bonnie Blair Bonnie Kathleen Blair ( born March 18 , 1964 ) is a retired American speed skater . She is one of the top skaters of her era , and one of the most decorated athletes in Olympic history . Blair competed for the United States in four Olympic...",Bonnie Blair,"Bonnie Kathleen Blair ( born March 18 , 1964 ) is a retired American speed skater . She is one of the top skaters of her era , and one of the most decorated athletes in Olympic history . Blair competed for the United States in four Olympics , winning f...",360.817719,"Hidy and Howdy Hidy and Howdy were the official mascots of the 1988 Winter Olympics in Calgary , Alberta , Canada . They were twin polar bears who wore western / cowboy style outfits . Students of Bishop Carroll High School in Calgary were used as perf...",Hidy and Howdy,"Hidy and Howdy were the official mascots of the 1988 Winter Olympics in Calgary , Alberta , Canada . They were twin polar bears who wore western / cowboy style outfits . 
Students of Bishop Carroll High School in Calgary were used as performers during H...",366.015198,232,162,"[[CLS], where, is, 1988, winter, olympics, located, ?, [SEP]]","[None, 0, 1, 2, 3, 4, 5, 5, None]","[-0.2705685, 0.21327102, 0.027267072, -0.14847289, 0.30499342, 0.16478366, 0.08824536, 0.41067123, -0.8745843, 0.07493399, 0.18861234, -0.35165995, -0.27747384, -0.03090475, -0.5381188, -0.11376917, 0.40764302, -0.57765657, 0.12356042, -0.12563798, 0.0...","[1.5313809, 6.3850822, -3.0228481, 60.157722, 10.7360935, 15.442906, 15.816189, 0.8967924, -0.7689862]","[[0.049635425, -0.009508106, 0.020345239, -0.054377608, 0.08128256, -0.048889954, 0.063257545, -0.045274526, 0.039293323, -0.0864244, -0.01598371, 0.040191233, 0.0092745405, 0.07901887, 0.040669706, 0.011983129, 0.016142886, 0.04401713, -0.05666238, -0...","[[CLS], bonnie, blair, bonnie, kathleen, blair, (, born, march, 18, ,, 1964, ), is, a, retired, american, speed, skater, ., she, is, one, of, the, top, skaters, of, her, era, ,, and, one, of, the, most, decorated, athletes, in, olympic, history, ., bla...","[None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,...","[-0.82829577, 0.24343441, 1.9338658, -1.0560403, 2.5366588, 0.80153346, 2.7819555, 0.6526467, -1.8093178, -3.4732068, -1.1833714, 1.1258397, -3.0513625, 2.734552, -0.48000515, 3.2937336, 3.522903, 1.0635357, 1.3156575, -1.8039742, 0.19531131, -1.575042...","[825.8424, 59.323547, 57.955547, 61.328846, 8.866669, 58.849594, 70.0806, 40.544395, 44.284805, 26.192507, 18.126127, 43.87616, 55.000526, 112.31613, 77.225464, 70.577255, 29.14624, 16.16713, 18.629015, 2.5324032, 13.12344, 31.19376, 26.663967, 21.5806...","[[-0.08381261, 0.0265952, 0.44328153, -0.21312334, 0.51829356, 0.031442717, 0.61838776, 0.16994202, -0.24362597, -0.62742734, -0.21527377, 
0.2326449, -0.6368765, 0.5854717, -0.07522446, 0.65526855, 0.6224634, 0.10837568, 0.20315197, -0.2831923, -0.0167...","[[CLS], hid, ##y, and, how, ##dy, hid, ##y, and, how, ##dy, were, the, official, mascot, ##s, of, the, 1988, winter, olympics, in, calgary, ,, alberta, ,, canada, ., they, were, twin, polar, bears, who, wore, western, /, cowboy, style, outfits, ., stud...","[None, 0, 0, 1, 2, 2, 3, 3, 4, 5, 5, 6, 7, 8, 9, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 48, 49, 50, 50, 51, 51, 52, 53, 54, 55, 56, ...","[-0.32474005, 0.92812675, 2.025266, -0.77648836, 2.471859, 0.5086309, 3.015669, 0.78596133, -1.1549944, -3.1919503, -1.5762548, 1.1222271, -3.3040223, 2.7097354, -0.63326174, 2.9868596, 3.873422, 0.97883713, 1.169049, -1.8874304, -0.0072711776, -1.5066...","[818.2252, 95.54701, 44.890816, 44.281597, 54.975273, 42.319206, 91.937775, 37.372707, 31.145405, 33.768864, 34.065903, 70.83093, 22.427689, 47.534775, 78.00682, 10.095488, 26.494701, 9.936909, 69.54457, 18.03001, 18.613297, 31.47308, 16.70467, 33.8499...","[[-0.056926906, -0.0014548544, 0.46632326, -0.2032531, 0.54791033, 0.033428095, 0.6148313, 0.1657125, -0.25933146, -0.64612085, -0.22267511, 0.2498698, -0.66362697, 0.5760204, -0.090066835, 0.649474, 0.6549501, 0.083352864, 0.22392829, -0.2491089, 0.02...",True,"Bonnie Kathleen Blair ( born March 18 , 1964 ) is a retired American speed skater . She is one of the top skaters of her era , and one of the most decorated athletes in Olympic history . Blair competed for the United States in four Olympics , winning f...",Blair returned to the Olympics in 1988 competing in long - track at the 1988 Winter Olympics in Calgary .,"Blair returned to the Olympics in 1988 competing in long - track at the 1988 Winter Olympics in Calgary . Across the Black Waters is an English novel by the Indian writer Mulk Raj Anand first published in 1939 . 
It describes the experience of Lalu , a ...","Blair returned to the Olympics in 1988 competing in long - track at the 1988 Winter Olympics in Calgary . Bonnie Kathleen Blair ( born March 18 , 1964 ) is a retired American speed skater . She is one of the top skaters of her era , and one of the most..."


In [4]:
class YourCustomDEModel:
    """Dual-encoder retriever: separate query/context encoders, one pooled vector per input.

    Exposes ``encode_queries`` / ``encode_corpus``, each returning a NumPy matrix with one
    row per input; scoring (e.g. a dot product) is left to the caller.

    Args:
        q_model: Hugging Face model id/path of the query encoder. Its tokenizer is used
            for BOTH queries and documents, so the context encoder is assumed to share
            the query encoder's vocabulary.
        doc_model: model id/path of the context (document) encoder; may equal ``q_model``.
        pooling: ``"cls"`` (first-token hidden state) or ``"avg"`` (mean over non-padding
            tokens). Any other value raises ``ValueError`` at encode time.
        sep: separator placed between a document's title and its text.
        verbose: show a tqdm progress bar and print the output matrix shape.
        **kwargs: accepted and ignored.
    """

    def __init__(self, q_model, doc_model, pooling, sep: str = " ", verbose=True, **kwargs):
        self.tokenizer = AutoTokenizer.from_pretrained(q_model)
        self.query_encoder = AutoModel.from_pretrained(q_model)
        self.context_encoder = AutoModel.from_pretrained(doc_model)
        self.pooling = pooling
        self.sep = sep
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.verbose = verbose

    def encode_queries(self, queries: List[str], batch_size=16, **kwargs) -> np.ndarray:
        """Encode query strings with the query encoder; one embedding row per query."""
        return self.encode_in_batch(self.query_encoder, queries, batch_size)

    def encode_corpus(self, corpus: List[Dict[str, str]], batch_size=16, **kwargs) -> np.ndarray:
        """Encode documents with the context encoder; one embedding row per document.

        ``corpus`` is either a list of ``{"text": ..., "title": ...}`` dicts (title
        optional) or a column-oriented dict ``{"text": [...], "title": [...]}``. With a
        title the encoded string is ``title + sep + text``; it is always ``.strip()``-ed.
        """
        if type(corpus) is dict:
            # Column-oriented input: parallel lists indexed by position.
            sentences = [(corpus["title"][i] + self.sep + corpus["text"][i]).strip() if "title" in corpus else corpus["text"][i].strip() for i in range(len(corpus['text']))]
        else:
            sentences = [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]
        return self.encode_in_batch(self.context_encoder, sentences, batch_size)

    def encode_in_batch(self, model, sentences: List[str], batch_size=32, **kwargs) -> np.ndarray:
        """Tokenize ``sentences`` in batches, run ``model`` and pool to one vector each.

        Inputs are truncated to 512 tokens. Returns ``np.array`` of the pooled rows, in
        input order.

        Raises:
            ValueError: if ``self.pooling`` is neither ``"avg"`` nor ``"cls"``.
        """
        model.to(self.device)
        all_embeddings = []
        loader = torch.utils.data.DataLoader(sentences, batch_size=batch_size, shuffle=False)
        # Inference only: no_grad stops autograd from recording a graph and keeping
        # every layer's activations for a backward pass that never happens.
        with torch.no_grad():
            for batch in tqdm(loader, disable=not self.verbose):
                inputs = self.tokenizer(batch, padding=True, truncation=True, return_tensors='pt', max_length=512)
                inputs = {key: val.to(self.device) for key, val in inputs.items()}
                outputs = model(**inputs)
                ### POOLING
                if self.pooling == "avg":
                    embeddings = self.mean_pooling(outputs[0], inputs['attention_mask'])
                elif self.pooling == "cls":
                    embeddings = outputs.last_hidden_state[:, 0, :]  # [batch, emb_dim]
                else:
                    raise ValueError("Pooling method not supported")
                all_embeddings.extend(embeddings.detach().cpu().numpy())
        all_embeddings = np.array(all_embeddings)
        if self.verbose: print(all_embeddings.shape)
        return all_embeddings

    def mean_pooling(self, token_embeddings, mask):
        """Average ``token_embeddings`` over dim 1, counting only positions where ``mask`` is nonzero."""
        # Zero out padding positions, then divide by the number of real tokens per row.
        token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
        sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
        return sentence_embeddings

In [5]:
### RUN MODELS AND COMPUTE DOT SCORES
def digitize_col(df_col, bins) -> pd.Series:
    """Bucket a numeric column into intervals with ``pd.cut``.

    Args:
        df_col: Values to bin (typically a single DataFrame column).
        bins: Anything ``pd.cut`` accepts as ``bins`` (a bin count or bin edges).

    Returns:
        The output of ``pd.cut``: for a Series input, a categorical ``pd.Series`` of
        intervals aligned with ``df_col``'s index.
    """
    return pd.cut(df_col, bins=bins)

# Retrievers to evaluate, as (query_model, context_model, pooling) triples.
# The DRAGON models ship separate query/context checkpoints; the remaining models
# use one checkpoint for both sides, so they are listed once as (name, pooling).
cfgs = [
    ("facebook/dragon-plus-query-encoder", "facebook/dragon-plus-context-encoder", "cls"),
    ("facebook/dragon-roberta-query-encoder", "facebook/dragon-roberta-context-encoder", "cls"),
] + [
    (name, name, pooling)
    for name, pooling in [
        ("facebook/contriever-msmarco", "avg"),
        ("facebook/contriever", "avg"),
        ("OpenMatch/cocodr-base-msmarco", "cls"),
        ("Shitao/RetroMAE_MSMARCO_finetune", "cls"),
        # Disabled baselines / ablations:
        # ("google-bert/bert-base-uncased", "cls"),
        # ("FacebookAI/roberta-base", "cls"),
        # ("Shitao/RetroMAE", "cls"),
        # ("Shitao/RetroMAE_MSMARCO", "cls"),
    ]
]

def to_doc_format(sentences: list):
    """Wrap raw strings as corpus documents (``{"text": s}``) for ``encode_corpus``."""
    return [{"text": s} for s in sentences]


# For every model, score each row's query against each sentence variant in `sents_cols`
# and store the dot products as new `df` columns named "<query_model>_<sent_col>_dot".
plot_col_dots = []
for query_model, context_model, POOLING in tqdm(cfgs):
    dpr = YourCustomDEModel(query_model, context_model, POOLING, verbose=False)
    # Queries are shared by every sentence variant, so embed them once per model.
    query_embds = np.asarray(dpr.encode_queries(df['query'].to_list()))
    for sent_col in tqdm(sents_cols, desc=f"{query_model}"):
        embds = np.asarray(dpr.encode_corpus(to_doc_format(df[sent_col].to_list())))
        # Row i of the queries is paired with row i of the sentences.
        assert embds.shape == query_embds.shape, (embds.shape, query_embds.shape)
        # Row-wise dot product; no need to round-trip through torch for this.
        embds_dot = np.einsum("bd,bd->b", query_embds, embds)
        new_col = f"{query_model}_{sent_col}_dot"
        df[new_col] = embds_dot
        plot_col_dots.append(new_col)

df_dot = df.copy()

  0%|          | 0/6 [00:00<?, ?it/s]

facebook/dragon-plus-query-encoder:   0%|          | 0/2 [00:00<?, ?it/s]

facebook/dragon-roberta-query-encoder:   0%|          | 0/2 [00:00<?, ?it/s]

facebook/contriever-msmarco:   0%|          | 0/2 [00:00<?, ?it/s]

facebook/contriever:   0%|          | 0/2 [00:00<?, ?it/s]

OpenMatch/cocodr-base-msmarco:   0%|          | 0/2 [00:00<?, ?it/s]

Shitao/RetroMAE_MSMARCO_finetune:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Start from a fresh copy of the frame holding every per-model dot-score column.
df = df_dot.copy(deep=True)
print(plot_col_dots)

def standard_ttest_ppf(n, confidence_level=0.95):
    """Lower-tail quantile of Student's t distribution with ``n - 1`` degrees of freedom.

    Evaluates the inverse CDF at ``1 - confidence_level`` (0.05 by default), i.e. the
    one-sided critical value; it is negative whenever ``confidence_level > 0.5``.
    """
    dof = n - 1
    tail_prob = 1 - confidence_level
    return stats.t(dof).ppf(tail_prob)

### T-TEST
# For each model, compare its score column for the first sentence variant
# (sents_cols[0]) with each of its other score columns using a paired t-test.
# Produces one row per (model, col1, col2) pair.
n_rows = len(df)
t_crit = standard_ttest_ppf(n_rows)  # loop-invariant: depends only on the sample size
records = []
for query_model, context_model, POOLING in tqdm(cfgs):
    model_cols = [c for c in plot_col_dots if c.split("_sents")[0] == query_model]
    for col1 in model_cols:
        if sents_cols[0] not in col1:
            continue
        for col2 in model_cols:
            if col2 == col1:
                continue
            ttest = stats.ttest_rel(df[col1], df[col2])
            ci = ttest.confidence_interval(confidence_level=0.95)
            score_diff = df[col1] - df[col2]
            records.append({
                "query_model": query_model,
                "col1": col1,
                "col2": col2,
                "ttest_stats": ttest.statistic,
                "ttest_pvalue": ttest.pvalue,
                # scipy's CI for the mean paired difference.
                "ttest_ci_low_stats": ci.low,
                "ttest_ci_high_stats": ci.high,
                # Symmetric error-bar half-width used by the plot: |one-sided 95% t critical value|.
                "ttest_ci_low": np.abs(t_crit),
                "ttest_ci_high": np.abs(t_crit),
                "standard_ttest_ppf": t_crit,
                "acc": (df[col1] > df[col2]).mean(),
                "mean_diff": score_diff.mean(),
                "std_diff": score_diff.std(),
                "n": n_rows,
            })
# Build and sort the frame once; concatenating inside the loop was quadratic and re-sorted every pass.
assert records, "no score-column pairs matched; check plot_col_dots / sents_cols"
results_df = pd.DataFrame(records).sort_values("ttest_stats", ascending=True)

### PLOT
# Display names for plots/tables: model id -> (family, variant). An empty variant
# means the family name alone is the label.
model_mappings = {
    "OpenMatch/cocodr-base-msmarco": ("COCO-DR", "Base MSMARCO"),
    "Shitao/RetroMAE_MSMARCO_finetune": ("RetroMAE", "MSMARCO FT"),
    "Shitao/RetroMAE_MSMARCO": ("RetroMAE", "MSMARCO"),
    "Shitao/RetroMAE": ("RetroMAE", ""),
    "facebook/contriever-msmarco": ("Contriever", "MSMARCO"),
    "facebook/contriever": ("Contriever", ""),
    "facebook/dragon-plus-query-encoder": ("Dragon+", ""),
    "facebook/dragon-roberta-query-encoder": ("Dragon RoBERTa", ""),
    "google-bert/bert-base-uncased": ("BERT", "Base Uncased"),
    "FacebookAI/roberta-base": ("RoBERTa", "Base"),
}


def _display_name(model_id: str) -> str:
    """Readable label for ``model_id``; ids missing from ``model_mappings`` are shown as-is.

    Empty parts are skipped so labels carry no trailing space (e.g. "Dragon+", not "Dragon+ ").
    """
    family, variant = model_mappings.get(model_id, (model_id, ""))
    return " ".join(part for part in (family, variant) if part)


results_df = (
    results_df
    .assign(query_model=results_df["query_model"].map(_display_name))
    .rename(columns={"ttest_stats": "Paired t-Test Statistic", "query_model": "Model"})
)

### PLOT T-TEST
# One bar per model: the paired t statistic, with whiskers of ±ttest_ci_low/high
# (the |critical value| stored by the t-test cell).
os.makedirs("results", exist_ok=True)  # to_json/savefig fail if the folders are missing
os.makedirs("figs", exist_ok=True)
results_df.to_json("results/brevity_df.json", orient="records")

n_models = results_df["Model"].nunique()
fig, ax = plt.subplots(figsize=(6, 4))
ax.set_title("Brevity Bias:\nIndividual Evidence vs. Evidence + Document")
sns.barplot(
    data=results_df, y="Model", x="Paired t-Test Statistic", hue="Model",
    palette=sns.color_palette("RdYlGn_r", n_colors=n_models), ax=ax,
)
for container in ax.containers:
    ax.bar_label(container, fmt='%.2f', label_type='center', fontsize=10)
ax.errorbar(
    x=results_df["Paired t-Test Statistic"], y=results_df["Model"],
    xerr=results_df[["ttest_ci_low", "ttest_ci_high"]].T.to_numpy(),
    fmt="none", c="k", capsize=5, elinewidth=1, markeredgewidth=1, alpha=0.5,
)
fig.tight_layout()
fig.savefig("figs/brevity_ttest.pdf")
plt.show()

### PLOT PAIRWISE ACCURACY
# "acc" = fraction of queries whose col1 score exceeds their col2 score.
fig, ax = plt.subplots(figsize=(6, 4))
ax.set_title(f"{sents_cols[0].replace('_', ' ').capitalize()} vs. {sents_cols[1].replace('_', ' ').capitalize()}")
sns.barplot(
    data=results_df, y="Model", x="acc", hue="Model",
    # Size the palette to the number of models so colors match the t-statistic plot.
    palette=sns.color_palette("RdYlGn_r", n_colors=results_df["Model"].nunique()), ax=ax,
)
for container in ax.containers:
    ax.bar_label(container, fmt='%.2f', label_type='center', fontsize=10)
ax.set_xlabel("Fraction of queries with col1 score > col2 score")
fig.tight_layout()
# Save before plt.show(), which may close the figure (e.g. in the inline backend).
# fig.savefig("figs/tail_foil_ttest.pdf")
plt.show()

results_df

In [None]:
# Per-model paired t statistic next to the raw difference statistics it is built from.
results_df.loc[:, ["Model", "Paired t-Test Statistic", "mean_diff", "std_diff", "n"]]

In [None]:
# Per-model paired t statistic with mean / std of the paired score differences.
results_df.loc[:, ["Model", "Paired t-Test Statistic", "mean_diff", "std_diff", "n"]]

In [None]:
# Sanity check: recompute the paired t statistic by hand, t = mean(d) / sd(d) * sqrt(n),
# where d is the per-query score difference (pandas .std() uses ddof=1, as ttest_rel does).
# "tm" is the standardized mean difference mean(d) / sd(d).
# Vectorized column arithmetic replaces a row-wise apply(axis=1), which was slow and
# returns a DataFrame (breaking the assignment) when results_df is empty.
results_df["t"] = results_df["mean_diff"] / results_df["std_diff"] * np.sqrt(results_df["n"])
results_df["tm"] = results_df["mean_diff"] / results_df["std_diff"]
results_df[["Model", "Paired t-Test Statistic", "mean_diff", "std_diff", "n", "t", "tm"]]

In [None]:
# Find Example: rank rows by how much higher Dragon+ scores the "evidence_single" variant
# than the "evidence_doc" variant ("diff", largest first), keeping only rows whose
# evidence text is under 20 words.
pd.set_option('display.max_colwidth', 1800)
df = df_dot.assign(
    diff=lambda d: d["facebook/dragon-plus-query-encoder_sents_evidence_single_dot"] - d["facebook/dragon-plus-query-encoder_sents_evidence_doc_dot"],
    evidence_text=lambda d: d["evidence_sents"].apply(lambda sents: " ".join(flatten(sents))),
    evidence_len=lambda d: d["evidence_text"].apply(lambda text: len(text.split())),
)
print(df["evidence_len"].describe())
df = df.loc[df["evidence_len"] < 20]
df[sents_cols + ["title", "query", "evidence_text", "gold_doc_len", "diff", "tail_entity"]].sort_values("diff", ascending=False)