In [16]:
import torch
from torch import Tensor

# Tutorial 1b: Softmax Function

**Question:** To have the logistic regressor output probabilities, they need to be processed through a softmax layer. Implement a softmax layer yourself. What numerical issues may arise in this layer? How can you solve them? Use the testing code to confirm you implemented it correctly.

In [17]:
logits = torch.rand((1, 20)) + 100

In [18]:
logits

tensor([[100.7289, 100.1536, 100.0530, 100.7531, 100.5457, 100.8170, 100.0151,
         100.3079, 100.1512, 100.6022, 100.9031, 100.6366, 100.5295, 100.9433,
         100.5648, 100.2664, 100.2527, 100.4309, 100.9932, 100.5948]])

In [19]:
logits.shape

torch.Size([1, 20])

In [20]:
def bad_softmax(x: Tensor) -> Tensor:
    return torch.exp(x) / torch.sum(torch.exp(logits), axis=0)

In [21]:
torch.sum(bad_softmax(logits))

tensor(nan)

#### Response

1. Issue: Overflow
2. Solution: Subtract the maximum value of x along the specified dimension to prevent overflow


In [22]:
def good_softmax(x: Tensor) -> Tensor:
    ###########################################################################
    # TODO: Implement a more stable way to compute softmax                    #
    ###########################################################################
    exp_x = torch.exp(x - torch.max(x, dim=1, keepdim=True)[0])
    denom = torch.sum(exp_x, dim=1, keepdim=True)
    return exp_x/denom
    

In [23]:
torch.sum(good_softmax(logits))

tensor(1.)

Because of numerical issues like the one you just experiences, PyTorch code typically uses a `LogSoftmax` layer.

**Question [optional]:** PyTorch automatically computes the backpropagation gradient of a module for you. However, it can be instructive to derive and implement your own backward function. Try and implement the backward function for your softmax module and confirm that it is correct.

In [10]:
class GoodSoftmax(torch.autograd.Function):
    #@staticmethod
    def forward(ctx, x):
        exp_x = torch.exp(x - torch.max(x, dim=1, keepdim=True)[0])
        denom = torch.sum(exp_x, dim=1, keepdim=True)
        softmax_output = exp_x / denom
        ctx.save_for_backward(softmax_output)
        return softmax_output

    #@staticmethod
    def backward(ctx, grad_output):
        softmax_output, = ctx.saved_tensors
        grad_input = torch.zeros_like(grad_output)
        for i in range(grad_input.shape[0]):
            jacobian = torch.zeros(softmax_output.shape[1], softmax_output.shape[1])
            for k in range(softmax_output.shape[1]):
                for l in range(softmax_output.shape[1]):
                    if k == l:
                        jacobian[k, l] = softmax_output[i, k] * (1 - softmax_output[i, l])
                    else:
                        jacobian[k, l] = -softmax_output[i, k] * softmax_output[i, l]
            grad_input[i, :] = torch.matmul(grad_output[i, :], jacobian)
        return grad_input


In [24]:
x = torch.rand((1, 20), requires_grad=True)
y = GoodSoftmax.apply(x)
loss = y.mean()
loss.backward(retain_graph=True) #no freeding of the intermediate values of the graph
grad_x = torch.autograd.grad(loss, x, create_graph=True)[0]


In [25]:
x

tensor([[0.3152, 0.2651, 0.1640, 0.4944, 0.8559, 0.5110, 0.4593, 0.7517, 0.2050,
         0.2256, 0.0648, 0.3093, 0.6877, 0.2461, 0.7317, 0.3387, 0.3145, 0.2922,
         0.5545, 0.8468]], requires_grad=True)

In [26]:
print(grad_x)

tensor([[-5.8208e-11, -2.0373e-10, -1.4552e-11,  1.3097e-10,  2.0373e-10,
          5.8208e-11, -1.0186e-10, -1.4552e-10,  2.9104e-11, -1.6007e-10,
         -4.3656e-11, -1.4552e-11, -1.4552e-10, -5.8208e-11,  1.4552e-10,
          7.2760e-11, -4.3656e-11,  7.2760e-11, -4.3656e-11,  2.3283e-10]],
       grad_fn=<CopySlices>)


In [27]:
print(x.grad)

tensor([[-5.8208e-11, -2.0373e-10, -1.4552e-11,  1.3097e-10,  2.0373e-10,
          5.8208e-11, -1.0186e-10, -1.4552e-10,  2.9104e-11, -1.6007e-10,
         -4.3656e-11, -1.4552e-11, -1.4552e-10, -5.8208e-11,  1.4552e-10,
          7.2760e-11, -4.3656e-11,  7.2760e-11, -4.3656e-11,  2.3283e-10]])


In [15]:
assert torch.allclose(x.grad, grad_x, rtol=1e-3, atol=1e-5)