# Model Customizations

In [1]:
# Import packages
import tensorflow as tf

tf.__version__ # 2.x

'2.7.0'

# Custom Loss Function

There are two ways to write a custom loss function:

- Write a plain function that accepts the inputs y_true and y_pred and returns the loss.
- Subclass tf.keras.losses.Loss when the loss needs extra parameters or other advanced features.

## Approach 1: Simple and Elegant

In [2]:
def custom_mean_squared_error(y_true, y_pred):
    return tf.math.reduce_mean(tf.square(y_true - y_pred))
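
A function like this is passed to model.compile() like any built-in loss. The sketch below uses made-up tensors and random toy data (not from the original). It checks that the function agrees with the built-in tf.keras.losses.MeanSquaredError when there are no sample weights, then trains a tiny model with it for one epoch:

```python
import numpy as np
import tensorflow as tf

def custom_mean_squared_error(y_true, y_pred):
    # Mean of the squared differences over every element in the batch
    return tf.math.reduce_mean(tf.square(y_true - y_pred))

y_true = tf.constant([[0.0, 1.0], [1.0, 0.0]])
y_pred = tf.constant([[0.1, 0.8], [0.7, 0.2]])
custom_value = float(custom_mean_squared_error(y_true, y_pred))  # 0.045
builtin_value = float(tf.keras.losses.MeanSquaredError()(y_true, y_pred))

# Random toy data, for illustration only
rng = np.random.default_rng(0)
x = rng.random((8, 3)).astype("float32")
y = rng.random((8, 2)).astype("float32")

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
model.compile(optimizer="adam", loss=custom_mean_squared_error)
history = model.fit(x, y, epochs=1, verbose=0)
```

One design note: because this function averages over the whole batch and returns a single scalar, per-sample weights cannot weight individual samples. Reducing only over the last axis (axis=-1) keeps one loss value per sample, which is what the built-in losses do.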

## Approach 2: Advanced

If you need a loss function that takes parameters besides y_true and y_pred, you can subclass the tf.keras.losses.Loss class and implement the following two methods:

- __init__(self): stores the extra parameters that the loss uses when it is called
- call(self, y_true, y_pred): uses the targets (y_true) and the model predictions (y_pred) to compute the model's loss

Let's say you want to use mean squared error, but with an added term that will de-incentivize prediction values far from 0.5 (we assume that the categorical targets are one-hot encoded and take values between 0 and 1).

In [3]:
class CustomMSE(tf.keras.losses.Loss):
    def __init__(self, regularization_factor=0.1, name="custom_mse"):
        super().__init__(name=name)
        self.regularization_factor = regularization_factor

    def call(self, y_true, y_pred):
        mse = tf.math.reduce_mean(tf.square(y_true - y_pred))
        reg = tf.math.reduce_mean(tf.square(0.5 - y_pred))
        return mse + reg * self.regularization_factor
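
When a subclassed loss has constructor parameters, it is usually worth also implementing get_config, so that Keras can recreate the object (for example when reloading a saved model with custom_objects). The sketch below repeats CustomMSE with two additions that are not in the original: a **kwargs pass-through for base-class options such as reduction, and get_config. The tensors are made up for illustration:

```python
import tensorflow as tf

class CustomMSE(tf.keras.losses.Loss):
    def __init__(self, regularization_factor=0.1, name="custom_mse", **kwargs):
        # **kwargs forwards base-class options such as `reduction`
        super().__init__(name=name, **kwargs)
        self.regularization_factor = regularization_factor

    def call(self, y_true, y_pred):
        mse = tf.math.reduce_mean(tf.square(y_true - y_pred))
        reg = tf.math.reduce_mean(tf.square(0.5 - y_pred))
        return mse + reg * self.regularization_factor

    def get_config(self):
        # Start from the base config (name, reduction) and add our own parameter
        config = super().get_config()
        config.update({"regularization_factor": self.regularization_factor})
        return config

loss_fn = CustomMSE(regularization_factor=0.5)
y_true = tf.constant([[0.0, 1.0], [1.0, 0.0]])
y_pred = tf.constant([[0.1, 0.8], [0.7, 0.2]])
value = float(loss_fn(y_true, y_pred))  # 0.045 + 0.5 * 0.095 = 0.0925

# Rebuild an equivalent loss object from its config
restored = CustomMSE.from_config(loss_fn.get_config())
```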

# Writing Custom Metrics

If you need a metric that isn't part of the API, you can easily create custom metrics by subclassing the tf.keras.metrics.Metric class.

You will need to implement 4 methods:

- __init__(self), in which you will create state variables for your metric.
- update_state(self, y_true, y_pred, sample_weight=None), which uses the targets (y_true) and the model predictions (y_pred) to update the state variables.
- result(self), which uses the state variables to compute the final results.
- reset_state(self), which reinitializes the state of the metric (older TensorFlow versions named this method reset_states).
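
Keras calls these methods for you during fit() and evaluate(), but they can also be called directly, which is handy when testing a metric. As a sketch, the built-in tf.keras.metrics.SparseCategoricalAccuracy can be driven through the same lifecycle a custom metric goes through (the labels and scores below are made up):

```python
import tensorflow as tf

metric = tf.keras.metrics.SparseCategoricalAccuracy()

# First batch: integer labels and per-class scores; 2 of 3 predictions are correct
metric.update_state([0, 1, 2], [[0.9, 0.05, 0.05], [0.1, 0.8, 0.1], [0.2, 0.7, 0.1]])
after_first = float(metric.result())

# Second batch: the state accumulates across calls (3 of 4 correct so far)
metric.update_state([1], [[0.1, 0.9, 0.0]])
after_second = float(metric.result())

# reset_state() clears the accumulated state, as Keras does at the start of each epoch
metric.reset_state()
after_reset = float(metric.result())
```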

In [4]:
class CategoricalTruePositives(tf.keras.metrics.Metric):
    def __init__(self, name="categorical_true_positives", **kwargs):
        super(CategoricalTruePositives, self).__init__(name=name, **kwargs)
        # State variable: running count of correctly classified samples
        self.true_positives = self.add_weight(name="ctp", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        # y_true holds integer class indices; y_pred holds per-class scores (batch, num_classes)
        y_true = tf.reshape(y_true, shape=(-1, 1))
        y_pred = tf.reshape(tf.argmax(y_pred, axis=1), shape=(-1, 1))
        values = tf.cast(y_true, "int32") == tf.cast(y_pred, "int32")
        values = tf.cast(values, "float32")
        if sample_weight is not None:
            # Reshape to (batch, 1) so the weights do not broadcast to (batch, batch)
            sample_weight = tf.reshape(tf.cast(sample_weight, "float32"), shape=(-1, 1))
            values = tf.multiply(values, sample_weight)
        self.true_positives.assign_add(tf.reduce_sum(values))

    def result(self):
        # Current metric value, computed from the accumulated state
        return self.true_positives

    def reset_state(self):
        # The state of the metric will be reset at the start of each epoch.
        self.true_positives.assign(0.0)

# Weighing Samples and Classes

## Class Weights

Class weights are set by passing a dictionary to the class_weight argument of Model.fit().

This dictionary maps class indices to the weight that should be used for samples belonging to that class.
This can be used to balance classes without resampling, or to train a model that gives more importance to a particular class.

For instance, if class 0 appears half as often as class 1 in your data, you could use Model.fit(..., class_weight={0: 1.0, 1: 0.5}).
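
A minimal sketch of deriving such a dictionary from label counts, assuming hypothetical imbalanced binary labels and random features. The inverse-frequency heuristic used here (total / (num_classes * count)) is one common choice, not something the original prescribes:

```python
import numpy as np
import tensorflow as tf

# Hypothetical imbalanced labels: class 0 appears half as often as class 1
y = np.array([0] * 4 + [1] * 8)
x = np.random.default_rng(0).random((len(y), 3)).astype("float32")

# Inverse-frequency weights: rarer classes get proportionally larger weights
counts = np.bincount(y)
class_weight = {i: len(y) / (len(counts) * c) for i, c in enumerate(counts)}
# -> {0: 1.5, 1: 0.75}, the same 2:1 ratio as {0: 1.0, 1: 0.5}

model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation="sigmoid", input_shape=(3,))])
model.compile(optimizer="adam", loss="binary_crossentropy")
history = model.fit(x, y, class_weight=class_weight, epochs=1, verbose=0)
```

Only the ratio between the weights matters for which class is emphasized; the overall scale also rescales the loss.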

## Sample Weights

For fine-grained control, or if you are not building a classifier, you can use "sample weights".

- When training from NumPy data: Pass the sample_weight argument to Model.fit().
- When training from tf.data or any other sort of iterator: Yield (input_batch, label_batch, sample_weight_batch) tuples.
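
Both input paths can be sketched with random toy data (the arrays and the weight values below are illustrative assumptions):

```python
import numpy as np
import tensorflow as tf

rng = np.random.default_rng(0)
x = rng.random((6, 3)).astype("float32")
y = rng.random((6, 1)).astype("float32")
# One weight per sample: 0 ignores a sample, 2 counts it twice as much
sample_weight = np.array([1.0, 1.0, 1.0, 0.0, 0.0, 2.0], dtype="float32")

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
model.compile(optimizer="sgd", loss="mse")

# NumPy data: pass sample_weight to fit()
model.fit(x, y, sample_weight=sample_weight, batch_size=3, epochs=1, verbose=0)

# tf.data: each element is an (inputs, targets, sample_weights) tuple
dataset = tf.data.Dataset.from_tensor_slices((x, y, sample_weight)).batch(3)
history = model.fit(dataset, epochs=1, verbose=0)
```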

A "sample weights" array is an array of numbers that specify how much weight each sample in a batch should have in computing the total loss.
It is commonly used in imbalanced classification problems (the idea being to give more weight to rarely-seen classes).

When the weights used are ones and zeros, the array can be used as a mask for the loss function (entirely discarding the contribution of certain samples to the total loss).
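
The masking effect can be checked directly on a loss object. In this sketch (made-up values), changing only the prediction for the zero-weighted sample leaves the loss unchanged:

```python
import tensorflow as tf

loss_fn = tf.keras.losses.MeanSquaredError()
y_true = tf.constant([[1.0], [0.0]])
# Weight 1 keeps the first sample; weight 0 masks out the second
mask = tf.constant([1.0, 0.0])

loss_a = float(loss_fn(y_true, tf.constant([[0.5], [0.0]]), sample_weight=mask))
# Only the masked sample's prediction changes here
loss_b = float(loss_fn(y_true, tf.constant([[0.5], [9.0]]), sample_weight=mask))
```

With the default reduction, the weighted sum is still divided by the full batch size (here 0.25 / 2 = 0.125), so zero weights remove a sample's contribution without renormalizing over the remaining samples.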

# Callbacks

Callbacks are objects that are called at different points during training (at the start of an epoch, at the end of a batch, at the end of an epoch, etc.). They can be used for tasks such as:

- Doing validation at different points during training (beyond the built-in per-epoch validation)
- Checkpointing the model at regular intervals or when it exceeds a certain accuracy threshold
- Changing the learning rate of the model when training seems to be plateauing
- Stopping training when validation loss starts increasing
- Doing fine-tuning of the top layers when training seems to be plateauing
- Sending email or instant message notifications when training ends or when a certain performance threshold is exceeded

In [5]:
# Define a callback to monitor validation loss for early stopping
callbacks = [
    tf.keras.callbacks.EarlyStopping(
        # Stop training when `val_loss` is no longer improving
        monitor="val_loss",
        # "no longer improving" being defined as "no better than 1e-2 less"
        min_delta=1e-2,
        # "no longer improving" being further defined as "for at least 2 epochs"
        patience=2,
        verbose=1,
    )
]
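
A callback list only takes effect once it is passed to fit(). A minimal sketch on random toy data; validation_split is used here so that val_loss exists for the callback to monitor:

```python
import numpy as np
import tensorflow as tf

callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor="val_loss", min_delta=1e-2, patience=2, verbose=1)
]

# Random toy regression data, for illustration only
rng = np.random.default_rng(0)
x = rng.random((64, 4)).astype("float32")
y = rng.random((64, 1)).astype("float32")

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")

# Training may stop well before 50 epochs once val_loss stops improving
history = model.fit(x, y, validation_split=0.25, epochs=50, callbacks=callbacks, verbose=0)
epochs_run = len(history.history["loss"])
```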

## Custom Callback

You can create a custom callback by extending the base class tf.keras.callbacks.Callback. A callback has access to its associated model through the attribute self.model.

Here's a simple example that records the loss reported after every batch, rather than only the per-epoch values kept by default.

In [6]:
class LossHistoryBatch(tf.keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        self.per_batch_losses = []

    def on_batch_end(self, batch, logs=None):
        # In TF 2.x the batch-level "loss" is the running average over the current epoch so far
        self.per_batch_losses.append(logs.get("loss"))
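
For a one-off hook like this, tf.keras.callbacks.LambdaCallback can collect the same per-batch values without writing a subclass. A sketch with random toy data:

```python
import numpy as np
import tensorflow as tf

per_batch_losses = []
batch_logger = tf.keras.callbacks.LambdaCallback(
    on_batch_end=lambda batch, logs: per_batch_losses.append(logs["loss"])
)

# Random toy data, for illustration only
rng = np.random.default_rng(0)
x = rng.random((8, 3)).astype("float32")
y = rng.random((8, 1)).astype("float32")

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
model.compile(optimizer="sgd", loss="mse")

# 8 samples with batch_size=2 -> 4 batches in one epoch
model.fit(x, y, batch_size=2, epochs=1, callbacks=[batch_logger], verbose=0)
```

A subclass such as LossHistoryBatch is still the better fit once the callback needs its own state or several hooks.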

## Model Checkpointing via Callbacks

When you're training a model on relatively large datasets, it's crucial to save checkpoints of your model at frequent intervals.

In [7]:
# Define a callback to save the model while monitoring validation loss
# A model is saved only when validation loss improves; because the filepath
# contains {epoch}, each improvement is written to a new path
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        filepath="tmp/mymodel_{epoch}",  # Path to save the model
        save_best_only=True,  # Save only if `val_loss` has improved
        monitor="val_loss",
        verbose=1,
    )
]
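
A variant that saves only the weights, and then restores the most recent checkpoint, can be sketched as follows. The temporary directory, file name pattern and toy data are assumptions for illustration:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    # No ".h5" suffix, so weights are written in the TensorFlow checkpoint format
    filepath=os.path.join(ckpt_dir, "ckpt_{epoch:02d}"),
    save_weights_only=True,
    save_best_only=True,
    monitor="val_loss",
)

# Random toy data, for illustration only
rng = np.random.default_rng(0)
x = rng.random((32, 3)).astype("float32")
y = rng.random((32, 1)).astype("float32")

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
model.compile(optimizer="adam", loss="mse")
model.fit(x, y, validation_split=0.25, epochs=3, callbacks=[checkpoint], verbose=0)

# With save_best_only=True the most recent checkpoint is also the best one
latest = tf.train.latest_checkpoint(ckpt_dir)
model.load_weights(latest)
```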

## Learning Rate Scheduler via Callbacks

A common pattern when training deep learning models is to gradually reduce the learning rate as training progresses. This is generally known as "learning rate decay".

The learning rate decay schedule could be static (fixed in advance, as a function of the current epoch or the current batch index), or dynamic (responding to the current behavior of the model, in particular the validation loss).

### Static

You can easily use a static learning rate decay schedule by passing a schedule object as the learning_rate argument in your optimizer.

In [8]:
# Set initial learning rate
initial_learning_rate = 0.01

# Define a scheduler
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=100000,
    decay_rate=0.96,
    staircase=True
)

# Define optimizer with a learning rate scheduler
optimizer = tf.keras.optimizers.RMSprop(learning_rate=lr_schedule)
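
A schedule object is callable with a step number, which makes it easy to inspect. With staircase=True the exponent is floor(step / decay_steps), so the rate only drops every 100,000 steps. A quick sketch using the same parameters as above:

```python
import tensorflow as tf

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    0.01, decay_steps=100000, decay_rate=0.96, staircase=True
)

# lr(step) = 0.01 * 0.96 ** floor(step / 100000)
lrs = {step: float(lr_schedule(step)) for step in (0, 99999, 100000, 250000)}
# 0 and 99999 -> 0.01, 100000 -> 0.0096, 250000 -> 0.009216
```

Without staircase=True the exponent is the unrounded ratio step / decay_steps, so the rate decays smoothly instead of in steps.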

### Dynamic

A dynamic learning rate schedule (for instance, decreasing the learning rate when the validation loss is no longer improving) cannot be achieved with these schedule objects since the optimizer does not have access to validation metrics.

However, callbacks do have access to all metrics, including validation metrics! You can thus achieve this pattern by using a callback that modifies the current learning rate on the optimizer. In fact, this is even built-in as the ReduceLROnPlateau callback.

In [9]:
from tensorflow.keras import backend as K

# A very simple dynamic learning rate scheduler using callbacks
# Note that this is just for the sake of example
class IncreaseLR(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Assumes the optimizer's learning rate is a plain value, not a schedule object
        lr = K.get_value(self.model.optimizer.learning_rate)
        new_lr = lr + 0.001  # Increase the learning rate
        K.set_value(self.model.optimizer.learning_rate, new_lr)  # Set the new learning rate
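
In practice, the built-in ReduceLROnPlateau callback covers the common case of lowering the learning rate when a monitored metric stops improving. A minimal sketch on random toy data; the factor, patience and min_lr values are illustrative, not recommendations:

```python
import numpy as np
import tensorflow as tf

reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",  # Watch the validation loss
    factor=0.5,          # Halve the learning rate...
    patience=2,          # ...after 2 epochs without improvement
    min_lr=1e-5,         # ...but never go below this value
)

# Random toy data, for illustration only
rng = np.random.default_rng(0)
x = rng.random((64, 4)).astype("float32")
y = rng.random((64, 1)).astype("float32")

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01), loss="mse")

history = model.fit(x, y, validation_split=0.25, epochs=10, callbacks=[reduce_lr], verbose=0)
# ReduceLROnPlateau also records the learning rate of each epoch in the logs under "lr"
final_lr = float(tf.keras.backend.get_value(model.optimizer.learning_rate))
```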