In [22]:
import copy
import json
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
import numpy as np
import math
import numpy
import torch
from torch.autograd import Variable

In [None]:
def _sequence_mask(sequence_length, max_len=None):
    if max_len is None:
        max_len = sequence_length.data.max()
        
    batch_size = sequence_length.size(0)
    
    seq_range = torch.range(0, max_len - 1).long()
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_range_expand = Variable(seq_range_expand)
    if sequence_length.is_cuda:
        seq_range_expand = seq_range_expand.cuda()
    seq_length_expand = (sequence_length.unsqueeze(1)
                         .expand_as(seq_range_expand))
    return seq_range_expand < seq_length_expand


def compute_loss(logits, target, length):
    """Length-masked average negative log-likelihood over a padded batch.

    Args:
        logits: Float tensor of size (batch, max_len, num_classes) holding
            the unnormalized score of every class at every step.
        target: Long tensor of size (batch, max_len) holding the index of
            the true class at every step.
        length: Long tensor of size (batch,) holding the number of valid
            steps of every sequence.

    Returns:
        loss: Scalar tensor, the sum of the per-step losses over the valid
            steps divided by the number of valid steps (0 when no step is
            valid).
    """
    # reshape (unlike view) also accepts non-contiguous inputs.
    # logits_flat: (batch * max_len, num_classes)
    logits_flat = logits.reshape(-1, logits.size(-1))
    # log_probs_flat: (batch * max_len, num_classes)
    log_probs_flat = torch.nn.functional.log_softmax(logits_flat, dim=-1)
    # target_flat: (batch * max_len, 1)
    target_flat = target.reshape(-1, 1)
    # losses_flat: (batch * max_len, 1) -- log-probability of the true class
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
    # losses: (batch, max_len)
    losses = losses_flat.view(target.size())
    # mask: (batch, max_len), 1.0 on valid steps and 0.0 on padding
    mask = _sequence_mask(sequence_length=length, max_len=target.size(1)).float()

    # Normalize by the steps that actually contribute. This equals
    # length.sum() whenever every length fits inside max_len, stays correct
    # when a length is larger than max_len (the mask is clipped there), and
    # the clamp avoids 0 / 0 when nothing is valid.
    num_valid = mask.sum().clamp(min=1)
    return (losses * mask).sum() / num_valid

In [None]:
# Toy inputs for the 2-D target case.
# logits: (batch, seq_len, num_classes)
# target: (batch, seq_len)

# Seed the generator so the toy batch is the same on every run.
torch.manual_seed(0)

batch_size = 4
nb_classes = 2
seq_len = 3

logits = torch.rand(batch_size, seq_len, nb_classes)
# torch.rand draws from [0, 1), so casting it to long always gives 0;
# draw real class indices from [0, nb_classes) instead.
target = torch.randint(0, nb_classes, (batch_size, seq_len))

print(logits)
print(target)

In [None]:
# Number of valid (non-padded) steps of each sequence; shape (batch,).
# The tensor is built from the values themselves: torch.LongTensor(4) would
# instead allocate an uninitialized tensor of size 4 and ignore the list.
# No entry may exceed seq_len, which is 3 for the 2-D toy batch.
length = torch.LongTensor([1, 2, 3, 3])
length.shape

In [None]:
#loss = compute_loss(logits, target, length)

In [None]:
# Step-by-step version of the per-step loss, printing every intermediate.

# logits_flat: (batch * max_len, num_classes) -- batch and time axes merged
logits_flat = logits.view(-1, logits.shape[-1])
print(logits_flat)

# log_probs_flat: (batch * max_len, num_classes) -- normalized over classes
log_probs_flat = logits_flat.log_softmax(dim=-1)
print(log_probs_flat)

# target_flat: (batch * max_len, 1) -- one class index per row
target_flat = target.view(-1).unsqueeze(-1)
print(target_flat)

# losses_flat: (batch * max_len, 1)
# gather picks, in every row, the log-probability at that row's target index;
# index and input must have the same number of dimensions
losses_flat = log_probs_flat.gather(1, target_flat).neg()
print(losses_flat)

# losses: (batch, max_len) -- back to the layout of target
losses = losses_flat.view(target.shape)
print(losses)

In [None]:
# mask: (batch, max_len) -- true on the steps that lie inside each length
mask = _sequence_mask(length, target.size(1))
print(mask)

# zero out the loss of every padded step
losses = mask.float() * losses
print(losses)

# average over the total number of steps given by length
loss = torch.div(losses.sum(), length.float().sum())
print(loss)

## Type 2 - 3D target (several labels per step)

In [3]:
# Toy inputs for the 3-D target case.
# logits: (batch, seq_len, num_classes)
# target: (batch, seq_len, nb_labels)
batch_size = 4
nb_classes = 10
nb_labels = 3
seq_len = 5

logits = torch.rand(batch_size, seq_len, nb_classes)
# notice target has a third dimension of nb_labels
# torch.rand draws from [0, 1), so casting it to long always gives 0;
# draw real class indices from [0, nb_classes) instead.
target = torch.randint(0, nb_classes, (batch_size, seq_len, nb_labels))

print(logits.shape)
print(target.shape)

torch.Size([4, 5, 10])
torch.Size([4, 5, 3])


In [11]:
# Hand-written target of shape (batch, seq_len, nb_labels) = (4, 5, 3).
# The trailing zeros stand for padded label slots.
target = torch.tensor(
    [
        [[1, 0, 0], [1, 2, 0], [1, 4, 0], [1, 3, 0], [1, 3, 0]],
        [[2, 2, 0], [5, 4, 0], [1, 8, 0], [3, 1, 0], [6, 8, 0]],
        [[3, 9, 0], [4, 6, 0], [4, 3, 0], [3, 1, 0], [6, 3, 0]],
        [[4, 6, 0], [2, 3, 3], [1, 2, 9], [1, 2, 6], [7, 9, 8]],
    ],
    dtype=torch.long,
)
print(target.shape)

torch.Size([4, 5, 3])


In [12]:
# Per-step loss walk-through for a target that carries nb_labels class
# indices per step.

# logits_flat: (batch * max_len, num_classes)
logits_flat = logits.view(-1, logits.shape[-1])

# log_probs_flat: (batch * max_len, num_classes)
log_probs_flat = logits_flat.log_softmax(dim=-1)
print("== log probs flat == \n", log_probs_flat)

# target_flat: (batch * max_len, nb_labels) -- one column per label,
# instead of the single column used for a 2-D target
target_flat = target.view(-1, nb_labels)
print("== target flat == \n", target_flat)

# losses_flat: (batch * max_len, nb_labels)
# gather reads, for every row, the log-probabilities at that row's target
# indices; index and input must have the same number of dimensions.
# An index of 0 reads class 0's log-probability, so zero-valued label
# slots are not masked out at this point.
losses_flat = log_probs_flat.gather(1, target_flat).neg()
print("== losses flat == \n", losses_flat)

# losses: (batch, max_len, nb_labels) -- back to the layout of target
losses = losses_flat.view(target.shape)
print("== losses == \n", losses)

== log probs flat == 
 tensor([[-1.8218, -2.7636, -1.9159, -2.1004, -2.6394, -2.0386, -2.6950, -2.8105,
         -2.4041, -2.4545],
        [-2.5887, -2.2012, -2.4453, -1.9583, -2.5401, -2.7374, -1.9999, -2.0180,
         -2.1662, -2.8145],
        [-1.9865, -2.4331, -2.4818, -1.9130, -2.3507, -2.5870, -2.3594, -2.0848,
         -2.5818, -2.5406],
        [-2.5858, -2.6240, -1.9591, -2.5773, -2.2977, -2.3896, -1.9731, -1.8685,
         -2.5250, -2.6653],
        [-2.4359, -2.2461, -2.4838, -2.6442, -2.4852, -2.5410, -2.4728, -1.7417,
         -1.9124, -2.4914],
        [-2.2621, -2.5824, -1.9419, -2.1887, -2.3836, -2.0864, -2.5794, -2.4914,
         -2.0394, -2.8182],
        [-2.7964, -2.0027, -1.8743, -2.1708, -2.3545, -2.3196, -2.4992, -2.8081,
         -2.4870, -2.1406],
        [-2.2576, -2.3424, -2.3087, -2.0824, -2.5972, -2.2311, -2.5954, -2.3671,
         -2.0642, -2.3219],
        [-2.8438, -2.0766, -2.4358, -2.4724, -1.9623, -2.0347, -2.7247, -2.5042,
         -2.1648, -2.196

In [16]:
# padding weights
# all-zero tensor with one entry per (batch, step); it is 2-D, so the last
# (label) dimension of the target is not represented yet
weights = torch.zeros(size=(batch_size, seq_len))
print(weights, weights.shape, sep="\n")

tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])
torch.Size([4, 5])


In [None]:
# torch nonzero on last element of targets?

In [20]:
# nonzero() returns one row per non-zero element and one column per tensor
# dimension, so its last size is the number of dimensions of target.
target.data.nonzero().shape[-1]

3

In [38]:
# Random integer array of shape (4, 5, 3) with entries in [0, 6).
a = np.random.randint(0, 6, size=(4, 5, 3))
print(a)

# One index array per axis; together they give the coordinates of every
# non-zero entry of a (a[idx] = 0 would zero exactly those entries).
idx = a.nonzero()
print(idx)

[[[1 1 2]
  [0 3 4]
  [4 1 1]
  [5 0 4]
  [1 3 4]]

 [[3 0 2]
  [2 2 5]
  [3 0 1]
  [5 2 4]
  [2 0 2]]

 [[4 2 5]
  [2 1 3]
  [5 1 5]
  [2 3 5]
  [4 4 0]]

 [[5 3 1]
  [5 0 3]
  [0 5 5]
  [3 5 4]
  [1 5 4]]]
(array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3]), array([0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 0, 0, 1, 1, 1, 2, 2, 3, 3,
       3, 4, 4, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 0, 0, 0, 1, 1,
       2, 2, 3, 3, 3, 4, 4, 4]), array([0, 1, 2, 1, 2, 0, 1, 2, 0, 2, 0, 1, 2, 0, 2, 0, 1, 2, 0, 2, 0, 1,
       2, 0, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 0, 1, 2, 0, 2,
       1, 2, 0, 1, 2, 0, 1, 2]))


In [10]:
# Integer (long) tensor of zeros with the same size and device as losses_flat.
ymask = losses_flat.data.new_zeros(losses_flat.shape, dtype=torch.long)

print(ymask)

tensor([[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]])


In [None]:
# zmask = ymask + target
# zmask

In [None]:
# In-place scatter along dim 0: for every position (i, j) of target_flat,
# ymask[target_flat[i, j], j] is set to 7.
# NOTE(review): an index reshaped with .view(-1, 1) was also being tried here;
# confirm which layout is intended.
ymask.scatter_(dim=0, index=target_flat, value=7)