# Loss functions in TensorFlow

Training neural networks requires the specification of a loss function. Gradient descent (and its variants, such as Adam and RMSprop) adjusts the model's weights to minimize this loss during training, so the choice of loss function is a key part of training a deep learning model.
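As a minimal sketch of what "minimizing the loss" means, the cell below runs plain gradient descent on a single weight `w` using `tf.GradientTape`. The toy data (`y = 3x`), the learning rate and the number of steps are arbitrary choices for illustration. During `model.fit`, Keras does the equivalent for every weight in the network, using whichever optimizer you pass to `compile`.

```python
import tensorflow as tf

# Toy data following y = 3x, and a single trainable weight that starts at 0
x = tf.constant([1.0, 2.0, 3.0])
y = 3.0 * x
w = tf.Variable(0.0)
learning_rate = 0.05

for step in range(50):
    with tf.GradientTape() as tape:
        # Mean squared error between the targets and the prediction w * x
        loss = tf.reduce_mean(tf.square(y - w * x))
    grad = tape.gradient(loss, w)
    # Move the weight a small step against the gradient
    w.assign_sub(learning_rate * grad)

print(float(w), float(loss))
```

The weight approaches 3.0, the value at which the loss is zero.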

In this notebook we demonstrate how to use different loss functions for training. 



## Imports

In [None]:
import tensorflow as tf

In this notebook we use the Keras API for TensorFlow and import the `keras` module. The submodule `keras.losses` contains many loss functions that can be used to train the network.

In [None]:
from tensorflow import keras
from tensorflow.keras import layers

Because the goal of this notebook is the loss function rather than the problem at hand, we pick what might be considered the "Hello World" problem of machine learning: MNIST handwritten digit classification.

To get the dataset, you can follow the guides [here](https://www.tensorflow.org/datasets/overview); an overview of MNIST is [here](https://www.tensorflow.org/datasets/catalog/mnist).

Recall that every datapoint in the MNIST dataset is a 28x28 image of a handwritten digit. The array that represents this image is flattened into an array of length 784 (28 * 28), which is the input to the neural network (the `Input` layer in the Keras code below). We create two fully connected (dense) hidden layers and a final output layer of size 10 (one unit per digit) that uses softmax to make predictions.
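The notebook does not show the loading step, so here is a minimal sketch of the flattening described above. `preprocess_mnist` is a hypothetical helper, and scaling the pixels to [0, 1] is our assumption rather than something the original specifies. With the real data you would call `keras.datasets.mnist.load_data()` (which downloads MNIST on first use) and pass the training arrays through the helper; the check at the end uses a fake batch so the cell runs offline.

```python
import numpy as np

def preprocess_mnist(images, labels):
    """Hypothetical helper: flatten 28x28 images into 784-length float vectors."""
    # (num_samples, 28, 28) uint8 -> (num_samples, 784) float32 scaled to [0, 1]
    x = images.reshape(-1, 28 * 28).astype("float32") / 255.0
    # Keep the labels as integer class indices (0-9), as the sparse loss expects
    y = labels.astype("int64")
    return x, y

# With the real dataset (uses the `keras` import from the cells above):
# (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# x_train, y_train = preprocess_mnist(x_train, y_train)

# Quick check on a fake batch of two blank images
fake_images = np.zeros((2, 28, 28), dtype=np.uint8)
fake_labels = np.array([3, 7], dtype=np.uint8)
x, y = preprocess_mnist(fake_images, fake_labels)
print(x.shape, y)  # (2, 784) [3 7]
```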



In [None]:
# Build a fresh, uncompiled model; later cells call this function again
def get_uncompiled_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x = layers.Dense(64, activation="relu", name="dense_1")(inputs)
    x = layers.Dense(64, activation="relu", name="dense_2")(x)
    outputs = layers.Dense(10, activation="softmax", name="predictions")(x)
    return keras.Model(inputs=inputs, outputs=outputs)


model = get_uncompiled_model()

Compiling the model is the next step; at this point we provide both the optimizer and the loss function to minimize. The loss function in this case is `SparseCategoricalCrossentropy`. Let's break this down:
- *Crossentropy* means that we use the cross entropy between the predicted values (the probabilities from the softmax layer) and the true labels.
- *Categorical* tells us that the output is categorical data (a classification problem with more than two classes).
- *Sparse* tells us that the class labels are provided as integers (as opposed to being one-hot encoded).
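To make the *Sparse* part concrete, here is a small sketch with made-up predictions for two samples and three classes. `SparseCategoricalCrossentropy` on integer labels gives the same value as `CategoricalCrossentropy` on the equivalent one-hot labels, and both match the average of -log(probability assigned to the true class) computed by hand.

```python
import tensorflow as tf
from tensorflow import keras

# Made-up softmax outputs: two samples, three classes (each row sums to 1)
y_pred = tf.constant([[0.7, 0.2, 0.1],
                      [0.1, 0.3, 0.6]])

# Sparse form: labels are integer class indices
y_int = tf.constant([0, 2])
sparse_loss = keras.losses.SparseCategoricalCrossentropy()(y_int, y_pred)

# Dense form: the same labels, one-hot encoded
y_one_hot = tf.one_hot(y_int, depth=3)
dense_loss = keras.losses.CategoricalCrossentropy()(y_one_hot, y_pred)

# By hand: mean of -log(probability of the true class) over the batch
manual = -(tf.math.log(0.7) + tf.math.log(0.6)) / 2

print(float(sparse_loss), float(dense_loss), float(manual))
```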


In [None]:
model.compile(
    optimizer=keras.optimizers.Adam(),  # Optimizer
    # Loss function to minimize
    loss=keras.losses.SparseCategoricalCrossentropy()
)

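You can also provide the loss function as a string. The string is typically the class name converted from CamelCase to snake_case (lowercase words joined by underscores), so `SparseCategoricalCrossentropy` becomes `"sparse_categorical_crossentropy"`. The optimizer can be given as a string in the same way.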

In [None]:
model.compile(
    optimizer="rmsprop",
    loss="sparse_categorical_crossentropy"
)
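Keras resolves the string to a loss for you. The sketch below uses `keras.losses.get` to look up the same name directly, with made-up tensors. In the Keras versions we are aware of, this name maps to the *function* `sparse_categorical_crossentropy`, which returns one loss value per sample, whereas the class averages over the batch by default.

```python
import tensorflow as tf
from tensorflow import keras

y_true = tf.constant([0, 2])
y_pred = tf.constant([[0.7, 0.2, 0.1],
                      [0.1, 0.3, 0.6]])

# Look up the loss by its string name (returns the functional form)
loss_fn = keras.losses.get("sparse_categorical_crossentropy")
per_sample = loss_fn(y_true, y_pred)  # one loss value per sample

# The class form reduces to the mean over the batch by default
class_loss = keras.losses.SparseCategoricalCrossentropy()(y_true, y_pred)

print(per_sample.numpy(), float(tf.reduce_mean(per_sample)), float(class_loss))
```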

There are many built-in loss functions in the Keras API in TensorFlow; some of them are listed below.

### Many built-in loss functions are available

- `MeanSquaredError()`
- `MeanSquaredLogarithmicError()`
- `MeanAbsoluteError()`
- `BinaryCrossentropy()`
- `Hinge()`
- `SquaredHinge()`
- `KLDivergence()`
- `CosineSimilarity()`

The full list is [here](https://keras.io/api/losses/).
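Every loss object in this list is a callable that takes `(y_true, y_pred)` and returns a scalar tensor, so you can try them outside of training. Here is a sketch with a made-up batch of two samples; the losses chosen are just a subset of the list above.

```python
import tensorflow as tf
from tensorflow import keras

# Made-up targets and predictions: two samples, two outputs each
y_true = tf.constant([[0.0, 1.0], [1.0, 0.0]])
y_pred = tf.constant([[0.1, 0.9], [0.8, 0.2]])

losses = {
    "MeanSquaredError": keras.losses.MeanSquaredError(),
    "MeanAbsoluteError": keras.losses.MeanAbsoluteError(),
    "BinaryCrossentropy": keras.losses.BinaryCrossentropy(),
    "Hinge": keras.losses.Hinge(),
    "KLDivergence": keras.losses.KLDivergence(),
}

results = {}
for name, loss_obj in losses.items():
    # Each call returns the loss averaged over the batch as a scalar tensor
    results[name] = float(loss_obj(y_true, y_pred))
    print(f"{name}: {results[name]:.4f}")
```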



## More examples of loss functions

### Mean Squared Error loss function
Computes the mean of the squared errors between labels and predictions:
 
$$loss = \frac{1}{n}\sum_{i=1}^n (y_{true,i} - y_{pred,i})^2$$
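A small worked example of the formula, using made-up values and checked against the built-in class:

```python
import tensorflow as tf
from tensorflow import keras

y_true = tf.constant([1.0, 2.0, 3.0, 4.0])
y_pred = tf.constant([1.5, 2.0, 2.0, 5.0])

# Errors: -0.5, 0.0, 1.0, -1.0 -> squares: 0.25, 0.0, 1.0, 1.0 -> mean: 2.25 / 4 = 0.5625
manual = tf.reduce_mean(tf.square(y_true - y_pred))
builtin = keras.losses.MeanSquaredError()(y_true, y_pred)
print(float(manual), float(builtin))
```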
 

In [None]:
model.compile(
    optimizer=keras.optimizers.Adam(),  # Optimizer
    # Loss function to minimize
    loss=keras.losses.MeanSquaredError()
)

### Mean Absolute Error loss function

Computes the mean of the absolute differences between labels and predictions:

$$loss = \frac{1}{n}\sum_{i=1}^n |y_{true,i} - y_{pred,i}|$$
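The same made-up values as in the MSE example, now with absolute errors. Because the errors are not squared, the 0.5 error contributes 0.5 here (instead of 0.25), while errors of size 1.0 contribute the same amount to both losses.

```python
import tensorflow as tf
from tensorflow import keras

y_true = tf.constant([1.0, 2.0, 3.0, 4.0])
y_pred = tf.constant([1.5, 2.0, 2.0, 5.0])

# Absolute errors: 0.5, 0.0, 1.0, 1.0 -> mean: 2.5 / 4 = 0.625
manual = tf.reduce_mean(tf.abs(y_true - y_pred))
builtin = keras.losses.MeanAbsoluteError()(y_true, y_pred)
print(float(manual), float(builtin))
```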

Usage with the `compile()` API:

In [None]:
model.compile(
    optimizer=keras.optimizers.Adam(),  # Optimizer
    # Loss function to minimize
    loss=keras.losses.MeanAbsoluteError()
)

## Creating custom loss functions

You can also create your own loss function. A loss function takes two arguments, the true values (`y_true`) and the predicted values (`y_pred`), and returns the loss value (here a scalar tensor). For example, the loss function below replicates the behavior of `MeanSquaredError`.

In [None]:
def custom_mean_squared_error(y_true, y_pred):
    return tf.math.reduce_mean(tf.square(y_true - y_pred))
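As a quick sanity check (a sketch with made-up tensors), the custom function gives the same value as the built-in `MeanSquaredError` on one-hot targets. The function is repeated so the cell runs on its own. The two agree because every sample has the same number of elements, so averaging over all elements equals averaging the per-sample means.

```python
import tensorflow as tf
from tensorflow import keras

def custom_mean_squared_error(y_true, y_pred):
    # Average of squared differences over all elements of the batch
    return tf.math.reduce_mean(tf.square(y_true - y_pred))

y_true = tf.one_hot([3, 0], depth=4)  # two one-hot labels
y_pred = tf.constant([[0.1, 0.1, 0.1, 0.7],
                      [0.4, 0.2, 0.2, 0.2]])

custom = custom_mean_squared_error(y_true, y_pred)
builtin = keras.losses.MeanSquaredError()(y_true, y_pred)
print(float(custom), float(builtin))
```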

In [None]:
model = get_uncompiled_model()
model.compile(optimizer=keras.optimizers.Adam(), loss=custom_mean_squared_error)

In [None]:
# We need to one-hot encode the labels to use MSE
y_train_one_hot = tf.one_hot(y_train, depth=10)
model.fit(x_train, y_train_one_hot, batch_size=64, epochs=1)

### Additional parameters to loss functions

You may have wondered why the loss functions above are used by instantiating a class, e.g., `MeanSquaredError()`. The answer is that the class constructor can accept additional arguments, which lets you configure more flexible or more sophisticated loss functions without duplicating code.
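For example, the cross-entropy losses accept a `from_logits` argument. The sketch below, with made-up logits, shows that `SparseCategoricalCrossentropy(from_logits=True)` applied to raw scores matches the default `SparseCategoricalCrossentropy()` applied to their softmax, so one class covers both kinds of model output. Using `from_logits=True` would mean dropping the `softmax` activation from the last layer, a variation not used elsewhere in this notebook.

```python
import tensorflow as tf
from tensorflow import keras

# Made-up raw scores (logits) for two samples and three classes
logits = tf.constant([[2.0, 1.0, 0.1],
                      [0.5, 2.5, 0.3]])
labels = tf.constant([0, 1])

# Same loss class, configured through its constructor
from_logits_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
from_probs_loss = keras.losses.SparseCategoricalCrossentropy()  # expects probabilities

a = from_logits_loss(labels, logits)
b = from_probs_loss(labels, tf.nn.softmax(logits))
print(float(a), float(b))
```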

If you need a loss function that takes parameters besides `y_true` and `y_pred`, you can subclass the `tf.keras.losses.Loss` class and implement the following two methods:

- `__init__(self)`: accept the parameters your loss needs and store them on the instance
- `call(self, y_true, y_pred)`: use the targets (`y_true`) and the model predictions (`y_pred`) to compute the model's loss



In the example below we add a term to the loss function that penalizes predicted values far from 0.5.
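Written in the notation of the formulas above, with $\lambda$ standing for `regularization_factor` and $n$ for the total number of predicted values (the code's `reduce_mean` averages over all elements), the loss computed by `CustomMSE` is:

$$loss = \frac{1}{n}\sum_{i=1}^n (y_{true,i} - y_{pred,i})^2 + \lambda \, \frac{1}{n}\sum_{i=1}^n (0.5 - y_{pred,i})^2$$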

In [None]:
class CustomMSE(keras.losses.Loss):
    def __init__(self, regularization_factor=0.1, name="custom_mse"):
        super().__init__(name=name)
        self.regularization_factor = regularization_factor

    def call(self, y_true, y_pred):
        mse = tf.math.reduce_mean(tf.square(y_true - y_pred))
        reg = tf.math.reduce_mean(tf.square(0.5 - y_pred))
        return mse + reg * self.regularization_factor


model = get_uncompiled_model()
model.compile(optimizer=keras.optimizers.Adam(), loss=CustomMSE())

y_train_one_hot = tf.one_hot(y_train, depth=10)
model.fit(x_train, y_train_one_hot, batch_size=64, epochs=1)

### References

1. [Training and evaluation with the built-in methods](https://www.tensorflow.org/guide/keras/train_and_evaluate)
1. [Keras loss functions list and API](https://keras.io/api/losses/)