In [1]:
import copy
import json
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar, Union

import math
import numpy
import torch
from torch.autograd import Variable

In [2]:
def _sequence_mask(sequence_length, max_len=None):
    if max_len is None:
        max_len = sequence_length.data.max()
        
    batch_size = sequence_length.size(0)
    
    seq_range = torch.range(0, max_len - 1).long()
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_range_expand = Variable(seq_range_expand)
    if sequence_length.is_cuda:
        seq_range_expand = seq_range_expand.cuda()
    seq_length_expand = (sequence_length.unsqueeze(1)
                         .expand_as(seq_range_expand))
    return seq_range_expand < seq_length_expand


def compute_loss(logits, target, length):
    """
    Args:
        logits: A Variable containing a FloatTensor of size
            (batch, max_len, num_classes) which contains the
            unnormalized probability for each class.
        target: A Variable containing a LongTensor of size
            (batch, max_len) which contains the index of the true
            class for each corresponding step.
        length: A Variable containing a LongTensor of size (batch,)
            which contains the length of each data in a batch.
    Returns:
        loss: An average loss value masked by the length.
    """

    # logits_flat: (batch * max_len, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # log_probs_flat: (batch * max_len, num_classes)
    log_probs_flat = torch.nn.functional.log_softmax(logits_flat, dim=-1)
    # target_flat: (batch * max_len, 1)
    target_flat = target.view(-1, 1)
    # losses_flat: (batch * max_len, 1)
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
    # losses: (batch, max_len)
    losses = losses_flat.view(*target.size())
    # mask: (batch, max_len)
    mask = _sequence_mask(sequence_length=length, max_len=target.size(1))
    losses = losses * mask.float()
    loss = losses.sum() / length.float().sum()
    return loss

In [3]:
# logits: (batch, seq_len, num_classes)
# target: (batch, seq_len)
# length: (batch,)
batch_size = 4
nb_classes = 2
seq_len = 3

logits = torch.rand(batch_size, seq_len, nb_classes)
target = torch.rand(batch_size, seq_len)
target = target.long()

print(logits)
print(target)

tensor([[[0.6333, 0.8669],
         [0.1677, 0.7046],
         [0.6290, 0.3936]],

        [[0.6042, 0.6700],
         [0.7108, 0.1940],
         [0.5285, 0.3843]],

        [[0.2732, 0.4841],
         [0.2949, 0.1905],
         [0.7884, 0.4613]],

        [[0.5274, 0.4262],
         [0.3142, 0.2303],
         [0.5850, 0.4656]]])
tensor([[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]])


In [19]:
length = ([[1],
        [2],
        [3],
        [4]])

length = torch.LongTensor(4)
length.shape

torch.Size([4])

In [None]:
#loss = compute_loss(logits, target, length)

In [20]:
# logits_flat: (batch * max_len, num_classes)
logits_flat = logits.view(-1, logits.size(-1))
print(logits_flat)

# log_probs_flat: (batch * max_len, num_classes)
log_probs_flat = torch.nn.functional.log_softmax(logits_flat, dim=-1)
print(log_probs_flat)

# target_flat: (batch * max_len, 1)
target_flat = target.view(-1, 1)
print(target_flat)

# losses_flat: (batch * max_len, 1)
# dimensions need to be the same size except for the dim you are gathering on
losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
print(losses_flat)

# losses: (batch, max_len)
losses = losses_flat.view(*target.size())
print(losses)

tensor([[0.6333, 0.8669],
        [0.1677, 0.7046],
        [0.6290, 0.3936],
        [0.6042, 0.6700],
        [0.7108, 0.1940],
        [0.5285, 0.3843],
        [0.2732, 0.4841],
        [0.2949, 0.1905],
        [0.7884, 0.4613],
        [0.5274, 0.4262],
        [0.3142, 0.2303],
        [0.5850, 0.4656]])
tensor([[-0.8168, -0.5831],
        [-0.9972, -0.4603],
        [-0.5824, -0.8177],
        [-0.7266, -0.6608],
        [-0.4678, -0.9846],
        [-0.6236, -0.7679],
        [-0.8041, -0.5933],
        [-0.6423, -0.7467],
        [-0.5429, -0.8700],
        [-0.6438, -0.7450],
        [-0.6521, -0.7360],
        [-0.6352, -0.7547]])
tensor([[0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0],
        [0]])
tensor([[0.8168],
        [0.9972],
        [0.5824],
        [0.7266],
        [0.4678],
        [0.6236],
        [0.8041],
        [0.6423],
        [0.5429],
        [0.6438],
        [0.65

In [21]:
# mask: (batch, max_len)
mask = _sequence_mask(sequence_length=length, max_len=target.size(1))
print(mask)

losses = losses * mask.float()
print(losses)

loss = losses.sum() / length.float().sum()
print(loss)

tensor([[1, 1, 1],
        [1, 1, 1],
        [1, 1, 1],
        [1, 1, 1]], dtype=torch.uint8)
tensor([[0.8168, 0.9972, 0.5824],
        [0.7266, 0.4678, 0.6236],
        [0.8041, 0.6423, 0.5429],
        [0.6438, 0.6521, 0.6352]])
tensor(9.3689e-19)


  import sys


## Type 2 - 3d target

In [59]:
# logits: (batch, seq_len, num_classes)
# target: (batch, seq_len)
# length: (batch,)
batch_size = 4
nb_classes = 10
nb_labels = 2
seq_len = 3

logits = torch.rand(batch_size, seq_len, nb_classes)
target = torch.rand(batch_size, seq_len, nb_labels)
target = target.long()

print(logits)
print(target)

tensor([[[0.6469, 0.4497, 0.5300, 0.8275, 0.4541, 0.9426, 0.2814, 0.0265,
          0.5041, 0.6166],
         [0.4891, 0.1579, 0.0265, 0.2191, 0.5626, 0.9797, 0.9507, 0.5394,
          0.0937, 0.3389],
         [0.3268, 0.8648, 0.3931, 0.0846, 0.3540, 0.7184, 0.7297, 0.9259,
          0.7639, 0.3943]],

        [[0.7751, 0.0392, 0.2076, 0.2525, 0.1707, 0.5629, 0.8782, 0.4542,
          0.8991, 0.4684],
         [0.7847, 0.8519, 0.6794, 0.6619, 0.4287, 0.0022, 0.5121, 0.9886,
          0.5650, 0.0018],
         [0.4407, 0.5124, 0.4109, 0.7643, 0.5718, 0.9197, 0.3073, 0.7319,
          0.0625, 0.1921]],

        [[0.7653, 0.2033, 0.9746, 0.5036, 0.0085, 0.2511, 0.3767, 0.2880,
          0.1173, 0.8549],
         [0.5013, 0.3789, 0.6818, 0.6572, 0.8783, 0.2324, 0.2101, 0.5186,
          0.9100, 0.6705],
         [0.7169, 0.8573, 0.0490, 0.7078, 0.7896, 0.6839, 0.2321, 0.6763,
          0.7682, 0.2440]],

        [[0.8551, 0.8265, 0.5504, 0.6530, 0.9327, 0.4980, 0.7012, 0.9977,
          0

In [60]:
length = ([[1],
        [2],
        [3],
        [4]])

length = torch.LongTensor(4)
length.shape

torch.Size([4])

In [65]:
target = ([[[1, 0],
         [1, 2],
         [1, 3]],

        [[2, 2],
         [5, 4],
         [6, 8]],

        [[3, 9],
         [4, 6],
         [6, 3]],

        [[4, 6],
         [2, 3],
         [7, 9]]])
target = torch.LongTensor(target)

In [70]:
# logits_flat: (batch * max_len, num_classes)
logits_flat = logits.view(-1, logits.size(-1))
#print(logits_flat)

# log_probs_flat: (batch * max_len, num_classes)
log_probs_flat = torch.nn.functional.log_softmax(logits_flat, dim=-1)
print("== log probs flat == \n", log_probs_flat)

# target_flat: (batch * max_len, 1)
#target_flat = target.view(-1, 1)
target_flat = target.view(-1, nb_labels)
print("== target flat == \n", target_flat)

# losses_flat: (batch * max_len, 1)
# dimensions need to be the same size except for the dim you are gathering on
# gathers from log probs the target index (row dim)
losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
print("== losses flat == \n", losses_flat)

# losses: (batch, max_len)
losses = losses_flat.view(*target.size())
print("== losses == \n", losses)

== log probs flat == 
 tensor([[-2.2134, -2.4105, -2.3302, -2.0327, -2.4061, -1.9176, -2.5788, -2.8337,
         -2.3561, -2.2436],
        [-2.3019, -2.6331, -2.7645, -2.5719, -2.2284, -1.8113, -1.8402, -2.2516,
         -2.6972, -2.4521],
        [-2.5656, -2.0275, -2.4992, -2.8078, -2.5383, -2.1739, -2.1627, -1.9664,
         -2.1284, -2.4980],
        [-2.0407, -2.7767, -2.6082, -2.5633, -2.6451, -2.2529, -1.9376, -2.3617,
         -1.9167, -2.3474],
        [-2.1113, -2.0441, -2.2166, -2.2341, -2.4673, -2.8939, -2.3840, -1.9074,
         -2.3310, -2.8943],
        [-2.3851, -2.3134, -2.4148, -2.0615, -2.2540, -1.9060, -2.5184, -2.0939,
         -2.7632, -2.6337],
        [-2.0222, -2.5843, -1.8129, -2.2839, -2.7790, -2.5364, -2.4108, -2.4996,
         -2.6702, -1.9326],
        [-2.3912, -2.5137, -2.2108, -2.2354, -2.0143, -2.6602, -2.6825, -2.3740,
         -1.9825, -2.2221],
        [-2.1916, -2.0512, -2.8594, -2.2007, -2.1188, -2.2246, -2.6764, -2.2321,
         -2.1403, -2.664

In [105]:
ymask = losses_flat.data.new(losses_flat.size()).zero_() # (all zero)
ymask = ymask.long()

print(ymask)

tensor([[0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0]])


In [106]:
# zmask = ymask + target
# zmask

In [110]:
ymask.scatter_(0, target_flat, 7) # .view(-1,1)

tensor([[1, 7],
        [7, 1],
        [7, 7],
        [7, 7],
        [7, 7],
        [7, 0],
        [7, 7],
        [7, 0],
        [0, 7],
        [0, 7],
        [0, 0],
        [0, 0]])