In [1]:
import torch
import torch.nn as nn

In [75]:
class CopyGeneratorLoss(nn.Module):
    """Copy generator criterion."""
    def __init__(self, vocab_size, force_copy, unk_index=-1,
                 ignore_index=-100, eps=1e-20):
        super(CopyGeneratorLoss, self).__init__()
        self.force_copy = force_copy
        self.eps = eps
        self.vocab_size = vocab_size
        self.ignore_index = ignore_index
        self.unk_index = unk_index

    def forward(self, scores, align, target):
        """
        Args:
            scores (FloatTensor): ``(batch_size*tgt_len)`` x dynamic vocab size
                whose sum along dim 1 is less than or equal to 1, i.e. cols
                softmaxed.
            align (LongTensor): ``(batch_size x tgt_len)``
            target (LongTensor): ``(batch_size x tgt_len)``
        """
        # probabilities assigned by the model to the gold indices vocabulary tokens 
        vocab_probs = scores.gather(1, target.unsqueeze(1)).squeeze(1)
        print(vocab_probs)
        # probability of tokens copied from source
        # offset the indices by vocabulary size.
        copy_ix = align.unsqueeze(1) + self.vocab_size
        print(copy_ix)
        copy_tok_probs = scores.gather(1, copy_ix).squeeze(1)
        print(copy_tok_probs)
        # Set scores for unk to 0 and add eps
        # (those that should not be copied)
        copy_tok_probs[align == self.unk_index] = 0
        copy_tok_probs += self.eps  # to avoid -inf logs

        # find the indices in which you do not use the copy mechanism
        non_copy = align == self.unk_index # tensor([-1,  1,  2, -1, -1, -1, -1])
        print(non_copy)
        print(self.unk_index)
            
        # If copy then use copy probs
        # If non-copy then use vocab probs
        probs = torch.where(
            non_copy, copy_tok_probs + vocab_probs, copy_tok_probs
        )
        print(probs)

        loss = -probs.log()  # just NLLLoss; can the module be incorporated?
        # Drop padding.
        loss[target == self.ignore_index] = 0
        return loss


In [76]:
vocab_size = 4 # special tokens
batch_size = 1
tgt_len = 7
copy_size = 3  # input entity embed

In [77]:
my_scores = torch.zeros(batch_size * tgt_len, vocab_size + copy_size)
for i in range(batch_size * tgt_len):
    my_scores[i,:] = torch.arange(7)
my_scores

tensor([[0., 1., 2., 3., 4., 5., 6.],
        [0., 1., 2., 3., 4., 5., 6.],
        [0., 1., 2., 3., 4., 5., 6.],
        [0., 1., 2., 3., 4., 5., 6.],
        [0., 1., 2., 3., 4., 5., 6.],
        [0., 1., 2., 3., 4., 5., 6.],
        [0., 1., 2., 3., 4., 5., 6.]])

In [78]:
tgt_plan = torch.tensor([-2, 1, 2, -1, -1, -3, -4])

In [79]:
my_target = torch.randint(low=0, high=vocab_size - 1, size=(batch_size * tgt_len,)).long().view(-1)
my_target = torch.where(tgt_plan < 0, tgt_plan, 0) + 4
my_target

tensor([2, 4, 4, 3, 3, 1, 0])

In [80]:
my_align = torch.randint(low=-1, high=copy_size - 1, size=(batch_size * tgt_len,)).long().view(-1)
my_align = torch.where(tgt_plan >= 0, tgt_plan, -1)
my_align

tensor([-1,  1,  2, -1, -1, -1, -1])

In [81]:
loss = CopyGeneratorLoss(vocab_size, force_copy=False)
loss(my_scores, my_align, my_target)

tensor([2., 4., 4., 3., 3., 1., 0.])
tensor([[3],
        [5],
        [6],
        [3],
        [3],
        [3],
        [3]])
tensor([3., 5., 6., 3., 3., 3., 3.])
tensor([ True, False, False,  True,  True,  True,  True])
-1
tensor([2.0000e+00, 5.0000e+00, 6.0000e+00, 3.0000e+00, 3.0000e+00, 1.0000e+00,
        1.0000e-20])


tensor([-0.6931, -1.6094, -1.7918, -1.0986, -1.0986, -0.0000, 46.0517])

In [2]:
class CopyGenerator(nn.Module):
    """An implementation of pointer-generator networks
    :cite:`DBLP:journals/corr/SeeLM17`.

    These networks consider copying words
    directly from the source sequence.

    The copy generator is an extended version of the standard
    generator that computes three values.

    * :math:`p_{softmax}` the standard softmax over `tgt_dict`
    * :math:`p(z)` the probability of copying a word from
      the source
    * :math:`p_{copy}` the probility of copying a particular word.
      taken from the attention distribution directly.

    The model returns a distribution over the extend dictionary,
    computed as

    :math:`p(w) = p(z=1)  p_{copy}(w)  +  p(z=0)  p_{softmax}(w)`
    Args:
       input_size (int): size of input representation
       output_size (int): size of output vocabulary
       pad_idx (int)
    """

    def __init__(self, input_size, output_size, pad_idx):
        super(CopyGenerator, self).__init__()
        self.linear = nn.Linear(input_size, output_size)
        self.linear_copy = nn.Linear(input_size, 1)
        self.pad_idx = pad_idx

    def forward(self, hidden, attn, src_map):
        """
        Compute a distribution over the target dictionary
        extended by the dynamic dictionary implied by copying
        source words.

        Args:
           hidden (FloatTensor): hidden outputs ``(batch x tlen, input_size)``
           attn (FloatTensor): attn for each ``(batch x tlen, input_size)``
           src_map (FloatTensor):
               A sparse indicator matrix mapping each source word to
               its index in the "extended" vocab containing.
               ``(src_len, batch, extra_words)``
        """

        # CHECKS
        batch_by_tlen, _ = hidden.size()
        batch_by_tlen_, slen = attn.size()
        slen_, batch, cvocab = src_map.size()
        aeq(batch_by_tlen, batch_by_tlen_)
        aeq(slen, slen_)

        # Original probabilities.
        logits = self.linear(hidden)
        logits[:, self.pad_idx] = -float('inf')
        prob = torch.softmax(logits, 1)

        # Probability of copying p(z=1) batch.
        p_copy = torch.sigmoid(self.linear_copy(hidden))
        # Probability of not copying: p_{word}(w) * (1 - p(z))
        out_prob = torch.mul(prob, 1 - p_copy)
        mul_attn = torch.mul(attn, p_copy)
        copy_prob = torch.bmm(
            mul_attn.view(-1, batch, slen).transpose(0, 1),
            src_map.transpose(0, 1)
        ).transpose(0, 1)
        copy_prob = copy_prob.contiguous().view(-1, cvocab)
        return torch.cat([out_prob, copy_prob], 1)


In [28]:
vocab_size = 4  # doc_start, edu_end, doc_end, pad
pad_idx = -4
input_size = 5
batch_size = 2
tlen = 3 # sentence length

In [29]:
copy_generator = CopyGenerator(input_size, vocab_size, pad_idx)

In [30]:
input_emb = torch.zeros(batch_size * tgt_len, input_size)
for i in range(batch_size * tgt_len):
    input_emb[i,:] = torch.rand(input_size)
input_emb

tensor([[0.8035, 0.3979, 0.7843, 0.4679, 0.8035],
        [0.7226, 0.4758, 0.2203, 0.0088, 0.0677],
        [0.8937, 0.3917, 0.6902, 0.9131, 0.7289],
        [0.5425, 0.4195, 0.9561, 0.8335, 0.5336],
        [0.9044, 0.2503, 0.8863, 0.9474, 0.1896],
        [0.3519, 0.8827, 0.7603, 0.7895, 0.0222],
        [0.4087, 0.5756, 0.2003, 0.7533, 0.4260],
        [0.9416, 0.5946, 0.1377, 0.0032, 0.8395],
        [0.0773, 0.6500, 0.5805, 0.1012, 0.8393],
        [0.2741, 0.1434, 0.5465, 0.2128, 0.1882]])

In [32]:
attn = torch.zeros(batch_size * tgt_len, input_size)
for i in range(batch_size * tgt_len):
    attn[i,:] = torch.rand(input_size)
attn

tensor([[0.7473, 0.8555, 0.9333, 0.7546, 0.1779],
        [0.5429, 0.1780, 0.7941, 0.3028, 0.5348],
        [0.4624, 0.3347, 0.9468, 0.3113, 0.0423],
        [0.6377, 0.8386, 0.8025, 0.1749, 0.7684],
        [0.3613, 0.0896, 0.5275, 0.0442, 0.7404],
        [0.5135, 0.1966, 0.0225, 0.9383, 0.2640],
        [0.5849, 0.7887, 0.4949, 0.0173, 0.1940],
        [0.7462, 0.0195, 0.9299, 0.6881, 0.3056],
        [0.8489, 0.4010, 0.1593, 0.3089, 0.8572],
        [0.2396, 0.0105, 0.7221, 0.2826, 0.1396]])

In [33]:
copy_generator(input_emb, attn)

TypeError: forward() missing 1 required positional argument: 'src_map'

tensor([2, 4, 4, 3, 4, 4, 3, 4, 3, 1, 0, 0])

tensor([-1,  2,  3, -1,  5,  1, -1,  7, -1, -1, -1, -1])