In [8]:
import os

# Restrict this notebook to one physical GPU. Set here, before the
# `import torch` cell below, so the restriction is in place when torch loads.
GPU_ID = '3'
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID


In [9]:
import torch

# Use the (single) visible GPU when CUDA is usable, otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [10]:
from datasets import load_dataset, load_metric

# SQuAD v2.0 from the Hugging Face Hub; a previously downloaded copy is reused
# from the local cache (see the "Found cached dataset" log below).
# NOTE(review): this name shadows the `datasets` package, and `load_metric` is
# not used in this part of the notebook (it is reportedly deprecated/removed in
# newer `datasets` releases) — confirm before upgrading the library.
datasets = load_dataset(path="squad_v2")

Found cached dataset squad_v2 (/root/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d)
100%|██████████| 2/2 [00:00<00:00, 378.99it/s]


In [11]:
# Show the available splits with their column names and row counts.
datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 130319
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 11873
    })
})

In [12]:
# Peek at one raw training row: `answers` holds parallel lists `text` and
# `answer_start`, where `answer_start` is a character offset into `context`.
datasets["train"][0]

{'id': '56be85543aeaaa14008c9063',
 'title': 'Beyoncé',
 'context': 'Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-YON-say) (born September 4, 1981) is an American singer, songwriter, record producer and actress. Born and raised in Houston, Texas, she performed in various singing and dancing competitions as a child, and rose to fame in the late 1990s as lead singer of R&B girl-group Destiny\'s Child. Managed by her father, Mathew Knowles, the group became one of the world\'s best-selling girl groups of all time. Their hiatus saw the release of Beyoncé\'s debut album, Dangerously in Love (2003), which established her as a solo artist worldwide, earned five Grammy Awards and featured the Billboard Hot 100 number-one singles "Crazy in Love" and "Baby Boy".',
 'question': 'When did Beyonce start becoming popular?',
 'answers': {'text': ['in the late 1990s'], 'answer_start': [269]}}

In [13]:
# Small column-oriented subsets for quick experiments.
# NOTE(review): the validation subset is sliced from the *train* split (the
# rows right after the training rows), not from datasets["validation"];
# neighbouring rows may share a context paragraph — confirm this is intended.
N_TRAIN, N_VAL = 100, 20
train_split = datasets["train"]
train_dict = train_split[:N_TRAIN]
val_dict = train_split[N_TRAIN:N_TRAIN + N_VAL]



def read_squad(dic):
    '''Flatten a column-oriented SQuAD slice into three aligned lists.

    :dic : mapping with equal-length lists under 'context', 'question' and
           'answers'; each 'answers' entry is a dict of parallel lists 'text'
           and 'answer_start' (the squad_v2 row layout shown above).
    :returns : (contexts, questions, answers) with one row per gold answer, so
               a question with k answers contributes k aligned rows and an
               unanswerable question (empty answer lists) contributes none.
               Each answer is a fresh dict {'text': str, 'answer_start': int}
               that add_end_idx can safely mutate.
    '''
    contexts = []
    questions = []
    answers = []
    for context, question, answer in zip(dic['context'], dic['question'], dic['answers']):
        # Emit the context/question once per answer so index i of all three
        # lists always refers to the same example (add_end_idx relies on this).
        for text, start in zip(answer['text'], answer['answer_start']):
            contexts.append(context)
            questions.append(question)
            answers.append({'text': text, 'answer_start': start})

    return contexts, questions, answers
                
# Flatten each subset into parallel context / question / answer lists.
train_contexts, train_questions, train_answers = read_squad(dic=train_dict)
val_contexts, val_questions, val_answers = read_squad(dic=val_dict)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokenizer'.


DPRContextEncoder(
  (ctx_encoder): DPREncoder(
    (bert_model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)


In [None]:
def add_end_idx(answers, contexts, max_offset=2):
    '''Add an exclusive character index 'answer_end' to every answer, in place.

    :answers : list of dicts with 'text' (str) and 'answer_start' (int); each
               dict is mutated to gain 'answer_end' and, when the gold span is
               misaligned, a corrected 'answer_start'
    :contexts : list of context strings aligned index-by-index with answers
    :max_offset : how many characters to the left of 'answer_start' to try
                  when the gold span does not match (default 2 tries 0, 1, 2)
    :raises ValueError : if answers and contexts have different lengths
    '''
    if len(answers) != len(contexts):
        raise ValueError(
            f"answers ({len(answers)}) and contexts ({len(contexts)}) must be aligned")

    for answer, context in zip(answers, contexts):
        gold_text = answer['text']
        start_idx = answer['answer_start']
        end_idx = start_idx + len(gold_text)

        # sometimes squad answers are off by a character or two – find the
        # smallest leftward shift that makes the span match the gold text.
        # Negative starts are skipped: Python would wrap them to the end.
        offset = next(
            (k for k in range(max_offset + 1)
             if start_idx - k >= 0 and context[start_idx - k:end_idx - k] == gold_text),
            None,
        )
        if offset is None:
            # No shifted span matched; keep the label as given so every answer
            # still carries 'answer_end' (previously the key was left missing).
            answer['answer_end'] = end_idx
        else:
            answer['answer_start'] = start_idx - offset
            answer['answer_end'] = end_idx - offset

# Attach 'answer_end' to every answer of both subsets (mutates the dicts in place).
add_end_idx(answers=train_answers, contexts=train_contexts)
add_end_idx(answers=val_answers, contexts=val_contexts)

In [None]:
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

# Pre-trained single-NQ DPR checkpoints: separate encoders (and matching
# tokenizers) for questions and for passages.
query_model_name = "facebook/dpr-question_encoder-single-nq-base"
passage_model_name = "facebook/dpr-ctx_encoder-single-nq-base"

query_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(query_model_name)
passage_tokenizer = DPRContextEncoderTokenizer.from_pretrained(passage_model_name)

query_model = DPRQuestionEncoder.from_pretrained(query_model_name)
passage_model = DPRContextEncoder.from_pretrained(passage_model_name)

# Put both encoders in training mode (enables their dropout layers) for
# fine-tuning. The trailing `;` stops the notebook from dumping the full
# module repr that `.train()` returns.
query_model.train()
passage_model.train();

In [None]:
# Identical tokenizer options for every call: truncation, padding, and
# return_tensors='pt' (PyTorch tensors).
_tok_kwargs = dict(truncation=True, padding=True, return_tensors='pt')

train_query_encodings = query_tokenizer(train_questions, **_tok_kwargs)
train_context_encodings = passage_tokenizer(train_contexts, **_tok_kwargs)

val_query_encodings = query_tokenizer(val_questions, **_tok_kwargs)
val_context_encodings = passage_tokenizer(val_contexts, **_tok_kwargs)

In [None]:
import torch
from torch import nn


class DPR(nn.Module):
  '''Dual-encoder retriever: separate query/passage encoders plus MLP projection heads.

  ``forward`` scores every query in a batch against every passage in the same
  batch by dot product and returns log-softmax scores over the passages.
  '''

  def __init__(self, query_model, passage_model, query_tokenizer, passage_tokenizer,
              dense_size, freeze_params = 0.0, batch_size = 2, sample_size = 4,
              encoder_dim = 768):
    '''
    :query_model : encoder for queries; called with input_ids / attention_mask
                   and read through its ``pooler_output``
    :passage_model : encoder for passages; same calling convention as query_model
    :query_tokenizer : tokenizer for queries (stored; not used by this class)
    :passage_tokenizer : tokenizer for passages (stored; not used by this class)
    :dense_size : output dimension of both projection heads
    :freeze_params : fraction (0.0-1.0) of each encoder's parameter tensors to
                     freeze, counted from the first tensor in .parameters() order
    :batch_size : the batch size for training (stored; not used by this class)
    :sample_size : the sample size for negative sampling (stored; not used by this class)
    :encoder_dim : width of the encoders' pooler_output (768 for the BERT-base
                   DPR checkpoints loaded in this notebook)
    '''
    super(DPR, self).__init__()
    self.query_model = query_model
    self.query_tokenizer = query_tokenizer
    self.passage_model = passage_model
    self.passage_tokenizer = passage_tokenizer
    self.freeze_params = freeze_params
    self.sample_size = sample_size
    self.batch_size = batch_size

    # Untied heads: encoder_dim -> 2*dense_size -> dense_size, one per tower.
    self.passage_to_dense = nn.Sequential(nn.Linear(encoder_dim, dense_size * 2),
                                          nn.ReLU(),
                                          nn.Linear(dense_size * 2, dense_size),
                                          nn.GELU())

    self.query_to_dense = nn.Sequential(nn.Linear(encoder_dim, dense_size * 2),
                                        nn.ReLU(),
                                        nn.Linear(dense_size * 2, dense_size),
                                        nn.GELU())
    self.log_softmax = nn.LogSoftmax(dim = 1)
    self.freeze_layers()


  def freeze_layers(self):
    '''Freeze the leading ``freeze_params`` fraction of each encoder's parameter tensors.

    Tensors are counted in .parameters() order (not by transformer layer); every
    tensor past the cutoff is explicitly re-enabled for training.
    '''
    for encoder in (self.query_model, self.passage_model):
      params = list(encoder.parameters())
      cutoff = int(self.freeze_params * len(params))
      for idx, param in enumerate(params):
        param.requires_grad = idx >= cutoff

  def get_passage_vectors(self, passage):
    '''Encode a tokenized passage batch and project it with the passage head.

    :passage : object exposing input_ids and attention_mask (e.g. tokenizer output)
    '''
    p_vector = self.passage_model(input_ids = passage.input_ids,
                                  attention_mask = passage.attention_mask)
    # Passages must go through the passage head, matching forward().
    p_vector = self.passage_to_dense(p_vector.pooler_output)
    return p_vector

  def get_query_vector(self, query):
    '''Encode a tokenized query batch and project it with the query head.

    :query : object exposing input_ids and attention_mask (e.g. tokenizer output)
    '''
    q_vector = self.query_model(input_ids = query.input_ids,
                                attention_mask = query.attention_mask)
    q_vector = self.query_to_dense(q_vector.pooler_output)
    return q_vector

  def dot_product(self, q_vector, p_vector):
    '''Dot-product similarity of each query with each passage.

    For 2-D inputs q_vector (B, D) and p_vector (P, D) the result is (B, 1, P);
    other ranks follow torch.matmul broadcasting.
    '''
    q_vector = q_vector.unsqueeze(1)
    sim = torch.matmul(q_vector, torch.transpose(p_vector, -2, -1))
    return sim

  def forward(self, context_input_ids, context_attention_mask, query_input_ids, query_attention_mask):
    '''Score every query against every passage in the batch.

    :returns : log-softmax over passages (dim 1) of the query/passage dot
               products; (num_queries, num_passages) for 2-D pooled outputs.
               NOTE(review): presumably passage i is the positive for query i
               (in-batch negatives) — confirm against the training loss.
    '''
    dense_passage = self.passage_model(input_ids = context_input_ids, attention_mask = context_attention_mask)
    dense_query = self.query_model(input_ids = query_input_ids, attention_mask = query_attention_mask)
    dense_passage = dense_passage['pooler_output']
    dense_query = dense_query['pooler_output']
    dense_passage = self.passage_to_dense(dense_passage)
    dense_query = self.query_to_dense(dense_query)
    similarity_score = self.dot_product(dense_query, dense_passage)
    # Drop the singleton query axis added by dot_product: (B, 1, P) -> (B, P).
    similarity_score = similarity_score.squeeze(1)
    logits = self.log_softmax(similarity_score)
    return logits