# Custom training loop in PyTorch

In [None]:
import torch

One of the benefits of using the Keras API is that it handles the execution of the training loop behind the scenes, minimising the need for boilerplate code. Keras is also flexible enough to take care of most custom training algorithms; all models and training pipelines in this module can be implemented in Keras.

Nevertheless, for completeness, in this notebook we will see how a low-level training loop can be implemented directly in PyTorch using the automatic differentiation tools we have covered. This approach breaks the training loop down into its individual steps and can give you extra flexibility when you need it. Everything here is implemented in PyTorch, in order to provide a complete example of how a training loop can be constructed in this framework. Throughout the rest of the course, we will always use Keras to construct models, optimizers and losses.

We will demonstrate the implementation of the training loop using a classifier model on the Fashion-MNIST dataset.

In [None]:
# Load the Fashion-MNIST dataset

from torchvision import datasets
from torchvision import transforms

train_dataset = datasets.FashionMNIST(
    root="~/torchdata",
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor()
    ])
)
val_dataset = datasets.FashionMNIST(
    root="~/torchdata",
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

In [None]:
# Get the class labels

classes = train_dataset.classes
print(classes)

In [None]:
# View a few training data examples

import numpy as np
import matplotlib.pyplot as plt

n_rows, n_cols = 3, 5
random_inx = np.random.choice(len(train_dataset), n_rows * n_cols, replace=False)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 8))
fig.subplots_adjust(hspace=0.2, wspace=0.1)

for n, i in enumerate(random_inx):
    image = torch.squeeze(train_dataset[i][0], dim=0)
    row = n // n_cols
    col = n % n_cols
    axes[row, col].imshow(image)
    axes[row, col].get_xaxis().set_visible(False)
    axes[row, col].get_yaxis().set_visible(False)
    axes[row, col].text(10., -1.5, f'{classes[train_dataset[i][1]]}')
plt.show()

In [None]:
# Create Dataloader

batch_size = 3
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=True)
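To get a feel for what a `DataLoader` yields, the sketch below inspects a single batch. To stay self-contained it uses a small synthetic stand-in dataset (random tensors with Fashion-MNIST shapes); `fake_dataset` and `demo_loader` are hypothetical names introduced only for this illustration, not part of the notebook.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for Fashion-MNIST: 10 random 1x28x28 "images" with integer labels
fake_images = torch.rand(10, 1, 28, 28)
fake_labels = torch.randint(0, 10, (10,))
fake_dataset = TensorDataset(fake_images, fake_labels)

demo_loader = DataLoader(fake_dataset, batch_size=3, shuffle=True, drop_last=True)

# Each iteration yields a (images, labels) pair with a leading batch dimension
images, labels = next(iter(demo_loader))
print(images.shape)      # torch.Size([3, 1, 28, 28])
print(labels.shape)      # torch.Size([3])

# With drop_last=True the incomplete final batch is discarded: 10 // 3 = 3 batches
print(len(demo_loader))  # 3
```

With `drop_last=True`, every batch has exactly `batch_size` examples, at the cost of skipping the leftover examples in each pass.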

Below we will build the model in PyTorch for demonstration purposes. In this course we will always use Keras for model building in order to be backend agnostic. The `torch.nn` module contains many important building blocks for neural network models. As you can see below, the PyTorch API is similar to Keras.

In [None]:
# Build the model

from torch.nn import Sequential, Flatten, Linear, ReLU

fashion_mnist_model = Sequential(
    Flatten(),
    Linear(784, 64),
    ReLU(),
    Linear(64, 64),
    ReLU(),
    Linear(64, 10)
)

PyTorch model layers/submodules (referred to as 'children') can be extracted using the `children` or `named_children` methods.

In [None]:
for name, submodule in fashion_mnist_model.named_children():
    print(name, submodule)
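As a further illustration, the sketch below contrasts `children` and `named_children` on a small model. The name `demo_model` is hypothetical and the layer sizes are only chosen to resemble the model above.

```python
import torch
from torch.nn import Sequential, Flatten, Linear, ReLU

# Small demo model (hypothetical; sizes resemble the notebook's model)
demo_model = Sequential(Flatten(), Linear(784, 64), ReLU(), Linear(64, 10))

# children() yields the submodules without their names
child_types = [type(m).__name__ for m in demo_model.children()]
print(child_types)  # ['Flatten', 'Linear', 'ReLU', 'Linear']

# In a Sequential, the child names are the positional indices as strings
child_names = [name for name, _ in demo_model.named_children()]
print(child_names)  # ['0', '1', '2', '3']

# A Sequential also supports integer indexing to pull out a single layer
first_linear = demo_model[1]
print(first_linear.in_features, first_linear.out_features)  # 784 64
```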

An alternative (and common) way to create the same model as above is to subclass `torch.nn.Module`. Layers are defined in the `__init__` method, and the forward pass is defined in `forward`:

In [None]:
# Build the model

class FashionMNISTModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = Flatten()
        self.linear1 = Linear(784, 64)
        self.linear2 = Linear(64, 64)
        self.linear3 = Linear(64, 10)

    def forward(self, x):
        x = self.flatten(x)  # Flatten each image into a 784-dimensional vector
        x = torch.relu(self.linear1(x))
        x = torch.relu(self.linear2(x))
        return self.linear3(x)

fashion_mnist_model = FashionMNISTModel()

The code below would still work if the model above were built with the Keras API using the `torch` backend. The same applies to the loss and optimizer defined below.

In [None]:
# Define an optimiser

rmsprop = torch.optim.RMSprop(fashion_mnist_model.parameters(), lr=0.005)

In [None]:
# Define the loss function

loss_fn = torch.nn.CrossEntropyLoss()

The training loop consists of an outer loop that iterates through the epochs, and an inner loop that iterates through the dataset. At each inner iteration, we extract a batch of examples from the dataset, get the model predictions, compute the loss and apply a gradient update.

In [None]:
# Build the custom training loop

import time

epochs = 5
start = time.perf_counter()
for epoch in range(epochs):

    fashion_mnist_model.train()
    losses = []
    for images, labels in train_loader:
        rmsprop.zero_grad()
        logits = fashion_mnist_model(images)
        batch_loss = loss_fn(logits, labels)
        batch_loss.backward()
        rmsprop.step()
        losses.append(batch_loss.item())

    fashion_mnist_model.eval()
    val_losses = []
    with torch.no_grad():
        for images, labels in val_loader:
            logits = fashion_mnist_model(images)
            batch_loss = loss_fn(logits, labels)
            val_losses.append(batch_loss.item())  # .item() also works if the loss lives on a GPU
    
    print(f"End of epoch {epoch}, training loss: {np.mean(losses):.4f}, validation loss: {np.mean(val_losses):.4f}")
print(f"End of training, time: {time.perf_counter() - start:.4f}")

A few points in the code above need some explanation:

* At the start of each epoch's training phase we call `fashion_mnist_model.train()`, and before the validation loop we call `fashion_mnist_model.eval()`. These calls switch the model between training and evaluation modes, which is necessary for layers such as batch normalization or dropout that behave differently at training time and at test time. The model here contains no such layers, but setting the mode explicitly is good practice. To check the current mode, inspect the model's `.training` attribute.
* Before running the forward pass in training, we call `rmsprop.zero_grad()`. This resets the accumulated gradients on the leaf Tensors (the model parameters). Otherwise, the gradients would accumulate every time we call `batch_loss.backward()` (which is sometimes what we want, for example to combine gradients over multiple minibatches).
    * We could alternatively write `fashion_mnist_model.zero_grad()`, which would be equivalent since the `rmsprop` optimizer was created with all of the parameters of the model.
* `rmsprop.step()` is the line that applies the optimizer update once the gradients have been calculated with `batch_loss.backward()`.
* When performing the validation steps, we run the model inside the `torch.no_grad` context. This context tells PyTorch not to track the operations carried out within it for gradient computation. In other words, the result of any computation inside this context will have `requires_grad` set to `False`, and we will not be able to call `backward()` on it. This avoids the memory and compute overhead of building the computation graph, and is useful when we are just evaluating a model and don't need gradients.
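The points above can be checked on a tiny example. The sketch below is only an illustration: `demo_layer` and the input `x` are hypothetical stand-ins, not part of the training code. It shows the `.training` flag, gradient accumulation across `backward()` calls, the effect of `zero_grad()`, and the `torch.no_grad` context.

```python
import torch

# Hypothetical single-layer model used only for this illustration
demo_layer = torch.nn.Linear(2, 1)

# train() / eval() toggle the .training flag on the module and its children
demo_layer.eval()
print(demo_layer.training)  # False
demo_layer.train()
print(demo_layer.training)  # True

x = torch.ones(1, 2)

# Gradients accumulate across backward() calls until they are reset
demo_layer(x).sum().backward()
grad_once = demo_layer.weight.grad.clone()
demo_layer(x).sum().backward()
grad_twice = demo_layer.weight.grad.clone()
print(torch.allclose(grad_twice, 2 * grad_once))  # True

# zero_grad() resets them; depending on the PyTorch version, .grad becomes
# either None or a tensor of zeros, so we accept both here
demo_layer.zero_grad()
grad_after_reset = demo_layer.weight.grad
print(grad_after_reset is None or bool((grad_after_reset == 0).all()))  # True

# Inside no_grad, results of computations are not tracked for differentiation
with torch.no_grad():
    out = demo_layer(x)
print(out.requires_grad)  # False
```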

Watch out for the following regarding loss function signatures: while all NumPy/TensorFlow/JAX/Keras APIs (as well as Python unittest APIs) use the argument order convention `fn(y_true, y_pred)` (reference values first, predicted values second), PyTorch uses `fn(y_pred, y_true)` for its losses. So make sure to pass the logits first and the targets second, as in `loss_fn(logits, labels)` in the training loop above.
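As a small sanity check of the PyTorch argument order, the sketch below calls `torch.nn.CrossEntropyLoss` with the predictions first and compares the result to a manual computation. The batch of random logits and the target values are made up for illustration, and `loss_fn_demo` is a hypothetical name.

```python
import torch
import torch.nn.functional as F

loss_fn_demo = torch.nn.CrossEntropyLoss()

# Hypothetical batch: 4 examples, 10 classes
logits = torch.randn(4, 10)           # float predictions, shape (batch, classes)
targets = torch.tensor([0, 3, 9, 1])  # integer class indices, shape (batch,)

# PyTorch order: predictions first, targets second
loss = loss_fn_demo(logits, targets)

# Manual check: mean negative log-probability assigned to the true class
log_probs = F.log_softmax(logits, dim=1)
manual_loss = -log_probs[torch.arange(4), targets].mean()
print(torch.allclose(loss, manual_loss))  # True
```

Note that `CrossEntropyLoss` expects raw logits rather than probabilities, since it applies the log-softmax internally.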