In [65]:
import torch
import torch.nn as nn

In [64]:
class CopyGenerator(nn.Module):
    """An implementation of pointer-generator networks
    :cite:`DBLP:journals/corr/SeeLM17`.

    These networks consider copying words
    directly from the source sequence.

    The copy generator is an extended version of the standard
    generator that computes three values.

    * :math:`p_{softmax}` the standard softmax over `tgt_dict`
    * :math:`p(z)` the probability of copying a word from
      the source
    * :math:`p_{copy}` the probility of copying a particular word.
      taken from the attention distribution directly.

    The model returns a distribution over the extend dictionary,
    computed as

    :math:`p(w) = p(z=1)  p_{copy}(w)  +  p(z=0)  p_{softmax}(w)`
    Args:
       input_size (int): size of input representation
       output_size (int): size of output vocabulary
       pad_idx (int)
    """

    def __init__(self, input_size, output_size, pad_idx):
        super(CopyGenerator, self).__init__()
        self.linear = nn.Linear(input_size, output_size)
        self.linear_copy = nn.Linear(input_size, 1)
        self.pad_idx = pad_idx

    def forward(self, hidden, attn, src_map):
        """
        Compute a distribution over the target dictionary
        extended by the dynamic dictionary implied by copying
        source words.

        Args:
           hidden (FloatTensor): hidden outputs ``(batch x tlen, input_size)``
           attn (FloatTensor): attn for each ``(batch x tlen, input_size)``
           src_map (FloatTensor):
               A sparse indicator matrix mapping each source word to
               its index in the "extended" vocab containing.
               ``(src_len, batch, extra_words)``
        """

        # CHECKS
        batch_by_tlen, _ = hidden.size()
        batch_by_tlen_, slen = attn.size()
        slen_, batch, cvocab = src_map.size()
        aeq(batch_by_tlen, batch_by_tlen_)
        aeq(slen, slen_)

        # Original probabilities.
        logits = self.linear(hidden)
        logits[:, self.pad_idx] = -float('inf')
        prob = torch.softmax(logits, 1)

        # Probability of copying p(z=1) batch.
        p_copy = torch.sigmoid(self.linear_copy(hidden))
        # Probability of not copying: p_{word}(w) * (1 - p(z))
        out_prob = torch.mul(prob, 1 - p_copy)
        mul_attn = torch.mul(attn, p_copy)
        copy_prob = torch.bmm(
            mul_attn.view(-1, batch, slen).transpose(0, 1),
            src_map.transpose(0, 1)
        ).transpose(0, 1)
        copy_prob = copy_prob.contiguous().view(-1, cvocab)
        return torch.cat([out_prob, copy_prob], 1)


In [74]:
class CopyGeneratorLoss(nn.Module):
    """Copy generator criterion."""
    def __init__(self, vocab_size, force_copy, unk_index=0,
                 ignore_index=-100, eps=1e-20):
        super(CopyGeneratorLoss, self).__init__()
        self.force_copy = force_copy
        self.eps = eps
        self.vocab_size = vocab_size
        self.ignore_index = ignore_index
        self.unk_index = unk_index

    def forward(self, scores, align, target):
        """
        Args:
            scores (FloatTensor): ``(batch_size*tgt_len)`` x dynamic vocab size
                whose sum along dim 1 is less than or equal to 1, i.e. cols
                softmaxed.
            align (LongTensor): ``(batch_size x tgt_len)``
            target (LongTensor): ``(batch_size x tgt_len)``
        """
        # probabilities assigned by the model to the gold targets
        vocab_probs = scores.gather(1, target.unsqueeze(1)).squeeze(1)

        # probability of tokens copied from source
        # offset the indices by vocabulary size.
        copy_ix = align.unsqueeze(1) + self.vocab_size
        print(copy_ix)
        copy_tok_probs = scores.gather(1, copy_ix).squeeze(1)
        # Set scores for unk to 0 and add eps
        # (those that should not be copied)
        copy_tok_probs[align == self.unk_index] = 0
        copy_tok_probs += self.eps  # to avoid -inf logs

        # find the indices in which you do not use the copy mechanism
        non_copy = align == self.unk_index
        if not self.force_copy:
            non_copy = non_copy | (target != self.unk_index)
            
        # If copy then use copy probs
        # If non-copy then use vocab probs
        probs = torch.where(
            non_copy, copy_tok_probs + vocab_probs, copy_tok_probs
        )

        loss = -probs.log()  # just NLLLoss; can the module be incorporated?
        # Drop padding.
        loss[target == self.ignore_index] = 0
        return loss


In [80]:
vocab_size = 4
batch_size = 3
tgt_len = 5
copy_size = 3

In [81]:
my_scores = torch.zeros(batch_size * tgt_len, vocab_size + copy_size)
for i in range(batch_size * tgt_len):
    my_scores[i,:] = torch.arange(7)
my_scores

tensor([[0., 1., 2., 3., 4., 5., 6.],
        [0., 1., 2., 3., 4., 5., 6.],
        [0., 1., 2., 3., 4., 5., 6.],
        [0., 1., 2., 3., 4., 5., 6.],
        [0., 1., 2., 3., 4., 5., 6.],
        [0., 1., 2., 3., 4., 5., 6.],
        [0., 1., 2., 3., 4., 5., 6.],
        [0., 1., 2., 3., 4., 5., 6.],
        [0., 1., 2., 3., 4., 5., 6.],
        [0., 1., 2., 3., 4., 5., 6.],
        [0., 1., 2., 3., 4., 5., 6.],
        [0., 1., 2., 3., 4., 5., 6.],
        [0., 1., 2., 3., 4., 5., 6.],
        [0., 1., 2., 3., 4., 5., 6.],
        [0., 1., 2., 3., 4., 5., 6.]])

In [89]:
my_align = torch.randint(low=0, high=vocab_size - 1, size=(batch_size * tgt_len,)).long().view(-1)
my_align

tensor([1, 1, 1, 0, 1, 1, 0, 0, 1, 2, 1, 1, 0, 2, 0])

In [90]:
my_target = torch.randint(low=0, high=tgt_len - 1, size=(batch_size * tgt_len,)).long().view(-1)
my_target

tensor([2, 3, 1, 3, 2, 3, 1, 1, 0, 3, 3, 3, 0, 1, 2])

In [91]:
loss = CopyGeneratorLoss(vocab_size, force_copy=False)
loss(my_scores, my_align, my_target)

tensor([2., 3., 1., 3., 2., 3., 1., 1., 0., 3., 3., 3., 0., 1., 2.])
tensor([[5],
        [5],
        [5],
        [4],
        [5],
        [5],
        [4],
        [4],
        [5],
        [6],
        [5],
        [5],
        [4],
        [6],
        [4]])


tensor([-1.9459, -2.0794, -1.7918, -1.0986, -1.9459, -2.0794, -0.0000, -0.0000,
        -1.6094, -2.1972, -2.0794, -2.0794, 46.0517, -1.9459, -0.6931])