In [1]:
import json
import copy
from tqdm.notebook import trange, tqdm
from tabulate import tabulate
import pandas as pd

import openai

import os
import copy
import numpy as np
from scipy.spatial.distance import cosine
from scipy import stats

import seaborn as sns
import matplotlib.pyplot as plt

import torch

from transformers import AutoModel, AutoTokenizer,T5Tokenizer, T5ForConditionalGeneration
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer,DPRQuestionEncoder, DPRQuestionEncoderTokenizer

import logging
import transformers
transformers.tokenization_utils.logger.setLevel(logging.ERROR)
transformers.configuration_utils.logger.setLevel(logging.ERROR)
transformers.modeling_utils.logger.setLevel(logging.ERROR)

import random
from datasets import load_dataset


In [2]:
path = "./demo_pir_dataset.json"

# Parse the demo benchmark file; the loop below reads the "queries" and
# "corpus" lists of every entry.
with open(path, "r", encoding="utf-8") as f:
    datasets = json.load(f)

# Report, per dataset, how many queries / corpus documents there are and their
# mean length in whitespace-separated tokens.
for k, v in datasets.items():
    print(k)
    query_word_counts = [len(text.split()) for text in v["queries"]]
    corpus_word_counts = [len(text.split()) for text in v["corpus"]]

    for caption, word_counts in (
        ("Query size and length", query_word_counts),
        ("Corpus size and length", corpus_word_counts),
    ):
        print(caption, len(word_counts), sum(word_counts) / len(word_counts))

perspectrum
Query size and length 100 15.01
Corpus size and length 500 11.026
agnews
Query size and length 100 84.39
Corpus size and length 500 162.01
story
Query size and length 100 24.58
Corpus size and length 500 15.916
ambigqa
Query size and length 100 12.32
Corpus size and length 500 28.858
allsides
Query size and length 100 12.21
Corpus size and length 500 1075.936
exfever
Query size and length 100 49.37
Corpus size and length 500 28.378


In [3]:
# Create a "root query only" variant of every dataset: one entry per distinct
# source query, whose references are the concatenation of the references of all
# queries that share that source query. Perspectives and labels are "none".
source_datasets = {}

for data_name, dataset in datasets.items():
    # {"queries":[],"source_queries":[],"perspectives":[],"corpus":[],"key_ref":{},"query_labels":[]}
    source_dataset = {
        "corpus": dataset["corpus"],
        "queries": [],
        "source_queries": [],
        "perspectives": [],
        "key_ref": {},
        "query_labels": [],
    }
    source_datasets["source_" + data_name] = source_dataset

    # source query text -> its id (a string index) in the new dataset
    reverse_source_query_dic = {}

    for i, query in enumerate(dataset["source_queries"]):
        refs = dataset["key_ref"][str(i)]
        # Always work on a copy. Storing the original list and extending it
        # later would add the sibling queries' references to
        # datasets[data_name]["key_ref"] itself (and again on every re-run).
        refs = list(refs) if isinstance(refs, (list, tuple)) else [refs]

        if query not in reverse_source_query_dic:
            query_id = str(len(source_dataset["queries"]))
            reverse_source_query_dic[query] = query_id
            source_dataset["queries"].append(query)
            source_dataset["source_queries"].append(query)
            source_dataset["perspectives"].append("none")
            source_dataset["query_labels"].append("none")
            source_dataset["key_ref"][query_id] = refs
        else:
            # this source query already exists: merge the references
            source_dataset["key_ref"][reverse_source_query_dic[query]].extend(refs)


for k, v in source_datasets.items():
    print(k)
    print(len(v["key_ref"].keys()))

    q_len = [len(x.split()) for x in v["queries"]]
    c_len = [len(x.split()) for x in v["corpus"]]

    print("Query size and length", len(v["queries"]), sum(q_len) / len(q_len))
    print("Corpus size and length", len(v["corpus"]), sum(c_len) / len(c_len))

source_perspectrum
16
Query size and length 16 7.9375
Corpus size and length 500 11.026
source_agnews
50
Query size and length 50 69.86
Corpus size and length 500 162.01
source_story
50
Query size and length 50 13.08
Corpus size and length 500 15.916
source_ambigqa
26
Query size and length 26 9.23076923076923
Corpus size and length 500 28.858
source_allsides
17
Query size and length 17 1.588235294117647
Corpus size and length 500 1075.936
source_exfever
34
Query size and length 34 40.529411764705884
Corpus size and length 500 28.378


In [4]:
def evaluation(key_ref, corpus_scores, query_labels, dataset_name):
    """Print and return Recall@1/5/10 overall, per query label, and macro-averaged.

    A query counts as a hit at threshold k when at least one of its reference
    documents is among the k highest-scoring corpus entries.

    Args:
        key_ref: mapping from query index (as a string) to either a single
            reference document index or a list/tuple of them.
        corpus_scores: per-query sequences of scores, one score per corpus
            document; higher means more relevant.
        query_labels: label of every query; each must be one of the parts
            known for ``dataset_name`` ("none" for source datasets).
        dataset_name: dataset key. Any name containing "source" is treated as
            a single-part ("none") dataset.

    Returns:
        dict with keys ``"overall"`` (list of recalls, one per threshold),
        ``"per_part"`` (label -> list of recalls; labels without any query are
        omitted) and ``"macro_average"`` (mean over the reported labels).
        Undefined averages (no queries / no labels) are ``nan``.

    Raises:
        ValueError: if ``dataset_name`` is not a known dataset.
    """
    recall_threshold = [1, 5, 10]

    dataset_parts = {
        "perspectrum": ["support", "undermine", "general"],
        "agnews": ["subtopic", "location"],
        "story": ["analogy", "entity"],
        "ambigqa": ["perspective"],
        "allsides": ["left", "right", "center"],
        "exfever": ["SUPPORT", "REFUTE", "NOT ENOUGH INFO"],
    }

    if "source" in dataset_name:
        parts = ["none"]
    elif dataset_name in dataset_parts:
        parts = dataset_parts[dataset_name]
    else:
        raise ValueError("unknown dataset name for evaluation: " + repr(dataset_name))

    # number of queries carrying each label (denominator of the per-part recall)
    parts_size = [0 for _ in parts]
    for lb in query_labels:
        parts_size[parts.index(lb)] += 1

    hit_counts = [0 for _ in recall_threshold]
    part_hit_counts = [[0 for _ in recall_threshold] for _ in parts]

    for k, v in key_ref.items():
        query_index = int(k)
        # Rank once per query; every threshold is a prefix of this ranking.
        ranking = (-np.array(corpus_scores[query_index])).argsort()
        references = v if isinstance(v, (list, tuple)) else [v]
        part_index = parts.index(query_labels[query_index])

        for j, thresh in enumerate(recall_threshold):
            top_documents = ranking[:thresh].tolist()
            # important: finding one reference is enough, this can be modified
            indicator = 1 if any(index in top_documents for index in references) else 0
            hit_counts[j] += indicator
            part_hit_counts[part_index][j] += indicator

    undefined = float("nan")
    num_queries = len(key_ref)
    overall = [count / num_queries if num_queries else undefined for count in hit_counts]

    print("overall")
    for i, thresh in enumerate(recall_threshold):
        print("Recall@" + str(thresh) + ":", overall[i])

    per_part = {}
    for t, part_counts in enumerate(part_hit_counts):
        print(parts[t])
        if parts_size[t] == 0:
            # Nothing to average over; leave this label out of the macro average
            # instead of dividing by zero.
            print("no queries with this label; skipped")
            continue

        part_recalls = [count / parts_size[t] for count in part_counts]
        per_part[parts[t]] = part_recalls
        for i, thresh in enumerate(recall_threshold):
            print("Recall@" + str(thresh) + ":", part_recalls[i])

    macro_average = []
    print("macro_average")
    for i, thresh in enumerate(recall_threshold):
        values = [recalls[i] for recalls in per_part.values()]
        macro = sum(values) / len(values) if values else undefined
        macro_average.append(macro)
        print("Recall@" + str(thresh) + ":", macro)

    return {"overall": overall, "per_part": per_part, "macro_average": macro_average}
                
                    

# BM25 and BERTScore
from rank_bm25 import BM25Okapi
from evaluate import load


import logging
import transformers
transformers.tokenization_utils.logger.setLevel(logging.ERROR)
transformers.configuration_utils.logger.setLevel(logging.ERROR)
transformers.modeling_utils.logger.setLevel(logging.ERROR)


def bm25_main(datasets):
    """Run a BM25 baseline on every dataset, save the score matrix, and evaluate it.

    For each ``name -> data`` entry, documents and queries are tokenized by
    splitting on single spaces, each query is scored against the whole corpus
    with ``BM25Okapi``, the scores are written to ``bm25_<name>_scores.json``
    in the working directory, and ``evaluation`` prints the recall figures.

    Args:
        datasets: mapping from dataset name to a dict providing the keys
            "queries", "corpus", "key_ref" and "query_labels".
    """
    for name, data in datasets.items():
        print("we are working on:", name)

        queries = data["queries"]
        corpus = data["corpus"]
        key_ref = data["key_ref"]
        query_labels = data["query_labels"]

        retriever = BM25Okapi([document.split(" ") for document in corpus])

        # one score array (one entry per corpus document) per query
        corpus_scores = [retriever.get_scores(query.split(" ")) for query in tqdm(queries)]

        with open("bm25_" + name + "_scores.json", "w", encoding="utf-8") as f:
            # score rows expose .tolist(); convert them so json can serialize
            json.dump([row.tolist() for row in corpus_scores], f)

        evaluation(key_ref, corpus_scores, query_labels, name)
        print()
        
        
def extract_layer_cls(embeddings, layer):
    """Collect the first element of every entry in ``embeddings[layer]``.

    Args:
        embeddings: indexable collection; ``embeddings[layer]`` must be
            iterable and each of its items indexable.
            (Presumably per-layer, per-example token vectors, so that item
            ``[0]`` is the [CLS] position - verify against the caller.)
        layer: index/key selecting the layer.

    Returns:
        list with ``item[0]`` for every item of ``embeddings[layer]``.
    """
    return [example[0] for example in embeddings[layer]]

        
def create_embeddings(tokenizer, model, texts, batch_size=17, max_length=None):
    """Encode ``texts`` in batches and return one pooled embedding per text.

    The model is moved to CUDA when available (CPU otherwise) and run under
    ``torch.no_grad()``; the ``pooler_output`` of each batch is collected.

    Every input goes through the same batching path, so short and long inputs
    are tokenized identically and the result always has exactly one row per
    text, in input order. Errors raised while encoding a batch propagate to
    the caller: silently skipping a batch would shift every later embedding
    relative to its text.

    Args:
        tokenizer: callable tokenizer; invoked with ``padding=True``,
            ``truncation=True`` and ``return_tensors="pt"``.
        model: encoder whose output exposes ``pooler_output``.
        texts: list of strings to encode.
        batch_size: number of texts per forward pass.
        max_length: forwarded to the tokenizer as ``max_length``; ``None``
            leaves the limit to the tokenizer.

    Returns:
        list of embeddings (each a list of floats), ``len(texts)`` long.
    """
    if len(texts) == 0:
        return []

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    embeddings = []
    # Stepping by batch_size never yields an empty trailing batch, even when
    # len(texts) is an exact multiple of batch_size.
    for batch_start in trange(0, len(texts), batch_size):
        batch_texts = texts[batch_start:batch_start + batch_size]

        inputs = tokenizer(
            batch_texts,
            max_length=max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        inputs = inputs.to(device)

        with torch.no_grad():
            batch_embeddings = model(**inputs, output_hidden_states=True, return_dict=True).pooler_output
        embeddings.extend(batch_embeddings.detach().cpu().tolist())

        # save cuda memory
        del batch_embeddings
        del inputs
        if device.type == "cuda":
            torch.cuda.empty_cache()

    # num_example * hidden_size
    return embeddings
        
    
def dpr_main(datasets, model_name):
    """Evaluate a dense retriever (DPR or SimCSE) on every dataset.

    Queries and corpus documents are embedded with ``create_embeddings``,
    every query/document pair is scored with cosine similarity, the score
    matrix is written to ``<model_name>_<dataset>_scores.json`` and
    ``evaluation`` prints the recall figures.

    Args:
        datasets: mapping from dataset name to a dict providing the keys
            "queries", "corpus", "key_ref" and "query_labels".
        model_name: one of "dpr", "simcse-unsup", "simcse-sup".

    Raises:
        ValueError: if ``model_name`` is not supported (checked before any
            model is loaded).
    """
    simcse_checkpoints = {
        "simcse-unsup": "princeton-nlp/unsup-simcse-bert-base-uncased",
        "simcse-sup": "princeton-nlp/sup-simcse-bert-base-uncased",
    }

    if model_name == "dpr":
        ctokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
        cmodel = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
        qtokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
        qmodel = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
    elif model_name in simcse_checkpoints:
        checkpoint = simcse_checkpoints[model_name]
        # SimCSE uses one encoder for both sides, so a single copy is loaded
        # and shared between the query and the corpus role.
        ctokenizer = qtokenizer = AutoTokenizer.from_pretrained(checkpoint)
        cmodel = qmodel = AutoModel.from_pretrained(checkpoint)
    else:
        raise ValueError(
            "unsupported model_name for dpr_main: " + repr(model_name)
            + " (expected 'dpr', 'simcse-unsup' or 'simcse-sup')"
        )

    qmodel.eval()
    cmodel.eval()

    for k, v in datasets.items():

        print("we are working on:", k)

        queries = v["queries"]
        corpus = v["corpus"]
        key_ref = v["key_ref"]
        query_labels = v["query_labels"]

        query_embeddings = create_embeddings(qtokenizer, qmodel, queries)
        corpus_embeddings = create_embeddings(ctokenizer, cmodel, corpus)

        # Convert the corpus once instead of on every cosine() call.
        corpus_vectors = [np.asarray(emb) for emb in corpus_embeddings]

        corpus_scores = []
        for emb1 in tqdm(query_embeddings):
            query_vector = np.asarray(emb1)
            # float() keeps the rows plain-Python for json.dump
            corpus_scores.append([float(1 - cosine(query_vector, emb2)) for emb2 in corpus_vectors])

        with open(model_name + "_" + k + "_scores.json", "w", encoding="utf-8") as f:
            json.dump(corpus_scores, f)

        evaluation(key_ref, corpus_scores, query_labels, k)
        print()
        
        
def contriever_main(datasets):
    """Evaluate the ``facebook/contriever`` retriever on every dataset.

    Queries and corpus documents are embedded by mean-pooling the model's
    token embeddings over the attention mask, every query/document pair is
    scored with cosine similarity, the score matrix is written to
    ``contriver_<dataset>_scores.json`` and ``evaluation`` prints the recall
    figures.

    Args:
        datasets: mapping from dataset name to a dict providing the keys
            "queries", "corpus", "key_ref" and "query_labels".
    """

    def mean_pooling(token_embeddings, mask):
        """Average the token embeddings of each sequence, ignoring padded positions."""
        token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.)
        sentence_embeddings = token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]

        return sentence_embeddings

    def contriever_embeddings(texts, tokenizer, model, batch_size=29):
        """Return one mean-pooled embedding (list of floats) per text, in input order."""
        # device = torch.device('cuda')
        device = torch.device('cpu')

        model.to(device)
        model.eval()
        embeddings = []

        # One code path for every input size: the same pooling is applied to
        # short and long inputs, and a failing batch raises instead of being
        # dropped (which would misalign embeddings and texts).
        for batch_start in trange(0, len(texts), batch_size):
            batch_texts = texts[batch_start:batch_start + batch_size]

            inputs = tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt")
            inputs = inputs.to(device)

            with torch.no_grad():
                outputs = model(**inputs)
                # NOTE(review): mean pooling follows the usage on the
                # facebook/contriever model card; the previous batched path
                # read `pooler_output` instead, so saved scores will differ.
                batch_embeddings = mean_pooling(outputs[0], inputs['attention_mask'])

            embeddings.extend(batch_embeddings.detach().cpu().tolist())
            del batch_embeddings
            del inputs

        return embeddings

#     tokenizer = AutoTokenizer.from_pretrained('facebook/contriever-msmarco')
#     model = AutoModel.from_pretrained('facebook/contriever-msmarco')

    tokenizer = AutoTokenizer.from_pretrained('facebook/contriever')
    model = AutoModel.from_pretrained('facebook/contriever')

    for k, v in datasets.items():

        print("we are working on:", k)
        corpus_scores = []

        queries = v["queries"]
        corpus = v["corpus"]
        key_ref = v["key_ref"]
        query_labels = v["query_labels"]

        query_embeddings = contriever_embeddings(queries, tokenizer, model)
        corpus_embeddings = contriever_embeddings(corpus, tokenizer, model)

        for emb1 in tqdm(query_embeddings):
            scores = []
            for emb2 in corpus_embeddings:
                scores.append(1 - cosine(emb1, emb2))

            corpus_scores.append(scores)

        # "contriver" (sic) is the prefix the score-loading cell looks for.
        with open("contriver_" + k + "_scores.json", "w", encoding="utf-8") as f:
            json.dump(corpus_scores, f)

        evaluation(key_ref, corpus_scores, query_labels, k)
        print()
        


# running the processing on source datasets by changing the names of the datasets
        
# print("BM25")
# bm25_main(datasets)

# model_names = ["dpr", "simcse-unsup", "simcse-sup","abs","aspire"]
# for model_name in model_names:
#     print("============", model_name, "============")
#     dpr_main(datasets, model_name) 

# print("Contriever")
# contriever_main(datasets)

# Results Processing

In [None]:
corpus_score_collection = {}

for retriever in ["bm25","dpr","simcse-sup","simcse-unsup","contriver"]:
    corpus_score_collection[retriever] = {}
    for data_name in ['agnews', 'perspectrum', 'story','allsides','exfever','ambigqa']:
        try:
            if "tart" not in retriever:
                with open("./scores/"+retriever+"_"+data_name+"_scores.json","r") as f:
                    corpus_score_collection[retriever][data_name] = json.load(f)
            else:
                with open("./scores/"+retriever+data_name+"_scores.json","r") as f:
                    corpus_score_collection[retriever][data_name] = json.load(f)                
        except:
            print(retriever,data_name)
            
        try:
            with open("./scores/"+retriever+"_source_"+data_name+"_scores.json","r") as f:
                corpus_score_collection[retriever]["source_"+data_name] = json.load(f)
        except:
            print("source",retriever,data_name)

In [None]:
import statistics

def evaluation_for_writing_mrr(datasets, corpus_scores, dataset_name, thresh):
    """Return, per query, the mean reciprocal rank of its reference documents.

    Documents are ranked by descending score; a reference at 1-based rank r
    contributes 1/r, and a query's value is the average over all of its
    references (0 when it has none).

    Args:
        datasets: mapping holding ``datasets[dataset_name]["key_ref"]``, which
            maps a query index (as a string) to one reference document index
            or a list/tuple of them.
        corpus_scores: per-query sequences of scores, one per corpus document.
        dataset_name: key of the dataset to evaluate.
        thresh: unused; kept so the signature matches
            ``evaluation_for_writing_recalls``.

    Returns:
        list with one value per row of ``corpus_scores``.
    """
    key_ref = datasets[dataset_name]["key_ref"]

    mrrs = []

    for k in range(len(corpus_scores)):
        references = key_ref[str(k)]
        if not isinstance(references, (list, tuple)):
            # a single reference index stored without a surrounding list
            references = [references]

        ranking = (-np.array(corpus_scores[k])).argsort().tolist()
        reciprocal_ranks = [1 / (ranking.index(doc) + 1) for doc in references]

        if len(reciprocal_ranks) == 0:
            mrrs.append(0)
        else:
            mrrs.append(sum(reciprocal_ranks) / len(reciprocal_ranks))

    return mrrs

def evaluation_for_writing_recalls(datasets, corpus_scores, dataset_name, thresh):
    """Return a 0/1 hit indicator per query for Recall@``thresh``.

    A query is a hit (1) when at least one of its reference documents is among
    the ``thresh`` highest-scoring corpus entries, otherwise 0.

    Args:
        datasets: mapping holding ``datasets[dataset_name]["key_ref"]``, which
            maps a query index (as a string) to one reference document index
            or a list/tuple of them.
        corpus_scores: per-query sequences of scores, one per corpus document.
        dataset_name: key of the dataset to evaluate.
        thresh: number of top-ranked documents to inspect.

    Returns:
        list of ints (0 or 1), one per row of ``corpus_scores``.
    """
    key_ref = datasets[dataset_name]["key_ref"]

    recall_results = []

    for k in range(len(corpus_scores)):
        references = key_ref[str(k)]
        if not isinstance(references, (list, tuple)):
            # a single reference index stored without a surrounding list
            references = [references]

        top_documents = (-np.array(corpus_scores[k])).argsort()[:thresh].tolist()
        indicator = 1 if any(doc in top_documents for doc in references) else 0
        recall_results.append(indicator)

    return recall_results


# retriever -> {dataset name -> perspective-aware Recall@5}
mrr_collection = {}

# prepare to compute perspective-aware scores:
# a mapping to group queries with the same root query {data_name: [[0,1,2],[3,4]]}
root_mapping = {}

for data_name in ['agnews','perspectrum', 'story','allsides','ambigqa',"exfever"]:
    this_source_qs = source_datasets["source_"+data_name]["queries"]
    root_mapping[data_name] = [[] for x in this_source_qs]

    # root query text -> position of its first occurrence (same as list.index,
    # without rescanning the list for every query)
    source_position = {}
    for position, source_query in enumerate(this_source_qs):
        source_position.setdefault(source_query, position)

    for i,query in enumerate(datasets[data_name]["queries"]):
        sq = datasets[data_name]["source_queries"][i]
        root_mapping[data_name][source_position[sq]].append(i)

for retriever in ["bm25","dpr","simcse-sup","simcse-unsup","contriver"]:
    mrr_collection[retriever] = {}
    for data_name in ['agnews','story','perspectrum','ambigqa','allsides','exfever']:
        per_corpus_score = corpus_score_collection[retriever][data_name]
        recalls = evaluation_for_writing_recalls(datasets, per_corpus_score, data_name,5)

        # hit indicators grouped by root query; indices without a score row are ignored
        grouped_recalls = [
            [recalls[x] for x in group if x < len(recalls)]
            for group in root_mapping[data_name]
        ]

        group_scores = []
        for group in grouped_recalls:
            if len(group) > 1:
                group_scores.append(statistics.mean(group))
            elif len(group) == 1:
                # a root query with a single perspective contributes 0
                group_scores.append(0.0) #x[0]
            # empty groups are skipped

        # nan marks "no root query could be scored" instead of dividing by zero
        if group_scores:
            mrr_collection[retriever][data_name] = sum(group_scores)/len(group_scores)
        else:
            mrr_collection[retriever][data_name] = float("nan")

In [None]:
# Table 2 in the draft: one LaTeX row per retriever with the per-dataset
# scores (in percent, one decimal) followed by their mean.

# NOTE(review): r_name_map (retriever key -> display name) is not defined in
# any cell visible in this notebook; fall back to the raw key when it is
# missing - confirm where the mapping is supposed to come from.
retriever_display_names = globals().get("r_name_map", {})

for retriever in ["bm25","dpr","simcse-sup","simcse-unsup","contriver","tart"]:
    if retriever not in mrr_collection:
        # e.g. "tart", which the aggregation loop above never fills in
        print("% no scores available for", retriever)
        continue

    print(retriever_display_names.get(retriever, retriever),end=" ")
    temp_str = ""
    temp_score = []
    for data_name in ['agnews','story','perspectrum','ambigqa','allsides','exfever']:
        percent = round(mrr_collection[retriever][data_name]*100,1)
        temp_str += "&"+ str(percent) + " "
        temp_score.append(percent)

    # the mean is taken over the already-rounded values shown in the row
    mean_score = sum(temp_score)/len(temp_score)
    print(temp_str+ "&" + str(round(mean_score,1))+ "\\\\")

In [None]:
# Sec 4.2 Exploring PIR
# collection of embeddings to enable future use

def extract_embeddings(datasets, model_name):
    """Embed queries, source queries, perspectives and corpus, and save them as JSON.

    For every dataset ``k`` four files are written:
    ``./embs/<k>_<model_name>_<part>.json`` with ``<part>`` in
    "queries", "source_queries", "perspectives", "corpus". The ``./embs``
    directory is created if needed.

    Args:
        datasets: mapping from dataset name to a dict providing the keys
            "queries", "source_queries", "perspectives" and "corpus".
        model_name: "dpr", "dpr-multiset", one of the SimCSE keys listed
            below, or "t5" / "flan-t5" / "unifiedqa".

    Raises:
        ValueError: if ``model_name`` is not supported (checked before any
            model is loaded).
    """
    # corpuses,key_refs = corpus_building(datasets)

    # model name -> (context encoder checkpoint, question encoder checkpoint)
    dpr_checkpoints = {
        "dpr": ("facebook/dpr-ctx_encoder-single-nq-base", "facebook/dpr-question_encoder-single-nq-base"),
        "dpr-multiset": ("facebook/dpr-ctx_encoder-multiset-base", "facebook/dpr-question_encoder-multiset-base"),
    }
    simcse_checkpoints = {
        "simcse-unsup":"princeton-nlp/unsup-simcse-bert-base-uncased",
        "simcse-sup":"princeton-nlp/sup-simcse-bert-base-uncased",
        "simcse-unsup-large":"princeton-nlp/unsup-simcse-bert-large-uncased",
        "simcse-sup-large":"princeton-nlp/sup-simcse-bert-large-uncased",
        "simcse-unsup-rb":"princeton-nlp/unsup-simcse-roberta-base",
        "simcse-sup-rb":"princeton-nlp/sup-simcse-roberta-base",
        "simcse-unsup-large-rb":"princeton-nlp/unsup-simcse-roberta-large",
        "simcse-sup-large-rb":"princeton-nlp/sup-simcse-roberta-large"
    }
    t5_checkpoints = {
        "t5":"google-t5/t5-large",
        "flan-t5": "google/flan-t5-large",
        "unifiedqa": "allenai/unifiedqa-v2-t5-large-1251000"
    }

    if model_name in dpr_checkpoints:
        context_checkpoint, question_checkpoint = dpr_checkpoints[model_name]
        ctokenizer = DPRContextEncoderTokenizer.from_pretrained(context_checkpoint)
        cmodel = DPRContextEncoder.from_pretrained(context_checkpoint)
        qtokenizer = DPRQuestionEncoderTokenizer.from_pretrained(question_checkpoint)
        qmodel = DPRQuestionEncoder.from_pretrained(question_checkpoint)

    elif model_name in simcse_checkpoints:
        ctokenizer = AutoTokenizer.from_pretrained(simcse_checkpoints[model_name])
        cmodel = AutoModel.from_pretrained(simcse_checkpoints[model_name])
        qtokenizer = AutoTokenizer.from_pretrained(simcse_checkpoints[model_name])
        qmodel = AutoModel.from_pretrained(simcse_checkpoints[model_name])

    elif model_name in t5_checkpoints:
        # encoder vs. decoder: queries go through the decoder stack, corpus
        # documents through the encoder stack of the same checkpoint.
        # NOTE(review): create_embeddings reads `.pooler_output`; confirm the
        # T5 stacks expose it before relying on this branch.
        qtokenizer = T5Tokenizer.from_pretrained(t5_checkpoints[model_name])
        ctokenizer = T5Tokenizer.from_pretrained(t5_checkpoints[model_name])
        # load the weights once and take both stacks from the same model
        t5_model = T5ForConditionalGeneration.from_pretrained(t5_checkpoints[model_name])
        qmodel = t5_model.decoder
        cmodel = t5_model.encoder

    else:
        raise ValueError("unsupported model_name for extract_embeddings: " + repr(model_name))

    qmodel.eval()
    cmodel.eval()

    # open(..., "w") does not create missing directories
    os.makedirs("./embs", exist_ok=True)

    for k,v in datasets.items():

        print("we are working on:",k)

        queries = v["queries"]
        source_queries = v["source_queries"]
        perspectives = v["perspectives"]
        corpus = v["corpus"]

        query_embeddings = create_embeddings(qtokenizer, qmodel, queries)
        source_query_embeddings = create_embeddings(qtokenizer, qmodel, source_queries)
        perspectives_embeddings = create_embeddings(qtokenizer, qmodel, perspectives)
        corpus_embeddings = create_embeddings(ctokenizer, cmodel, corpus)

        names = ["queries","source_queries","perspectives","corpus"]
        embs = [query_embeddings, source_query_embeddings, perspectives_embeddings, corpus_embeddings]

        for name, part_embeddings in zip(names, embs):
            path = "./embs/"+k+"_"+model_name+"_"+name+".json"

            with open(path,"w",encoding="utf-8") as f:
                json.dump(part_embeddings,f)
    
# Model names to extract embeddings for: the two DPR variants followed by the
# eight SimCSE checkpoints (BERT / RoBERTa, base / large, unsup / sup).
model_lists = [
    "dpr",
    "dpr-multiset",
    "simcse-unsup",
    "simcse-sup",
    "simcse-unsup-large",
    "simcse-sup-large",
    "simcse-unsup-rb",
    "simcse-sup-rb",
    "simcse-unsup-large-rb",
    "simcse-sup-large-rb",
]


# for model_name in model_lists:
#     print(model_name)
#     extract_embeddings(datasets, model_name)
    

In [None]:
# The computation below is for demonstration purposes only.

def _pir_similarity(vec_a, vec_b):
    """Cosine similarity, computed as 1 - scipy's cosine distance."""
    return 1 - cosine(vec_a, vec_b)


def _pir_projection(vec, onto):
    """Component of ``vec`` along ``onto``: (vec.onto / onto.onto) * onto."""
    return (np.dot(vec, onto) / np.dot(onto, onto)) * onto


# mode -> score of one (query q, source query s, perspective p, document c) tuple.
# "aug" modes use the full query q where the plain mode uses the source query s.
_PIR_PAIR_SCORERS = {
    # vector manipulation
    "vec_projection_rev": lambda q, s, p, c: _pir_similarity(q + _pir_projection(q, p), c),
    "vec_dual_projection_rev": lambda q, s, p, c: _pir_similarity(q + _pir_projection(q, p), c + _pir_projection(c, p)),
    "vec_projection": lambda q, s, p, c: _pir_similarity(q - _pir_projection(q, p), c),
    "vec_dual_projection": lambda q, s, p, c: _pir_similarity(q - _pir_projection(q, p), c - _pir_projection(c, p)),
    "vec_add": lambda q, s, p, c: _pir_similarity(s + p, c),
    "vec_concat": lambda q, s, p, c: _pir_similarity(np.concatenate((s, p), axis=None), np.concatenate((c, p), axis=None)),
    "vec_aug_add": lambda q, s, p, c: _pir_similarity(q + p, c),
    "vec_aug_concat": lambda q, s, p, c: _pir_similarity(np.concatenate((q, p), axis=None), np.concatenate((c, p), axis=None)),
    "vec_dual_concat": lambda q, s, p, c: _pir_similarity(np.concatenate((s, p), axis=None), np.concatenate((c, p), axis=None)),
    "vec_cast_single": lambda q, s, p, c: _pir_similarity(s - p, c),
    "vec_aug_cast_single": lambda q, s, p, c: _pir_similarity(q - p, c),
    "vec_cast": lambda q, s, p, c: _pir_similarity(s - p, c - p),
    "vec_aug_cast": lambda q, s, p, c: _pir_similarity(q - p, c - p),
    # score manipulation
    "additive": lambda q, s, p, c: _pir_similarity(s, c) + _pir_similarity(p, c),
    "additive_aug": lambda q, s, p, c: _pir_similarity(q, c) + _pir_similarity(p, c),
    "additive_tripple": lambda q, s, p, c: _pir_similarity(s, c) + _pir_similarity(q, c) + _pir_similarity(p, c),
    # first-stage scores of the two-stage modes (second stage is applied per query)
    "re-ranking": lambda q, s, p, c: _pir_similarity(s, c),
    "re-ranking_aug": lambda q, s, p, c: _pir_similarity(q, c),
    "bm25_ranking": lambda q, s, p, c: _pir_similarity(s - p, c - p),
}

# Modes that, on the "ambigqa" dataset, replace the perspective embedding by
# the difference (query - source query).
_PIR_DIFF_PERSPECTIVE_MODES = frozenset([
    "vec_projection_rev", "vec_dual_projection_rev", "vec_projection", "vec_dual_projection",
    "vec_cast_single", "vec_aug_cast_single", "vec_cast", "vec_aug_cast",
])

_PIR_RERANK_MODES = frozenset(["re-ranking", "re-ranking_aug"])


def general_pir_main(datasets, model_name="simcse-sup", mode="vec_cast"):
    """Score every query against its corpus with one PIR strategy, save and evaluate.

    Embeddings are read from ``./embs/<dataset>_<model_name>_<part>.json``
    (parts: queries, source_queries, perspectives, corpus). For each dataset
    the score matrix is written to
    ``./pir_scores/pir_<dataset>_<mode>_scores.json`` (the directory is
    created if needed) and ``evaluation`` prints the recall figures.

    Two-stage modes:
        * "re-ranking" / "re-ranking_aug": documents whose first-stage score
          reaches the 80th percentile get the perspective similarity added;
          all other documents get -100.
        * "bm25_ranking": the mean-centred embedding score and the
          mean-centred BM25 score of the query are summed.

    Args:
        datasets: mapping from dataset name to a dict providing the keys
            "queries", "query_labels", "corpus", "key_ref" and "perspectives".
        model_name: name used in the embedding file names.
        mode: one of the keys of ``_PIR_PAIR_SCORERS``.

    Raises:
        ValueError: if ``mode`` is unknown (previously this silently produced
            empty score lists).
    """
    if mode not in _PIR_PAIR_SCORERS:
        raise ValueError(
            "unknown PIR mode " + repr(mode) + "; expected one of " + ", ".join(sorted(_PIR_PAIR_SCORERS))
        )
    pair_score = _PIR_PAIR_SCORERS[mode]

    for k, v in datasets.items():

        print("we are working on:", k)

        queries = v["queries"]
        query_labels = v["query_labels"]
        corpus = v["corpus"]
        key_ref = v["key_ref"]

        bm25 = None
        if mode == "bm25_ranking":
            tokenized_corpus = [doc.split(" ") for doc in corpus]
            bm25 = BM25Okapi(tokenized_corpus)

        # Load each embedding list and convert it to arrays once, instead of
        # re-converting inside the query x document loop.
        embs_collection = {}
        for name in ["queries", "source_queries", "perspectives", "corpus"]:
            with open("./embs/" + k + "_" + model_name + "_" + name + ".json", "r", encoding="utf-8") as f:
                embs_collection[name] = [np.array(emb) for emb in json.load(f)]

        query_embeddings = embs_collection["queries"]
        source_query_embeddings = embs_collection["source_queries"]
        perspectives_embeddings = embs_collection["perspectives"]
        corpus_embeddings = embs_collection["corpus"]

        corpus_scores = []

        for index in trange(len(query_embeddings)):
            emb_q = query_embeddings[index]
            emb_s = source_query_embeddings[index]
            emb_p = perspectives_embeddings[index]

            if k == "ambigqa" and mode in _PIR_DIFF_PERSPECTIVE_MODES:
                emb_p = emb_q - emb_s

            # float() keeps the rows plain-Python for json.dump
            scores = [float(pair_score(emb_q, emb_s, emb_p, emb_c)) for emb_c in corpus_embeddings]

            if mode in _PIR_RERANK_MODES:
                p_scores = [float(_pir_similarity(emb_p, emb_c)) for emb_c in corpus_embeddings]
                thresh = float(np.percentile(np.array(scores), 80))
                # Every document must receive a final score; the earlier loop
                # computed it but never stored it, leaving the row empty.
                scores = [
                    score + p_score if score >= thresh else -100
                    for score, p_score in zip(scores, p_scores)
                ]

            if mode == "bm25_ranking":
                # BM25 scores depend only on the query: compute them once per
                # query, not once per (query, document) pair.
                bm25_scores = [float(x) for x in bm25.get_scores(queries[index].split(" "))]
                sm = statistics.mean(scores)
                bm25m = statistics.mean(bm25_scores)
                scores = [
                    score - sm + bm25_score - bm25m
                    for score, bm25_score in zip(scores, bm25_scores)
                ]

            corpus_scores.append(scores)

        # open(..., "w") does not create missing directories
        os.makedirs("./pir_scores", exist_ok=True)
        with open("./pir_scores/pir_" + k + "_" + mode + "_scores.json", "w", encoding="utf-8") as f:
            json.dump(corpus_scores, f)

        evaluation(key_ref, corpus_scores, query_labels, k)
        print()

all_pir_mode = []

# all_pir_mode.extend(["vec_add","vec_concat","vec_aug_add","vec_aug_concat","vec_dual_concat","vec_cast","vec_aug_cast"])
# all_pir_mode.extend(["additive","additive_aug","additive_tripple"])
# all_pir_mode.extend(["vec_cast_single","vec_aug_cast_single", "re-ranking", "re-ranking_aug","bm25_ranking"])

# all_pir_mode.extend(["vec_projection","vec_dual_projection","vec_cast","vec_aug_cast"])

# all_pir_mode.extend(["vec_projection_rev","vec_dual_projection_rev"])
# for mode in all_pir_mode:
#     print("-------------------",mode,"--------------------------")
#     general_pir_main(datasets,model_name="simcse-sup",mode=mode)

In [None]:
# process the PIR scores computed