# Tutorial 4: Optimization and Initialization
In this tutorial, we will review techniques for optimization and initialization of neural networks. When increasing the depth of neural networks, there are various challenges we face. Most importantly, we need to have a stable gradient flow through the network, as otherwise, we might encounter vanishing or exploding gradients. This is why we will take a closer look at the following concepts: initialization and optimization.
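To make the gradient-flow argument concrete, here is a minimal sketch that is not part of the original notebook. It uses a toy chain of scalar "layers" y = w · x with no activation functions. By the chain rule, the gradient of the output with respect to the input is the product of all weights, so it shrinks or grows exponentially with depth. The names `deep_chain` and `input_gradient`, the depth of 20, and the chosen weight values are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def deep_chain(x, weights):
    """Apply a chain of scalar linear 'layers' y = w * x (hypothetical toy model)."""
    for w in weights:
        x = w * x
    return x


def input_gradient(weight_value, depth=20):
    """Gradient of the chain output w.r.t. its input for `depth` identical layers."""
    weights = jnp.full((depth,), weight_value)
    # Differentiate with respect to the first argument (the input x).
    return jax.grad(deep_chain)(1.0, weights)


# The gradient equals weight_value ** depth: it vanishes for |w| < 1
# and explodes for |w| > 1 as the chain gets deeper.
for w in (0.5, 1.0, 1.5):
    print(f'w={w}: gradient={float(input_gradient(w)):.3e}')
```

With 20 layers, a weight of 0.5 yields a gradient on the order of 1e-6, while 1.5 yields one on the order of 1e3. Real networks involve matrices and nonlinearities, but the same multiplicative effect is what careful initialization aims to control.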

In the first half of the notebook, we will review different initialization techniques, going step by step from the simplest initialization to methods that are nowadays used in very deep networks. In the second half, we focus on optimization, comparing the optimizers SGD, SGD with Momentum, and Adam.

Let’s start by importing our standard libraries:

In [1]:
# Standard libraries
import os
import json
import math
import copy
from typing import Any, Sequence, Callable, NamedTuple, Optional, Tuple

PyTree = Any  # Type definition for PyTree, for readability
import pickle

# Third-party libraries
import numpy as np

# Imports for plotting
import matplotlib.pyplot as plt
from matplotlib import cm

%matplotlib inline
from matplotlib_inline.backend_inline import set_matplotlib_formats

set_matplotlib_formats('svg', 'pdf')  # SVG and PDF are for export
import seaborn as sns

sns.set_theme()

# Progress bar
from tqdm.auto import tqdm

# Jax
import jax
import jax.numpy as jnp
from jax import random
from jax.tree_util import tree_map

# Seeding for random operations
main_rng = random.PRNGKey(42)

# Flax
import flax
from flax import linen as nn
from flax.training import checkpoints, train_state

# Optax
import optax




Set path variables `DATASET_PATH` and `CHECKPOINT_PATH`.

In [4]:
# Path to the folder where the datasets are/should be downloaded.
DATASET_PATH = './data'
# Path to the folder where the pretrained models are saved.
CHECKPOINT_PATH = './saved_models/tutorial4'

# Verifying the device that will be used throughout this notebook
print('Device:', jax.devices()[0])

Device: cuda:0


In the last part of the notebook, we will train models using three different optimizers. The pretrained models for those are downloaded below.

In [5]:
pass  # Placeholder cell for downloading the pretrained models

## Preparation
Throughout this notebook, we will use a deep fully connected network, similar to our previous tutorial. We will apply the network to FashionMNIST. We start by loading the FashionMNIST dataset:


In [20]:
import torch
import torch.utils.data as data
from torchvision.datasets import FashionMNIST
from torchvision import transforms


def image_to_numpy(image):
    """Transformations applied on each image =>
    bring them into a numpy array and normalize to mean 0 and std 1.
    """
    img = np.array(image, dtype=np.float32)
    img = (img / 255.0 - 0.2860) / 0.3530  # Dataset mean and std (computed below)
    return img


def numpy_collate(batch):
    """Stack the batch elements as numpy arrays. By default, PyTorch stacks them as tensors. For JAX, we need numpy arrays."""
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)


# Load the FashionMNIST dataset. We need to split the dataset into training and validation sets.
train_dataset = FashionMNIST(
    root=DATASET_PATH, train=True, transform=image_to_numpy, download=True
)
train_set, val_set = data.random_split(
    train_dataset, [50000, 10000], generator=torch.Generator().manual_seed(42)
)

# Load the test set
test_set = FashionMNIST(
    root=DATASET_PATH, train=False, transform=image_to_numpy, download=True
)

# Create data loaders
# Shuffle the training set in every epoch; evaluation sets keep a fixed order.
train_loader = data.DataLoader(
    train_set, batch_size=1024, shuffle=True, drop_last=False, collate_fn=numpy_collate
)
val_loader = data.DataLoader(
    val_set, batch_size=1024, shuffle=False, drop_last=False, collate_fn=numpy_collate
)
test_loader = data.DataLoader(
    test_set, batch_size=1024, shuffle=False, drop_last=False, collate_fn=numpy_collate
)
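The `numpy_collate` function can be checked in isolation. The sketch below re-declares the documented version (the one with the docstring) so that it runs on its own, and feeds it a hypothetical toy batch of `(image, label)` pairs shaped like FashionMNIST samples (28×28 float images, integer labels). The toy data is an assumption for illustration, not real dataset content.

```python
import numpy as np


def numpy_collate(batch):
    """Stack the batch elements as numpy arrays (same logic as the notebook cell)."""
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        # Regroup [(img, label), ...] into ([img, ...], [label, ...]) and collate each.
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)


# Hypothetical toy batch: four (image, label) pairs mimicking FashionMNIST samples.
toy_batch = [(np.zeros((28, 28), dtype=np.float32), label) for label in range(4)]
imgs, labels = numpy_collate(toy_batch)
print(imgs.shape, imgs.dtype)  # images stacked along a new batch axis
print(labels)                  # integer labels as a 1D numpy array
```

Because the recursion handles tuples and lists generically, the same function also works for batches with more than two fields per sample.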

In comparison to the previous tutorial, we have changed the parameters of the normalization transformation in `image_to_numpy`. The normalization is now designed to give us an expected mean of 0 and a standard deviation of 1 across pixels. This is particularly relevant for the discussion of initialization below, which is why we change it here. Note that in most classification tasks, both normalization techniques (scaling to between -1 and 1, or to mean 0 and standard deviation 1) have been shown to work well. We can calculate the normalization parameters by determining the mean and standard deviation on the original images:

In [23]:
print('Mean:', (train_dataset.data.float() / 255.0).mean().item())
print('Std:', (train_dataset.data.float() / 255.0).std().item())

Mean: 0.28604060411453247
Std: 0.3530242443084717
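The same procedure can be sketched without downloading the dataset. Here, hypothetical random `uint8` images stand in for `train_dataset.data`; the statistics are computed after scaling to [0, 1] and then used to standardize the pixels, which is what `image_to_numpy` does with the constants 0.2860 and 0.3530.

```python
import numpy as np

# Hypothetical stand-in for the raw uint8 training images (N x 28 x 28).
rng = np.random.default_rng(0)
raw_images = rng.integers(0, 256, size=(1000, 28, 28), dtype=np.uint8)

# Statistics over all pixels after scaling to [0, 1], as in the cell above.
scaled = raw_images.astype(np.float32) / 255.0
data_mean = scaled.mean()
data_std = scaled.std()  # ddof=0; PyTorch's Tensor.std applies Bessel's correction

# Standardize with those statistics.
normalized = (scaled - data_mean) / data_std
print(f'Mean: {normalized.mean(): .3f}, Std: {normalized.std(): .3f}')
```

NumPy's `std` defaults to `ddof=0`, whereas PyTorch's `Tensor.std` uses the unbiased estimate by default; with tens of millions of pixels the difference is negligible.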


We can verify the transformation by looking at the statistics of a single batch:

In [24]:
imgs, _ = next(iter(train_loader))
print(f'Mean: {imgs.mean().item(): 5.3f}')
print(f'Std: {imgs.std().item(): 5.3f}')
print(f'Maximum: {imgs.max().item(): 5.3f}')
print(f'Minimum: {imgs.min().item(): 5.3f}')

Mean:  0.008
Std:  1.009
Maximum:  2.022
Minimum: -0.810


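Note that the maximum and minimum are no longer 1 and -1, but are shifted towards positive values. This is because FashionMNIST, similar to MNIST, contains a lot of black (zero-valued) pixels. They pull the mean down to about 0.286, so after standardization a black pixel ends up at about -0.81 while a white pixel ends up at about 2.02.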
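The shifted range follows directly from the normalization constants. A quick check, applying the formula from `image_to_numpy` to a black (0) and a white (255) pixel; the helper name `normalize_pixel` is introduced here for illustration only:

```python
MEAN, STD = 0.2860, 0.3530  # normalization constants used in image_to_numpy


def normalize_pixel(value):
    """Map a raw 8-bit pixel value to the standardized scale (hypothetical helper)."""
    return (value / 255.0 - MEAN) / STD


black, white = normalize_pixel(0), normalize_pixel(255)
print(f'black -> {black:.3f}, white -> {white:.3f}')
```

Black maps to roughly -0.81 and white to roughly 2.02, matching the extremes observed in the batch above up to rounding.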

Next, we create a fully connected neural network, i.e. a stack of linear layers with activation functions in between.

In [None]:
# Network definition
class BaseNetwork(nn.Module):
    features: int
    act_fn: Callable
    num_classes: int = 10
    hidden_sizes: Sequence[int] = (512, 256, 256, 128)
    kernel_init: Callable = nn.linear.default_kernel_init

    @nn.compact
    def __call__(self, x: jax.typing.ArrayLike, return_activations: bool = False):
        x = x.reshape(x.shape[0], -1)  # Flatten the input image to a vector.
        # We collect all activations throughout the network for later visualizations.
        # Note that in jitted functions, unused arrays are removed by the compiler anyway.
        activations = []
        for hd in self.hidden_sizes:
            x = nn.Dense(hd, kernel_init=self.kernel_init)(x)
            activations.append(x)  # Output of the linear layer
            x = self.act_fn(x)
            activations.append(x)  # Output of the activation function
        x = nn.Dense(self.num_classes, kernel_init=self.kernel_init)(x)
        activations.append(x)
        return x if not return_activations else (x, activations)

For the activation functions, we use the implementations provided by JAX (`jax.nn`) instead of implementing them ourselves. In addition, we define an `Identity` activation function. Although this activation function would significantly limit the network’s modeling capabilities (stacked linear layers without a nonlinearity collapse into a single linear map), we will use it in the first steps of our discussion about initialization for simplicity.

In [None]:
act_fn_by_name = {
    'tanh': jax.nn.tanh,
    'relu': jax.nn.relu,
    'identity': lambda x: x,
}
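Since gradient flow depends on the derivatives of the activation functions, a small sketch can compute them elementwise with `jax.vmap(jax.grad(fn))`. It also checks the claim that, with the identity activation, two stacked linear maps equal a single one. The input values and the random matrices `W1`, `W2` are arbitrary assumptions.

```python
import jax
import jax.numpy as jnp

act_fn_by_name = {
    'tanh': jax.nn.tanh,
    'relu': jax.nn.relu,
    'identity': lambda x: x,
}

xs = jnp.array([-2.0, 0.5, 2.0])
for name, fn in act_fn_by_name.items():
    # Elementwise derivative: vectorize the scalar gradient over the inputs.
    grads = jax.vmap(jax.grad(fn))(xs)
    print(f'{name:>8}: f(x)={fn(xs)}, df/dx={grads}')

# With the identity activation, W2 @ (W1 @ x) == (W2 @ W1) @ x,
# so adding depth does not add modeling power.
key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
W1 = jax.random.normal(key1, (4, 3))
W2 = jax.random.normal(key2, (2, 4))
x = jax.random.normal(key3, (3,))
stacked = W2 @ act_fn_by_name['identity'](W1 @ x)
collapsed = (W2 @ W1) @ x
# Loose tolerance covers float32 rounding (and reduced-precision matmuls on some GPUs).
print(jnp.allclose(stacked, collapsed, rtol=1e-3, atol=1e-3))
```

ReLU's derivative is 0 for negative inputs and 1 for positive ones, tanh's derivative is at most 1 and decays for large inputs, and the identity always passes the gradient through unchanged.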

Finally, we define a few plotting functions that we will use for our discussions. These functions help us to

1. visualize the weight/parameter distribution inside a network;

2. visualize the gradients that the parameters at different layers receive;

3. visualize the activations, i.e. the outputs of the linear layers (and of the activation functions).

The detailed code is not important, but feel free to take a closer look if interested.
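As an illustration of the bookkeeping such helpers need (this is not the notebook's own plotting code), the sketch below collects the kernel values of each layer from a Flax-style parameter dictionary, ready to be passed to a histogram plot. The helper name `kernel_values_by_layer` and the toy parameter tree are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def kernel_values_by_layer(params):
    """Collect the flattened kernel (weight-matrix) values of each layer.

    Assumes a Flax-style tree {'params': {layer_name: {'kernel': ..., 'bias': ...}}}.
    Biases are skipped; Flax's Dense initializes them to zeros by default.
    """
    return {
        name: np.asarray(layer['kernel']).reshape(-1)
        for name, layer in params['params'].items()
        if 'kernel' in layer
    }


# Hypothetical parameter tree with two dense layers.
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
toy_params = {'params': {
    'Dense_0': {'kernel': 0.1 * jax.random.normal(key1, (784, 512)), 'bias': jnp.zeros(512)},
    'Dense_1': {'kernel': 0.1 * jax.random.normal(key2, (512, 10)), 'bias': jnp.zeros(10)},
}}

for name, values in kernel_values_by_layer(toy_params).items():
    # np.histogram yields the bin counts that a plot (e.g. sns.histplot) would draw.
    counts, _ = np.histogram(values, bins=50)
    print(f'{name}: {values.size} weights in {len(counts)} bins, std={values.std():.3f}')
```

The same pattern applies to gradients, since `jax.grad` returns a tree with the same structure as the parameters.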