In [1]:
import torch
import torch.nn as nn

# First, check out CrossEntropyLoss(): its input is raw logits, not probabilities
f = torch.tensor([[-1., -3., 4.], [-3., 3., -1.]])  # logits: 2 samples x 3 classes
print(f.shape)
target = torch.tensor([0, 2])                        # one class index per sample
print(target.shape)

criterion = torch.nn.CrossEntropyLoss()
print('Loss label encoded:',criterion(f, target)) 
# Note that, in the above, f has 3 logits per sample, while target has the class label!

# The order is criterion(logits, true_label). Swapping the arguments raises an error here,
# because an integer (2,) label tensor is not a valid (2, 3) float logits tensor.
# The exception is the point of this demo, so catch it and show it instead of crashing.
try:
  criterion(target, f)
except RuntimeError as err:
  print('Swapped arguments raise', type(err).__name__ + ':', str(err).splitlines()[0])

# Targets may also be one-hot encoded (float, same shape as the logits)
target2 = torch.tensor([[1., 0., 0.], [0., 0., 1.]])
print('Loss one-hot encoded:',criterion(f, target2))
# Careful: with two float tensors of the same shape, swapping the arguments does NOT raise.
# target2 is silently used as the logits and f as the target "probabilities",
# which gives a meaningless number. (Check out the cross-entropy formula.)
print('Loss one-hot encoded reversed:',criterion(target2, f))

# Instead of CrossEntropyLoss, you can apply LogSoftmax and then NLLLoss
# (f and target above are unchanged, so they are reused here).
model = nn.LogSoftmax(dim = 1)  # dim=1: normalize across the 3 logits of each sample (row)

criterion = torch.nn.NLLLoss()
print('LogSoftmax and NLLoss:',criterion(model(f), target))



torch.Size([2, 3])
torch.Size([2])
Loss label encoded: tensor(4.5141)
Loss one-hot encoded: tensor(4.5141)
Loss one-hot encoded reversed: tensor(0.2243)
LogSoftmax and NLLoss: tensor(4.5141)


In [51]:
# Custom cost function
import torch.nn.functional as F

def my_cross_entropy(y_pred,y_true):
  """Cross-entropy between raw logits and targets, averaged over the batch.

  Computes  mean_i( -sum_c y_true[i, c] * log_softmax(y_pred)[i, c] ),
  i.e. what nn.CrossEntropyLoss() computes with its default reduction='mean'.

  Args:
    y_pred: logits indexed as (sample, class); must be 2-D with shape (N, C).
    y_true: either class indices with N elements (shape (N,) or (N, 1);
      integer dtype, or float holding whole numbers such as 0., 2.),
      or per-class targets of shape (N, C) (one-hot or class probabilities).

  Returns:
    0-dim tensor: the per-sample losses averaged over the N samples.
  """
  n_samples, n_classes = y_pred.size(0), y_pred.size(1)

  # reshape (unlike view) works regardless of the target's memory layout.
  y_true = y_true.reshape(n_samples, -1)

  # Label-encoded targets have a single column instead of C: expand them to one-hot.
  if y_true.size(1) != n_classes:
    # F.one_hot only accepts integer tensors, so cast first (float labels would raise).
    y_true = F.one_hot(y_true.long(), num_classes=n_classes).reshape(y_pred.size())
    # Match the logits' dtype so the product below is plain floating point.
    y_true = y_true.to(y_pred.dtype)

  log_probs = F.log_softmax(y_pred, dim=1)
  # Sum over classes for each sample, then take the mean over the batch.
  return -(y_true * log_probs).sum(dim=1).mean()


y_pred = torch.tensor([[-1., -3., 4.], [-3., 3., -1.]])      # logits: 2 samples x 3 classes
y_true = torch.tensor([0,2])                                  # label-encoded targets
y_true_onehot = torch.tensor([[1., 0., 0.], [0., 0., 1.]])    # the same targets, one-hot encoded

print(y_pred)
print(y_true)

loss = my_cross_entropy(y_pred,y_true)

# Sanity checks: both target encodings must give the same loss,
# and it must match PyTorch's built-in cross-entropy.
assert torch.allclose(loss, my_cross_entropy(y_pred, y_true_onehot))
assert torch.allclose(loss, F.cross_entropy(y_pred, y_true))

print(loss)
  

tensor([[-1., -3.,  4.],
        [-3.,  3., -1.]])
tensor([0, 2])
torch.Size([2, 3])
torch.Size([2, 1])
tensor([[1, 0, 0],
        [0, 0, 1]])
tensor(4.5141)
