### Normalization

Colin has previously explained why correctly initializing the weights is important: it keeps the gradients from exploding or vanishing.
To help with this, we can also transform the inputs so that each channel has a mean of 0 and a standard deviation of 1.

The values are easy to calculate: take the per-channel mean and standard deviation of the training data.
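As a quick sanity check of the idea, here is a sketch on synthetic data. A random uint8 array stands in for CIFAR-10's (N, H, W, C) layout; the shape and seed are arbitrary assumptions. Subtracting the per-channel mean and dividing by the per-channel standard deviation leaves each channel with a mean of about 0 and a standard deviation of about 1.

```python
import numpy as np

# Synthetic stand-in for an image dataset in (N, H, W, C) layout, like CIFAR-10's `data`.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(100, 32, 32, 3), dtype=np.uint8)

# Scale to [0, 1] first, then compute one statistic per channel.
x = images.astype(np.float64) / 255
mean = x.mean(axis=(0, 1, 2))  # shape (3,)
std = x.std(axis=(0, 1, 2))    # shape (3,)

# Standardize; broadcasting aligns the (3,) statistics with the trailing channel axis.
z = (x - mean) / std

print(z.mean(axis=(0, 1, 2)))  # each entry is close to 0
print(z.std(axis=(0, 1, 2)))   # each entry is close to 1
```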

In [None]:
from typing import *

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torch.optim as optim

import numpy as np
from torchvision import datasets, transforms

from tqdm import tqdm

import PIL.Image

In [None]:
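def to_image(image: torch.Tensor) -> PIL.Image.Image:
    # Min-max rescale into [0, 1] so the tensor can be displayed.
    # This only approximately reverses the normalization, because each
    # channel was shifted and scaled by slightly different amounts.
    # Out-of-place arithmetic avoids modifying the caller's tensor.
    image = image - image.min()
    image = image / image.max()
    return transforms.functional.to_pil_image(image.cpu(), 'RGB')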
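The min-max rescaling in `to_image` is only an approximation. If the original pixel values are wanted exactly, the normalization can be inverted per channel by multiplying by the standard deviation and adding the mean. A minimal sketch, assuming a (C, H, W) tensor and the CIFAR-10 statistics computed in this notebook; `normalize` and `denormalize` are hypothetical helpers, not part of torchvision.

```python
import torch

# CIFAR-10 per-channel statistics (the values computed in this notebook).
MEAN = torch.tensor([0.49139968, 0.48215841, 0.44653091])
STD = torch.tensor([0.24703223, 0.24348513, 0.26158784])

def normalize(image: torch.Tensor) -> torch.Tensor:
    # (x - mean) / std per channel; [:, None, None] reshapes (C,) to (C, 1, 1)
    # so the statistics broadcast over a (C, H, W) image.
    return (image - MEAN[:, None, None]) / STD[:, None, None]

def denormalize(image: torch.Tensor) -> torch.Tensor:
    # Exact inverse of normalize: multiply by std, then add mean, per channel.
    return image * STD[:, None, None] + MEAN[:, None, None]

x = torch.rand(3, 32, 32)
restored = denormalize(normalize(x))
print(torch.allclose(restored, x, atol=1e-6))  # True
```

Unlike min-max rescaling, this inverse does not stretch each image to fill [0, 1], so relative brightness between images is preserved.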

---

The first thing to do is to calculate the normalization values.

In [None]:
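# Normalization is important, so let's calculate the per-channel mean and standard deviation.
# train_ds.data is a uint8 array of shape (N, H, W, C); dividing by 255 matches the [0, 1] scale that ToTensor produces.
train_ds = datasets.CIFAR10(root='data', train=True, download=True, transform=transforms.ToTensor())
print(train_ds.data.shape)
print(train_ds.data.mean(axis=(0, 1, 2)) / 255)
print(train_ds.data.std(axis=(0, 1, 2)) / 255)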
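Loading the whole of `train_ds.data` into memory works for CIFAR-10, but for larger datasets the statistics can be accumulated batch by batch instead. The sketch below keeps running per-channel sums of the values and their squares. `running_channel_stats` is a hypothetical helper, and it assumes batches shaped (B, C, H, W) already scaled to [0, 1], as a `DataLoader` over a `ToTensor()` dataset would yield. It returns the population standard deviation, matching the default of `numpy.std` used above.

```python
from typing import Iterable, Tuple

import torch

def running_channel_stats(batches: Iterable[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    total = None
    total_sq = None
    count = 0
    for batch in batches:
        # Accumulate in float64 to limit rounding error over many pixels.
        batch = batch.double()
        # Sum over batch, height and width, leaving one value per channel.
        s = batch.sum(dim=(0, 2, 3))
        sq = (batch ** 2).sum(dim=(0, 2, 3))
        total = s if total is None else total + s
        total_sq = sq if total_sq is None else total_sq + sq
        count += batch.shape[0] * batch.shape[2] * batch.shape[3]
    if count == 0:
        raise ValueError("no batches were provided")
    mean = total / count
    # Population variance E[x^2] - E[x]^2 (ddof=0, like numpy.std's default).
    var = total_sq / count - mean ** 2
    return mean, var.clamp(min=0).sqrt()

# Usage with synthetic batches standing in for a DataLoader.
data = torch.rand(10, 3, 8, 8)
mean, std = running_channel_stats(data.split(4))
print(torch.allclose(mean, data.double().mean(dim=(0, 2, 3))))  # True
print(torch.allclose(std, data.double().std(dim=(0, 2, 3), unbiased=False)))  # True
```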

---

Now we want to apply them to the dataset.
We can look at the transform to determine what inputs it is expecting:

In [None]:
transforms.Normalize?

The docstring shows that `Normalize` subtracts the mean and divides by the standard deviation itself, so the computed values can be passed in directly. There is no need to negate the mean or take the reciprocal of the standard deviation.
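Written out, the documented behavior of `Normalize` is, for each channel $c$:

$$\text{output}_c = \frac{\text{input}_c - \text{mean}_c}{\text{std}_c}$$

As a worked example with the red-channel values computed above, a fully saturated pixel (1.0 after `ToTensor`) becomes $(1 - 0.49139968) / 0.24703223 \approx 2.06$, and a black pixel (0.0) becomes $-0.49139968 / 0.24703223 \approx -1.99$, so normalized values are not confined to [0, 1].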

---

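The final step is to add the normalization to the transformation list.
`Normalize` operates on tensors rather than PIL images, so it has to come after `ToTensor`, which also rescales the pixel values from [0, 255] to [0, 1] (this is why the statistics above were divided by 255).
Let's see how to apply it: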

In [None]:
transform = transforms.Compose([
    # Normalization works on tensors
    transforms.ToTensor(),

    # this takes the values that were calculated before
    transforms.Normalize(
        mean=(0.49139968, 0.48215841, 0.44653091),
        std=(0.24703223, 0.24348513, 0.26158784),
    ),
])

train_ds = datasets.CIFAR10(
    'data',
    download=True,
    train=True,
    transform=transform,
)

image, target = next(iter(train_ds))
print(train_ds.classes[target])
to_image(image)  # rescales for display, approximately reversing the normalization
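To make the ordering concrete, here is a rough numpy equivalent of the `Compose` pipeline for a single image. It is a sketch that assumes a uint8 (H, W, C) array like one entry of `train_ds.data`; `manual_transform` is a hypothetical helper. The first step mirrors `ToTensor` (channels first, divide by 255) and the second mirrors `Normalize`.

```python
import numpy as np

MEAN = np.array([0.49139968, 0.48215841, 0.44653091], dtype=np.float32)
STD = np.array([0.24703223, 0.24348513, 0.26158784], dtype=np.float32)

def manual_transform(image: np.ndarray) -> np.ndarray:
    # ToTensor step: uint8 (H, W, C) in [0, 255] -> float32 (C, H, W) in [0, 1].
    x = image.astype(np.float32).transpose(2, 0, 1) / 255
    # Normalize step: subtract the mean and divide by the std, per channel.
    return (x - MEAN[:, None, None]) / STD[:, None, None]

white = np.full((1, 1, 3), 255, dtype=np.uint8)
black = np.zeros((1, 1, 3), dtype=np.uint8)
print(manual_transform(white).ravel())
print(manual_transform(black).ravel())
```

For a real image, `manual_transform(train_ds.data[0])` should match the tensor returned by `train_ds[0][0]` up to float rounding.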