In [1]:
# pytorch seq2seq考虑mask的loss计算
# https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/masked_cross_entropy.py

import torch
import torch.nn as nn
from torch.nn import functional
from torch.autograd import Variable
# Fixed RNG seed so that the torch.rand() examples further down are
# reproducible on a fresh kernel.
SEED = 224
torch.manual_seed(SEED)

<torch._C.Generator at 0x17eecd73710>

In [27]:
def sequence_mask(sequence_length, max_len=None):
    """Build a (batch, max_len) mask that is true on the real steps of each sequence.

    Args:
        sequence_length: integer tensor of shape (batch,) holding the number
            of valid (non-padding) steps of every sequence.
        max_len: width of the mask. Defaults to the largest value found in
            ``sequence_length`` (0 for an empty batch).

    Returns:
        A tensor of shape (batch, max_len), on the same device as
        ``sequence_length``, whose entry [i, t] is true iff
        ``t < sequence_length[i]``. The dtype is whatever the installed
        PyTorch uses for comparisons (uint8 on old releases, bool on newer).
    """
    if max_len is None:
        # .max() raises on an empty tensor, so treat an empty batch as width 0.
        max_len = sequence_length.max() if sequence_length.numel() > 0 else 0
    # Plain Python int, so it can be used as a size below.
    max_len = int(max_len)
    # torch.arange excludes the end point (unlike the deprecated torch.range),
    # hence `max_len` rather than `max_len - 1`. Creating it directly on the
    # input's device avoids landing on the wrong GPU.
    positions = torch.arange(max_len, dtype=torch.long,
                             device=sequence_length.device)
    # Broadcast (1, max_len) against (batch, 1) -> (batch, max_len).
    return positions.unsqueeze(0) < sequence_length.unsqueeze(1)

#### 注意需要修改的地方：torch.range() 已被弃用，新版本中将不再支持，应改用 torch.arange()（注意 torch.arange() 不包含终点）


In [28]:
# torch.range() includes the end point, while torch.arange() stops just
# before it -- compare the two printed tensors.
a = torch.range(start=0, end=3)
print(a)
b = torch.arange(start=0, end=3)
print(b)
# Because of this difference, the original line
#   seq_range = torch.range(0, max_len - 1).long()
# was rewritten as
#   seq_range = torch.arange(0, max_len).long()

tensor([0., 1., 2., 3.])
tensor([0, 1, 2])


  


In [54]:
def masked_cross_entropy(logits, target, length):
    """Cross-entropy averaged over the non-padding steps of a padded batch.

    Args:
        logits: float tensor of size (batch, max_len, num_classes) holding
            the unnormalized score of every class at every step.
        target: long tensor of size (batch, max_len) holding the index of
            the true class at every step. Every entry, padding included,
            must be a valid index into the class dimension, because the
            lookup happens before the mask is applied.
        length: long tensor of size (batch,) holding the number of valid
            steps of each sequence in the batch.

    Returns:
        A scalar tensor: the negative log-likelihood summed over the valid
        steps and divided by the total number of valid steps.
    """
    # reshape (rather than view) also accepts non-contiguous inputs, e.g. a
    # batch-first tensor obtained by transposing a time-major output.
    # logits_flat: (batch * max_len, num_classes)
    logits_flat = logits.reshape(-1, logits.size(-1))
    # log_probs_flat: (batch * max_len, num_classes); softmax followed by log
    log_probs_flat = functional.log_softmax(logits_flat, dim=-1)
    # target_flat: (batch * max_len, 1)
    target_flat = target.reshape(-1, 1)
    # losses_flat: (batch * max_len, 1); gather picks the log-probability of
    # the true class at every step
    # (background: https://blog.csdn.net/qq_22210253/article/details/85229988)
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
    # losses: (batch, max_len)
    losses = losses_flat.view(target.size())
    # mask: (batch, max_len), true on the valid steps only
    mask = sequence_mask(sequence_length=length, max_len=target.size(1))
    losses = losses * mask.float()
    # Lengths are non-negative integers, so the clamp only changes the
    # all-empty batch, which would otherwise yield 0 / 0 = NaN.
    valid_steps = length.float().sum().clamp(min=1)
    return losses.sum() / valid_steps

In [55]:
logits = torch.rand((3,4,4))  # batch=3, max_len=4, last dim=4: three real classes plus the padding tag 0
logits

tensor([[[0.6427, 0.4695, 0.8888, 0.7078],
         [0.2212, 0.2054, 0.7708, 0.3889],
         [0.3415, 0.7100, 0.3211, 0.7718],
         [0.7857, 0.7259, 0.7164, 0.8696]],

        [[0.8742, 0.9993, 0.8661, 0.0262],
         [0.9472, 0.0056, 0.5668, 0.9110],
         [0.3902, 0.2152, 0.2788, 0.0949],
         [0.6346, 0.7060, 0.5457, 0.6759]],

        [[0.9141, 0.7632, 0.8881, 0.0853],
         [0.4445, 0.4018, 0.6818, 0.7057],
         [0.5727, 0.5517, 0.6963, 0.9769],
         [0.1144, 0.6929, 0.6979, 0.2920]]])

In [56]:
target = torch.tensor([[1,2,3,3],[2,3,3,0],[2,1,0,0]]) # tag 0 is reserved for padding, so three real classes need the four tag ids 0-3 (another id could serve as padding instead)
target

tensor([[1, 2, 3, 3],
        [2, 3, 3, 0],
        [2, 1, 0, 0]])

In [57]:
# Number of real (non-padding) steps in each row of `target`.
length = torch.tensor([4,3,2])
length

tensor([4, 3, 2])

In [58]:
# First look at what sequence_mask(sequence_length, max_len=None) produces
mask = sequence_mask(sequence_length=length, max_len=target.size(1)) # max_len is the longest length in the current batch
mask

tensor([[1, 1, 1, 1],
        [1, 1, 1, 0],
        [1, 1, 0, 0]], dtype=torch.uint8)

可以看到，sequence_mask输出为1的部分即是实际长度的部分，可以用得到的mask矩阵来做计算loss时的屏蔽处理

In [59]:
# Compute the masked loss for the first example
loss = masked_cross_entropy(logits, target, length)
loss

tensor(1.3153)

In [60]:
# 另外一种想法是:既然padding部分的误差会通过与mask相乘被过滤掉，那么不妨令target中的padding为0，这样不用按照num_classes+1来处理
# 在seq2seq预测结果是words时，一般都包括了padding,unk等，因此可以用上面的方法处理；
# 如果预测的是pos,个人认为应该用下面这种方法，保证最后一个全连接层输出的维度与实际的num_classes一致

In [61]:
new_logits = torch.rand((3,4,3))  # batch=3, max_len=4, num_classes=3 (no extra output slot for a padding class)
new_logits

tensor([[[0.3551, 0.9598, 0.6522],
         [0.2716, 0.0676, 0.9010],
         [0.6095, 0.1420, 0.4506],
         [0.1210, 0.4030, 0.0129]],

        [[0.0405, 0.2025, 0.0784],
         [0.6600, 0.3458, 0.9125],
         [0.5765, 0.0901, 0.7497],
         [0.4754, 0.6096, 0.6145]],

        [[0.0646, 0.5801, 0.4453],
         [0.9615, 0.2805, 0.8449],
         [0.5224, 0.2518, 0.3558],
         [0.8679, 0.1321, 0.9071]]])

In [62]:
new_target = torch.tensor([[0,1,2,2],[1,2,2,3],[1,0,3,3]]) # 3 is the padding placeholder; it is not a valid index into the 3-class logits
new_target

tensor([[0, 1, 2, 2],
        [1, 2, 2, 3],
        [1, 0, 3, 3]])

In [63]:
# Number of real (non-padding) steps in each row of `new_target`.
new_length = torch.tensor([4,3,2])
new_length

tensor([4, 3, 2])

In [64]:
def new_masked_cross_entropy(logits, target, length):
    """Masked cross-entropy that tolerates out-of-range padding targets.

    Unlike a plain gather over the raw targets, the padding positions of
    ``target`` are zeroed before the lookup, so the padding id does not need
    a class slot of its own in ``logits``.

    Args:
        logits: float tensor of size (batch, max_len, num_classes) holding
            the unnormalized score of every class at every step.
        target: long tensor of size (batch, max_len) holding the index of
            the true class at every valid step. Values at padding steps are
            ignored and may be any integer.
        length: long tensor of size (batch,) holding the number of valid
            steps of each sequence in the batch.

    Returns:
        A scalar tensor: the negative log-likelihood summed over the valid
        steps and divided by the total number of valid steps.

    Note:
        Prints the masked targets and the per-step losses as a teaching aid.
    """
    # mask: (batch, max_len), true on the valid steps only
    mask = sequence_mask(sequence_length=length, max_len=target.size(1))
    # reshape (rather than view) also accepts non-contiguous inputs, e.g. a
    # batch-first tensor obtained by transposing a time-major output.
    # logits_flat: (batch * max_len, num_classes)
    logits_flat = logits.reshape(-1, logits.size(-1))
    # log_probs_flat: (batch * max_len, num_classes); dim is given explicitly
    log_probs_flat = functional.log_softmax(logits_flat, dim=-1)
    # Zero the padding positions first so that gather never indexes out of
    # range; their losses are multiplied by 0 afterwards anyway.
    target_mask = target * mask.long()
    print("target_mask:", target_mask)
    # target_mask_flat: (batch * max_len, 1)
    target_mask_flat = target_mask.reshape(-1, 1)
    # losses_flat: (batch * max_len, 1)
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_mask_flat)
    # losses: (batch, max_len)
    losses = losses_flat.view(target.size())
    losses = losses * mask.float()
    print("losses: ", losses)
    # Lengths are non-negative integers, so the clamp only changes the
    # all-empty batch, which would otherwise yield 0 / 0 = NaN.
    valid_steps = length.float().sum().clamp(min=1)
    return losses.sum() / valid_steps

In [65]:
# Use this example's own lengths (`new_length`) rather than `length`, which
# belongs to the earlier example and only happens to hold the same values.
new_loss = new_masked_cross_entropy(new_logits, new_target, new_length)
new_loss

target_mask: tensor([[0, 1, 2, 2],
        [1, 2, 2, 0],
        [1, 0, 0, 0]])
losses:  tensor([[1.4295, 1.5101, 1.0670, 1.2785],
        [1.0057, 0.8520, 0.8578, 0.0000],
        [0.9047, 0.8738, 0.0000, 0.0000]])


tensor(1.0866)

In [66]:
# Verification: flatten to (batch * max_len, num_classes) to inspect the per-step rows
val_logits = new_logits.view(-1, new_logits.size(-1))
val_logits

tensor([[0.3551, 0.9598, 0.6522],
        [0.2716, 0.0676, 0.9010],
        [0.6095, 0.1420, 0.4506],
        [0.1210, 0.4030, 0.0129],
        [0.0405, 0.2025, 0.0784],
        [0.6600, 0.3458, 0.9125],
        [0.5765, 0.0901, 0.7497],
        [0.4754, 0.6096, 0.6145],
        [0.0646, 0.5801, 0.4453],
        [0.9615, 0.2805, 0.8449],
        [0.5224, 0.2518, 0.3558],
        [0.8679, 0.1321, 0.9071]])

In [67]:
# Logits with the padding steps removed. They are sliced out of `new_logits`
# by length instead of being hand-copied from the rounded printout, so the
# check stays valid if the random draw changes. Plain slicing keeps the
# check independent of sequence_mask.
val_logits = torch.cat(
    [steps[:n_valid] for steps, n_valid in zip(new_logits, new_length.tolist())]
)

In [68]:
# Targets with the padding steps removed, sliced from `new_target` by length
# rather than typed in by hand.
val_targets = torch.cat(
    [steps[:n_valid] for steps, n_valid in zip(new_target, new_length.tolist())]
)

In [69]:
# Reference value from PyTorch's built-in criterion on the unpadded data.
# Named `criterion` so it does not overwrite the `loss` tensor computed earlier.
criterion = nn.CrossEntropyLoss()
loss_value = criterion(val_logits, val_targets)
loss_value

tensor(1.0866)

In [None]:
# 可以看到Out[65]和Out[69]的结果是一致的（均为1.0866）