<a href="https://colab.research.google.com/github/ferdouszislam/pytorch-practice/blob/main/cross_entropy_from_label_probabilities.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn

### Cross Entropy

In [2]:
# Cross Entropy Loss implementation was modified from- <https://discuss.pytorch.org/t/how-to-write-custom-crossentropyloss/58072> 

def log_softmax(x):
  return x - torch.logsumexp(x,dim=1, keepdim=True)

def cross_entropy(outputs, targets):
  num_examples = targets.shape[0]
  cross_entropies = targets*log_softmax(outputs)
  return - torch.sum(cross_entropies) / num_examples

In [3]:
outputs = torch.tensor([[300.0, 100.0, 400.0, 100.0, 100.0],
                        [100.0, 200.0, 300.0, 200.0, 200.0],
                        [300.0, 100.0, 400.0, 100.0, 100.0]], requires_grad=True)

labels = torch.tensor([3, 2, 1])


label_one_hot = torch.tensor([[0, 0, 0, 1, 0], 
                              [0, 0, 1, 0, 0], 
                              [0, 1, 0, 0, 0]])

label_probs = torch.tensor([[0.1, 0.2, 0.1, 0.4, 0.2], 
                            [0.2, 0.2, 0.4, 0.1, 0.1], 
                            [0.1, 0.6, 0.1, 0.1, 0.1]])

In [4]:
# get labels from probs
_, labels_from_label_probs = torch.max(label_probs, 1)
labels_from_label_probs

tensor([3, 2, 1])

In [5]:
built_in_ce = nn.CrossEntropyLoss()
built_in_ce(outputs, labels)

tensor(200., grad_fn=<NllLossBackward>)

In [6]:
cross_entropy(outputs, label_one_hot)

tensor(200., grad_fn=<DivBackward0>)

In [7]:
cross_entropy(outputs, label_probs)

tensor(193.3333, grad_fn=<DivBackward0>)

### Cross Entropy with class weights

In [8]:
def cross_entropy_with_class_weights(outputs, targets, class_weights = [1.0, 1.0, 1.0, 1.0, 1.0]):
  _, labels = torch.max(targets, 1)
  labels = labels.float()
  print(labels.type())
  weights_tensor = torch.ones_like(labels)
  for curr_class in range(5):
    curr_class_tensor = torch.full(labels.size(), curr_class)
    curr_class_weight = class_weights[curr_class]
    curr_class_weight_tensor = torch.full(labels.size(), curr_class_weight)
    weights_tensor = torch.where(labels==curr_class_tensor, curr_class_weight_tensor, weights_tensor)

  num_examples = targets.shape[0]
  cross_entropies = targets*log_softmax(outputs)
  # TODO: verify 
  # collapse cross_entropies of shape (batch_size, 5) to shape (batch_size, 1) by taking their sum
  # then multiply with weights_tensor
  cross_entropies = torch.sum(cross_entropies, dim=1) * weights_tensor

  return - torch.sum(cross_entropies) / num_examples

In [9]:
cross_entropy(outputs, label_one_hot)

tensor(200., grad_fn=<DivBackward0>)

In [10]:
cross_entropy_with_class_weights(outputs, label_probs)

torch.FloatTensor


tensor(193.3333, grad_fn=<DivBackward0>)