This notebook gives a shallow comparison between the InverseSigmoid loss and the standard CrossEntropy loss, in terms of convergence speed.

In [1]:
''' Implementation of a loss function that's similar to the inverse of the sigmoid function.
It could be used instead of CrossEntropy loss function.
'''

from typing import Optional
from numpy import log
import torch
from torch import nn
from torch.nn.functional import nll_loss, softmax

@torch.jit.script
def invsig_loss(y:torch.Tensor, inds:torch.Tensor, alpha:float, beta:float, gamma:float, reduction:str):
    ''' Functional form of the inverse sigmoid loss.

    Args:
        - y: Float tensor of raw class scores; softmax is applied over dim 1.
        - inds: Long tensor of target class indices, used to pick one probability per entry.
        - alpha, beta: scale and offset applied to the target probability.
        - gamma: constant added to the loss.
        - reduction: 'sum' or 'mean' reduce the loss to a scalar; any other value
                     returns the unreduced loss.

    Returns:
        - The loss tensor: log(1 / (alpha * p + beta) - 1) + gamma, where p is the
          softmax probability of the target class.
    '''
    probs = softmax(y, dim=1)
    # nll_loss on probabilities yields -probs[target]; negating recovers the target probability.
    target_prob = nll_loss(probs, inds, reduction='none').neg()
    loss = torch.log(1 / (alpha * target_prob + beta) - 1) + gamma

    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss


class InverseSigmoid(nn.Module):
    ''' Loss module whose curve is similar to the inverse of the sigmoid function.
    It could be used instead of the CrossEntropy loss function.
    '''

    def __init__(self, alpha: float = .9, beta: float = .043, reduction:Optional[str]='mean'):
        ''' Initialisation of the loss function.
            The formula is :
                loss = torch.log(1 / (alpha * p + beta) - 1) + gamma
            where: p is the softmax probability of the target class, and
                   gamma is a float chosen so that loss(p=1) == 0.
        Args: 
            - alpha, beta (floats): Two positive floats used to control: gradient step near 0 and 1.
                                    They need to be set carefully; always make sure that: 'alpha + beta < 1.'
            - reduction (string, optional): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 
                                            'none': no reduction will be applied, 
                                            'mean': the mean of the output is taken, 
                                            'sum': the output will be summed.
                                            None is accepted and treated as 'none'.
        Raises:
            - AssertionError: if alpha <= 0, beta <= 0 or alpha + beta >= 1.
        '''
        super().__init__()

        assert alpha>0 and beta>0 and alpha+beta<1, f'Make sure: {alpha=}>0, {beta=}>0, and {alpha+beta=}<1 '

        # The scripted functional takes `reduction` as a plain `str`, so None
        # (allowed by the signature) is normalised here instead of failing in forward().
        self.reduction = 'none' if reduction is None else reduction
        self.alpha = alpha
        self.beta = beta
        # Assert loss(1)==0; stored as a built-in float rather than a numpy scalar.
        self.gamma = float(-log(1/(alpha + beta)-1))

    def forward(self, y:torch.Tensor, inds:torch.Tensor) -> torch.Tensor:
        ''' Compute the inverse sigmoid loss function
        
        Args:
            - y: A Float tensor with a shape (batch, C, d1, d2, ..., dK), C: number of classes.
            - inds: A Long tensor of indices, with a shape (batch, d1, d2, ..., dK), 
                    where each entry is non-negative and smaller than C.
        
        Returns:
            - loss: depending on the reduction, the resulting tensor could be a scalar tensor, 
                    or a tensor with the same shape as 'inds' if 'reduction' is 'none'.
        '''
        return invsig_loss(y, inds, self.alpha, self.beta, self.gamma, self.reduction)

In [2]:
# Random initial scores, used as a (batch=4, classes=5) input to the losses.
_a = torch.rand(4, 5)
# Target class index for each of the 4 rows.
b = torch.tensor([1, 2, 3, 0])
# Display the initial values.
_a

tensor([[0.7790, 0.0273, 0.0389, 0.3824, 0.7264],
        [0.5642, 0.3020, 0.2234, 0.1934, 0.3828],
        [0.9071, 0.1188, 0.5804, 0.1271, 0.6016],
        [0.0436, 0.3242, 0.5156, 0.3085, 0.3335]])

In [3]:
# Loss under test, built with its default alpha, beta and reduction.
loss1 = InverseSigmoid()
# Independent trainable copy of `_a`, so training does not modify the shared start values.
params1 = nn.Parameter(_a.clone().detach())

In [4]:
# Baseline loss for the comparison.
loss2 = nn.CrossEntropyLoss()
# Second independent trainable copy of `_a`, starting from the same values as `params1`.
params2 = nn.Parameter(_a.clone().detach())

In [5]:
from tqdm.notebook import tqdm

def train(loss, params, epochs):
    ''' Optimise `params` in place with Adam (default settings) for `epochs` steps.

    Args:
        - loss: callable invoked as loss(params, b); `b` is the notebook-level target tensor.
        - params: the tensor handed to the optimiser and to the loss.
        - epochs: number of optimisation steps.

    The loss value is printed every 1000 steps and shown in the progress bar at every step.
    '''
    optimizer = torch.optim.Adam([params])
    with tqdm(total=epochs) as progress:
        for step in range(epochs):
            optimizer.zero_grad()
            loss_value = loss(params, b)
            loss_value.backward()
            optimizer.step()

            # Read the scalar once; it is the loss computed before this step's update.
            current = loss_value.item()
            if step % 1000 == 0:
                print('loss =', current)
            progress.update(1)
            progress.set_postfix({'loss': current})

In [6]:
# Number of optimisation steps passed to train() for each loss.
epochs = 2000

In [7]:
# Optimise params1 against the InverseSigmoid loss.
train(loss1, params1, epochs)

HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))

loss = 4.3429436683654785
loss = 2.548475503921509



In [8]:
# Optimise params2 against the CrossEntropy loss.
train(loss2, params2, epochs)

HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))

loss = 1.9074592590332031
loss = 0.677466094493866



In [9]:
# Cross-check: CrossEntropy value of the parameters trained with the InverseSigmoid loss.
loss2(params1, b)

tensor(0.1080, grad_fn=<NllLossBackward>)

In [10]:
# Cross-check: InverseSigmoid value of the parameters trained with the CrossEntropy loss.
loss1(params2, b)

tensor(1.7347, grad_fn=<MeanBackward0>)