In [1]:
%config Completer.use_jedi = False

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# 1. Binary cross entropy loss is used for multi-label classification problems

For example: classify the animals present in a picture. A single picture can contain multiple animals (assuming an image can have one or more animals out of 4 possible animals).

In [36]:
# Stand-in for the output of a network's last layer: random values that are
# treated as logits (inverse-sigmoid values), not as probabilities.
n_samples, n_classes = 10, 4
x = torch.randn(n_samples, n_classes)
x

tensor([[-0.8146, -1.0212, -0.4949, -0.5923],
        [ 0.1543,  0.4408, -0.1483, -2.3184],
        [-0.3980,  1.0805, -1.7809,  1.5080],
        [ 0.3094, -0.5003,  1.0350,  1.6896],
        [-0.0045,  1.6668,  0.1539, -1.0603],
        [-0.5727,  0.0836,  0.3999,  1.9892],
        [ 0.1729,  1.0514,  0.0075, -0.0774],
        [ 0.6427,  0.5742,  0.5867, -0.0188],
        [-0.9143,  1.4840, -0.9109, -0.5291],
        [-0.8051,  0.5158, -0.7129,  0.2196]])

In [37]:
# Arbitrary multi-label targets: every entry is drawn independently from {0, 1},
# so a row may contain several 1s (multi-hot) despite the `true_one_hot` name.
true_one_hot = torch.randint(2, 
                     size=(n_samples, n_classes),
                       dtype=torch.float)
true_one_hot

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [1., 1., 0., 1.],
        [0., 1., 0., 1.],
        [1., 1., 1., 0.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 1., 1.],
        [1., 0., 0., 1.],
        [0., 0., 0., 1.]])

In [48]:
def sigmoid(x: torch.Tensor) -> torch.Tensor:
    """Apply the logistic function 1 / (1 + exp(-x)) element-wise."""
    denominator = torch.exp(-x) + 1
    return torch.reciprocal(denominator)

def binary_cross_entropy_loss(
    pred: torch.Tensor,
    true: torch.Tensor, 
    eps: float=1e-10) -> torch.Tensor:
    """
    Compute Binary Cross Entropy between the prediction and true outputs.
    
    Parameters
    ----------
    pred: (n, c)
        contains probabilities across c classes and n samples.
        
    true: (n, c)
        0/1 target values (a row may contain several 1s in the
        multi-label setting).

    eps:
        lower bound applied to the argument of each logarithm, so that
        neither log(pred) nor log(1 - pred) can reach -inf.

    Returns
    -------
    Scalar tensor holding the mean loss over all n * c entries.
    """
    # Both logarithms need protecting: a saturated probability (pred == 0 or
    # pred == 1) would otherwise give log(0) = -inf, and 0 * -inf = NaN even
    # when the prediction is perfectly correct.
    log_pos = torch.log(torch.clamp(pred, min=eps))
    log_neg = torch.log(torch.clamp(1 - pred, min=eps))
    loss = true * log_pos + (1 - true) * log_neg
    return -torch.mean(loss)

In [49]:
# Note: sigmoid (not softmax) is applied here, i.e. each entry is squashed
# on its own, so the values in a row are not forced to sum to 1.
pred_prob = sigmoid(x) 
pred_prob

tensor([[0.3069, 0.2648, 0.3787, 0.3561],
        [0.5385, 0.6084, 0.4630, 0.0896],
        [0.4018, 0.7466, 0.1442, 0.8188],
        [0.5767, 0.3775, 0.7379, 0.8442],
        [0.4989, 0.8411, 0.5384, 0.2573],
        [0.3606, 0.5209, 0.5987, 0.8797],
        [0.5431, 0.7410, 0.5019, 0.4807],
        [0.6554, 0.6397, 0.6426, 0.4953],
        [0.2861, 0.8152, 0.2868, 0.3707],
        [0.3089, 0.6262, 0.3290, 0.5547]])

In [50]:
# Hand-written BCE evaluated on the sigmoid probabilities.
binary_cross_entropy_loss(pred=pred_prob, true=true_one_hot)

tensor(0.7942)

In [51]:
# Built-in BCE on the same probabilities and targets, for comparison.
F.binary_cross_entropy(input=pred_prob, target=true_one_hot)

tensor(0.7942)

In [52]:
def binary_cross_entropy_loss_with_logits(
    pred: torch.Tensor,
    true: torch.Tensor, 
    eps: float=1e-10) -> torch.Tensor:
    """
    Compute Binary Cross Entropy between logit predictions and true outputs.
    
    Parameters
    ----------
    pred: (n, c)
        contains logit values across c classes and n samples.
        
    true: (n, c)
        0/1 target values (a row may contain several 1s in the
        multi-label setting).

    eps:
        kept only so existing callers keep working; the formulation below
        never takes log(0), so no epsilon is needed and the value is ignored.

    Returns
    -------
    Scalar tensor holding the mean loss over all n * c entries.
    """
    # Going through sigmoid() first saturates to exactly 0 or 1 for large
    # |logit| and then log(0) yields inf/NaN. The algebraically equivalent form
    #     max(x, 0) - x * t + log(1 + exp(-|x|))
    # only ever exponentiates a non-positive number, so it stays finite.
    loss = (
        torch.clamp(pred, min=0)
        - pred * true
        + torch.log1p(torch.exp(-torch.abs(pred)))
    )
    return torch.mean(loss)

In [55]:
# Hand-written BCE fed with the raw logits `x` instead of probabilities.
binary_cross_entropy_loss_with_logits( pred=x, true=true_one_hot)

tensor(0.7942)

In [56]:
# Note: the built-in function expects logits (inverse-sigmoid values) for
# every class, so the raw `x` is passed rather than `pred_prob`.
F.binary_cross_entropy_with_logits(input=x, target=true_one_hot)

tensor(0.7942)

# 2. Categorical cross entropy loss - used for multi-class classification, where every sample/observation belongs to exactly one class

In [57]:
# Defining true labels as integer class indices in [0, n_classes),
# one per sample (these are NOT one-hot encoded).
true = torch.randint(n_classes, 
                     size=(n_samples,),
                     dtype=torch.long)
true

tensor([1, 0, 3, 1, 3, 1, 3, 2, 3, 2])

In [58]:
# Not used by the loss computations; shown only as the one-hot view of `true`.
# Note: this rebinds `true_one_hot` from the binary cross entropy section.
true_one_hot = torch.zeros((n_samples, n_classes))
# Set a single 1 per row, at the column given by that row's integer label.
true_one_hot[range(true_one_hot.shape[0]), true] = 1
true_one_hot

tensor([[0., 1., 0., 0.],
        [1., 0., 0., 0.],
        [0., 0., 0., 1.],
        [0., 1., 0., 0.],
        [0., 0., 0., 1.],
        [0., 1., 0., 0.],
        [0., 0., 0., 1.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.],
        [0., 0., 1., 0.]])

In [59]:
# Not used by the loss computations; shown only to illustrate what softmax
# over the class axis produces. Rebinds `pred_prob` from the sigmoid cell.
pred_prob = torch.softmax(x, axis=1)
pred_prob
# Note: each row sums to 1

tensor([[0.2253, 0.1832, 0.3101, 0.2814],
        [0.3170, 0.4221, 0.2342, 0.0267],
        [0.0809, 0.3548, 0.0203, 0.5441],
        [0.1336, 0.0594, 0.2759, 0.5310],
        [0.1276, 0.6786, 0.1495, 0.0444],
        [0.0540, 0.1040, 0.1427, 0.6993],
        [0.1987, 0.4783, 0.1684, 0.1547],
        [0.2945, 0.2750, 0.2785, 0.1520],
        [0.0691, 0.7601, 0.0693, 0.1015],
        [0.1159, 0.4342, 0.1271, 0.3229]])

In [60]:
def categorical_cross_entropy_loss(
    pred: torch.Tensor, 
    true: torch.Tensor,
    eps: float=1e-10) -> torch.Tensor:
    """
    Categorical cross entropy loss computes loss only across true labels.
    
    Parameters
    ----------
    pred: (n, c)
        n: number of samples
        
        c: number of classes
        
        pred contains raw unnormalized values.
    
    true: (n, )
        true contains integer values.

    eps:
        unused; kept so existing callers that pass it keep working.

    Returns
    -------
    Scalar tensor: mean negative log-probability assigned to the true class.
    """
    # log_softmax works in log space, so a class whose softmax probability
    # underflows to 0 still gets a finite log-probability instead of -inf.
    log_prob = torch.log_softmax(pred, dim=1)
    rows = torch.arange(pred.shape[0], device=pred.device)
    # Pick, for every sample, the log-probability of its true class.
    log_prob_of_true = log_prob[rows, true]
    return -torch.mean(log_prob_of_true)

In [61]:
# Hand-written categorical cross entropy on raw scores and integer labels.
categorical_cross_entropy_loss(pred=x, true=true)

tensor(1.9151)

In [62]:
# Using the built-in categorical cross entropy loss (module form).
# Note: it takes raw unnormalized scores (no softmax applied beforehand)
# together with the true labels given as integer class indices.
torch.nn.CrossEntropyLoss()(x, true)

tensor(1.9151)

In [63]:
# Functional form of the built-in loss, on the same scores and labels.
F.cross_entropy(x, true)

tensor(1.9151)

Adding the ability to skip samples whose true label equals `ignore_index` when computing the loss

In [64]:
def categorical_cross_entropy_loss(
    pred: torch.Tensor, 
    true: torch.Tensor,
    ignore_index: int=-100):
    """
    Mean negative log-likelihood of the true class, skipping ignored labels.

    Parameters
    ----------
    pred: (n, c)
        raw unnormalized scores for n samples across c classes.

    true: (n, )
        integer class label of every sample.

    ignore_index:
        label value to leave out: samples whose true label equals it are
        dropped before the loss is averaged.

    Returns
    -------
    Scalar tensor with the loss averaged over the samples that were kept.
    """
    # Boolean mask of the samples that take part in the loss.
    keep_mask = true != ignore_index
    kept_targets = true[keep_mask]
    kept_log_probs = torch.log_softmax(pred[keep_mask], dim=1)
    # Row i contributes the log-probability of its own true class.
    rows = torch.arange(kept_targets.shape[0], device=kept_targets.device)
    picked = kept_log_probs[rows, kept_targets]
    return -picked.mean()

In [65]:
# Same loss, but samples whose true label is 2 are left out of the average.
categorical_cross_entropy_loss(pred=x, true=true, ignore_index=2)

tensor(1.9762)

In [66]:
# Using the built-in categorical cross entropy loss (module form), this time
# with ignore_index=2 so that samples labelled 2 do not contribute.
# Note: it takes raw unnormalized scores with integer class labels.
torch.nn.CrossEntropyLoss(ignore_index=2)(x, true)

tensor(1.9762)

In [67]:
# Functional form of the built-in loss with the same ignore_index.
F.cross_entropy(x, true, ignore_index=2)

tensor(1.9762)