In [1]:
# ref : https://stackoverflow.com/questions/55681502/label-smoothing-in-pytorch

In [2]:
import torch


# Number of samples per class
count_class_0 = 10000
count_class_1 = 400

# Total number of samples
total_count = count_class_0 + count_class_1

# Per-class weights: total / (2 * class_count); the factor 2 matches the two
# classes counted above, so the rarer class receives the larger weight.
weight_class_0 = total_count / (2 * count_class_0)
weight_class_1 = total_count / (2 * count_class_1)

weights = torch.tensor([weight_class_0, weight_class_1])
# Prefer the first GPU, but fall back to the CPU so this cell still runs on a
# machine without CUDA instead of raising in `.to(device)`.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
weights = weights.to(device)

In [3]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.modules.loss import _WeightedLoss


class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against label-smoothed target distributions.

    The target distribution puts ``1 - smoothing`` on the true class and
    spreads ``smoothing`` evenly over the other ``classes - 1`` entries.

    Args:
        classes: number of classes used to spread the smoothing mass. It is
            not checked against the size of ``pred`` along ``dim``.
        smoothing: value in ``[0, 1)``; ``0`` yields plain one-hot targets.
        dim: dimension of ``pred`` that indexes the classes.
        weight: optional tensor of per-class factors multiplied into the
            log-probabilities (a 1-D tensor with one entry per class).
    """

    def __init__(self, classes, smoothing=0.0, dim=-1, weight = None):
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.weight = weight
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        """Return the mean smoothed cross-entropy.

        Args:
            pred: raw scores with the classes along ``self.dim``.
            target: integer class indices shaped like ``pred`` with the
                class dimension removed.

        Returns:
            Scalar tensor: the sum over the class dimension, averaged over
            every remaining position.

        Raises:
            AssertionError: if ``self.smoothing`` is outside ``[0, 1)``.
        """
        assert 0 <= self.smoothing < 1
        pred = pred.log_softmax(dim=self.dim)

        if self.weight is not None:
            weight = self.weight.to(pred.device)
            if weight.dim() == 1:
                # Place the class axis of the weight vector at self.dim so it
                # lines up with the class axis of pred for any input rank.
                shape = [1] * pred.dim()
                shape[self.dim] = -1
                weight = weight.reshape(shape)
            else:
                # Non-1-D weights keep the original broadcasting behaviour.
                weight = weight.unsqueeze(0)
            pred = pred * weight

        with torch.no_grad():
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            # Scatter along the same dimension used by log_softmax and sum;
            # a fixed dim of 1 is only right for 2-D input.
            true_dist.scatter_(self.dim, target.unsqueeze(self.dim), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))

In [4]:
class SmoothCrossEntropyLoss(_WeightedLoss):
    """Cross-entropy with label smoothing; classes are on the last dimension.

    Args:
        weight: optional per-class factors multiplied into the
            log-probabilities.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced per-position loss.
        smoothing: value in ``[0, 1)``; ``0`` yields plain one-hot targets.
    """

    def __init__(self, weight=None, reduction='mean', smoothing=0.0):
        # weight and reduction are handed to the _WeightedLoss constructor,
        # which keeps them as self.weight / self.reduction, so they are not
        # assigned a second time here.
        super().__init__(weight=weight, reduction=reduction)
        self.smoothing = smoothing

    def k_one_hot(self, targets:torch.Tensor, n_classes:int, smoothing=0.0):
        """Turn integer labels into smoothed target distributions.

        Args:
            targets: integer class indices of any shape.
            n_classes: size of the class dimension that is appended.
            smoothing: mass spread evenly over the ``n_classes - 1`` other
                classes; the true class receives ``1 - smoothing``.

        Returns:
            Tensor of shape ``(*targets.shape, n_classes)`` on the device of
            ``targets``.
        """
        with torch.no_grad():
            # Append the class axis to whatever shape the labels have, so
            # batched or sequence labels work, not only a flat vector.
            smoothed = torch.empty(size=(*targets.shape, n_classes),
                                   device=targets.device)
            smoothed.fill_(smoothing / (n_classes - 1))
            smoothed.scatter_(-1, targets.unsqueeze(-1), 1. - smoothing)
        return smoothed

    def reduce_loss(self, loss):
        """Apply ``self.reduction`` to the per-position loss tensor."""
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        # Every other reduction value leaves the loss unreduced.
        return loss

    def forward(self, inputs, targets):
        """Compute the smoothed cross-entropy.

        Args:
            inputs: raw scores with the classes along the last dimension.
            targets: integer class indices shaped like ``inputs`` without the
                last dimension.

        Returns:
            The loss reduced according to ``self.reduction``.

        Raises:
            AssertionError: if ``self.smoothing`` is outside ``[0, 1)``.
        """
        assert 0 <= self.smoothing < 1

        targets = self.k_one_hot(targets, inputs.size(-1), self.smoothing)
        log_preds = F.log_softmax(inputs, -1)

        if self.weight is not None:
            log_preds = log_preds * self.weight.unsqueeze(0)

        return self.reduce_loss(-(targets * log_preds).sum(dim=-1))

In [14]:
# Build a 3x7 tensor in which every entry holds 0.5 / 6, mirroring the fill
# step of a smoothed-target construction (smoothing 0.5 over 7 - 1 entries).
# The previous trailing `.scatter_(1, )` was an incomplete call that raises a
# TypeError; the recorded output corresponds to the expression without it.
targets = torch.empty(size=(3, 7)).fill_(0.5 / 6)

print(targets)

tensor([[0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833],
        [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833],
        [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833]])


In [11]:
# Inspect the shape of the scratch `targets` tensor.
targets.shape

torch.Size([3, 7])

In [12]:
# Display the current contents of `targets`.
# NOTE(review): the recorded output holds values such as 1e-37 / 1e-45, which
# looks like uninitialised `torch.empty` memory from an earlier, out-of-order
# run -- re-run after the cell that fills the tensor to refresh it.
targets

tensor([[1.6929e-37, 0.0000e+00, 7.4146e-34, 0.0000e+00, 1.1210e-43, 0.0000e+00,
         8.9683e-44],
        [0.0000e+00, 7.4146e-34, 0.0000e+00, 7.0065e-45, 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 7.0065e-45, 0.0000e+00,
         1.4013e-45]])

In [6]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.modules.loss import _WeightedLoss


if __name__=="__main__":
    # Shared fixture: 3 samples x 5 class scores and their integer labels.
    # Plain tensors are passed directly; the deprecated `Variable` wrapper is
    # not needed to call the losses.
    predict = torch.tensor([[0, 0.2, 0.7, 0.1, 0],
                            [0, 0.9, 0.2, 0.2, 1],
                            [1, 0.2, 0.7, 0.9, 1]], dtype=torch.float32)
    labels = torch.tensor([2, 1, 0], dtype=torch.long)

    # 1. Devin Yang
    crit = LabelSmoothingLoss(classes=5, smoothing=0.5)
    v = crit(predict, labels)
    print(v)

    # 2. Shital Shah
    crit = SmoothCrossEntropyLoss(smoothing=0.5)
    v = crit(predict, labels)
    print(v)

tensor(1.5161)
tensor(1.5161)


In [7]:
# Check what wrapping the label tensor in `Variable` yields; the recorded
# output below is a plain tensor.
Variable(torch.LongTensor([2, 1, 0]))

tensor([2, 1, 0])