In [1]:
import torch
import torch.nn as nn


In [40]:
class LS_CELoss(nn.Module):
    """Label-smoothed, class-weighted cross-entropy for binary classification.

    The caller supplies one logit ``z`` per sample for the positive class.
    ``forward`` expands it to the two-class logit pair ``[-z, z]``, builds a
    smoothed target distribution per sample (``1 - ls_alpha / 2`` on the true
    class, ``ls_alpha / 2`` on the other), computes the per-sample
    cross-entropy against those soft targets, scales each sample by the weight
    of its true class and returns the mean.

    Args:
        ls_alpha: Label-smoothing strength; ``0`` gives hard one-hot targets.
        class_weights: Two per-class weights, indexed by class id (0 then 1).
        device: Device on which the weight tensor is initially placed.
    """

    def __init__(self, ls_alpha=0.1, class_weights=(1.0, 1.0), device="cpu"):
        super().__init__()
        self.ls_alpha = ls_alpha
        # reduction="none" keeps one loss value per sample so that the
        # per-class weights can be applied before averaging.
        self.ce_loss = nn.CrossEntropyLoss(reduction="none")
        # Registered as a buffer so `module.to(...)` moves it with the module;
        # non-persistent so the state_dict stays unchanged.
        self.register_buffer(
            "class_weights",
            torch.tensor(class_weights, dtype=torch.float).to(device),
            persistent=False,
        )

    def forward(self, predicted_pos_logits, class_targets):
        """Return the mean weighted, label-smoothed cross-entropy.

        Args:
            predicted_pos_logits: 1-D tensor with one positive-class logit per
                sample.
            class_targets: 1-D tensor of class ids (0 or 1), same length as
                ``predicted_pos_logits``; cast to ``long`` internally.

        Returns:
            Scalar tensor holding the mean weighted loss.
        """
        # Use the argument, never a same-named notebook global.
        class_idx = class_targets.long()

        # (N,) -> (N, 2): column 0 is the negative-class logit, column 1 the
        # positive-class logit.
        expanded_logits = torch.stack(
            [-predicted_pos_logits, predicted_pos_logits], dim=1
        )

        # Start with ls_alpha / 2 on both classes, then add the remaining
        # 1 - ls_alpha to the true class of every row.
        targets_probs = torch.full_like(expanded_logits, self.ls_alpha / 2)
        row_idx = torch.arange(targets_probs.shape[0], device=targets_probs.device)
        targets_probs[row_idx, class_idx] += 1 - self.ls_alpha

        # Per-sample weight looked up from the sample's true class.
        weights = self.class_weights[class_idx]

        # CrossEntropyLoss is given class-probability (soft) targets here.
        per_sample_loss = self.ce_loss(expanded_logits, targets_probs)
        return torch.mean(per_sample_loss * weights)

In [41]:
# Toy batch for exercising the loss: one positive-class logit per sample ...
logits = torch.tensor([0.1, 2.2, 1.0, -1.0, -0.003])
# ... and the matching class ids (0 or 1), one per sample.
targets = torch.tensor([1, 1, 1, 0, 0])

In [42]:
# Label-smoothing strength used by the step-by-step cells in this notebook;
# same value as the LS_CELoss default.
ls_alpha=0.1

In [43]:
# Build the loss with its default arguments and evaluate it on the toy batch.
ls_loss = LS_CELoss()
ls_loss(logits, targets)

tensor(0.3969)

In [14]:
# Turn the per-sample positive logits into an (N, 2) table of two-class
# logits: column 0 holds -logit, column 1 holds +logit.
expanded_logits = torch.vstack(
    (-logits, logits)
).t()

In [15]:
# Show the expanded two-class logits.
expanded_logits

tensor([[-0.1000,  0.1000],
        [-2.2000,  2.2000],
        [-1.0000,  1.0000],
        [ 1.0000, -1.0000],
        [ 0.0030, -0.0030]])

In [22]:
# Smoothed targets: every entry starts at ls_alpha / 2, then the entry of each
# row's true class receives the remaining 1 - ls_alpha.
targets_probs = torch.full_like(expanded_logits, ls_alpha / 2)
targets_probs[torch.arange(targets_probs.size(0)), targets] += 1 - ls_alpha
targets_probs

tensor([[0.0500, 0.9500],
        [0.0500, 0.9500],
        [0.0500, 0.9500],
        [0.9500, 0.0500],
        [0.9500, 0.0500]])

In [21]:
# Show the class ids used above.
targets

tensor([1, 1, 1, 0, 0])