# Lensing Softmax

**Numerical examples of softmax output**

In [110]:
# Python 3.6.2. Requires PyTorch >= 0.4: the notebook uses torch.tensor, the
# `requires_grad=` factory kwarg and prints the `tensor(...)` repr, none of
# which exist in PyTorch 0.2.0.
import torch
from torch.autograd import Variable  # legacy wrapper; kept because other cells still call it
import torch.nn.functional as F

# Q is the output of an imagined softmax: a probability vector over 4 classes.
Q = torch.tensor([.4, .3, .2, .1], requires_grad=True)
# Equivalently, logQ is the output of an imagined log_softmax.
logQ = torch.log(Q)

# Q and logQ as a batch containing a single example, shape (1, 4).
Q_batch = torch.tensor([[.4, .3, .2, .1]], requires_grad=True)
logQ_batch = torch.log(Q_batch)

**Single Example**

In [112]:
# Unlensed loss:
# P is a one-hot (delta) target indicating that the correct label is class 0.
P = torch.tensor([1., 0, 0, 0])
# reduction='sum' replaces the deprecated size_average=False.
F.kl_div(logQ, P, reduction='sum')

tensor(0.9163)

In [113]:
# Lensed Loss:
# Cw indicates that the correct label is interchangeably class 0 or class 1.
Cw = torch.tensor([0.5, 0.5, 0, 0])
# reduction='sum' replaces the deprecated size_average=False.
F.kl_div(logQ, Cw, reduction='sum')

tensor(0.3670)

**Batch of Examples**

*First, we demonstrate that Cross Entropy Loss can be decomposed into Log Softmax followed by NLL Loss*

In [118]:
# A batch containing one example. True label is class 0.
torch.manual_seed(0)  # make the random logits (and the losses computed from them) reproducible
target = torch.tensor([0])
# Logits for 5 classes. NOTE: `input` shadows the builtin; the name is kept
# because later cells refer to it.
input = torch.rand(1, 5, requires_grad=True)
input

tensor([[ 0.5616,  0.8824,  0.8206,  0.0873,  0.4063]])

In [119]:
# Mean over the batch (the default reduction), spelled out explicitly.
F.cross_entropy(input, target, reduction='mean')

tensor(1.6394)

In [120]:
# Same loss built by hand: log-probabilities over the class dimension, then NLL.
F.nll_loss(input=F.log_softmax(input, dim=1), target=target)

tensor(1.6394)

*Next, we replace the output of the Log Softmax with our example batched output, `logQ_batch`*

In [122]:
# logQ_batch stands in for the log_softmax output; default mean reduction made explicit.
F.nll_loss(logQ_batch, target, reduction='mean')

tensor(0.9163)

*Note that the result equals the KL divergence $D_{KL}(P \Vert Q)$ between the delta target P and the prediction Q*

In [126]:
# Unlensed loss:
# P is a delta distribution indicating that the correct label is class 0.
P = torch.tensor([1., 0, 0, 0])
# reduction='sum' replaces the deprecated size_average=False.
F.kl_div(logQ, P, reduction='sum')

tensor(0.9163)

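*Therefore, unlensed Cross Entropy Loss is equivalent to KL Divergence with a delta (one-hot) target distribution. Since $D_{KL}(P \Vert Q) = H(P, Q) - H(P)$ and a delta distribution has zero entropy, $H(P) = 0$, the KL divergence reduces to the cross entropy $H(P, Q)$.*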

In [135]:
# Unlensed batch loss
# P_batch is a batch of one delta target: the correct label for the example is class 0.
P_batch = torch.tensor([[1., 0, 0, 0]])
# reduction='sum' replaces the deprecated size_average=False.
F.kl_div(logQ_batch, P_batch, reduction='sum')

tensor(0.9163)

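*and* **lensed** *Cross Entropy Loss is equivalent to KL Divergence with the confusion lens as the target distribution. Strictly, for a non-delta target the KL divergence differs from the soft cross entropy $H(C_w, Q)$ by the target's entropy $H(C_w)$. That term does not depend on the model's output, so both losses have identical gradients. Here $H(C_w, Q) \approx 1.0601$ and $H(C_w) = \ln 2 \approx 0.6931$, which gives the $0.3670$ below.*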

In [137]:
# Lensed batch loss
# Cw_batch indicates that the correct label is interchangeably class 0 or class 1.
Cw_batch = torch.tensor([[0.5, 0.5, 0, 0]])
# reduction='sum' replaces the deprecated size_average=False.
F.kl_div(logQ_batch, Cw_batch, reduction='sum')

tensor(0.3670)

**This is exactly what our implementation of Lensed Cross Entropy Loss below does: each target label selects its row of the confusion lens as that example's target distribution.**

In [234]:
import torch
import torch.nn.functional as F
def lensed_cross_entropy(input, target, confusion_lens, size_average=True):
    """Cross-entropy-style loss against a "lensed" (soft) target distribution.

    Each integer label in ``target`` selects its row of ``confusion_lens``.
    That row is the example's target distribution, and the loss is the KL
    divergence between it and ``softmax(input)``. With an identity lens this
    equals ``F.cross_entropy``. For a non-delta row, the value differs from
    the soft cross entropy only by the row's entropy. That entropy does not
    depend on ``input``, so the gradients are the same.

    Args:
        input: logits, assumed (batch, num_classes) since softmax is taken
            over dim 1 and dim 0 is treated as the batch.
        target: integer class labels, used to index rows of ``confusion_lens``.
        confusion_lens: matrix whose row ``c`` is the target distribution for
            label ``c`` (rows presumably sum to 1 -- not checked here).
        size_average: if True, divide the summed loss by the batch size
            (matching ``F.cross_entropy``'s default mean); otherwise return
            the sum over the batch.

    Returns:
        The loss as a 0-dim tensor.
    """
    logQ_batch = F.log_softmax(input, dim=1)
    # Row lookup: one target distribution per example.
    Cw_batch = confusion_lens[target]
    # reduction='sum' replaces the deprecated size_average=False.
    loss = F.kl_div(logQ_batch, Cw_batch, reduction='sum')
    if size_average:
        return loss / input.shape[0]
    return loss

**As a final sanity check, we demonstrate that when we use an identity matrix as our confusion lens, `lensed_cross_entropy` is equivalent to the unlensed `F.cross_entropy`**

In [224]:
torch.manual_seed(0)  # reproducible logits for the comparisons below
input = torch.rand(3, 4, requires_grad=True)  # 3 examples, 4 classes
target = torch.tensor([0, 2, 2], dtype=torch.long)  # replaces legacy torch.LongTensor constructor

In [226]:
# The identity lens maps each label to its own one-hot target distribution.
identity_lens = torch.eye(4)  # plain tensor; Variable is deprecated and requires_grad defaults to False
identity_lens

tensor([[ 1.,  0.,  0.,  0.],
        [ 0.,  1.,  0.,  0.],
        [ 0.,  0.,  1.,  0.],
        [ 0.,  0.,  0.,  1.]])

In [227]:
# Summed loss with the identity lens; should match F.cross_entropy summed below.
lensed_cross_entropy(input, target, confusion_lens=identity_lens, size_average=False)

tensor(4.0221)

In [228]:
# Summed over the batch; reduction='sum' replaces the deprecated size_average=False.
F.cross_entropy(input, target, reduction='sum')

tensor(4.0221)

In [229]:
# Batch-averaged loss with the identity lens; should match F.cross_entropy's mean below.
lensed_cross_entropy(input=input, target=target, confusion_lens=identity_lens, size_average=True)

tensor(1.3407)

In [230]:
# Averaged over the batch; reduction='mean' replaces the deprecated size_average=True.
F.cross_entropy(input, target, reduction='mean')

tensor(1.3407)

**So now we have a `lensed_cross_entropy` function, which lets us apply interesting (non-identity) lenses**

In [232]:
# Classes 0 and 1 are treated as interchangeable: a label of either one targets
# a 50/50 split between them. Classes 2 and 3 keep one-hot targets.
confusion_lens = torch.tensor([[0.5, 0.5, 0., 0.],
                               [0.5, 0.5, 0., 0.],
                               [0., 0., 1., 0.],
                               [0., 0., 0., 1.]])
# Each row is used as a target distribution, so it must sum to 1.
assert torch.allclose(confusion_lens.sum(dim=1), torch.ones(4))
confusion_lens

tensor([[ 0.5000,  0.5000,  0.0000,  0.0000],
        [ 0.5000,  0.5000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  1.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  1.0000]])

In [235]:
# Batch-averaged lensed loss (size_average's default, written out).
lensed_cross_entropy(input, target, confusion_lens, size_average=True)

tensor(1.2537)

:-)