# Loss functions

A loss function defines the objective that a model is trained to minimize. PyTorch provides a suite of common loss functions, and we will explore their properties here.

In [4]:
import torch
import torch.nn.functional as F

## Cross entropy

This loss function is commonly used for multi-class classification problems. In PyTorch's implementation, you need to provide:

- A batch $B$ of shape $(n, K)$, where $n$ is the number of observations in the batch and $K$ is the number of classes. Each row holds the model's raw, unnormalized prediction scores (logits) for one observation; the function applies log-softmax to them internally.
- A vector $(y_1, y_2, \ldots, y_n)$, where $y_i$ is an integer from $0$ to $K-1$ indicating the true class label of the $i$-th observation in the batch.

For a more detailed and precise description, refer to the official [torch documentation page](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html).
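To make the computation concrete, here is a minimal sketch that reproduces the loss by hand. It assumes the default settings (mean reduction, no class weights, no label smoothing), in which case the loss is $-\frac{1}{n}\sum_{i=1}^{n} \log \mathrm{softmax}(B_i)_{y_i}$. The tensor values and the names `logits`, `manual_loss` and `builtin_loss` are made up for illustration.

```python
import torch
import torch.nn.functional as F

# Hypothetical batch: n = 3 observations, K = 4 classes.
logits = torch.tensor([[0.1, 0.5, 0.7, 0.5],
                       [0.1, 0.1, 0.7, 0.1],
                       [2.0, 0.0, -1.0, 0.5]])
labels = torch.tensor([0, 2, 0])

# Step 1: log-softmax turns each row of scores into log-probabilities.
log_probs = F.log_softmax(logits, dim=1)

# Step 2: pick the log-probability of the true class for each observation.
true_class_log_probs = log_probs[torch.arange(len(labels)), labels]

# Step 3: negate and average over the batch (the default "mean" reduction).
manual_loss = -true_class_log_probs.mean()

builtin_loss = F.cross_entropy(logits, labels)
print(manual_loss, builtin_loss)
```

The two printed values should agree up to floating-point error.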

---

Consider two examples: one where the model makes a mistake and one where it is correct.

In the following cell, the true label of the observation is 0, but the model assigns class 0 the lowest score of all classes. The model is wrong, so the loss is comparatively high.

In [25]:
F.cross_entropy(
    torch.tensor([[0.1, 0.5, 0.7, 0.5]]),
    torch.tensor([0])
)

tensor(1.7589)

In this case, the model assigns its highest score to class 2 and a lower score to every other class. The true label is indeed 2, so the prediction is correct and the loss is lower than in the previous example.

In [26]:
F.cross_entropy(
    torch.tensor([[0.1, 0.1, 0.7, 0.1]]),
    torch.tensor([2])
)

tensor(0.9732)
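The two loss values are easier to interpret as probabilities. The sketch below applies softmax to the same scores and takes the negative log of the probability assigned to the true class; the variable names are illustrative. The true class gets a probability of roughly 0.17 in the first example and roughly 0.38 in the second, which is why the second loss is lower even though it is far from zero.

```python
import torch
import torch.nn.functional as F

wrong = torch.tensor([[0.1, 0.5, 0.7, 0.5]])    # true label is 0
correct = torch.tensor([[0.1, 0.1, 0.7, 0.1]])  # true label is 2

# Probability that softmax assigns to the true class in each example.
p_wrong = F.softmax(wrong, dim=1)[0, 0]
p_correct = F.softmax(correct, dim=1)[0, 2]

# For a single observation, cross entropy is -log(probability of the true class).
print(p_wrong, -torch.log(p_wrong))
print(p_correct, -torch.log(p_correct))
```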

Here is an example that shows typical inputs for the cross entropy loss.

In [59]:
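n = 1000           # number of observations in the batch
classes_count = 5  # number of classes K
# Random scores of shape (n, K), standing in for model outputs.
predictions_batch = torch.randn(n, classes_count)
# Random integer labels in [0, K-1], one per observation.
labels = torch.randint(low=0, high=classes_count, size=[n])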
print("predictions batch")
print(predictions_batch[:10])
print("labels")
print(labels[:10])

F.cross_entropy(predictions_batch, labels)

predictions batch
tensor([[-0.8844, -0.3556,  0.9763,  1.2577, -0.6795],
        [ 0.2358, -0.7628, -1.4757,  1.0294, -0.8632],
        [ 0.4365, -0.9948,  0.0902, -0.7807, -0.3827],
        [ 2.1360,  1.1786, -0.0036,  0.0101, -0.4870],
        [ 0.1012, -0.2581,  1.4023,  0.5791,  1.3517],
        [ 0.4122,  0.9149,  1.0000,  0.8574,  0.2017],
        [-0.6575,  1.5868,  0.1800, -0.9720,  0.7293],
        [ 0.5361,  1.8960,  1.0395,  0.5292, -0.0472],
        [ 0.0543, -0.2939, -0.5856,  0.6681, -0.5274],
        [ 0.9524,  0.9101, -0.0125,  0.4696,  0.6906]])
labels
tensor([1, 3, 1, 1, 1, 3, 4, 3, 2, 1])


tensor(1.9175)
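A useful reference point for a value like this is the loss of a model that assigns the same score to every class. In that case softmax gives each class probability $1/K$, so the loss equals $\ln K$ whatever the labels are. The sketch below checks this for $K = 5$; `uniform_logits` and `baseline` are illustrative names. Random scores that are unrelated to the labels, as in the cell above, tend to land above this baseline because the model is sometimes confidently wrong.

```python
import math

import torch
import torch.nn.functional as F

n, classes_count = 1000, 5

# Identical scores for every class: softmax yields 1/K for each class.
uniform_logits = torch.zeros(n, classes_count)
labels = torch.randint(low=0, high=classes_count, size=[n])

baseline = F.cross_entropy(uniform_logits, labels)
print(baseline, math.log(classes_count))
```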