## NumPy vs. PyTorch: Softmax and Cross-Entropy Loss

### NumPy: Cross-Entropy Loss

In [31]:
import numpy as np 

In [32]:
# NumPy version: the target is a one-hot vector and the predictions are
# already probabilities (each row sums to 1).
Y = np.array([1, 0, 0])

y_pred1 = np.array([0.7, 0.2, 0.1])  # good: most probability on class 0
y_pred2 = np.array([0.1, 0.3, 0.6])  # bad: most probability on class 2


def cross_entropy(actual, predicted, eps=1e-15):
    """Return the cross-entropy loss ``-sum(actual * log(predicted))``.

    Args:
        actual: one-hot target vector.
        predicted: predicted class probabilities, same shape as ``actual``.
        eps: lower bound applied to ``predicted`` before taking the log.

    Clipping matters because ``log(0)`` is ``-inf``, and for a non-target
    class ``0 * -inf`` is ``nan``, which would poison the whole sum.
    """
    predicted = np.clip(predicted, eps, 1.0)
    return np.sum(-actual * np.log(predicted))


l1 = cross_entropy(Y, y_pred1)  # ~0.357
l2 = cross_entropy(Y, y_pred2)  # ~2.303
print("Numpy Loss 1 = ", l1)               
print("Numpy Loss 2 = ", l2)               

Numpy Loss 1 =  0.35667494393873245
Numpy Loss 2 =  2.3025850929940455


### PyTorch: Softmax + Cross-Entropy Loss

In [30]:
import torch
import torch.nn as nn

from torch.autograd import Variable

__Single Prediction__

In [24]:
import torch
import torch.nn as nn

from torch.autograd import Variable

# CrossEntropyLoss applies LogSoftmax followed by NLLLoss, so it takes raw
# logits as input, not probabilities.
loss = nn.CrossEntropyLoss()

# PyTorch targets are class indices (int64), not one-hot vectors.
# `Variable` is deprecated: plain tensors carry autograd state directly, and
# targets never need gradients.
Y = torch.tensor([0])

# Input: a (batch, num_classes) matrix of raw logits, here one row of three
# classes. The loss softmaxes each row internally before computing CE.
y_pred1 = torch.tensor([[2.0, 1.0, 0.1]])  # good: largest logit is class 0
y_pred2 = torch.tensor([[0.5, 2.0, 0.3]])  # bad: largest logit is class 1

# Loss
l1 = loss(y_pred1, Y)  # ~0.42
l2 = loss(y_pred2, Y)  # ~1.84

print("PyTorch Loss 1 = ", l1.item()) 
print("PyTorch Loss 2 = ", l2.item())

PyTorch Loss 1 =  0.41702985763549805
PyTorch Loss 2 =  1.840616226196289


__Batch (Matrix) Prediction__

In [29]:
# Re-create the criterion so this cell does not silently depend on `loss`
# from an earlier cell (LogSoftmax + NLLLoss, averaged over the batch).
loss = nn.CrossEntropyLoss()

# Targets: one class index per row of the prediction matrix.
Y = torch.tensor([2, 0, 1])

# Raw logits (unnormalized scores -- neither one-hot nor probabilities),
# one row per sample, one column per class.
y_pred1 = torch.tensor([[0.1, 0.2, 0.9],   # high value for 2 label
                        [1.1, 0.1, 0.2],   # high value for 0 label
                        [0.2, 2.1, 0.1]])  # high value for 1 label

y_pred2 = torch.tensor([[0.8, 0.2, 0.3],   # bad predictions
                        [0.2, 0.3, 0.5],
                        [0.2, 0.2, 0.5]])

l1 = loss(y_pred1, Y)  # ~0.50
l2 = loss(y_pred2, Y)  # ~1.24

print("Batch Loss 1 = ", l1.item())
print("Batch Loss 2 = ", l2.item())

Batch Loss 1 =  0.4966353178024292
Batch Loss 2 =  1.2388995885849
