**Documentation**
This is a modified version of MUF adapted for generalization to document retrieval.

The only thing that needs to be ensured is that the dataset file you are trying to use is cleaned properly beforehand. It needs to be stored in a JSON file as a dictionary with the following keys: question, text, map. The question key contains a list of questions, the text key contains a list of context passages, and the map key contains a dictionary mapping each question number (stored as a string key, e.g. "0") to the index of the text passage where its answer is located. Because every question is scored against the passages numbered 0 through (number of questions − 1), the file should contain at least as many passages as questions. See examples of properly formatted dataset files in Noah's folder under the data subfolder.

In [None]:
from rank_bm25 import BM25Okapi
from tqdm import tqdm
from nltk.tokenize import word_tokenize
import os

import numpy as np
import json
import torch
import gzip
import matplotlib.pyplot as plt

from transformers import DPRContextEncoderTokenizer
from transformers import DPRQuestionEncoderTokenizer
from transformers import DPRQuestionEncoder
from transformers import DPRContextEncoder
import csv
from transformers import BertModel, BertTokenizer, BertTokenizerFast
from torch.nn import CosineSimilarity
from torch.nn import Softmax
from torch.utils.data import DataLoader, Dataset
from nltk import word_tokenize
import pandas as pd
import random

from IPython import embed
from sklearn.metrics import classification_report
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Release unoccupied memory cached by PyTorch's CUDA allocator.
torch.cuda.empty_cache()
print(device)

In [None]:
def embed_phrases(data, token_max_len=256):
    '''
    Embed (passage number, phrase) pairs with the DPR context encoder.

    Args:
        data: sequence of (passage_num, phrase_text) tuples, e.g. the output of
            phrase_creator.
        token_max_len: maximum number of tokens per phrase; longer phrases are
            truncated.

    Returns:
        CPU float tensor of shape (len(data), 769). Column 0 holds the passage
        number and columns 1: hold the DPR pooled embedding of the phrase.
    '''
    context_tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
    context_encoder = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base').to(device)
    rows = []
    with torch.no_grad():
        for item in tqdm(data):
            passage_num, phrase = item[0], item[1]
            # Phrases are encoded one at a time, so no padding is needed. The
            # attention mask is passed explicitly so the encoder never attends
            # to pad positions (padding without a mask changes the embedding).
            tokenized = context_tokenizer(
                phrase, max_length=token_max_len, truncation=True, return_tensors='pt'
            )
            pooled = context_encoder(
                input_ids=tokenized['input_ids'].to(device),
                attention_mask=tokenized['attention_mask'].to(device),
            )[0]
            num_col = torch.full((1, 1), float(passage_num), dtype=pooled.dtype, device=device)
            rows.append(torch.cat((num_col, pooled), 1))
    # Concatenate once at the end rather than growing a tensor inside the loop,
    # which would copy all previous rows on every iteration.
    if not rows:
        return torch.empty((0, 769), dtype=torch.float)
    return torch.cat(rows, 0).cpu()

def embed_contexts(data, token_max_len=512):
    '''
    Embed context passages with the DPR context encoder.

    Args:
        data: sequence of passage strings.
        token_max_len: maximum number of tokens per passage; longer passages
            are truncated.

    Returns:
        CPU float tensor of shape (len(data), 768), one pooled DPR embedding
        per passage, in input order.
    '''
    context_tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
    context_encoder = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base').to(device)
    embeds = []
    with torch.no_grad():
        for passage in tqdm(data):
            # One passage at a time, so no padding is needed; pass the attention
            # mask explicitly so pad positions can never influence the encoding.
            tokenized = context_tokenizer(
                passage, max_length=token_max_len, truncation=True, return_tensors='pt'
            )
            embeds.append(context_encoder(
                input_ids=tokenized['input_ids'].to(device),
                attention_mask=tokenized['attention_mask'].to(device),
            )[0])
    # Single concatenation instead of a quadratic grow-in-loop.
    if not embeds:
        return torch.empty((0, 768), dtype=torch.float)
    return torch.cat(embeds, 0).cpu()

def embed_questions(data, token_max_len=512):
    '''
    Embed questions with the DPR question encoder.

    Args:
        data: sequence of question strings.
        token_max_len: maximum number of tokens per question; longer questions
            are truncated.

    Returns:
        CPU float tensor of shape (len(data), 768), one pooled DPR embedding
        per question, in input order.
    '''
    question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')
    question_encoder = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base').to(device)
    embeds = []
    with torch.no_grad():
        for question in tqdm(data):
            # One question at a time, so no padding is needed; pass the attention
            # mask explicitly so pad positions can never influence the encoding.
            tokenized = question_tokenizer(
                question, max_length=token_max_len, truncation=True, return_tensors='pt'
            )
            embeds.append(question_encoder(
                input_ids=tokenized['input_ids'].to(device),
                attention_mask=tokenized['attention_mask'].to(device),
            )[0])
    # Single concatenation instead of a quadratic grow-in-loop.
    if not embeds:
        return torch.empty((0, 768), dtype=torch.float)
    return torch.cat(embeds, 0).cpu()

def phrase_creator(data, phrase_len):
    '''
    Split each context passage into phrases.

    Characters are accumulated until a sentence mark ('.', '!' or '?') is seen.
    Only marks that end a run of more than 120 characters count towards the
    phrase length, and once `phrase_len` such marks have been counted the text
    collected so far is emitted as a phrase. Sentence marks themselves are not
    copied into the phrases. Whatever remains at the end of a passage (possibly
    an empty string) is emitted as a final phrase.

    Returns:
        list of (passage_index, phrase_text) tuples.
    '''
    sentence_marks = ('.', '!', '?')
    phrases = []
    for passage_num, passage in enumerate(data):
        pieces = []
        counted_marks = 0
        run_length = 0
        for ch in passage:
            if ch not in sentence_marks:
                pieces.append(ch)
                run_length += 1
                continue
            if run_length > 120:
                counted_marks += 1
            if counted_marks >= phrase_len:
                counted_marks = 0
                phrases.append((passage_num, ''.join(pieces)))
                pieces = []
            run_length = 0
        phrases.append((passage_num, ''.join(pieces)))
    return phrases

In [None]:
#ALL DATASET CLASSES USED in MUF, there are distinct ones for DPR and Dense Phrase Retrieval 
class DPR_Dataset(Dataset):
    '''
    Dataset for plain DPR retrieval; one item per question.

    Args:
        context_file: passage embeddings indexable by passage number (e.g. the
            (num_passages, 768) tensor returned by embed_contexts).
        question_file: question embeddings indexable by question number.
        idxs: for each question, the list of candidate passage numbers.
        NQ: if true, `map` is ignored and question i is labelled with passage i.
        map: dict from question number (as a string) to its gold passage number.

    Each item is (query_embed, context_embeds, candidate_ids, label), where the
    query and every context are returned as (1, D) rows.
    '''
    def __init__(self, context_file, question_file, idxs, NQ, map):
        # Keep the passed-in embeddings (they were previously overwritten by []).
        self.context_embeds = context_file
        self.question_embeds = question_file
        self.idxs = idxs
        if not NQ:
            self.map = map
        else:
            self.map = {str(i): i for i in range(len(idxs))}

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, idx):
        batch_ixs = self.idxs[idx]
        # Callers index row 0 of the query/contexts, so expose them as (1, D).
        query_embed = self.question_embeds[idx].reshape(1, -1)
        context_embeds = [self.context_embeds[int(ind)].reshape(1, -1) for ind in batch_ixs]
        # Gold passage comes from the map, as in Phrase_Dataset (identity when NQ).
        label = self.map[str(idx)]
        return (query_embed, context_embeds, batch_ixs[:], label)
        
class Phrase_Dataset(Dataset):
    '''
    Dataset for dense phrase retrieval; one item per question.

    Args:
        context_file: phrase embeddings, one row per phrase; column 0 holds the
            number of the passage the phrase came from and the remaining columns
            hold its embedding (the layout produced by embed_phrases).
        question_file: question embeddings indexable by question number.
        idxs: for each question, the list of candidate passage numbers.
        NQ: if true, `map` is ignored and question i is labelled with passage i.
        map: dict from question number (as a string) to its gold passage number.

    Each item is (query_embed, phrase_rows, candidate_ids, label): the query is
    a (1, D) row and phrase_rows holds every phrase row of every candidate
    passage, each as a (1, D + 1) tensor.
    '''
    def __init__(self, context_file, question_file, idxs, NQ, map):
        self.question_embeds = question_file
        # Group phrase rows by passage number: {passage_num: [row, ...]}. Rows are
        # kept as (1, D + 1) tensors so row[0][0] is the passage number and
        # row[0][1:] the embedding.
        self.context_embeds = {}
        for row in context_file:
            row = row.reshape(1, -1)
            self.context_embeds.setdefault(int(row[0, 0]), []).append(row)
        self.idxs = idxs
        if not NQ:
            self.map = map
        else:
            self.map = {str(i): i for i in range(len(idxs))}

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, idx):
        batch_ixs = self.idxs[idx]
        query_embed = self.question_embeds[idx].reshape(1, -1)
        context_embeds = [row for ind in batch_ixs for row in self.context_embeds[int(ind)]]
        label = self.map[str(idx)]
        return (query_embed, context_embeds, batch_ixs[:], label)

    def collate_fn_(self, data):
        '''
        Collate items, padding every phrase list to the longest one in the batch.

        Padding rows are zeros with the same shape/dtype/device as the real rows,
        so they can be scored alongside them. The phrase lists are padded in place.
        '''
        q_batch = [item[0] for item in data]
        c_batch = [item[1] for item in data]
        ixs_batch = [item[2] for item in data]
        label_batch = [item[3] for item in data]
        max_len = max((len(c) for c in c_batch), default=0)
        if max_len:
            template = next(row for c in c_batch for row in c)
            pad = torch.zeros_like(template)
            for c in c_batch:
                c.extend([pad] * (max_len - len(c)))
        return q_batch, c_batch, ixs_batch, label_batch


In [None]:
test_len = 10
def conf_linear(similarities, T=1, top_k=None):
    '''
    Linear confidence: the mean margin between the best score and the runners-up.

    The margins from the best score to ranks 2 .. top_k-1 are averaged (top_k
    defaults to test_len). When fewer scores are available, the average is taken
    over the margins that exist; with fewer than two scores the result is 0.0.

    Args:
        similarities: 1-D array-like of retrieval scores.
        T: accepted (and ignored) so this function can be passed as
            `confidence_func(similarities, T)` like conf_softmax.
        top_k: overrides test_len when given.
    '''
    k = test_len if top_k is None else top_k
    elems = np.sort(np.asarray(similarities))
    n_terms = min(k - 2, len(elems) - 1)
    if n_terms <= 0:
        return 0.0
    ranked = elems[-(n_terms + 1):]  # ascending; best score is last
    return float(np.mean(ranked[-1] - ranked[:-1]))

def softmax_func(elems):
    '''
    Numerically stable softmax of a 1-D array-like.

    The maximum is subtracted before exponentiating; this leaves the result
    unchanged mathematically but avoids overflow to inf/nan for large logits.
    An empty input returns an empty array.
    '''
    elems = np.asarray(elems, dtype=float)
    if elems.size == 0:
        return elems
    exps = np.exp(elems - np.max(elems))
    return exps / np.sum(exps)

def softmax_grad(first, m, elems, T):
    '''
    Derivative with respect to T of one entry of softmax(elems / T).

    Args:
        first: the unnormalized numerator of that entry, i.e. exp(m / T).
        m: the raw logit of that entry.
        elems: numpy array of all raw logits.
        T: the temperature.

    Returns:
        first * sum(z * exp(z/T)) / (T * Z)**2 - m * first / (T**2 * Z),
        with Z = sum(exp(elems / T)).
    '''
    weights = np.exp(elems / T)
    total = np.sum(weights)
    weighted_logits = sum(z * w for z, w in zip(elems, weights))
    expected_term = weighted_logits * (first / (T * total) ** 2)
    own_term = m * first / ((T ** 2) * total)
    return expected_term - own_term

def conf_softmax(similarities, T=1, top_k=None):
    '''
    Softmax confidence with temperature scaling, plus its derivative in T.

    The top_k (default test_len) largest similarities are temperature-scaled and
    normalized with a numerically stable softmax. The confidence is the
    probability of the best score, p_max. The gradient is d p_max / dT:
        p_max / T**2 * (sum_j z_j * p_j - z_max)

    Args:
        similarities: non-empty 1-D array-like of retrieval scores.
        T: softmax temperature.
        top_k: overrides test_len when given (must be positive).

    Returns:
        (confidence, gradient) as numpy floats.
    '''
    k = test_len if top_k is None else top_k
    elems = np.sort(np.asarray(similarities, dtype=float))[-k:]
    scaled = elems / T
    # Subtract the max before exp to avoid overflow for large scores / small T.
    weights = np.exp(scaled - np.max(scaled))
    probs = weights / np.sum(weights)
    conf = probs[-1]
    # Analytic derivative of the normalized top probability. (Passing the
    # normalized value into softmax_grad, which expects the raw numerator
    # exp(z_max / T), scaled the gradient down by the partition sum.)
    grad = conf / T ** 2 * (np.dot(elems, probs) - elems[-1])
    return conf, grad

def run_DensePhraseRet(query_embed, context_embeds, ixs, confidence_func, T=1):
    '''
    Score phrase rows against a query and return the best phrase's passage.

    Args:
        query_embed: query embedding, either 1-D or a single (1, D) row.
        context_embeds: non-empty sequence of phrase rows, each 1-D or (1, D + 1),
            where element 0 is the passage number and the rest the embedding.
        ixs: unused; kept for interface parity with run_DPR.
        confidence_func: called as confidence_func(similarities, T).
        T: temperature forwarded to confidence_func.

    Returns:
        (passage number of the highest-scoring phrase, confidence_func output).
        Scores are dot(query, phrase) / ||phrase||; NaN scores (zero-norm rows,
        e.g. padding) are replaced by 0.
    '''
    if len(context_embeds) == 0:
        raise ValueError("context_embeds is empty")
    rows = torch.stack([torch.as_tensor(c).reshape(-1).double() for c in context_embeds])
    query = torch.as_tensor(query_embed).reshape(-1).double().to(rows.device)
    passage_nums = rows[:, 0]
    vectors = rows[:, 1:]
    sims = (vectors @ query) / torch.norm(vectors, dim=1)
    similarities = sims.cpu().numpy()
    similarities = np.where(np.isnan(similarities), 0, similarities)
    best = int(np.argmax(similarities))
    return int(passage_nums[best]), confidence_func(similarities, T)

def run_DPR(query_embed, context_embeds, ixs, confidence_func, T=1):
    '''
    Score candidate passages against a query and return the best candidate id.

    Args:
        query_embed: query embedding, either 1-D or a single (1, D) row.
        context_embeds: non-empty sequence of passage embeddings, each 1-D or (1, D).
        ixs: candidate passage ids aligned with context_embeds.
        confidence_func: called as confidence_func(similarities, T).
        T: temperature forwarded to confidence_func.

    Returns:
        (ixs entry of the highest-scoring passage, confidence_func output).
        Scores are dot(query, passage) / ||passage||; NaN scores (zero-norm
        passages) are replaced by 0, as in run_DensePhraseRet, so they cannot
        win the argmax.
    '''
    if len(context_embeds) == 0:
        raise ValueError("context_embeds is empty")
    contexts = torch.stack([torch.as_tensor(c).reshape(-1).double() for c in context_embeds])
    query = torch.as_tensor(query_embed).reshape(-1).double().to(contexts.device)
    sims = (contexts @ query) / torch.norm(contexts, dim=1)
    similarities = sims.cpu().numpy()
    similarities = np.where(np.isnan(similarities), 0, similarities)
    return int(ixs[int(np.argmax(similarities))]), confidence_func(similarities, T)


def update_batch(preds, batch_acc, label):
    '''Append 1 to batch_acc for every prediction equal to label and 0 otherwise (in place).'''
    batch_acc.extend(1 if p == label else 0 for p in preds)

def update_bins(bins, preds, confs, increments, label):
    '''
    Drop each (confidence, gradient) pair into its calibration bin, in place.

    Bins are keyed by their upper edge k/increments (k = 1..increments); a
    confidence goes into the first bin whose edge is >= it. Confidences above
    1.0 (or NaN) match no bin and are skipped. The bin's 'acc' counter is
    incremented when the matching prediction equals label.

    Args:
        bins: dict edge -> {'confs': list, 'grads': list, 'acc': int}.
        preds: predictions aligned with confs.
        confs: sequence of (confidence, gradient) pairs.
        increments: number of bins.
        label: the gold label.
    '''
    upper_edges = [k * 1 / increments for k in range(1, increments + 1)]
    for j, row in enumerate(confs):
        c = row[0]
        edge = next((e for e in upper_edges if c <= e), None)
        if edge is None:
            continue
        slot = bins[edge]
        slot['grads'].append(row[1])
        slot['confs'].append(c)
        if preds[j] == label:
            slot['acc'] += 1

def calc_bin_loss(bins):
    '''
    Binned calibration loss and its temperature gradient term.

    For every non-empty bin with n entries (N entries over all bins):
        gap   = mean(confs) - acc / n
        loss += n / N * gap**2
        grad += n / N * gap * mean(grads)

    Returns:
        (loss, grad); (0, 0) when every bin is empty.
    '''
    total = sum(len(slot['confs']) for slot in bins.values())
    loss, grad = 0, 0
    for slot in bins.values():
        count = len(slot['confs'])
        if not count:
            continue
        mean_conf = sum(slot['confs']) / count
        mean_grad = sum(slot['grads']) / len(slot['grads'])
        accuracy = slot['acc'] / count
        gap = mean_conf - accuracy
        weight = count / total
        grad += weight * gap * mean_grad
        loss += weight * gap ** 2
    return loss, grad


In [None]:
#MUF Evaluation Function
#REGULAR NO BATCH MUF
def eval(T_init, rows, datasets):
    '''
    Run MUF over every question and return the macro-averaged F1.

    For each question, every retriever makes a prediction with a
    temperature-scaled softmax confidence. The retrievers are DPR (phrase
    length 0) and dense phrase retrieval with phrase lengths 1, 3 and 5. The
    most confident prediction is kept. Confidences are binned, and roughly ten
    times over the run the temperature takes a gradient-descent step on the
    binned calibration loss.

    Args:
        T_init: initial softmax temperature.
        rows: one entry per question (only its length is used).
        datasets: [DPR dataset, phrase datasets for lengths 1, 3, 5], each
            indexable by question number.

    Returns:
        Macro-averaged F1 of the selected predictions against the labels.
    '''
    phrase_lens = [0, 1, 3, 5]
    y_pred = []
    y_label = []
    conf_func = conf_softmax
    temp = T_init
    step = 100
    increments = 10
    # Update the temperature about ten times; guard against fewer than 10 rows
    # (int(len(rows) / 10) would be 0 and the modulo would divide by zero).
    update_every = max(1, len(rows) // 10)

    def _empty_bins():
        # Upper bin edge -> collected confidences, correct-prediction count, gradients.
        return {k / increments: {"confs": [], "acc": 0, "grads": []} for k in range(1, increments + 1)}

    bins = _empty_bins()
    for i in tqdm(range(len(rows))):
        preds = []
        for phrase_len, dataset in zip(phrase_lens, datasets):
            query, context, ixs, label = dataset[i]
            retriever = run_DPR if phrase_len == 0 else run_DensePhraseRet
            preds.append(retriever(query, context, ixs, conf_func, temp))
        pred_ids = [p[0] for p in preds]
        confs = np.array([p[1][0] for p in preds], dtype=float)
        conf_grads = np.array([p[1][1] for p in preds], dtype=float)
        # np.where returns a new array; the result must be assigned to take effect.
        confs = np.where(np.isnan(confs), 0, confs)
        conf_grads = np.where(np.isnan(conf_grads), 0, conf_grads)
        cum_data = np.column_stack((confs, conf_grads))
        pred = pred_ids[int(np.argmax(confs))]
        # Pass the predicted ids (not the (id, (conf, grad)) tuples) so bin
        # accuracy compares predictions with the label.
        update_bins(bins, pred_ids, cum_data, increments, label)
        y_pred.append(int(pred))
        y_label.append(int(label))
        if i % update_every == 0 and i != 0:
            _, grad = calc_bin_loss(bins)
            # grad is proportional to d(loss)/dT, so descend (subtract) to
            # reduce the calibration loss.
            # NOTE(review): temp is not clamped; a large step could make it <= 0.
            temp -= float(step * grad)
            bins = _empty_bins()

    return classification_report(y_label, y_pred, digits=4, output_dict=True)['macro avg']['f1-score']
        
        

In [None]:
# Generate question, context and phrase embeddings, then build one dataset per retriever.

data_file = "" #CHANGE THIS TO FILE OF DATA YOU WANT TO PROCESS
if not data_file:
    raise ValueError("data_file is empty: set it to the path of the JSON dataset to process")
with open(data_file, encoding='utf-8') as file:
    data = json.load(file)
context_embeds = embed_contexts(data['text'])
question_embeds = embed_questions(data['question'])
phrases = [phrase_creator(data['text'], i) for i in [1, 3, 5]]
phrase_embeds = [embed_phrases(elem) for elem in phrases]

# Candidates for question k are passages 0..n-1 (n = number of questions) with
# k swapped to the front, so this assumes at least n passages.
num_questions = len(data['question'])
idxs = []
for k in range(num_questions):
    candidates = list(range(num_questions))
    candidates[0], candidates[k] = k, 0
    idxs.append(candidates)

# The dataset constructors take (context, question, idxs, NQ, map); NQ=False so
# the gold passage of each question comes from data['map'].
datasets = [DPR_Dataset(context_embeds, question_embeds, idxs, False, data['map'])]
datasets += [Phrase_Dataset(p_embeds, question_embeds, idxs, False, data['map']) for p_embeds in phrase_embeds]


In [None]:
# Run MUF with an initial softmax temperature of 0.1 over every question.
eval(T_init=0.1, rows=idxs, datasets=datasets)