In [1]:
import torch 
import torch.nn as nn

#### Note that
- Computing the cross-entropy loss directly from logits is preferred over computing it from softmax outputs, for numerical stability. So avoid putting a Softmax layer at the end of the model.
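- Both the binary and the multiclass cross-entropy losses come in variants that accept either logits or (log-)probabilities. In the examples below, the binary losses take float targets, while the multiclass losses take integer (Long) class indices.
    - nn.BCEWithLogitsLoss() : Logits
    - nn.BCELoss() : Probabilities
    - nn.CrossEntropyLoss() : Logits
    - nn.NLLLoss() : **Log probabilities**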

## Binary class classification

In [8]:
# Single-sample binary example: raw model score of 1.2 with true label 0.
logits = torch.tensor([1.2])
target = torch.tensor([0.0])  # BCE targets must be float, not Long
prob = torch.sigmoid(logits)  # functional call; nn.Sigmoid() is the module (class) form

# Two criteria for the same quantity, differing only in the input they expect.
loss_logits_func = nn.BCEWithLogitsLoss()  # fed raw logits
loss_prob_func = nn.BCELoss()  # fed probabilities

# Evaluate each criterion on its matching input, then report both.
bce_from_logits = loss_logits_func(logits, target)
bce_from_prob = loss_prob_func(prob, target)
print(f"Error With logits: {bce_from_logits}")
print(f"Error With Probabiltiy: {bce_from_prob}")

Error With logits: 1.4632824659347534
Error With Probabiltiy: 1.4632827043533325


### If you look carefully at the output above, the two values differ slightly in the last digits because of floating-point precision error. To avoid such numerical instability, the loss should be calculated from logits rather than from probabilities.

## Multiclass classification

In [17]:
# Single-sample, three-class example: the true class index is 1.
target = torch.tensor([1])  # class-index targets must be Long, not float
logits = torch.tensor([[2.0, 0.1, 0.3]])  # one row of 3 class scores

prob = torch.softmax(logits, dim=1)  # probabilities over the class dimension
# Use log_softmax rather than torch.log(torch.softmax(...)): the fused form
# stays finite when softmax would underflow to 0 (log(0) = -inf).
log_prob = torch.log_softmax(logits, dim=1)

loss_logits_fn = nn.CrossEntropyLoss()  # expects raw logits
loss_logprob_fn = nn.NLLLoss()  # expects log-probabilities, not probabilities

print(f"Error With logits: {loss_logits_fn(logits, target)}")
print(f"Error With Probabiltiy: {loss_logprob_fn(log_prob, target)}")

Error With logits: 2.186870813369751
Error With Probabiltiy: 2.186870813369751
