In [1]:
import copy
import json
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
import numpy as np
import math
import numpy
import torch
from torch.autograd import Variable

In [None]:
def _sequence_mask(sequence_length, max_len=None):
    """Build a boolean mask marking the valid (non-padded) steps of each sequence.

    Args:
        sequence_length: 1-D tensor of size (batch,) holding the number of
            valid steps of every sequence.
        max_len: Width of the mask. When ``None`` the largest value found in
            ``sequence_length`` is used.

    Returns:
        A tensor of size (batch, max_len) whose entry ``[b, t]`` is true
        exactly when ``t < sequence_length[b]``. It lives on the same device
        as ``sequence_length``.
    """
    if max_len is None:
        max_len = sequence_length.max()
    # Accept both Python ints and 0-dim tensors.
    max_len = int(max_len)

    # torch.arange replaces the deprecated torch.range: it excludes the end
    # point (so max_len == 0 gives an empty range instead of an error) and it
    # is created directly on the device of ``sequence_length`` rather than on
    # whatever the default CUDA device happens to be.
    positions = torch.arange(max_len, device=sequence_length.device)

    # Broadcasting (1, max_len) against (batch, 1) gives (batch, max_len).
    return positions.unsqueeze(0) < sequence_length.unsqueeze(1)


def compute_loss(logits, target, length):
    """Length-masked cross-entropy averaged over all valid steps.

    Args:
        logits: Float tensor of size (batch, max_len, num_classes) with the
            unnormalized score of each class.
        target: Long tensor of size (batch, max_len) with the index of the
            true class at each step.
        length: Long tensor of size (batch,) with the number of valid steps
            of each sequence in the batch.

    Returns:
        A 0-dim tensor: the sum of the negative log-likelihoods of the valid
        steps divided by the total number of valid steps. If every length is
        zero the division is 0 / 0 and the result is NaN.
    """
    # reshape (rather than view) also accepts non-contiguous inputs, e.g. the
    # result of a transpose from a time-major layout.
    # logits_flat: (batch * max_len, num_classes)
    logits_flat = logits.reshape(-1, logits.size(-1))
    # log_probs_flat: (batch * max_len, num_classes)
    log_probs_flat = torch.nn.functional.log_softmax(logits_flat, dim=-1)
    # target_flat: (batch * max_len, 1)
    target_flat = target.reshape(-1, 1)
    # losses_flat: (batch * max_len, 1) -- picks the log-prob of the true class
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
    # losses: (batch, max_len)
    losses = losses_flat.view(*target.size())
    # mask: (batch, max_len)
    mask = _sequence_mask(sequence_length=length, max_len=target.size(1))
    # Zero the contribution of padded steps before averaging.
    losses = losses * mask.float()
    loss = losses.sum() / length.float().sum()
    return loss

In [None]:
# Toy inputs for the 2-D target case.
# logits: (batch, seq_len, num_classes)
# target: (batch, seq_len)

batch_size = 4
nb_classes = 2
seq_len = 3

# Seed the RNG so the printed example is the same after a kernel restart.
torch.manual_seed(0)

logits = torch.rand(batch_size, seq_len, nb_classes)
# torch.rand samples from [0, 1), so casting its result with .long() always
# produced an all-zero target. Sample class indices in [0, nb_classes) instead.
target = torch.randint(nb_classes, size=(batch_size, seq_len))

print(logits)
print(target)

In [None]:
# Walk through the loss computation one tensor at a time (2-D target).

# Merge the batch and time axes -> (batch * max_len, num_classes)
logits_flat = logits.view(-1, logits.shape[-1])
print(logits_flat)

# Log-probabilities over the class axis -> (batch * max_len, num_classes)
log_probs_flat = logits_flat.log_softmax(dim=-1)
print(log_probs_flat)

# One class index per row -> (batch * max_len, 1)
target_flat = target.view(-1, 1)
print(target_flat)

# Pick, for every row, the log-probability stored at the target index.
# gather needs index and input to agree on every dim except the gathered one,
# hence the trailing singleton axis -> (batch * max_len, 1)
losses_flat = -log_probs_flat.gather(1, target_flat)
print(losses_flat)

# Restore the layout of the target -> (batch, max_len)
losses = losses_flat.view(target.shape)
print(losses)

In [None]:
# length: (batch,) number of valid (non-padded) steps of each sequence.
# No earlier cell shown here defines `length`, so this cell raised NameError on
# a fresh kernel. Draw one example length in [1, max_len] per sequence.
length = torch.randint(1, target.size(1) + 1, size=(target.size(0),))
print(length)

# mask: (batch, max_len)
mask = _sequence_mask(sequence_length=length, max_len=target.size(1))
print(mask)

# Zero the loss of every step that lies beyond its sequence length.
losses = losses * mask.float()
print(losses)

# Average over the valid steps only.
loss = losses.sum() / length.float().sum()
print(loss)

## Type 2 - 3d target

In [2]:
# Toy inputs for the 3-D target case (several labels per step).
# logits: (batch, seq_len, num_classes)
# target: (batch, seq_len, nb_labels)
batch_size = 4
nb_classes = 10
nb_labels = 3
seq_len = 5

logits = torch.rand(batch_size, seq_len, nb_classes)
# notice target has a third dimension of nb_labels
# torch.rand samples from [0, 1), so casting its result with .long() always
# produced an all-zero target. Sample class indices in [0, nb_classes) instead.
target = torch.randint(nb_classes, size=(batch_size, seq_len, nb_labels))

print(logits.shape)
print(target.shape)

torch.Size([4, 5, 10])
torch.Size([4, 5, 3])


In [3]:
# create a target tensor with some padded elements
# Hand-written target of size (batch=4, seq_len=5, nb_labels=3) in which some
# label slots are left at 0 to act as padding.
padded_target_rows = [
    [[1, 0, 0], [1, 2, 0], [1, 4, 0], [1, 3, 0], [1, 3, 0]],
    [[2, 2, 0], [5, 4, 0], [1, 8, 0], [3, 1, 0], [6, 8, 0]],
    [[3, 9, 0], [4, 6, 0], [4, 3, 0], [3, 1, 0], [6, 3, 0]],
    [[4, 6, 0], [2, 3, 3], [1, 2, 9], [1, 2, 6], [7, 9, 8]],
]

# Class indices must be 64-bit integers to be usable as a gather index.
target = torch.tensor(padded_target_rows, dtype=torch.long)
print(target.shape)

torch.Size([4, 5, 3])


In [4]:
# Replace the target with random class indices in [0, 6), laid out as
# (batch=4, seq_len=5, nb_labels=3).
target = torch.randint(low=0, high=6, size=(4, 5, 3))
print(target)

tensor([[[1, 0, 0],
         [3, 3, 4],
         [4, 5, 1],
         [4, 2, 2],
         [5, 5, 2]],

        [[1, 5, 5],
         [4, 1, 2],
         [1, 4, 5],
         [5, 2, 2],
         [3, 4, 4]],

        [[5, 4, 0],
         [2, 1, 3],
         [2, 3, 0],
         [1, 0, 1],
         [2, 5, 1]],

        [[5, 4, 2],
         [2, 1, 1],
         [3, 2, 0],
         [3, 4, 2],
         [0, 1, 4]]])


In [5]:
# Walk through the loss computation one tensor at a time (3-D target).

# Merge the batch and time axes -> (batch * max_len, num_classes)
logits_flat = logits.view(-1, logits.shape[-1])

# Log-probabilities over the class axis -> (batch * max_len, num_classes)
log_probs_flat = logits_flat.log_softmax(dim=-1)
print("== log probs flat == \n", log_probs_flat)

# Unlike the 2-D case, every row keeps all of its labels
# -> (batch * max_len, nb_labels)
target_flat = target.view(-1, nb_labels)
print("== target flat == \n", target_flat)

# For each row, gather the log-probability stored at each of its label
# indices. The result has the shape of the index
# -> (batch * max_len, nb_labels)
losses_flat = -log_probs_flat.gather(1, target_flat)
print("== losses flat == \n", losses_flat)

# Restore the layout of the target -> (batch, max_len, nb_labels)
losses = losses_flat.view(target.shape)
print("== losses == \n", losses)

== log probs flat == 
 tensor([[-2.5541, -2.6483, -1.9175, -2.3984, -1.9399, -2.5420, -2.1021, -2.6512,
         -1.9804, -2.8061],
        [-2.3979, -2.3752, -2.3320, -2.3054, -2.0664, -2.2317, -2.7316, -1.9382,
         -2.2414, -2.6605],
        [-2.2840, -2.2207, -1.8146, -2.6244, -2.6774, -2.7192, -1.8871, -2.0133,
         -2.7447, -2.6564],
        [-2.6910, -1.9292, -2.2400, -2.8074, -2.0766, -2.3287, -2.3621, -2.3430,
         -2.8001, -1.9223],
        [-2.5979, -2.2199, -2.5472, -2.8370, -1.9948, -2.1050, -2.1352, -1.9835,
         -2.9466, -2.1731],
        [-2.7588, -2.6624, -1.9396, -2.5732, -2.3752, -2.6498, -2.6546, -2.0218,
         -1.9942, -1.9360],
        [-2.0923, -2.7621, -2.3564, -2.4340, -1.9870, -2.7169, -1.9742, -2.6905,
         -2.0041, -2.4499],
        [-2.4599, -2.6612, -2.3287, -1.9371, -2.4446, -2.5606, -2.4221, -2.0665,
         -1.8560, -2.6960],
        [-2.0872, -1.9946, -2.0831, -2.3067, -2.2554, -2.6125, -2.8531, -2.2836,
         -2.2313, -2.650

## Padding

In [13]:
# = torch.LongTensor(a)

def mask_tensor(tensor, padding_idx=0):
    """Return an integer mask that is 1 where ``tensor`` holds a real value.

    Args:
        tensor: Tensor of label indices in which padded slots are filled with
            ``padding_idx``.
        padding_idx: Value that marks a padded slot. Defaults to 0, which
            means a genuine class index 0 is also treated as padding; pass a
            different value if class 0 is a real label.

    Returns:
        A long tensor with the same shape as ``tensor``: 0 at every position
        equal to ``padding_idx`` and 1 everywhere else.
    """
    return tensor.ne(padding_idx).long()

# Float version of the padding mask so that it can scale the float losses.
tensor_mask = mask_tensor(target).to(torch.float32)
print(tensor_mask)

tensor([[[1., 0., 0.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 0.],
         [1., 1., 1.],
         [1., 1., 0.],
         [1., 0., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 0.],
         [1., 1., 1.],
         [0., 1., 1.]]])


In [12]:
# multiply losses by mask
masked_losses = losses * tensor_mask
print(masked_losses)

tensor([[[2.6483, 0.0000, 0.0000],
         [2.3054, 2.3054, 2.0664],
         [2.6774, 2.7192, 2.2207],
         [2.0766, 2.2400, 2.2400],
         [2.1050, 2.1050, 2.5472]],

        [[2.6624, 2.6498, 2.6498],
         [1.9870, 2.7621, 2.3564],
         [2.6612, 2.4446, 2.5606],
         [2.6125, 2.0831, 2.0831],
         [2.0581, 1.9493, 1.9493]],

        [[2.4236, 2.3315, 0.0000],
         [2.2571, 2.1235, 2.3943],
         [2.8639, 2.4257, 0.0000],
         [2.1213, 0.0000, 2.1213],
         [2.2974, 2.5731, 2.0028]],

        [[2.6323, 2.5186, 2.6361],
         [2.7851, 2.4291, 2.4291],
         [1.9899, 1.9217, 0.0000],
         [2.6419, 2.0817, 2.5879],
         [0.0000, 2.6173, 2.4899]]])


The losses are now multiplied by the mask, which zeroes out the loss of the padded labels.
Next, we need to sum along the row dimension (the label axis), so that each item in the sequence gets a single summed loss.
The result can then be a 2-D tensor of shape (batch, seq_len), because each position along the seq_len dimension will hold one scalar loss value.
Finally, we can average all of these losses over the batch.