In [41]:
from transformers import PreTrainedModel, DPRConfig, DPRReaderOutput, DPRPretrainedReader, DPRReaderTokenizer
from transformers.models.dpr.modeling_dpr import DPREncoder, DPR_READER_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
from transformers.file_utils import add_start_docstrings_to_model_forward, replace_return_docstrings
from torch import Tensor, nn
from typing import Optional, Union, List, Tuple

In [68]:
class DPRSpanPredictor(PreTrainedModel):

    base_model_prefix = "encoder"

    def __init__(self, config: DPRConfig):
        super().__init__(config)
        self.encoder = DPREncoder(config)
        self.qa_outputs = nn.Linear(self.encoder.embeddings_size, 2)
        self.qa_classifier = nn.Linear(self.encoder.embeddings_size, 1)
        self.init_weights()

    def forward(
        self,
        input_ids: Tensor,
        attention_mask: Tensor,
        inputs_embeds: Optional[Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = False,
    ) -> Union[DPRReaderOutput, Tuple[Tensor, ...]]:
        # notations: N - number of questions in a batch, M - number of passages per questions, L - sequence length
        n_passages, sequence_length = input_ids.size() if input_ids is not None else inputs_embeds.size()[:2]
        # feed encoder
        outputs = self.encoder(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]

        # compute logits
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
        relevance_logits = self.qa_classifier(sequence_output[:, 0, :])

        # resize
        start_logits = start_logits.view(n_passages, sequence_length)
        end_logits = end_logits.view(n_passages, sequence_length)
        relevance_logits = relevance_logits.view(n_passages)

        if not return_dict:
            return (start_logits, end_logits, relevance_logits) + outputs[2:]

        return DPRReaderOutput(
            start_logits=start_logits,
            end_logits=end_logits,
            relevance_logits=relevance_logits,
#             hidden_states=outputs.hidden_states,
#             attentions=outputs.attentions,
        )

    def init_weights(self):
        self.encoder.init_weights()


In [61]:
class DPRReader(DPRPretrainedReader):
    def __init__(self, config: DPRConfig):
        super().__init__(config)
        self.config = config
        self.span_predictor = DPRSpanPredictor(config)
        self.init_weights()

    @add_start_docstrings_to_model_forward(DPR_READER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=DPRReaderOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        inputs_embeds: Optional[Tensor] = None,
        output_attentions: bool = None,
        output_hidden_states: bool = None,
        return_dict=None,
    ) -> Union[DPRReaderOutput, Tuple[Tensor, ...]]:
        r"""
        Return:

        Examples::

            >>> from transformers import DPRReader, DPRReaderTokenizer
            >>> tokenizer = DPRReaderTokenizer.from_pretrained('facebook/dpr-reader-single-nq-base')
            >>> model = DPRReader.from_pretrained('facebook/dpr-reader-single-nq-base')
            >>> encoded_inputs = tokenizer(
            ...         questions=["What is love ?"],
            ...         titles=["Haddaway"],
            ...         texts=["'What Is Love' is a song recorded by the artist Haddaway"],
            ...         return_tensors='pt'
            ...     )
            >>> outputs = model(**encoded_inputs)
            >>> start_logits = outputs.stat_logits
            >>> end_logits = outputs.end_logits
            >>> relevance_logits = outputs.relevance_logits

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)

        return self.span_predictor(
            input_ids,
            attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


In [71]:
tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base")
model = DPRReader.from_pretrained("facebook/dpr-reader-single-nq-base")

encoded_inputs = tokenizer(
    questions="What is love ?",
    titles="Haddaway",
    texts="What Is Love is a song recorded by the artist Haddaway",
    padding=True,
    return_tensors="pt",
)


In [72]:
import random
import torch
global_rng = random.Random()
torch_device='cpu'
def ids_tensor(shape, vocab_size, rng=None, name=None):
    #  Creates a random int32 tensor of the shape within the vocab size
    if rng is None:
        rng = global_rng

    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(rng.randint(0, vocab_size - 1))

    return torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous()


def random_attention_mask(shape, rng=None, name=None):
    attn_mask = ids_tensor(shape, vocab_size=2, rng=None, name=None)
    # make sure that at least one token is attended to for each batch
    attn_mask[:, -1] = 1
    return attn_mask

In [73]:
attn = random_attention_mask([1, 23]) # self.batch_size, self.seq_length

In [75]:
encoded_inputs['attention_mask'] = attn

In [79]:
encoded_inputs

{'input_ids': tensor([[ 101, 2054, 2003, 2293, 1029,  102, 2018, 2850, 4576,  102, 2054, 2003,
         2293, 2003, 1037, 2299, 2680, 2011, 1996, 3063, 2018, 2850, 4576]]), 'attention_mask': tensor([[1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1]])}

In [76]:
outputs = model(**encoded_inputs)

In [78]:
outputs

DPRReaderOutput(start_logits=tensor([[-3.0630, -4.7473, -5.8192, -5.6407, -5.8589, -9.0629, -4.9761, -6.8090,
         -7.5950, -9.0560, -4.7091, -4.9144, -5.0070, -5.2773, -4.5161, -5.3637,
         -4.3490, -4.1066, -1.6287, -2.3917,  1.7558, -5.2994, -2.0889]],
       grad_fn=<ViewBackward>), end_logits=tensor([[-3.1214, -5.8755, -6.2211, -5.3107, -5.7220, -4.5273, -7.4018, -7.3266,
         -5.7475, -4.5124, -4.2047, -5.1659, -3.1289, -6.3303, -4.7987, -5.4662,
         -4.9542, -5.1752, -3.3463, -2.2221, -2.0111, -4.1174, -0.6945]],
       grad_fn=<ViewBackward>), relevance_logits=tensor([-11.7683], grad_fn=<ViewBackward>))

In [87]:
dist_start_logits = torch.softmax(outputs.start_logits, dim=1)
print(dist_start_logits)

tensor([[7.3607e-03, 1.3659e-03, 4.6761e-04, 5.5902e-04, 4.4941e-04, 1.8246e-05,
         1.0865e-03, 1.7380e-04, 7.9194e-05, 1.8373e-05, 1.4191e-03, 1.1558e-03,
         1.0535e-03, 8.0394e-04, 1.7211e-03, 7.3742e-04, 2.0342e-03, 2.5923e-03,
         3.0890e-02, 1.4403e-02, 9.1133e-01, 7.8641e-04, 1.9496e-02]],
       grad_fn=<SoftmaxBackward>)


In [91]:
predicted_start_pos = torch.argmax(dist_start_logits, dim=1)
print(predicted_start_pos.item())

20
