# Custom training loop in TensorFlow

In [None]:
import tensorflow as tf

One of the benefits of using the Keras API is that it handles the execution of the training loop behind the scenes, minimising the need for boilerplate code. Keras is also sufficiently flexible that it can take care of most custom training algorithms; all models and training pipelines in this module can be implemented in Keras.

Nevertheless, for completeness, in this notebook we will see how a low-level training loop can be implemented directly in TensorFlow using the automatic differentiation tools we have covered. (Note that Keras is included as a submodule in TensorFlow, so many of the constructions below will be familiar.) This approach breaks down the training loop and can give you extra flexibility when you need it. 

We will demonstrate the implementation of the training loop using a classifier model on the Fashion-MNIST dataset.

In [None]:
# Load the Fashion-MNIST dataset

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

In [None]:
# Get the class labels

classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot"
]

In [None]:
# View a few training data examples

import numpy as np
import matplotlib.pyplot as plt

n_rows, n_cols = 3, 5
random_inx = np.random.choice(x_train.shape[0], n_rows * n_cols, replace=False)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 8))
fig.subplots_adjust(hspace=0.2, wspace=0.1)

for n, i in enumerate(random_inx):
    row = n // n_cols
    col = n % n_cols
    axes[row, col].imshow(x_train[i])
    axes[row, col].get_xaxis().set_visible(False)
    axes[row, col].get_yaxis().set_visible(False)
    axes[row, col].text(10., -1.5, f'{classes[y_train[i]]}')
plt.show()

In [None]:
# Build the model

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense, Input

def get_model():
    model = Sequential([
        Input(shape=(28, 28)),
        Flatten(),
        Dense(64, activation='relu'),
        Dense(64, activation='relu'),
        Dense(10)
    ], name='fashion_mnist_classifier')
    return model

fashion_mnist_model = get_model()

In [None]:
# Print the model summary

fashion_mnist_model.summary()

In [None]:
# Define an optimiser

rmsprop = tf.keras.optimizers.RMSprop(learning_rate=0.005)

In [None]:
# Define the loss function

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
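
The model's final `Dense(10)` layer has no activation, so it outputs raw logits; this is why the loss is created with `from_logits=True`. As a minimal sketch (the toy logits and labels below are made up for illustration and are not part of the notebook's data), we can check that this loss agrees with applying a log-softmax by hand and averaging the negative log-probability of the correct class:

```python
import numpy as np
import tensorflow as tf

# Toy logits for 2 examples and 3 classes (illustrative values only)
logits = tf.constant([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]])
labels = tf.constant([0, 1])

# Loss object configured as in the notebook
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
loss_from_logits = float(loss_fn(labels, logits))

# Manual computation: log-softmax, pick the log-probability of the true class,
# then average the negative values over the batch
log_probs = tf.nn.log_softmax(logits, axis=-1)
true_class_log_probs = tf.reduce_sum(tf.one_hot(labels, depth=3) * log_probs, axis=-1)
manual_loss = float(-tf.reduce_mean(true_class_log_probs))

print(loss_from_logits, manual_loss)
```

The two printed values should agree up to floating-point error, since the loss object averages over the batch by default.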

In [None]:
# Load the data into tf.data.Dataset objects

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))

train_dataset.element_spec

In [None]:
# Shuffle and batch the training dataset, and batch the validation dataset

train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size=32)
val_dataset = val_dataset.batch(batch_size=32)
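
To see what the batched dataset yields, here is a small sketch using a synthetic array of 100 blank images in place of Fashion-MNIST (the array contents and the size of 100 are assumptions chosen to keep the example fast). Note that the final batch is smaller when the dataset size is not a multiple of the batch size. Also note that in the notebook the shuffle buffer of 1024 is smaller than the 60,000 training examples, so the shuffle there is only approximate.

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for the image and label arrays (illustrative only)
x = np.zeros((100, 28, 28), dtype=np.uint8)
y = np.zeros((100,), dtype=np.uint8)

# Same shuffle and batch settings as in the notebook
ds = tf.data.Dataset.from_tensor_slices((x, y))
ds = ds.shuffle(buffer_size=1024).batch(batch_size=32)

# Collect the shape of the image tensor in every batch
batch_shapes = [tuple(images.shape) for images, labels in ds]
print(batch_shapes)
```

Passing `drop_remainder=True` to `batch` would discard the final partial batch if a fixed batch size is needed.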

The training loop consists of an outer loop that iterates through the epochs, and an inner loop that iterates through the dataset. At each inner iteration, we extract a batch of examples from the dataset, get the model predictions, compute the loss and apply a gradient update.

In [None]:
# Build the custom training loop

import time

epochs = 5
start = time.perf_counter()
for epoch in range(epochs):

    losses = []
    for images, labels in train_dataset:
        with tf.GradientTape() as tape:
            logits = fashion_mnist_model(images)
            batch_loss = loss_fn(labels, logits)
        grads = tape.gradient(batch_loss, fashion_mnist_model.trainable_weights)
        losses.append(batch_loss.numpy())
        
        rmsprop.apply_gradients(zip(grads, fashion_mnist_model.trainable_weights))

    val_losses = []
    for images, labels in val_dataset:
        logits = fashion_mnist_model(images)
        batch_loss = loss_fn(labels, logits)
        val_losses.append(batch_loss.numpy())
    
    print(f"End of epoch {epoch}, training loss: {np.mean(losses):.4f}, validation loss: {np.mean(val_losses):.4f}")
print(f"End of training, time: {time.perf_counter() - start:.4f}s")

A custom training loop such as the one above can often be sped up significantly by compiling the training update step into a computational graph, which allows TensorFlow to optimise the computation. The `@tf.function` decorator can be used for this purpose. Below we pull out the main computational step of running the forward and backward passes into a separate function, to which we then apply the decorator to tell TensorFlow to construct a computational graph for that function. See [here](https://www.tensorflow.org/guide/function) for more information.

In [None]:
# Build a new model and create a new optimiser

fashion_mnist_model = get_model()
rmsprop = tf.keras.optimizers.RMSprop(learning_rate=0.005)

In [None]:
# Optimise the custom training loop by compiling the training step into a graph

@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        logits = fashion_mnist_model(images)
        batch_loss = loss_fn(labels, logits)
    grads = tape.gradient(batch_loss, fashion_mnist_model.trainable_weights)
    return batch_loss, grads

@tf.function
def test_step(images, labels):
    logits = fashion_mnist_model(images)
    batch_loss = loss_fn(labels, logits)
    return batch_loss

epochs = 5
start = time.perf_counter()
for epoch in range(epochs):
    losses = []
    for images, labels in train_dataset:
        batch_loss, grads = train_step(images, labels)
        rmsprop.apply_gradients(zip(grads, fashion_mnist_model.trainable_weights))
        losses.append(batch_loss.numpy())

    val_losses = []
    for images, labels in val_dataset:
        batch_loss = test_step(images, labels)
        val_losses.append(batch_loss.numpy())
    
    print(f"End of epoch {epoch}, training loss: {np.mean(losses):.4f}, validation loss: {np.mean(val_losses):.4f}")
print(f"End of training, time: {time.perf_counter() - start:.4f}s")
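
To illustrate what the `@tf.function` decorator does, the following minimal sketch (the function `square` and the list `trace_log` are hypothetical names introduced here) records each time the Python body of a decorated function is executed. The Python body runs only when TensorFlow traces the function to build a graph; later calls with tensors of the same dtype and shape reuse that graph, while a new input shape triggers a new trace.

```python
import tensorflow as tf

trace_log = []  # records each time the Python body is traced

@tf.function
def square(x):
    # Python side effect: runs during tracing only, not on every call
    trace_log.append("traced")
    return x * x

a = square(tf.constant(2.0))  # first call: traces and builds a graph
b = square(tf.constant(3.0))  # same dtype and shape: reuses the graph
traces_after_scalars = len(trace_log)

c = square(tf.constant([1.0, 2.0]))  # new input shape: traces again
traces_after_vector = len(trace_log)

print(traces_after_scalars, traces_after_vector)
```

This is why Python-side bookkeeping such as `losses.append(...)` is kept outside the decorated `train_step` and `test_step` functions in the loop above.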

Note that the Keras API applies this kind of graph compilation automatically whenever you call `model.fit`, unless the model has been compiled with `run_eagerly=True`.
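
For comparison, here is a rough sketch of the same kind of training run expressed with `compile` and `fit`. To keep it quick and self-contained it uses random synthetic data and a smaller network rather than Fashion-MNIST; the data, the network size, the number of epochs and the `validation_split` value are all assumptions made for illustration.

```python
import numpy as np
import tensorflow as tf

# Random synthetic data standing in for images and labels (illustrative only)
rng = np.random.default_rng(0)
x = rng.random((256, 28, 28)).astype("float32")
y = rng.integers(0, 10, size=(256,))

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(10),
])

# Same optimiser and loss settings as in the custom loop
model.compile(
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.005),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

# fit runs the epoch and batch loops and records the per-epoch losses
history = model.fit(x, y, batch_size=32, epochs=2,
                    validation_split=0.25, verbose=0)
print(history.history["loss"], history.history["val_loss"])
```

The returned `history.history` dictionary holds the per-epoch training and validation losses that the custom loop printed by hand.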

In many cases the data processing pipeline can also be optimised for performance gains; see [here](https://www.tensorflow.org/guide/data_performance) for more information.
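
As one possible sketch of such a pipeline, the example below adds parallel preprocessing, caching and prefetching to the shuffle and batch steps used earlier. The synthetic data, the hypothetical `scale` function (which rescales pixels to the range [0, 1], a step the notebook above does not perform) and the buffer size are assumptions for illustration; whether each step helps depends on the dataset and hardware.

```python
import numpy as np
import tensorflow as tf

# Random synthetic data standing in for images and labels (illustrative only)
rng = np.random.default_rng(0)
x = rng.integers(0, 256, size=(128, 28, 28)).astype("uint8")
y = rng.integers(0, 10, size=(128,)).astype("uint8")

def scale(image, label):
    # Hypothetical preprocessing step: rescale pixel values to [0, 1]
    return tf.cast(image, tf.float32) / 255.0, label

ds = (
    tf.data.Dataset.from_tensor_slices((x, y))
    .map(scale, num_parallel_calls=tf.data.AUTOTUNE)  # preprocess in parallel
    .cache()                      # keep preprocessed examples in memory
    .shuffle(buffer_size=128)     # shuffle after caching so order varies per epoch
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)   # prepare the next batch while the current one trains
)

images, labels = next(iter(ds))
print(images.shape, images.dtype)
```

Placing `cache` before `shuffle` means the preprocessing runs once while the example order still changes from epoch to epoch.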