<a href="https://colab.research.google.com/github/ferdouszislam/Algorithms/blob/main/cross_entropy_false_error.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn

In [3]:
# Cross-entropy loss implementation adapted from <https://discuss.pytorch.org/t/how-to-write-custom-crossentropyloss/58072>,
# extended with the cross_entropy_false_error() function.

def log_softmax(x, dim=1):
  """Return the log-softmax of ``x`` along ``dim``.

  Computed as ``x - logsumexp(x)`` so large logits do not overflow ``exp``.

  Args:
    x: input tensor (logits).
    dim: axis to normalise over; defaults to 1, the class axis used by the
      loss functions in this notebook.

  Returns:
    A tensor with the same shape as ``x``.
  """
  return x - torch.logsumexp(x, dim=dim, keepdim=True)

def cross_entropy(outputs, targets):
  """Return the mean cross-entropy between logits and integer class targets.

  On the example batch in this notebook it gives the same value as
  ``nn.CrossEntropyLoss()`` with its default mean reduction.

  Args:
    outputs: logits indexed as (row, class); normalised with ``log_softmax``.
    targets: one class index per row of ``outputs``.

  Returns:
    The negated sum of the target log-probabilities divided by
    ``targets.shape[0]``.
  """
  log_probs = log_softmax(outputs)
  # Row i contributes the log-probability of its own target class targets[i].
  target_log_probs = log_probs[range(outputs.shape[0]), targets]
  return -torch.sum(target_log_probs) / targets.shape[0]

def cross_entropy_false_error(outputs, targets):
  """Class-balanced cross entropy over per-class "hit" and "false" errors.

  For every class ``c``, the per-sample cross entropy is averaged twice:
  once over the samples whose target is ``c`` (positive term) and once over
  the samples whose target is not ``c`` (negative term). A term whose group
  is empty contributes 0. Both terms are summed over all classes, and the
  total is divided by the number of classes.

  The number of classes is read from ``outputs.shape[1]`` rather than being
  fixed, so any class count is handled. 5-class inputs give the same values
  as before.

  Args:
    outputs: logits of shape (batch, num_classes).
    targets: integer class indices of shape (batch,); presumably on the
      same device as ``outputs``.

  Returns:
    A 0-dim tensor holding the loss (0 for an empty batch).
  """
  batch_size, num_classes = outputs.shape[0], outputs.shape[1]

  # Per-sample cross entropy: -log p(target_i | x_i), shape (batch,).
  log_probs = torch.log_softmax(outputs, dim=1)
  rows = torch.arange(batch_size, device=outputs.device)
  sample_loss = -log_probs[rows, targets]

  # Membership matrix (batch, num_classes). It is built on the targets' device
  # so the comparison also works for CUDA tensors, unlike a CPU torch.full.
  classes = torch.arange(num_classes, device=targets.device)
  is_pos = targets.unsqueeze(1) == classes
  is_neg = ~is_pos

  # Select with torch.where instead of multiplying by a 0/1 mask. That way a
  # non-finite loss in one group cannot turn other groups' sums into NaN.
  per_sample = sample_loss.unsqueeze(1)
  zero = per_sample.new_zeros(())
  pos_sum = torch.where(is_pos, per_sample, zero).sum(dim=0)
  neg_sum = torch.where(is_neg, per_sample, zero).sum(dim=0)

  # An empty group has a zero sum. Clamping its count to 1 makes its mean 0,
  # which is the same as skipping that term.
  pos_mean = pos_sum / is_pos.sum(dim=0).clamp(min=1)
  neg_mean = neg_sum / is_neg.sum(dim=0).clamp(min=1)

  return (pos_mean + neg_mean).sum() / num_classes

In [4]:
# Toy batch: 8 rows x 5 columns of raw logits, plus one integer label per row.
logit_rows = [
    [300.0, 100.0, 400.0, 100.0, 100.0],
    [100.0, 200.0, 300.0, 200.0, 200.0],
    [300.0, 100.0, 400.0, 100.0, 100.0],
    [50.0, 150.0, 40.0, 300.0, 200.0],
    [300.0, 100.0, 400.0, 100.0, 156.0],
    [205.0, 41.0, 400.0, 62.0, 100.0],
    [30.0, 15.0, 400.0, 600.0, 215.0],
    [300.0, 100.0, 400.0, 100.0, 520.0],
]
outputs = torch.tensor(logit_rows, requires_grad=True)
labels = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2])

In [5]:
# PyTorch's reference loss; mean reduction is the default, spelled out here.
built_in_ce = nn.CrossEntropyLoss(reduction="mean")
built_in_ce(outputs, labels)

tensor(168., grad_fn=<NllLossBackward>)

In [6]:
# Hand-written loss on the same batch; the cell output shows it matches the built-in value.
cross_entropy(outputs, labels)

tensor(168., grad_fn=<DivBackward0>)

In [7]:
# Per-class positive/negative averaged loss on the same batch (see cell output).
cross_entropy_false_error(outputs, labels)

tensor(326.3619, grad_fn=<DivBackward0>)

In [8]:
# label_tensor = torch.full(labels.size(), 4)
# #labels.eq(label_tensor).sum()

# class4_pos_mul_tensors = torch.where(labels == label_tensor, torch.ones_like(labels), torch.zeros_like(labels))
# class4_neg_mul_tensors = torch.where(labels != label_tensor, torch.ones_like(labels), torch.zeros_like(labels))

# print(labels, class4_pos_mul_tensors, class4_neg_mul_tensors)