In [None]:
%matplotlib inline

In [41]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as nnf
from torchvision import datasets, transforms

In [None]:
data_path = Path("../cifar_data")
# Transform statistics taken from https://stackoverflow.com/a/69750247
cifar10 = datasets.CIFAR10(
    data_path,
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4915, 0.4823, 0.4468),  # Precomputed mean and
                             (0.2470, 0.2435, 0.2616))  # standard deviation
    ]))


Since I got the normalization values from the Internet, I should verify that these statistics are accurate. I will create numpy batch arrays and take the mean and std along the batch and 32x32 pixel axes.

In [36]:
# use np.concatenate to stick all the images together to form a (batch, 3, 32, 32) array
imgs = np.concatenate(
    np.asarray([[
        [
            cifar10[i][0][0].numpy(),
            cifar10[i][0][1].numpy(),
            cifar10[i][0][2].numpy()
        ]
        for i in range(len(cifar10))
    ]])
)

print(imgs.shape)

(50000, 3, 32, 32)


In [39]:
# calculate the mean along the (batch, pixel, pixel) axes
train_mean = np.mean(imgs, axis=(0, 2, 3))
print(train_mean)

[-0.00040607 -0.0005815  -0.00102856]


In [40]:
# calculate the std along the (batch, pixel, pixel) axes
train_std = np.std(imgs, axis=(0, 2, 3))
print(train_std)


[1.0001289  0.9999368  0.99995327]
