In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [37]:
class LabelSmoothingCrossEntropy(nn.Module):
    """ NLL loss with label smoothing.
    """
    def __init__(self, smoothing=0.1):
        super(LabelSmoothingCrossEntropy, self).__init__()
        assert smoothing < 1.0
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        logprobs = F.log_softmax(x, dim=-1)
        print('logprobs:', logprobs)
#         print('softmax:', torch.softmax(x, dim = 1))
#         print('confirm logprobs:', torch.log(torch.softmax(x, dim = 1)))
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
        print(target.unsqueeze(1))
        print('nll_loss:', nll_loss)
        nll_loss = nll_loss.squeeze(1)
        smooth_loss = -logprobs.mean(dim=-1)
        print('smooth loss:', smooth_loss)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()


In [38]:
loss = LabelSmoothingCrossEntropy(smoothing = 0.2)

In [39]:
preds = torch.Tensor([[-9, 12], [0, 99], [50, 5], [0, 0], [999, -999]]).float()
targets = torch.Tensor([1, 1, 0, 1, 1]).long()

In [40]:
loss(preds, targets)

logprobs: tensor([[-2.1000e+01,  0.0000e+00],
        [-9.9000e+01,  0.0000e+00],
        [ 0.0000e+00, -4.5000e+01],
        [-6.9315e-01, -6.9315e-01],
        [ 0.0000e+00, -1.9980e+03]])
tensor([[1],
        [1],
        [0],
        [1],
        [1]])
nll_loss: tensor([[-0.0000e+00],
        [-0.0000e+00],
        [-0.0000e+00],
        [6.9315e-01],
        [1.9980e+03]])
smooth loss: tensor([1.0500e+01, 4.9500e+01, 2.2500e+01, 6.9315e-01, 9.9900e+02])


tensor(363.0786)