In [1]:
from rank_bm25 import BM25Okapi
from tqdm import tqdm
from nltk.tokenize import word_tokenize

import numpy as np
import json
from bs4 import BeautifulSoup
import torch
import gzip

from transformers import DPRContextEncoderTokenizer
from transformers import DPRQuestionEncoderTokenizer
from transformers import DPRQuestionEncoder
from transformers import DPRContextEncoder
import csv
from transformers import BertModel, BertTokenizer, BertTokenizerFast
from torch.nn import CosineSimilarity
from torch.utils.data import DataLoader, Dataset
from nltk import word_tokenize
import pandas as pd
import random

from IPython import embed
from sklearn.metrics import classification_report
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 
print(device)
torch.cuda.empty_cache()
device = torch.device("cpu")

cuda


In [2]:
torch.cuda.empty_cache()

In [2]:
class SimpleDataset(Dataset):
    def __init__(self, context_file, question_file, idxs, NQ, map):
        context_vals = torch.load(context_file) 
        question_vals = torch.load(question_file)
        self.context_embeds = {}
        for elem in context_vals:
            val = torch.reshape(elem, (1, 769))
            if int(val[0][0].item()) not in self.context_embeds.keys():
                self.context_embeds[int(val[0][0].item())] = [val]
            else:
                self.context_embeds[int(val[0][0].item())].append(val)
        self.query_embeds = []
        for elem in question_vals:
            self.query_embeds.append(torch.reshape(elem, (1, 768)))
        self.ixs = idxs
        if not NQ:
            self.map = map
        else:
            self.map = {}
            for i in range(len(idxs)):
                self.map[str(i)] = i

    def __len__(self):
        return len(self.ixs)

    def __getitem__(self, idx):
        # query embed, context embeds, true label
        batch_ixs = self.ixs[idx]
        query_embed = self.query_embeds[idx]
        context_embeds = []
        for i in range(len(batch_ixs)):
            ind = batch_ixs[i]
            for elem in self.context_embeds[int(ind)]:
                context_embeds.append(elem)
        label = self.map[str(idx)]
        return ( query_embed, context_embeds, batch_ixs[:], label )
        
    def collate_fn_(self, data):
        q_batch = [q[0] for q in data]
        c_batch = [c[1] for c in data]
        ixs_batch = [i[2] for i in data]
        label_batch = [l[3] for l in data]
        max_len = max([len(c_batch[j]) for j in range(len(c_batch))])
        default = torch.full((1, 769), 0)
        for elem in c_batch:
            while len(elem) < max_len:
                elem.append(default)
        return q_batch, c_batch, ixs_batch, label_batch

In [3]:
def run_DPR(query_embed, context_embeds, ixs):
    similarities = []
    for i in range(len(context_embeds)):
        context_val = context_embeds[i][0][1:]
        similarities.append( torch.dot(query_embed[0], context_val)/torch.norm(context_val))
    similarities = np.array(similarities)
    similarities = np.where(np.isnan(similarities), 0, similarities)
    return int(context_embeds[np.argmax(np.array(similarities))][0][0])

In [4]:
def NQ_processing(idx_f):
    rows = []
    with open(idx_f, newline='') as idx_file:
        for line in csv.reader(idx_file):
            rows.append(list(line))
    idx_file.close()
    return rows
    

In [3]:
#For Dev Datasets
i = 0
phrase_len = 75
question_file = f"/home/ubuntu/nlm/williamyang/DPR_Preprocess_Data/dev/question-{i}-embeds"
context_file = f"/home/ubuntu/nlm/williamyang/DPR_Preprocess_Data/dev/context-{i}-{phrase_len}-embeds"
idx_file = f"/home/ubuntu/nlm/williamyang/DPR_Preprocess_Data/dev/nq-dev-ix-{i}.csv"

In [23]:
#For Training Datasets
i = 0
phrase_len = 250
question_file = f"/home/ubuntu/nlm/williamyang/DPR_Preprocess_Data/train/question-{i}-embeds"
context_file = f"/home/ubuntu/nlm/williamyang/DPR_Preprocess_Data/train/context-{i}-{phrase_len}-embeds"
idx_file = f"/home/ubuntu/nlm/williamyang/DPR_Preprocess_Data/train/nq-train-ix-{i}.csv"

In [31]:
#For SQUAD
phrase_len = 5
question_file = "/home/ubuntu/nlm/williamyang/DPR_Preprocess_Data/squad/dev-questions"
context_file = f"/home/ubuntu/nlm/williamyang/DPR_Preprocess_Data/squad/dev-contexts-{phrase_len}"
idx_file = "/home/ubuntu/nlm/williamyang/DPR_Preprocess_Data/squad/dev_ixs.csv"


In [20]:
#FOR PUBMED
phrase_len = 1
question_file = "/home/ubuntu/nlm/williamyang/DPR_Preprocess_Data/pubmed-flex/dev-questions"
context_file = f"/home/ubuntu/nlm/williamyang/DPR_Preprocess_Data/pubmed-flex/dev-contexts-{phrase_len}"
idx_file = "/home/ubuntu/nlm/noah/pubmed/small_ix.csv"

In [23]:
#FOR SCOTUS
length = 3
question_file = "/home/ubuntu/nlm/williamyang/DPR_Preprocess_Data/scotus-flex/dev-questions"
context_file = f"/home/ubuntu/nlm/williamyang/DPR_Preprocess_Data/scotus-flex/dev-contexts-{length}"
idx_file = "/home/ubuntu/nlm/noah/scotus/dev_ix.csv"

In [9]:
#FOR NFCORPUS
length = 1
question_file = "/home/ubuntu/nlm/williamyang/DPR_Preprocess_Data/nfcorpus/dev-questions"
context_file = f"/home/ubuntu/nlm/williamyang/DPR_Preprocess_Data/nfcorpus/dev-contexts-{length}"
idx_file = "/home/ubuntu/nlm/noah/nfcorpus/dev-clean-ix.csv"

In [21]:
#FOR NQ
rows = NQ_processing(idx_file)
dataset = SimpleDataset(context_file, question_file, rows, True, [])
dataloader = DataLoader(dataset, batch_size=30, shuffle=False, collate_fn=dataset.collate_fn_)

In [24]:
#FOR SQUAD/SCOTUS
rows = NQ_processing(idx_file)
file_name = "/home/ubuntu/nlm/noah/scotus/dev.json"
with open(file_name) as file: 
    data = json.load(file)
QC_map = data['map']
dataset = SimpleDataset(context_file, question_file, rows, False, QC_map)
dataloader = DataLoader(dataset, batch_size=30, shuffle=False, collate_fn=dataset.collate_fn_)

In [22]:
#Dense Phrase Retrieval No Batches
y_pred = []
y_label = []
for i in tqdm(range(len(rows))):
    query, context, ixs, label = dataset[i]
    pred = run_DPR(query, context, ixs)
    y_pred.append(int(pred))
    y_label.append(int(label))
    # print("PRED:", pred)
    # print("LABEL:", label)
print(classification_report(y_label, y_pred, digits=4))    
    

100%|██████████| 1000/1000 [00:04<00:00, 208.17it/s]

              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         1
           1     0.0000    0.0000    0.0000         1
           2     0.0000    0.0000    0.0000         1
           3     0.0833    1.0000    0.1538         1
           4     1.0000    1.0000    1.0000         1
           5     0.0000    0.0000    0.0000         1
           6     0.0000    0.0000    0.0000         1
           7     0.0000    0.0000    0.0000         1
           8     1.0000    1.0000    1.0000         1
           9     0.0000    0.0000    0.0000         1
          10     0.0000    0.0000    0.0000         1
          11     0.0000    0.0000    0.0000         1
          12     0.0000    0.0000    0.0000         1
          13     0.0000    0.0000    0.0000         1
          14     0.0000    0.0000    0.0000         1
          15     0.0000    0.0000    0.0000         1
          16     0.0000    0.0000    0.0000         1
          17     0.0000    


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
#Dense Phrase Retrieval Implementation
y_pred = []
y_label = []
for step, (query, context, ixs, label) in tqdm(enumerate(dataloader)): 
    context_embeds = []
    # if step == 2: break
    for j in range(len(context)):
        new = []
        for k in range(len(context[j])):
            new.append(context[j][k])
        context_embeds.append(torch.stack(new))
    context_embeds = torch.stack(context_embeds)
    inds = []
    for m in range(len(ixs)):
        new = []
        for n in range(len(ixs[m])):
            new.append(int(ixs[m][n]))
        inds.append(new)
    inds = torch.tensor(inds) 
    for i in range(len(query)): # this should be batch size
        #embed()
        question_embed = query[i] #.detach()
        #embed()
        pred = run_DPR(question_embed, context_embeds[i], inds[i])
        # print("PRED:", pred)
        # print("LABEL:", label[i])
        # print("choices:", inds[i])
        y_pred.append(int(pred))
        y_label.append(int(label[i]))

print(classification_report(y_label, y_pred, digits=4))
    

In [31]:
print(dataset.context_embeds[0][0].shape)

torch.Size([768])


In [8]:
def run_model(rows, dataset):
    y_pred = []
    y_label = []
    for i in tqdm(range(len(rows))):
        query, context, ixs, label = dataset[i]
        pred = run_DPR(query, context, ixs)
        y_pred.append(int(pred))
        y_label.append(int(label))
        # print("PRED:", pred)
        # print("LABEL:", label)
    print(classification_report(y_label, y_pred, digits=4)) 

i = 0
for phrase_len in [1, 3, 5]:
    print(phrase_len)
    question_file = "/home/ubuntu/nlm/williamyang/DPR_Preprocess_Data/pubmed/dev-questions"
    context_file = f"/home/ubuntu/nlm/williamyang/DPR_Preprocess_Data/pubmed/dev-contexts-{phrase_len}"
    idx_file = "/home/ubuntu/nlm/noah/pubmed/full_clean_small_ix.csv"
    rows = NQ_processing(idx_file)
    dataset = SimpleDataset(context_file, question_file, rows, True, [])
    run_model(rows, dataset)

    
   

1


100%|██████████| 1000/1000 [00:13<00:00, 72.53it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0     1.0000    1.0000    1.0000         1
           1     1.0000    1.0000    1.0000         1
           2     1.0000    1.0000    1.0000         1
           3     0.1250    1.0000    0.2222         1
           4     1.0000    1.0000    1.0000         1
           5     0.0000    0.0000    0.0000         1
           6     1.0000    1.0000    1.0000         1
           7     1.0000    1.0000    1.0000         1
           8     1.0000    1.0000    1.0000         1
           9     1.0000    1.0000    1.0000         1
          10     1.0000    1.0000    1.0000         1
          11     1.0000    1.0000    1.0000         1
          12     1.0000    1.0000    1.0000         1
          13     1.0000    1.0000    1.0000         1
          14     1.0000    1.0000    1.0000         1
          15     1.0000    1.0000    1.0000         1
          16     1.0000    1.0000    1.0000         1
          17     0.0000    

100%|██████████| 1000/1000 [00:04<00:00, 202.80it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0     1.0000    1.0000    1.0000         1
           1     1.0000    1.0000    1.0000         1
           2     1.0000    1.0000    1.0000         1
           3     0.1111    1.0000    0.2000         1
           4     1.0000    1.0000    1.0000         1
           5     0.0000    0.0000    0.0000         1
           6     1.0000    1.0000    1.0000         1
           7     0.0000    0.0000    0.0000         1
           8     1.0000    1.0000    1.0000         1
           9     1.0000    1.0000    1.0000         1
          10     1.0000    1.0000    1.0000         1
          11     1.0000    1.0000    1.0000         1
          12     1.0000    1.0000    1.0000         1
          13     1.0000    1.0000    1.0000         1
          14     1.0000    1.0000    1.0000         1
          15     1.0000    1.0000    1.0000         1
          16     0.0000    0.0000    0.0000         1
          17     0.0000    

100%|██████████| 1000/1000 [00:03<00:00, 315.16it/s]


              precision    recall  f1-score   support

           0     1.0000    1.0000    1.0000         1
           1     1.0000    1.0000    1.0000         1
           2     1.0000    1.0000    1.0000         1
           3     0.1111    1.0000    0.2000         1
           4     1.0000    1.0000    1.0000         1
           5     0.0000    0.0000    0.0000         1
           6     1.0000    1.0000    1.0000         1
           7     0.0000    0.0000    0.0000         1
           8     1.0000    1.0000    1.0000         1
           9     0.0000    0.0000    0.0000         1
          10     1.0000    1.0000    1.0000         1
          11     1.0000    1.0000    1.0000         1
          12     1.0000    1.0000    1.0000         1
          13     1.0000    1.0000    1.0000         1
          14     1.0000    1.0000    1.0000         1
          15     0.5000    1.0000    0.6667         1
          16     0.0000    0.0000    0.0000         1
          17     0.0000    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [19]:
print(dataset.context_embeds.keys())

dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222

In [None]:
# class SimpleDataset(Dataset):
#     def __init__(self, context_file, question_file, idxs):
#         context_vals = torch.load(context_file) 
#         question_vals = torch.load(question_file)
#         self.context_embeds = {} #define context vals as dict with keys as passage nums and values in entries as passage encodings
#         for elem in context_vals:
#             val = torch.reshape(elem, (1, 769))
#             if int(val[0][0].item()) not in self.context_embeds.keys():
#                 self.context_embeds[int(val[0][0].item())] = [val]
#             else:
#                 self.context_embeds[int(val[0][0].item())].append(val)
#         self.question_embeds = []
#         for elem in question_vals:
#             self.question_embeds.append(torch.reshape(elem, (1, 768)))
#         self.idxs = idxs

#     def __len__(self):
#         return len(self.idxs)

#     def __getitem__(self, idx):
#         batch_ixs = self.idxs[idx]
#         query_embed = self.question_embeds[   int(batch_ixs[0])   ]
#         context_embeds = []
#         for ind in batch_ixs:
#             for elem in self.context_embeds[int(ind)]:
#                 context_embeds.append(elem)
#         label = idx
#         return (query_embed, context_embeds, batch_ixs[:], label)

#     def collate_fn_(self, data):
#         q_batch = [q[0] for q in data]
#         c_batch = [c[1] for c in data]
#         ixs_batch = [i[2] for i in data]
#         label_batch = [l[3] for l in data]
#         max_len = max([len(c_batch[j]) for j in range(len(c_batch))])
#         default = torch.full((1, 769), 0)
#         for elem in c_batch:
#             while len(elem) < max_len:
#                 elem.append(default)
#         return q_batch, c_batch, ixs_batch, label_batch
