# Cross Entropy Loss with Label Smoothing

This notebook implements cross entropy loss with label smoothing. The label smoothing loss is tested with dummy input (but in the same format that is expected for Bengali.AI) and tested against cross entropy without label smoothing to confirm that the results are the same when smoothing is set to zero.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class CrossEntropySumLoss(nn.Module):
    """Neural network module to compute sum of cross entropy losses.

    Attributes:
        device = [torch.device] device to compute the loss on
    """

    def __init__(self, device):
        """Initializes the loss module

        Args:
            device = [torch.device] device to compute the loss on
        """
        super(CrossEntropySumLoss, self).__init__()
        self.device = device

    def forward(self, input, target):
        """Sums cross entropy losses of given predictions and targets.

        Args:
            input  = [tuple] sequence of tensors of (raw) predictions
            target = [tuple] sequence of tensors of targets

        Returns [torch.Tensor]:
            The grapheme_root, vowel_dacritic, consonant_diacritic,
            and combined losses given the predictions and targets.
        """
        losses = []
        for y, t in zip(input, target):
#             t = t.to(self.device)
            loss = F.cross_entropy(y, t)
            losses.append(loss)

        losses.append(sum(losses))
        return torch.stack(losses)

In [3]:
class LabelSmoothingLoss(nn.Module):
    """
    Adapted from:
    https://github.com/pytorch/pytorch/issues/7455#issuecomment-513062631
    
    Cross entropy loss with label smoothing.    
    When `smoothing=0.0`, the loss will be equivalent to 
    standard cross entropy loss (`F.cross_entropy`).
    """
    def __init__(self, classes, smoothing=0.0, dim=-1):
        """
        Args:
            classes   = [tuple] number of classes for grapheme_root, 
                        vowel_diacritic, and consonant_diacritic
            smoothing = [float] controls degree of smoothing
            dim       = [int] dimension to compute the loss over
        """
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing # alpha
        self.classes =  classes # (graph, vowel, consonant) 
        self.dim = dim

    def forward(self, pred, target):
        """
        Args:
            pred   = [tuple] sequence of tensors of (raw) predictions
            target = [tuple] sequence of tensors of targets
            
        Returns [torch.Tensor]:
            The grapheme_root, vowel_dacritic, consonant_diacritic,
            and combined losses given the predictions and targets.
        """
        losses = []
        for y, t, cls in zip(pred, target, self.classes):
            y = y.log_softmax(dim=self.dim)  
            with torch.no_grad():
                true_dist = torch.zeros_like(y)
                true_dist.fill_(self.smoothing / (cls - 1))
                true_dist.scatter_(1, t.data.unsqueeze(1), self.confidence)
                losses.append( torch.mean(torch.sum(-true_dist * y, dim=self.dim)) )
        losses.append(sum(losses))
        return torch.stack(losses)

In [4]:
# Initalize label smoothing loss
classes = (3, 2, 4)
loss = LabelSmoothingLoss(classes, 0.2)

In [5]:
# Dummy input
g = torch.tensor( [ [0.2, 1.6, 0.5], [8.5, 0.3, 3.2] ] )
v = torch.tensor( [ [2.1, 0.2, 0.1, 0.8], [1.5, 0.4, 4.2, 0.2] ] )
c = torch.tensor( [ [0.5, 3.2], [2.1, 1.1] ] )

gt = torch.tensor( [1, 0] )
vt = torch.tensor( [0, 2] )
ct = torch.tensor( [1, 0] )

y = (g, v, c)
t = (gt, vt, ct)
lsl = loss.forward(y, t)
print(lsl)

tensor([1.0312, 1.9518, 0.2873, 3.2703])


In [6]:
ce = CrossEntropySumLoss(None)
cel = ce.forward(y, t)
print(cel)

tensor([0.2312, 0.2727, 0.1892, 0.6931])


In [16]:
# when smoothing=0, label smoothing and cross entropy loss should be equal
assert torch.all( torch.eq(lsl, cel) ), 'Not equal.'