In [1]:
# Understanding Mean and Variance in PyTorch

In [2]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [3]:
# We will calculate mean and variance for CIFAR 10
# There are many ways of doing this in pytorch
# Mean: [0.491, 0.482, 0.447] (RGB)
# Var: [0.247, 0.243, 0.261]

In [4]:
from fastai.vision import *
from fastai.callbacks.hooks import *
import fastai
from imageio import imread
from tqdm import tqdm
import numpy as np
from torchvision import transforms

In [5]:
cifar_stats

([0.491, 0.482, 0.447], [0.247, 0.243, 0.261])

In [None]:
# Dowload CIFAR 10

In [4]:
path = untar_data(URLs.CIFAR)
path

PosixPath('/Users/krishna/.fastai/data/cifar10')

In [8]:
TRAIN_ROOT = Path('/Users/krishna/.fastai/data/cifar10/train')
sub_folders_train = ([x for x in TRAIN_ROOT.iterdir() if x.is_dir()])

# PyTorch

In [9]:
img = []
re_mean = 0

for i in tqdm(sub_folders_train):
    for j in i.iterdir():
        img.append(j)      

100%|██████████| 10/10 [00:00<00:00, 16.36it/s]


In [11]:
class cifar_dataset(Dataset):
    def __init__(self, img_array):
        self.data = img_array
        self.to_tensor = transforms.ToTensor()
        
    def __getitem__(self, index):
        x = self.data[index]
        x = imread(x)
        return self.to_tensor(x)
        #return torch.Tensor(x)

    def __len__(self):
        return len(self.data)

In [12]:
dataset = cifar_dataset(img)
loader = DataLoader(
    dataset,
    batch_size=500,
    num_workers=1,
    shuffle=False
)

In [15]:
# 1st Method

In [13]:
def online_mean_and_sd(loader):
    """Compute the mean and sd in an online fashion

        Var[x] = E[X^2] - E^2[X]
    """
    cnt = 0
    fst_moment = torch.empty(3)
    snd_moment = torch.empty(3)

    for data in loader:

        b, c, h, w = data.shape
        nb_pixels = b * h * w
        sum_ = torch.sum(data, dim=[0, 2, 3])
        sum_of_square = torch.sum(data ** 2, dim=[0, 2, 3])
        fst_moment = (cnt * fst_moment + sum_) / (cnt + nb_pixels)
        snd_moment = (cnt * snd_moment + sum_of_square) / (cnt + nb_pixels)

        cnt += nb_pixels

    return fst_moment, torch.sqrt(snd_moment - fst_moment ** 2)

In [14]:
mean, std = online_mean_and_sd(loader)

In [18]:
print(mean, std)
# ([0.491, 0.482, 0.447], [0.247, 0.243, 0.261])

tensor([0.4914, 0.4822, 0.4465]) tensor([0.2470, 0.2435, 0.2616])


In [16]:
def mean_and_sd(loader):
    mean = 0.
    std = 0.
    nb_batch = 0.
    for data in loader:
        batch_samples = data.size(0)
        channel = data.size(1)
        data = data.view(batch_samples, channel, -1)
        mean += data.mean(2).sum(0)
        std += data.std(2).sum(0)
        nb_batch += batch_samples
    return mean/nb_batch, std/nb_batch

In [None]:
mean, std = online_mean_and_sd(loader)

In [19]:
print(mean, std)
# ([0.491, 0.482, 0.447], [0.247, 0.243, 0.261])

tensor([0.4914, 0.4822, 0.4465]) tensor([0.2470, 0.2435, 0.2616])


In [26]:
# Here the STD is calculated over batch and averaged
# Hence its not very accurate

# Fast AI

In [22]:
path

PosixPath('/Users/krishna/.fastai/data/cifar10')

In [24]:
ds_tfms = ([*rand_pad(4, 32), flip_lr(p=0.5)], [])
data = ImageDataBunch.from_folder(path, valid='test', ds_tfms=ds_tfms, bs=500).normalize()

In [25]:
data.stats
# # ([0.491, 0.482, 0.447], [0.247, 0.243, 0.261])

[tensor([0.4979, 0.4590, 0.4228]), tensor([0.2528, 0.2488, 0.2545])]

In [6]:
# Conclusion:
# Its really hard to calculate varaince when the data sets are too large
# Sometimes its better to calculate the variance of a mini-batch
# Other times online version works well too

In [7]:
# https://stackoverflow.com/questions/48159562/calculating-mean-std-for-batch-python-numpy
# https://forums.fast.ai/t/image-normalization-in-pytorch/7534/11
# https://discuss.pytorch.org/t/about-normalization-using-pre-trained-vgg16-networks/23560
# https://discuss.pytorch.org/t/how-to-normalize-a-tensor-to-0-mean-and-1-variance/18766