# 4b. Loss functions

In [None]:
from typing import Tuple, List

import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline
import seaborn as sns
sns.set(style="darkgrid")

from workshop import data
import helper

## Recap

So far, we have seen how 

1. the `torch.Tensor` and `autograd` framework enable the computation of gradients
1. the `torch.nn` module helps us define a neural network architecture
1. the `torch.utils.data` dataset and data loader high-level APIs encapsulate the provisioning of training examples

The one step still missing before we can train our own neural network is the computation of a **loss function**.


## A loss function for the MNIST dataset

We have already seen a loss function in our Tensor notebook, namely **mean squared error**.

This loss function, which evaluates how **close we are to predicting the correct number**, is an adequate choice for a regression problem. For digit classification, however, we want to evaluate how **close we are to predicting the correct class**.

The latter is typically achieved using the **cross entropy loss** function, which (in our case) measures the **negative log likelihood** for the target class $t$:
$$ce(\hat p) = -\log(p_t)$$

Here, $\hat p$ is the vector of predicted probabilities for each class and $p_t$ is the predicted probability for the target class.
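To get a feeling for the formula, here is a minimal sketch that evaluates $-\log(p_t)$ for a few target-class probabilities. The helper name `cross_entropy_for_target` and the probability values are made up for illustration.

```python
import math


def cross_entropy_for_target(p_t: float) -> float:
    # Negative log likelihood of the target class, given its predicted probability
    return -math.log(p_t)


# Hypothetical predicted probabilities for the target class
for p_t in (0.9, 0.5, 0.01):
    print(f"p_t = {p_t}: loss = {cross_entropy_for_target(p_t):.3f}")
```

The loss grows as the probability assigned to the target class shrinks, and it approaches zero as that probability approaches one.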

Notice that we previously defined our neural network to end in a **softmax layer** on purpose; its output can be interpreted as $\hat p$.
With that convention in mind, we can easily define our loss function:

In [None]:
def nll(predictions: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # predictions: probabilities of shape [batch_size, n_classes]
    # target: class indices in {0, ..., n_classes - 1} of shape [batch_size]
    
    return -predictions[range(len(target)), target].log().mean()

A few things to notice:

* We use pairwise indexing here, which we have seen before.
* This works because our target vector is a 1d vector of class labels, and each label doubles as a column index into the matching row of our predictions.
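The indexing step can be looked at in isolation. The sketch below uses a made-up batch of three examples and four classes. It also shows `torch.gather` as an alternative way to express the same lookup; that is not what the `nll` function above uses.

```python
import torch

# Hypothetical toy batch: 3 examples, 4 classes (values chosen for illustration)
predictions = torch.tensor([
    [0.7, 0.1, 0.1, 0.1],
    [0.2, 0.5, 0.2, 0.1],
    [0.1, 0.1, 0.2, 0.6],
])
target = torch.tensor([0, 1, 3])

# Pairwise indexing: element i of the result is predictions[i, target[i]]
rows = torch.arange(len(target))
picked = predictions[rows, target]

# The same lookup expressed with torch.gather along the class dimension
gathered = predictions.gather(1, target.unsqueeze(1)).squeeze(1)

print(picked)
print(gathered)
```

Both variants select one probability per row, namely the one the model assigned to the correct class.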

In [None]:
targets = torch.tensor([0, 1, 2])

perfect_preds = torch.tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
torch.testing.assert_allclose(nll(perfect_preds, targets), 0)

bad_preds = torch.tensor([[0., 1., 0.], [1., 0., 0.], [1., 0., 0.]])
torch.testing.assert_allclose(nll(bad_preds, targets), np.inf)

some_preds = torch.tensor([[.98, 0.01, 0.01], [0.5, .5, 0.], [0.4, 0.4, .2]])
torch.testing.assert_allclose(nll(some_preds, targets), -(np.log(.98) + np.log(.5) + np.log(0.2))/3.)

## Loss functions in PyTorch

There are a couple of reasons why one might prefer [the PyTorch implementation](https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html?highlight=nllloss#torch.nn.NLLLoss) over our custom implementation of the NLL loss.

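The documentation for `torch.nn.NLLLoss` explains that it consumes log probabilities rather than probabilities. The advantage of this approach is numerical stability: log probabilities can be computed directly from the network's raw scores using the [log-sum-exp](https://en.wikipedia.org/wiki/LogSumExp) trick, instead of taking the logarithm of probabilities that may already have underflowed to zero.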
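The following sketch illustrates the stability issue. It assumes a deliberately extreme, made-up pair of raw scores and compares taking the log of softmax probabilities (as our custom loss does) with feeding `torch.log_softmax` output into `torch.nn.NLLLoss`.

```python
import torch

# Hypothetical extreme raw scores for one example with two classes
logits = torch.tensor([[1000.0, 0.0]])
target = torch.tensor([1])

# Naive route: softmax first, then log.
# The probability of class 1 underflows to 0 in float32, so its log is -inf.
probs = torch.softmax(logits, dim=1)
naive_loss = -probs[range(len(target)), target].log().mean()

# Stable route: compute log probabilities directly and pass them to NLLLoss
log_probs = torch.log_softmax(logits, dim=1)
stable_loss = torch.nn.NLLLoss()(log_probs, target)

print(naive_loss, stable_loss)
```

The naive route yields an infinite loss, whose gradient is useless for training, while the log-probability route stays finite.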

The documentation also outlines an alternative that avoids a softmax layer in the model altogether: [Cross Entropy Loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#crossentropyloss), which takes the raw, unnormalized scores of the network as input.
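As a quick check of how the two losses relate, the sketch below applies `torch.nn.CrossEntropyLoss` to raw scores and compares the result with `torch.nn.NLLLoss` applied to `torch.log_softmax` of the same scores. The random scores and the targets are placeholders, not MNIST data.

```python
import torch

torch.manual_seed(0)

# Placeholder raw scores for a batch of 4 examples and 10 classes
logits = torch.randn(4, 10)
target = torch.tensor([3, 0, 9, 1])

# Cross entropy loss works on raw scores; no softmax layer in the model is needed
ce_loss = torch.nn.CrossEntropyLoss()(logits, target)

# Equivalent two-step route: log-softmax followed by NLL loss
nll_loss = torch.nn.NLLLoss()(torch.log_softmax(logits, dim=1), target)

print(ce_loss, nll_loss)
```

With this choice, the model's last layer would output raw scores, and a softmax would only be applied when actual probabilities are needed, for example at prediction time.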

## Bottom line
PyTorch provides a **large number of loss functions** that are applicable across a wide range of deep learning tasks and that handle caveats like the one described above.

