In [1]:
from rank_bm25 import BM25Okapi
from tqdm import tqdm
from nltk.tokenize import word_tokenize
import os
import numpy as np
import json
from bs4 import BeautifulSoup
import torch
import gzip
import torch.nn as nn
from collections import OrderedDict

from transformers import DPRContextEncoderTokenizer
from transformers import DPRQuestionEncoderTokenizer
from transformers import DPRQuestionEncoder
from transformers import DPRContextEncoder
import csv
from transformers import BertModel, BertTokenizer, BertTokenizerFast
from torch.nn import CosineSimilarity
from torch.utils.data import DataLoader, Dataset
from nltk import word_tokenize
import pandas as pd
import random

from IPython import embed
from sklearn.metrics import classification_report
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
# device = "cpu"
torch.cuda.empty_cache()
# os.environ['CUDA_VISIBLE_DEVICES'] = "4,5,6,7"

In [2]:
def kmax(k, nums, blocked):
    ix = []
    while len(ix) < k:
        w = max(nums)
        w_ix = nums.tolist().index(w)
        nums[w_ix] = float('-inf')
        if w_ix != blocked:
            ix.append(w_ix)
    return ix

def sample(j, nums, blocked):
    rand = random.sample(list(range(len(nums))), 30)
    true = []
    for r in rand:
        if r not in blocked:
            true.append(r)
    return true[0:20]
    
def generate_ixs(examples, file_out, ix_map=[]):
    file = open(file_out, 'w+', newline ='\n')
    k_hard = 10
    j_rand = 20
    bm25 = BM25Okapi([word_tokenize(a) for a in examples['text']])
    for i in tqdm(range(len(examples['question']))):
        tokenized_query = examples['question'][i]
        doc_scores = bm25.get_scores(word_tokenize(tokenized_query))
        if ix_map:
            block = ix_map[str(i)]
        else:
            block = i
        hard_ix = kmax(k_hard, doc_scores, block)
        rand_ix = sample(j_rand, doc_scores, hard_ix + [block])
        passage_num = examples['map'][str(i)]
        answers = [[passage_num] + hard_ix + rand_ix]
        write = csv.writer(file)
        write.writerows(answers)
    file.close()
    
    


In [3]:
def preprocess(file_name, file_out):
    with open(file_name) as file: 
        data = json.load(file)
    data = data['data']
    example = {'question': [], 'text': [], 'map': {}, 'answers': [], 'title': {}}
    question_count = 0
    text_count = 0
    for par in data:
        title = " ".join((par['title'].split("_")))
        par = par['paragraphs']
        for elem in par:
            for chunk in elem['qas']:
                if chunk['is_impossible']: 
                    if len(chunk['plausible_answers']) == 0: continue
                    example['question'].append(chunk['question'])
                    #example['answers'].append(chunk['plausible_answers'][0])
                else:
                    example['question'].append(chunk['question'])
                    #example['answers'].append(chunk['answers'][0])
                example['map'][question_count] = text_count
                question_count += 1
            example['text'].append(title+ " " + elem['context'])
            example['title'][text_count] = title 
            text_count += 1
    with open(file_out, 'w+') as outfile:
        json.dump(example, outfile)
    return example
    
    

In [4]:
class DPR(torch.nn.Module):

    def __init__(self, context_encoder,  question_encoder, pretrained=True):
        super().__init__()
        self.question_encoder =  question_encoder
        self.context_encoder = context_encoder#nn.DataParallel(context_encoder).to(device)
        #self.context_head = torch.nn.Linear(1024, 2048)
        #self.question_head = torch.nn.Linear(1024, 1024)
        #self.tanh = torch.nn.Tanh()
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.log_softmax = torch.nn.LogSoftmax(dim=1)
        self.pretrained = pretrained

    def forward(self, context_batch, question_batch):
        #Context_batch (shape) : batchsize x seq_lenth : (questino_batch (shape))
        cb = [torch.tensor(i).to(device) for i in context_batch]
        #embeddings_context = self.context_encoder(*cb).last_hidden_state
        if not self.pretrained:
            embeddings_context = torch.stack([self.context_encoder(i).last_hidden_state[:,0,:] for i in cb])
            embedding_question = self.question_encoder(torch.tensor(question_batch).to(device)).last_hidden_state[:,0,:]
        else:
            #enc = self.context_encoder
            #embed()
            embeddings_context = torch.stack([self.context_encoder(i)['pooler_output'].to(device) for i in cb]).to(device)
            #embeddings_context = self.context_encoder(torch.tensor(context_batch).to(device))['pooler_output'].to(device)#torch.stack([self.context_encoder(i)['pooler_output'].to(device) for i in cb]).to(device)
            embedding_question = self.question_encoder(torch.tensor(question_batch).to(device))['pooler_output'].to(device) 
        #embed()
        logits = torch.einsum('ijk, ik ->ij', embeddings_context, embedding_question) 
        return logits

    def predict(self,batch):
        
        with torch.no_grad():
            question_batch = batch[0]
            context_batch = batch[1]
            context_ixs = batch[2]
            labels = batch[3]

            logits = self.log_softmax(self.forward(context_batch, question_batch))
            loss = self.loss_fn(logits.to(device), torch.tensor(labels).to(device))
            acc = sum([1 if logits.argmax(axis = -1)[i] == labels[i] else 0 for i in range(len(labels))]) / len(labels) 

        return logits.argmax(axis = -1), logits.max(axis=-1), loss, acc # predictions, loss


    def criterion(self, batch):#context_batch, question_batch, labels):
        question_batch = batch[0] 
        context_batch = batch[1]
        context_ixs = batch[2]
        labels = batch[3]

        #batch_size, _ , = labels.shape
        batch_size = len(context_batch)
        #change this to BCELoss with logits 
        #loss_func = torch.nn.BCEWithLogitsLoss()

        logits = self.forward(context_batch, question_batch)

        #loss = loss_func(logits.reshape(-1, 2), labels.reshape(-1).long())
        loss = self.loss_fn(logits.to(device), torch.tensor(labels).to(device))
        
        logging_output = {
            "loss": loss.item(), 
            "accuracy" : sum([1 if logits.argmax(axis = -1)[i] == labels[i] else 0 for i in range(len(labels))]) / batch_size   #(logits.argmax(axis = -1) == label).prod(axis = -1).sum() / batch_size
        }
        return loss, logging_output

In [12]:
ce = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base').to(device)
qe = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base').to(device)

def embed_phrases(data, token_max_len, model_file=None):
    '''
    Takes as input a list of strings (whether they be questions, contexts etc.) and tokenizes/embeds them
    '''
    if model_file is None:
        with torch.no_grad():
            context_tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')
            context_encoder = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base').to(device)
    else:
        model = DPR(ce, qe)
        l = torch.load(model_file)
        new_state_dict = OrderedDict()
        for k, v in l.items():
            mod = k.split('.')
            if mod[1] != "module":
                new_state_dict[k] = v
                continue
            del mod[1]
            k = '.'.join(mod)
            new_state_dict[k] = v
        model.load_state_dict(new_state_dict)
        context_tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')#.to(device)
        context_encoder = model.context_encoder
    phrase_embeds = torch.empty((1, 769), dtype=torch.float).to(device)
    for i in tqdm(range(len(data))):
        passage_num = torch.tensor([[data[i][0]]]).to(device)
        with torch.no_grad():
            tokenized = context_tokenizer(data[i][1], padding='max_length', max_length = token_max_len,truncation=True)
            batch_embeds = context_encoder( torch.tensor([tokenized['input_ids']]).to(device) )[0]
            final_val = torch.cat((passage_num, batch_embeds), 1).to(device)
        phrase_embeds = torch.cat((phrase_embeds, final_val), 0).to(device)
    return phrase_embeds[1:, :].cpu()

def embed_contexts(data, model_file=None):
    if model_file is None:
        with torch.no_grad():
            context_tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base')#.to(device)
            context_encoder = DPRContextEncoder.from_pretrained('facebook/dpr-ctx_encoder-single-nq-base').to(device)
    else:
        model = DPR(ce, qe)
        l = torch.load(model_file)
        new_state_dict = OrderedDict()
        for k, v in l.items():
            mod = k.split('.')
            if mod[1] != "module":
                new_state_dict[k] = v
                continue
            del mod[1]
            k = '.'.join(mod)
            new_state_dict[k] = v
        model.load_state_dict(new_state_dict)
        context_tokenizer = DPRContextEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')#.to(device)
        context_encoder = model.context_encoder
    phrase_embeds = torch.empty((1, 768), dtype=torch.float).to(device)
    for i in tqdm(range(len(data))):
        with torch.no_grad():
            tokenized = context_tokenizer(data[i], padding='max_length', max_length = 512,truncation=True)
            batch_embeds = context_encoder( torch.tensor([tokenized['input_ids']]).to(device) )[0]
        phrase_embeds = torch.cat((phrase_embeds, batch_embeds), 0).to(device)
    return phrase_embeds[1:, :].cpu()

def embed_questions(data, model_file=None):
    if model_file is None:
        with torch.no_grad():
            question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')#.to(device)
            question_encoder = DPRQuestionEncoder.from_pretrained('facebook/dpr-question_encoder-single-nq-base').to(device)
    else:
        model = DPR(ce, qe)
        l = torch.load(model_file)
        new_state_dict = OrderedDict()
        for k, v in l.items():
            mod = k.split('.')
            if mod[1] != "module":
                new_state_dict[k] = v
                continue
            del mod[1]
            m = '.'.join(mod)
            new_state_dict[m] = v
        model.load_state_dict(new_state_dict)
        question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained('facebook/dpr-question_encoder-single-nq-base')#.to(device)
        question_encoder = model.question_encoder
    phrase_embeds = torch.empty((1, 768), dtype=torch.float).to(device)
    for i in tqdm(range(len(data))):
        with torch.no_grad():
            tokenized = question_tokenizer(data[i], padding='max_length', max_length = 512,truncation=True)
            batch_embeds = question_encoder( torch.tensor([tokenized['input_ids']]).to(device) )[0]
        phrase_embeds = torch.cat((phrase_embeds, batch_embeds), 0).to(device)
    return phrase_embeds[1:, :].cpu()

def phrase_creator(data, titles=[], phrase_len=[]):
    punc = ['.', '!', '?']
    phrases = []
    for i in range(len(data)):
        context_sents = data[i]
        batch = ""
        punc_count = 0
        punc_diff = 0
        for j in range(len(context_sents)):
            if context_sents[j] in ['.', '!', '?']:
                if punc_diff > 120: 
                    punc_count += 1
                if punc_count >= phrase_len:
                    punc_count = 0
                    phrases.append((i, batch))
                    batch = ""
                punc_diff = 0
                continue
            batch = batch + context_sents[j]
            punc_diff += 1
        phrases.append((i, batch))
    return phrases

    



        

Some weights of the model checkpoint at facebook/dpr-ctx_encoder-single-nq-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.weight', 'ctx_encoder.bert_model.pooler.dense.bias']
- This IS expected if you are initializing DPRContextEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRContextEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the 

In [6]:
def store_vals(embeds, q_or_c, dev, phrase_len, data_set):
    if q_or_c and dev:
        torch.save(embeds, f"/home/ubuntu/nlm/williamyang/DPR_Preprocess_Data/{data_set}/dev-questions")
    elif q_or_c and not dev:
        torch.save(embeds, f"/home/ubuntu/nlm/williamyang/DPR_Preprocess_Data/{data_set}/train-questions")
    elif not q_or_c and dev:
        torch.save(embeds, f"/home/ubuntu/nlm/williamyang/DPR_Preprocess_Data/{data_set}/dev-contexts-{phrase_len}")
    else:
        torch.save(embeds, f"/home/ubuntu/nlm/williamyang/DPR_Preprocess_Data/{data_set}/train-contexts-{phrase_len}")


In [9]:
#FOR SQUAD
data_set = "squad"
file_path = "/home/ubuntu/nlm/williamyang/DPR_Preprocess_Data/squad/dev-clean.json"
with open(file_path) as file: 
    data = json.load(file)

In [7]:
#FOR PUBMED
data_set = "pubmed-flex"
file_path = "/home/ubuntu/nlm/noah/pubmed/small.json"
with open(file_path) as file: 
    data = json.load(file)

In [7]:
#FOR SCOTUS
data_set = "scotus-flex"
file_path = "/home/ubuntu/nlm/noah/scotus/dev.json"
with open(file_path) as file: 
    data = json.load(file)



In [43]:
#FOR NFCORPUS
data_set = "nfcorpus-flex"
file_path = "/home/ubuntu/nlm/noah/nfcorpus/dev-clean.json"
with open(file_path) as file: 
    data = json.load(file)

In [8]:
#FOR P_RANK
data_set = "p_rank"
file_path = "/home/ubuntu/nlm/noah/p_rank/dset.csv"
data = {"text": [], "question": [], "map": {}}
first = True
with open(file_path, newline='') as idx_file:
    for line in csv.reader(idx_file):
        if first:
            first = False
            continue
        row = list(line)
        data['question'].append(row[1])
        data['text'].append(row[2])
        data['map'][int(row[0])] = int(row[0])


In [5]:
file_path = "/home/ubuntu/nlm/noah/squad/dev_raw.json"
file_out = "/home/ubuntu/nlm/williamyang/DPR_Preprocess_Data/squad/dev-clean.json"
data = preprocess(file_path, file_out)
#phrases = phrase_creator(data['text'], 3)

In [9]:
#embed and store questions
embeds = embed_questions(data['question'], "/home/ubuntu/nlm/noah/pubmed__meta.pt")
store_vals(embeds, True, True, 0, data_set)

100%|██████████| 1000/1000 [01:08<00:00, 14.64it/s]


In [10]:
#embed and store contexts
c_embeds = embed_contexts(data['text'], "/home/ubuntu/nlm/noah/models/scotus_in_domain.pt")
store_vals(c_embeds, False, True, 0, data_set)

question_encoder.question_encoder.bert_model.embeddings.position_ids
question_encoder.question_encoder.bert_model.embeddings.word_embeddings.weight
question_encoder.question_encoder.bert_model.embeddings.position_embeddings.weight
question_encoder.question_encoder.bert_model.embeddings.token_type_embeddings.weight
question_encoder.question_encoder.bert_model.embeddings.LayerNorm.weight
question_encoder.question_encoder.bert_model.embeddings.LayerNorm.bias
question_encoder.question_encoder.bert_model.encoder.layer.0.attention.self.query.weight
question_encoder.question_encoder.bert_model.encoder.layer.0.attention.self.query.bias
question_encoder.question_encoder.bert_model.encoder.layer.0.attention.self.key.weight
question_encoder.question_encoder.bert_model.encoder.layer.0.attention.self.key.bias
question_encoder.question_encoder.bert_model.encoder.layer.0.attention.self.value.weight
question_encoder.question_encoder.bert_model.encoder.layer.0.attention.self.value.bias
question_encoder

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokenizer'.
100%|██████████| 1000/1000 [01:08<00:00, 14.61it/s]


In [14]:
#embed phrases
phrase_len = 5
phrases = phrase_creator(data['text'], False, phrase_len)
p_embeds = embed_phrases(phrases, 400, "/home/ubuntu/nlm/noah/models/scotus_in_domain.pt")
store_vals(p_embeds, False, True, phrase_len, data_set)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokenizer'.
100%|██████████| 1500/1500 [01:59<00:00, 12.54it/s]


In [7]:
file_out = f"/home/ubuntu/nlm/williamyang/DPR_Preprocess_Data/{data_set}/dev_ixs.csv"
generate_ixs(data, file_out, data['map'])

100%|██████████| 11858/11858 [01:16<00:00, 154.15it/s]


In [39]:
print(phrases[28000])

(266, " For the State there are also advantages \x97the more promptly imposed punishment after an admission of guilt may more effectively attain the objectives of punishment; and with the avoidance of trial, scarce judicial and prosecutorial resources are conserved for those cases in which there is a substantial issue of the defendant's guilt or in which there is substantial doubt that the State can sustain its burden of proof")


In [1]:
def forward(self, context_batch, question_batch):
        #Context_batch (shape) : batchsize x seq_lenth : (questino_batch (shape))
        batch_size = 20
        cb = [torch.tensor(i).to(device) for i in context_batch]
        embeddings_context = []
        #embeddings_context = self.context_encoder(*cb).last_hidden_state
        for i in range(0, len(cb), batch_size):
            end = min(len(cb), i+batch_size)
            c_batch = cb[i:end]
            if not self.pretrained:
                embeddings_context = embeddings_context + [a.last_hidden_state[:,0,:] for a in self.context_encoder(c_batch)]
            else:
                embeddings_context = embeddings_context + [a['pooler_output'] for a in self.context_encoder(c_batch)]
        embeddings_context = torch.stack(embeddings_context)
        if not self.pretrained:
            embedding_question = self.question_encoder(torch.tensor(question_batch).to(device)).last_hidden_state[:,0,:]
        else:
            embedding_question = self.question_encoder(torch.tensor(question_batch).to(device))['pooler_output']
        logits = torch.einsum('ijk, ik ->ij', embeddings_context, embedding_question) 
        return logits