In [42]:
import os
import torch
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from tqdm.auto import tqdm
import huggingface_hub as hf
from dotenv import load_dotenv
import matplotlib.pyplot as plt
from typing import List, Dict, Union, Tuple
from transformers import AutoTokenizer, AutoModel

pd.set_option('display.max_columns', None)
pd.set_option('display.max_colwidth', 512)

plt.style.use('seaborn-v0_8')
load_dotenv()
hf.login(os.environ["HF_TOKEN"])
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
print("CUDA_VISIBLE_DEVICES:", os.environ["CUDA_VISIBLE_DEVICES"], "HF_HOME:", os.environ["HF_HOME"])

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /home/mohsenfayyaz/.cache/huggingface/token
Login successful
CUDA_VISIBLE_DEVICES: 1 HF_HOME: /local1/mohsenfayyaz/.hfcache/


# Download Dataset + DecompX Tensors

In [3]:
# DATASET = "re-docred_facebook--contriever-msmarco_7170.pkl"
DATASET = "re-docred_facebook--dragon-plus-query-encoder_7170.pkl"
# DATASET = "re-docred_OpenMatch--cocodr-base-msmarco_7170.pkl.gz"

hf.hf_hub_download(repo_id="Retriever-Contextualization/datasets", filename=f"results/{DATASET}", repo_type="dataset", local_dir="hf/")

'hf/results/re-docred_facebook--dragon-plus-query-encoder_7170.pkl'

In [4]:
df_raw = pd.read_pickle(f"./hf/results/{DATASET}")
print(df_raw.attrs)
df_raw.head(3)

{'model': 'facebook/dragon-plus-query-encoder', 'query_model': 'facebook/dragon-plus-query-encoder', 'context_model': 'facebook/dragon-plus-context-encoder', 'pooling': 'cls', 'dataset': 're-docred', 'corpus_size': 105925, 'eval': {'ndcg': {'NDCG@1': 0.47685, 'NDCG@3': 0.52523, 'NDCG@5': 0.53646, 'NDCG@10': 0.54955, 'NDCG@100': 0.58002, 'NDCG@1000': 0.59556}, 'map': {'MAP@1': 0.47685, 'MAP@3': 0.51341, 'MAP@5': 0.51959, 'MAP@10': 0.52496, 'MAP@100': 0.53058, 'MAP@1000': 0.53109}, 'recall': {'Recall@1': 0.47685, 'Recall@3': 0.55941, 'Recall@5': 0.58689, 'Recall@10': 0.62748, 'Recall@100': 0.77741, 'Recall@1000': 0.90349}, 'precision': {'P@1': 0.47685, 'P@3': 0.18647, 'P@5': 0.11738, 'P@10': 0.06275, 'P@100': 0.00777, 'P@1000': 0.0009}}}


Unnamed: 0,query_id,query,gold_docs,gold_docs_text,scores_stats,scores_gold,scores_1000,predicted_docs_text_10,id,title,vertexSet,labels,sents,split,label,label_idx,head_entity,tail_entity,head_entity_names,tail_entity_names,head_entity_longest_name,tail_entity_longest_name,head_entity_types,tail_entity_types,evidence_sent_ids,evidence_sents,head_entity_in_evidence,tail_entity_in_evidence,relation,relation_name,query_question,duplicate_titles_len,duplicate_titles,hit_rank,gold_doc,gold_doc_title,gold_doc_text,gold_doc_score,pred_doc,pred_doc_title,pred_doc_text,pred_doc_score,gold_doc_len,pred_doc_len,query_decompx_tokens,query_decompx_tokenizer_word_ids,query_decompx_cls_or_mean_pooled,query_decompx_tokens_dot_scores,query_decompx_decompx_last_layer_pooled,gold_doc_decompx_tokens,gold_doc_decompx_tokenizer_word_ids,gold_doc_decompx_cls_or_mean_pooled,gold_doc_decompx_tokens_dot_scores,gold_doc_decompx_decompx_last_layer_pooled,pred_doc_decompx_tokens,pred_doc_decompx_tokenizer_word_ids,pred_doc_decompx_cls_or_mean_pooled,pred_doc_decompx_tokens_dot_scores,pred_doc_decompx_decompx_last_layer_pooled
0,test0,When was Loud Tour published?,[Loud Tour],{'Loud Tour': {'text': 'The Loud Tour was the ...,"{'len': 1000, 'max': 390.3378601074219, 'min':...",{'Loud Tour': 390.3378601074219},"{'Loud Tour': 390.3378601074219, 'Loud'n'proud...",{'Loud Tour': {'text': 'The Loud Tour was the ...,test0,Loud Tour,"[[{'name': 'Loud', 'pos': [23, 24], 'sent_id':...","[{'r': 'P577', 'h': 0, 't': 6, 'evidence': [1]...","[[The, Loud, Tour, was, the, fourth, overall, ...",test,"{'r': 'P577', 'h': 0, 't': 6, 'evidence': [1]}",0,"[{'name': 'Loud', 'pos': [23, 24], 'sent_id': ...","[{'pos': [25, 26], 'type': 'TIME', 'sent_id': ...","{Loud Tour, Loud}",{2010},Loud Tour,2010,{MISC},{TIME},[1],"[[Performing, in, over, twenty, countries, in,...","[{'name': 'Loud', 'pos': [23, 24], 'sent_id': ...","[{'pos': [25, 26], 'type': 'TIME', 'sent_id': ...",P577,publication date,When was Loud Tour published?,0,{},1.0,Loud Tour The Loud Tour was the fourth overall...,Loud Tour,The Loud Tour was the fourth overall and third...,390.33786,Loud Tour The Loud Tour was the fourth overall...,Loud Tour,The Loud Tour was the fourth overall and third...,390.33786,142,142,"[[CLS], when, was, loud, tour, published, ?, [...","[None, 0, 1, 2, 3, 4, 4, None]","[-0.17805682, -0.3927267, 0.34883702, -0.38739...","[2.2196622, 6.71451, 0.9866385, 58.316944, 37....","[[0.0026502553, 0.044497166, 0.009840142, -0.0...","[[CLS], loud, tour, the, loud, tour, was, the,...","[None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 1...","[-0.7096514, -0.43747085, 2.078466, -0.8606712...","[650.5565, 112.46794, 110.70713, 35.217003, 88...","[[-0.06098142, 0.030208647, 0.35368052, -0.157...","[[CLS], loud, tour, the, loud, tour, was, the,...","[None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 1...","[-0.7096514, -0.43747085, 2.078466, -0.8606712...","[650.5565, 112.46794, 110.70713, 35.217003, 88...","[[-0.06098142, 0.030208647, 0.35368052, -0.157..."
1,test1,Who performed Loud Tour?,[Loud Tour],{'Loud Tour': {'text': 'The Loud Tour was the ...,"{'len': 1000, 'max': 398.40228271484375, 'min'...",{'Loud Tour': 398.40228271484375},"{'Loud Tour': 398.40228271484375, 'Tonnage Tou...",{'Loud Tour': {'text': 'The Loud Tour was the ...,test1,Loud Tour,"[[{'name': 'Loud', 'pos': [23, 24], 'sent_id':...","[{'r': 'P577', 'h': 0, 't': 6, 'evidence': [1]...","[[The, Loud, Tour, was, the, fourth, overall, ...",test,"{'r': 'P175', 'h': 0, 't': 2, 'evidence': [0, 1]}",1,"[{'name': 'Loud', 'pos': [23, 24], 'sent_id': ...","[{'name': 'Rihanna', 'pos': [3, 4], 'sent_id':...","{Loud Tour, Loud}",{Rihanna},Loud Tour,Rihanna,{MISC},{PER},"[0, 1]","[[The, Loud, Tour, was, the, fourth, overall, ...","[{'name': 'Loud', 'pos': [23, 24], 'sent_id': ...","[{'name': 'Rihanna', 'pos': [18, 19], 'sent_id...",P175,performer,Who performed Loud Tour?,0,{},1.0,Loud Tour The Loud Tour was the fourth overall...,Loud Tour,The Loud Tour was the fourth overall and third...,398.402283,Loud Tour The Loud Tour was the fourth overall...,Loud Tour,The Loud Tour was the fourth overall and third...,398.402283,142,142,"[[CLS], who, performed, loud, tour, ?, [SEP]]","[None, 0, 1, 2, 3, 3, None]","[-0.20367393, -0.43282467, 0.085645154, -0.127...","[1.0651449, 2.8293247, 14.008613, 61.79544, 30...","[[0.014361359, 0.08712164, 0.015172923, -0.007...","[[CLS], loud, tour, the, loud, tour, was, the,...","[None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 1...","[-0.7096514, -0.43747085, 2.078466, -0.8606712...","[650.5565, 112.46794, 110.70713, 35.217003, 88...","[[-0.06098142, 0.030208647, 0.35368052, -0.157...","[[CLS], loud, tour, the, loud, tour, was, the,...","[None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 1...","[-0.7096514, -0.43747085, 2.078466, -0.8606712...","[650.5565, 112.46794, 110.70713, 35.217003, 88...","[[-0.06098142, 0.030208647, 0.35368052, -0.157..."
2,test2,Which administrative territorial entity is The...,[Loud Tour],{'Loud Tour': {'text': 'The Loud Tour was the ...,"{'len': 1000, 'max': 379.07220458984375, 'min'...",{'Loud Tour': None},{'Olympic Delivery Authority': 379.07220458984...,{'Olympic Delivery Authority': {'text': 'The O...,test2,Loud Tour,"[[{'name': 'Loud', 'pos': [23, 24], 'sent_id':...","[{'r': 'P577', 'h': 0, 't': 6, 'evidence': [1]...","[[The, Loud, Tour, was, the, fourth, overall, ...",test,"{'r': 'P131', 'h': 10, 't': 8, 'evidence': [4]}",2,"[{'sent_id': 4, 'type': 'LOC', 'pos': [11, 14]...","[{'name': 'London', 'pos': [1, 2], 'sent_id': ...",{The O2 Arena},{London},The O2 Arena,London,{LOC},{LOC},[4],"[[In, London, ,, Rihanna, played, a, record, b...","[{'sent_id': 4, 'type': 'LOC', 'pos': [11, 14]...","[{'name': 'London', 'pos': [1, 2], 'sent_id': ...",P131,located in the administrative territorial entity,Which administrative territorial entity is The...,0,{},inf,Loud Tour The Loud Tour was the fourth overall...,Loud Tour,The Loud Tour was the fourth overall and third...,,Olympic Delivery Authority The Olympic Deliver...,Olympic Delivery Authority,The Olympic Delivery Authority ( ODA ) was a n...,379.072205,142,226,"[[CLS], which, administrative, territorial, en...","[None, 0, 1, 2, 3, 4, 5, 6, 6, 7, 8, 9, 9, None]","[-0.033545855, 0.04599817, -0.09456632, -0.027...","[6.398443, -1.6474726, 4.8608465, 11.063667, 1...","[[0.0017761212, 0.025855744, 0.023759678, -0.0...","[[CLS], loud, tour, the, loud, tour, was, the,...","[None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 1...","[-0.7096514, -0.43747085, 2.078466, -0.8606712...","[650.5565, 112.46794, 110.70713, 35.217003, 88...","[[-0.06098142, 0.030208647, 0.35368052, -0.157...","[[CLS], olympic, delivery, authority, the, oly...","[None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 9, 10, 11...","[-0.56008214, 0.15913305, 1.5809196, -0.728217...","[771.9823, 40.083252, 107.10669, 9.972132, 17....","[[-0.06645119, 0.03153456, 0.4184458, -0.19593..."


In [40]:
df = df_raw.copy()
df = df[df["evidence_sent_ids"].str.len() == 1]  # 1 Evidence
df = df[df["head_entity_in_evidence"].str.len() == 1]  # 1 Head in Evidence
df = df[df["head_entity_names"].str.len() == 1]  # All heads have the same name
df = df[df["evidence_sents"].str.len() == 1]  # 1 Evidence Sentence
print(len(df))  # 2239

def flatten(xss):
    return [x for xs in xss for x in xs]

head_w_tail_sents = []
head_wo_tail_sents = []
for row in df.to_dict(orient="records"):
    head_w_tail = " ".join(flatten(row["evidence_sents"]))
    head_wo_tail = None
    for head_entity in row["head_entity"]:
        evidence_sent_id = row["evidence_sent_ids"][0]
        if head_entity["sent_id"] != evidence_sent_id:
            head_wo_tail = " ".join(row["sents"][head_entity["sent_id"]])
            break
    if head_wo_tail is None:
        head_w_tail_sents.append(None)
        head_wo_tail_sents.append(None)
    else:
        head_w_tail_sents.append(head_w_tail)
        head_wo_tail_sents.append(head_wo_tail)
    
df["head_w_tail_sentence"] = head_w_tail_sents
df["head_wo_tail_sentence"] = head_wo_tail_sents
df = df.dropna(subset=["head_w_tail_sentence", "head_wo_tail_sentence"])

d = df.iloc[0].to_dict()
print(d["query"])
print(d["tail_entity_in_evidence"])
print(d["head_entity_in_evidence"])
for i, d in enumerate(d["sents"]): print(i, d)

df[["query", "head_w_tail_sentence", "head_wo_tail_sentence"]]

2239
Who performed Long Hard Road Out of Hell?
[{'name': 'Sneaker Pimps', 'pos': [22, 24], 'sent_id': 0, 'type': 'ORG', 'global_pos': [22, 22], 'index': '4_0'}]
[{'name': 'Long Hard Road Out of Hell', 'pos': [1, 7], 'sent_id': 0, 'type': 'MISC', 'global_pos': [1, 1], 'index': '0_0'}]
0 ['"', 'Long', 'Hard', 'Road', 'Out', 'of', 'Hell', '"', 'is', 'a', 'song', 'by', 'American', 'rock', 'band', 'Marilyn', 'Manson', 'and', 'British', 'trip', 'hop', 'band', 'Sneaker', 'Pimps', ',', 'released', 'as', 'a', 'single', 'from', 'the', 'soundtrack', 'to', 'the', '1997', 'motion', 'picture', 'Spawn', '.']
1 ['An', 'arena', 'rock', 'and', 'gothic', 'rock', 'song', ',', '"', 'Long', 'Hard', 'Road', 'Out', 'of', 'Hell', '"', 'was', 'written', 'by', 'Marilyn', 'Manson', 'and', 'Twiggy', 'Ramirez', 'and', 'produced', 'by', 'Manson', 'and', 'Sean', 'Beavan', '.']
2 ['Its', 'lyrics', 'are', 'about', 'self', '-', 'loathing', 'and', 'its', 'title', 'is', 'derived', 'from', 'John', 'Milton', "'s", 'Paradise

Unnamed: 0,query,head_w_tail_sentence,head_wo_tail_sentence
32,Who performed Long Hard Road Out of Hell?,""" Long Hard Road Out of Hell "" is a song by American rock band Marilyn Manson and British trip hop band Sneaker Pimps , released as a single from the soundtrack to the 1997 motion picture Spawn .","An arena rock and gothic rock song , "" Long Hard Road Out of Hell "" was written by Marilyn Manson and Twiggy Ramirez and produced by Manson and Sean Beavan ."
33,What is Long Hard Road Out of Hell a part of?,""" Long Hard Road Out of Hell "" is a song by American rock band Marilyn Manson and British trip hop band Sneaker Pimps , released as a single from the soundtrack to the 1997 motion picture Spawn .","An arena rock and gothic rock song , "" Long Hard Road Out of Hell "" was written by Marilyn Manson and Twiggy Ramirez and produced by Manson and Sean Beavan ."
36,When was Spawn published?,""" Long Hard Road Out of Hell "" is a song by American rock band Marilyn Manson and British trip hop band Sneaker Pimps , released as a single from the soundtrack to the 1997 motion picture Spawn .","After the track was written , the Sneaker Pimps ' Kelli Ali was recruited to perform background vocals on it , as the Spawn soundtrack featured collaborations between hard rock artists and electronic music artists ."
37,When was Long Hard Road Out of Hell published?,""" Long Hard Road Out of Hell "" is a song by American rock band Marilyn Manson and British trip hop band Sneaker Pimps , released as a single from the soundtrack to the 1997 motion picture Spawn .","An arena rock and gothic rock song , "" Long Hard Road Out of Hell "" was written by Marilyn Manson and Twiggy Ramirez and produced by Manson and Sean Beavan ."
38,What is a notable work of Sneaker Pimps?,""" Long Hard Road Out of Hell "" is a song by American rock band Marilyn Manson and British trip hop band Sneaker Pimps , released as a single from the soundtrack to the 1997 motion picture Spawn .","After the track was written , the Sneaker Pimps ' Kelli Ali was recruited to perform background vocals on it , as the Spawn soundtrack featured collaborations between hard rock artists and electronic music artists ."
...,...,...,...
7032,Where was Kerstin Thorborg born?,"Born in Venjan , Sweden , the contralto Kerstin Thorborg was one of the best dramatic Wagnerian singers in the two decades between 1930 and 1950 .","Kerstin Thorborg ( May 19 , 1896 - April 12 , 1970 )"
7089,Which country is Cimatti associated with?,"Cimatti was an Italian manufacturer of bicycles , motorcycles and mopeds active between 1937 and 1984 .",Cimatti used two - stroke engines bought from both Moto Morini and Moto Minarelli .
7090,Which administrative territorial entity is Cimatti located in?,"Cimatti was an Italian manufacturer of bicycles , motorcycles and mopeds active between 1937 and 1984 .",Cimatti used two - stroke engines bought from both Moto Morini and Moto Minarelli .
7106,What is Line M1 a part of?,Młociny is a Warsaw Metro station serving as a northern terminus to Line M1 .,"Although there are no plans to extend Line M1 further , the station is built in such way that it will be possible to do so if need be ."


In [43]:
class YourCustomDEModel:
    def __init__(self, q_model, doc_model, pooling, sep: str = " ", **kwargs):
        self.tokenizer = AutoTokenizer.from_pretrained(q_model)
        self.query_encoder = AutoModel.from_pretrained(q_model)
        self.context_encoder = AutoModel.from_pretrained(doc_model)
        self.pooling = pooling
        self.sep = sep
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
    
    # Write your own encoding query function (Returns: Query embeddings as numpy array)
    def encode_queries(self, queries: List[str], batch_size=128, **kwargs) -> np.ndarray:
        print("Q")
        print(len(queries))
        return self.encode_in_batch(self.query_encoder, queries, batch_size)
    
    # Write your own encoding corpus function (Returns: Document embeddings as numpy array)  
    def encode_corpus(self, corpus: List[Dict[str, str]], batch_size=128, **kwargs) -> np.ndarray:
        if type(corpus) is dict:
            sentences = [(corpus["title"][i] + self.sep + corpus["text"][i]).strip() if "title" in corpus else corpus["text"][i].strip() for i in range(len(corpus['text']))]
        else:
            sentences = [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]
        return self.encode_in_batch(self.context_encoder, sentences, batch_size)

    def encode_in_batch(self, model, sentences: List[str], batch_size=128, **kwargs) -> np.ndarray:
        model.to(self.device)
        all_embeddings = []
        for batch in tqdm(torch.utils.data.DataLoader(sentences, batch_size=batch_size, shuffle=False)):
            inputs = self.tokenizer(batch, padding=True, truncation=True, return_tensors='pt', max_length=512)
            inputs = {key: val.to(self.device) for key, val in inputs.items()}
            outputs = model(**inputs)
            ### POOLING
            if self.pooling == "avg":
                embeddings = self.mean_pooling(outputs[0], inputs['attention_mask'])
            elif self.pooling == "cls":
                embeddings = outputs.last_hidden_state[:, 0, :]  # [128, 768] = [batch, emb_dim]
            else:
                raise ValueError("Pooling method not supported")
            all_embeddings.extend(embeddings.detach().cpu().numpy())
        all_embeddings = np.array(all_embeddings)
        print(all_embeddings.shape)
        return all_embeddings

    def mean_pooling(self, token_embeddings, mask):
        token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
        sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]
        return sentence_embeddings

In [44]:
### DRAGON
query_model = "facebook/dragon-plus-query-encoder"
context_model = "facebook/dragon-plus-context-encoder"
POOLING = "cls"

dpr = YourCustomDEModel(query_model, context_model, POOLING)

In [48]:
def to_doc_format(sentences: list):
    return [{"text": s} for s in sentences]

query_embds = dpr.encode_queries(df["query"].to_list())
head_w_tail_embds = dpr.encode_corpus(to_doc_format(df["head_w_tail_sentence"].to_list()))
head_wo_tail_embds = dpr.encode_corpus(to_doc_format(df["head_wo_tail_sentence"].to_list()))

Q
420


  0%|          | 0/4 [00:00<?, ?it/s]

(420, 768)


  0%|          | 0/4 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 34.00 MiB. GPU 0 has a total capacity of 47.33 GiB of which 14.81 MiB is free. Process 3277830 has 40.25 GiB memory in use. Including non-PyTorch memory, this process has 7.04 GiB memory in use. Of the allocated memory 6.52 GiB is allocated by PyTorch, and 226.04 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)