In [1]:
from transformers import AlbertTokenizerFast, AlbertModel
from transformers import RobertaTokenizerFast, RobertaModel
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import json
import numpy as np
from pathlib import Path
from tqdm.notebook import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def compare_models(model1, model2):
    """Return True when two models hold exactly the same parameter values.

    Parameters are paired in the order yielded by ``.parameters()``. Models
    with a different number of parameter tensors, or with tensors of different
    shapes, are reported as different (a plain ``zip`` would silently ignore
    the surplus parameters and an element-wise comparison would raise on a
    shape mismatch).

    :model1 : first ``nn.Module`` to compare
    :model2 : second ``nn.Module`` to compare
    :return : ``True`` if every parameter pair is identical, ``False`` otherwise
    """
    params1 = list(model1.parameters())
    params2 = list(model2.parameters())
    if len(params1) != len(params2):
        return False
    # torch.equal checks shape and values in one call.
    return all(torch.equal(p1.data, p2.data) for p1, p2 in zip(params1, params2))

In [3]:
def read_squad(path):
    """Flatten a SQuAD-format JSON file into three parallel lists.

    One entry is produced per answer, so a question with several answers
    repeats its context and question text, and a question with an empty
    ``answers`` list contributes nothing.

    :path : location of the SQuAD JSON file (string or ``Path``)
    :return : ``(contexts, questions, answers)`` where ``answers`` holds the
              raw answer dicts from the file
    """
    with Path(path).open('rb') as handle:
        payload = json.load(handle)

    contexts, questions, answers = [], [], []
    for article in payload['data']:
        for paragraph in article['paragraphs']:
            passage_text = paragraph['context']
            for qa_entry in paragraph['qas']:
                question_text = qa_entry['question']
                gold_answers = qa_entry['answers']
                # one (context, question) copy per gold answer
                contexts.extend(passage_text for _ in gold_answers)
                questions.extend(question_text for _ in gold_answers)
                answers.extend(gold_answers)

    return contexts, questions, answers

# Locations of the SQuAD v2.0 splits, relative to the notebook's working directory.
# The files are published at https://rajpurkar.github.io/SQuAD-explorer/.
TRAIN_PATH = Path('train-v2.0.json')
VAL_PATH = Path('dev-v2.0.json')

# Fail early with an actionable message instead of a bare open() error.
missing_files = [str(p) for p in (TRAIN_PATH, VAL_PATH) if not p.is_file()]
if missing_files:
    raise FileNotFoundError(
        f"SQuAD file(s) not found in {Path.cwd()}: {', '.join(missing_files)}"
    )

train_contexts, train_questions, train_answers = read_squad(TRAIN_PATH)
val_contexts, val_questions, val_answers = read_squad(VAL_PATH)

FileNotFoundError: [Errno 2] No such file or directory: 'train-v2.0.json'

In [None]:
def add_end_idx(answers, contexts, max_shift = 2):
    """Add an ``'answer_end'`` character index to every answer dict, in place.

    The end index is ``answer_start + len(text)``. Because some gold labels
    are off by a character or two, the start is also tried at up to
    ``max_shift`` positions to the left; when a shifted span matches the gold
    text, ``'answer_start'`` is corrected as well.

    If no candidate span matches, the original ``'answer_start'`` is kept and
    ``'answer_end'`` is still written, so every answer ends up with the key.

    :answers : list of answer dicts with ``'text'`` and ``'answer_start'``
    :contexts : list of context strings, aligned with ``answers``
    :max_shift : largest leftward correction (in characters) to try
    :return : ``None``; the answer dicts are mutated
    """
    for answer, context in zip(answers, contexts):
        gold_text = answer['text']
        start_idx = answer['answer_start']
        span_len = len(gold_text)

        # sometimes squad answers are off by a character or two – fix this
        matched_start = None
        for shift in range(max_shift + 1):
            candidate_start = start_idx - shift
            if candidate_start < 0:
                # a negative start would slice from the end of the context
                break
            if context[candidate_start:candidate_start + span_len] == gold_text:
                matched_start = candidate_start
                break

        if matched_start is None:
            # no alignment found: keep the labelled start, but always set the end
            answer['answer_end'] = start_idx + span_len
        else:
            answer['answer_start'] = matched_start
            answer['answer_end'] = matched_start + span_len

# Annotate the answer dicts with 'answer_end' (and a corrected 'answer_start'
# where the label was misaligned); both answer lists are modified in place.
add_end_idx(train_answers, train_contexts)
add_end_idx(val_answers, val_contexts)

In [None]:
# Pretrained DPR checkpoints: one encoder for questions, one for passages.
query_model_name = "facebook/dpr-question_encoder-single-nq-base"
passage_model_name = "facebook/dpr-ctx_encoder-single-nq-base"

query_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(query_model_name)
passage_tokenizer = DPRContextEncoderTokenizer.from_pretrained(passage_model_name)
query_model = DPRQuestionEncoder.from_pretrained(query_model_name)
passage_model = DPRContextEncoder.from_pretrained(passage_model_name)

# Switch both encoders to training mode. The trailing semicolon keeps the
# notebook from echoing the module repr returned by the last .train() call.
query_model.train()
passage_model.train();

In [None]:
# Settings shared by every tokenizer call: truncate over-long inputs, pad each
# set to a common length, and return PyTorch tensors.
tokenizer_kwargs = dict(truncation=True, padding=True, return_tensors='pt')

# Training split: questions go through the query tokenizer, contexts through the passage tokenizer.
train_query_encodings = query_tokenizer(train_questions, **tokenizer_kwargs)
train_context_encodings = passage_tokenizer(train_contexts, **tokenizer_kwargs)

# Validation split, encoded the same way.
val_query_encodings = query_tokenizer(val_questions, **tokenizer_kwargs)
val_context_encodings = passage_tokenizer(val_contexts, **tokenizer_kwargs)

In [None]:
class DPR(nn.Module):
    """Dual-encoder retriever: two encoders followed by small projection heads.

    Queries and passages are encoded by their own models, the pooled outputs
    are projected down to ``dense_size`` dimensions, and ``forward`` returns
    the log-softmax of the query/passage dot products.
    """

    def __init__(self, query_model, passage_model, query_tokenizer, passage_tokenizer,
                 dense_size, freeze_params = 0.0, batch_size = 2, sample_size = 4,
                 encoder_dim = 768):
        '''
        :query_model : The model that encodes queries to dense representation
        :passage_model : The model that encodes passages to dense representation
        :query_tokenizer : tokenizer for queries (stored, not used by this class)
        :passage_tokenizer : tokenizer for passages (stored, not used by this class)
        :dense_size : the dimension to which the DPR has to encode
        :freeze_params : the fraction (0.0 - 1.0) of each encoder's parameter
                         tensors, counted from the first, to be frozen
        :batch_size : the batch size for training (stored, not used by this class)
        :sample_size: the sample size for negative sampling (stored, not used by this class)
        :encoder_dim : width of the encoders' pooled output fed to the projection heads
        '''
        super(DPR, self).__init__()
        if not 0.0 <= freeze_params <= 1.0:
            raise ValueError(f"freeze_params must be within [0, 1], got {freeze_params}")

        self.query_model = query_model
        self.query_tokenizer = query_tokenizer
        self.passage_model = passage_model
        self.passage_tokenizer = passage_tokenizer
        self.freeze_params = freeze_params
        self.sample_size = sample_size
        self.batch_size = batch_size

        # Separate projection heads for passages and queries (same architecture).
        self.passage_to_dense = self._build_projection(encoder_dim, dense_size)
        self.query_to_dense = self._build_projection(encoder_dim, dense_size)
        self.log_softmax = nn.LogSoftmax(dim = 1)
        self.freeze_layers()

    @staticmethod
    def _build_projection(input_dim, dense_size):
        """Build the two-layer head mapping pooled encoder output to ``dense_size``."""
        return nn.Sequential(nn.Linear(input_dim, dense_size * 2),
                             nn.ReLU(),
                             nn.Linear(dense_size * 2, dense_size),
                             nn.GELU())

    @staticmethod
    def _set_trainable(model, frozen_fraction):
        """Freeze the first ``frozen_fraction`` of ``model``'s parameter tensors, unfreeze the rest."""
        params = list(model.parameters())
        cutoff = int(frozen_fraction * len(params))
        for position, param in enumerate(params):
            param.requires_grad = position >= cutoff

    @staticmethod
    def _pooled(encoder_output):
        """Read ``pooler_output`` from either a mapping-style or attribute-style output."""
        if isinstance(encoder_output, dict):
            return encoder_output['pooler_output']
        return encoder_output.pooler_output

    # Freeze the first self.freeze_params fraction of parameter tensors in each encoder
    def freeze_layers(self):
        """Apply ``self.freeze_params`` to both encoders.

        The count is over parameter tensors (weights and biases), not over
        whole layers. The projection heads are not touched.
        """
        self._set_trainable(self.query_model, self.freeze_params)
        self._set_trainable(self.passage_model, self.freeze_params)

    def _encode_passages(self, input_ids, attention_mask):
        """Encode passages and project them with the passage head."""
        encoded = self.passage_model(input_ids = input_ids, attention_mask = attention_mask)
        return self.passage_to_dense(self._pooled(encoded))

    def _encode_queries(self, input_ids, attention_mask):
        """Encode queries and project them with the query head."""
        encoded = self.query_model(input_ids = input_ids, attention_mask = attention_mask)
        return self.query_to_dense(self._pooled(encoded))

    def get_passage_vectors(self, passage):
        """Return dense vectors for ``passage`` (an object with ``input_ids`` and ``attention_mask``).

        Uses the passage projection head, matching what ``forward`` does.
        """
        return self._encode_passages(passage.input_ids, passage.attention_mask)

    def get_query_vector(self, query):
        """Return dense vectors for ``query`` (an object with ``input_ids`` and ``attention_mask``)."""
        return self._encode_queries(query.input_ids, query.attention_mask)

    def dot_product(self, q_vector, p_vector):
        """Dot-product similarity between every query and every passage.

        For 2-D inputs of shape ``(B, D)`` and ``(N, D)`` the result has shape
        ``(B, 1, N)``: an axis is inserted into ``q_vector`` before the matmul.
        """
        q_vector = q_vector.unsqueeze(1)
        sim = torch.matmul(q_vector, torch.transpose(p_vector, -2, -1))
        return sim

    def forward(self, context_input_ids, context_attention_mask, query_input_ids, query_attention_mask):
        """Score each query against all supplied contexts.

        :return : log-probabilities over the contexts (log-softmax along
                  dim 1), one row per query; suitable for ``nn.NLLLoss``
        """
        dense_passage = self._encode_passages(context_input_ids, context_attention_mask)
        dense_query = self._encode_queries(query_input_ids, query_attention_mask)
        similarity_score = self.dot_product(dense_query, dense_passage)
        similarity_score = similarity_score.squeeze(1)
        logits = self.log_softmax(similarity_score)
        return logits

In [None]:
# Wrap the two pretrained encoders in the DPR module; 30% of each encoder's
# parameter tensors are frozen and the heads project down to 64 dimensions.
dpr_config = dict(dense_size = 64,
                  freeze_params = 0.3,
                  batch_size = 2,
                  sample_size = 8)
dpr_model = DPR(query_model = query_model,
                passage_model = passage_model,
                query_tokenizer = query_tokenizer,
                passage_tokenizer = passage_tokenizer,
                **dpr_config)

dpr_model.train()

# Displayed as (total parameter count, trainable parameter count).
total_param_count = sum(p.numel() for p in dpr_model.parameters())
trainable_param_count = sum(p.numel() for p in dpr_model.parameters() if p.requires_grad)
total_param_count, trainable_param_count

In [None]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# The model's forward pass already returns log-softmax scores, so the matching
# loss is plain negative log-likelihood.
criterion = nn.NLLLoss()

# AdamW over all parameters of the DPR module (encoders and projection heads).
optimizer = torch.optim.AdamW(params = dpr_model.parameters(), lr = 5e-5)

In [None]:
# Sampling configuration for the training cells below.
batch_size = 1    # not read by get_batch or the training loop in this notebook
neg_samples = 2   # negative contexts requested per query

# Number of encoded training rows (first dimension of the padded id tensor).
num_questions = train_context_encodings.input_ids.shape[0]

In [None]:
def get_batch():
    """Draw one training example: a query, its positive context and negatives.

    Reads the module-level ``train_query_encodings``, ``train_context_encodings``,
    ``num_questions``, ``neg_samples`` and ``optimizer``.

    The positive context is always placed first. Negatives are taken at
    multiples of a fixed index stride from the positive (wrapping around the
    end of the data), skipping any candidate whose token ids are identical to
    a context already chosen, so the negatives are distinct from the positive
    and from each other. Fewer than ``neg_samples`` negatives are returned
    only when not enough distinct candidates exist.

    :return : ``(context_input_ids, context_attention_mask, query_input_ids,
              query_attention_mask, true)`` where the context tensors stack
              the positive followed by the negatives, the query tensors have a
              leading batch axis of 1, and ``true`` is a list with 1 for the
              positive and 0 for each negative
    """
    negative_stride = 100  # index distance between successive negative candidates

    # Kept here because callers have relied on this function to reset gradients.
    optimizer.zero_grad()

    ## Selecting the query and its positive context
    # randrange excludes the upper bound, so idx is always a valid row index
    # (randint(0, n) could return n and index past the end).
    idx = random.randrange(num_questions)
    query_input_ids = train_query_encodings.input_ids[idx].unsqueeze(0)
    query_attention_mask = train_query_encodings.attention_mask[idx].unsqueeze(0)

    context_input_ids_list = [train_context_encodings.input_ids[idx]]
    context_attention_mask_list = [train_context_encodings.attention_mask[idx]]

    ## Selecting the negative contexts at idx + stride, idx + 2 * stride, ...
    for step in range(1, num_questions):
        if len(context_input_ids_list) > neg_samples:
            break
        neg_idx = (idx + negative_stride * step) % num_questions
        candidate_ids = train_context_encodings.input_ids[neg_idx]
        # Skip rows that encode a context we already have (e.g. the positive
        # passage repeated for another question).
        if any(torch.equal(candidate_ids, chosen) for chosen in context_input_ids_list):
            continue
        context_input_ids_list.append(candidate_ids)
        context_attention_mask_list.append(train_context_encodings.attention_mask[neg_idx])

    # Labels always line up with the contexts actually returned.
    true = [1] + [0] * (len(context_input_ids_list) - 1)

    context_input_ids_tensor = torch.stack(context_input_ids_list)
    context_attention_mask_tensor = torch.stack(context_attention_mask_list)
    return context_input_ids_tensor, context_attention_mask_tensor, query_input_ids, query_attention_mask, true

In [None]:
log_every = 20   # report the mean loss once per this many steps

batch_loss = 0
for i in range(1, 500):
    # The label list returned by get_batch is not needed: the positive context
    # is always the first row, so the target class index is 0.
    context_input_ids_tensor, context_attention_mask_tensor, query_input_ids, query_attention_mask, _ = get_batch()

    # Reset gradients explicitly rather than relying on get_batch's side effect.
    optimizer.zero_grad()
    pred = dpr_model(context_input_ids_tensor, context_attention_mask_tensor, query_input_ids, query_attention_mask)

    # One target per query row, created on the same device as the predictions.
    true = torch.zeros(pred.size(0), dtype = torch.long, device = pred.device)
    loss = criterion(pred, true)
    loss.backward()
    batch_loss += loss.item()
    optimizer.step()

    if i % log_every == 0:
        print(f"Batch : {i // log_every}  Loss : {batch_loss / log_every}")
        batch_loss = 0

In [None]:
# Despite its name, save_dir is the path of the checkpoint file itself.
save_dir = "./content/drive/MyDrive/Haystack/dpr_3.pt"

# Create the target folder first so a missing directory cannot cost the trained model.
Path(save_dir).parent.mkdir(parents = True, exist_ok = True)

# NOTE: this pickles the whole module object rather than a state_dict, so the
# DPR class definition must be available wherever the file is loaded.
torch.save(dpr_model, save_dir)