# Data normalization and standardization

In [1]:
import os
import torch
from torchvision import datasets, transforms

In [2]:
# Root directory of the image data, relative to this notebook. The 'train',
# 'val' and 'test' sub-directories are joined onto it when the datasets are built.
DATA_ROOT = '../data/segmented/'

In [3]:
# Preprocessing pipeline passed to every ImageFolder below. No normalization is
# applied here: the point of this notebook is to measure the statistics that a
# later Normalize step would need.
data_transform = transforms.Compose([
    # Grayscale output kept at 3 channels (all three channels carry the same
    # values, which is why the computed per-channel statistics are identical).
    transforms.Grayscale(num_output_channels=3),
    # Resize followed by a 224 center crop so every image in a batch ends up
    # with the same spatial size and can be stacked.
    transforms.Resize(224),
    transforms.CenterCrop(224),
    # Convert to a tensor so batches can be reduced with torch operations.
    transforms.ToTensor()
])

In [4]:
# One ImageFolder per split; each split lives in its own sub-directory of
# DATA_ROOT and shares the same preprocessing pipeline.
image_datasets = {
    split: datasets.ImageFolder(
        os.path.join(DATA_ROOT, split),
        transform=data_transform,
    )
    for split in ('train', 'val', 'test')
}

# Wrap every dataset in a loader. Shuffling is disabled: the statistics
# computed from these loaders do not depend on sample order.
data_loaders = {
    split: torch.utils.data.DataLoader(
        dataset,
        batch_size=8,
        shuffle=False,
        num_workers=4,
    )
    for split, dataset in image_datasets.items()
}

To normalize and standardize the data, we first need to compute the per-channel mean and standard deviation over the whole dataset.

In [5]:
def get_mean_and_std(dataloaders, phases=('train', 'val', 'test')):
    """Compute the per-channel mean and standard deviation of image batches.

    Statistics are accumulated as per-pixel sums, so every pixel carries the
    same weight even when batches differ in size (e.g. the shorter last batch
    of a loader). Averaging per-batch means instead would over-weight the
    images that fall into smaller batches.

    Args:
        dataloaders: Mapping from phase name to an iterable yielding
            ``(data, target)`` pairs. ``data`` must be a 4-D tensor; it is
            reduced over dimensions 0, 2 and 3, keeping dimension 1 as the
            channel axis.
        phases: Keys of ``dataloaders`` to include. Defaults to all three
            splits; pass e.g. ``('train',)`` to restrict the statistics to
            the training data only.

    Returns:
        A tuple ``(mean, std)`` of 1-D tensors with one entry per channel.
        ``std`` is the population standard deviation, ``sqrt(E[x^2] - E[x]^2)``.

    Raises:
        ValueError: If the selected loaders yield no pixels at all.
    """
    channels_sum, channels_squared_sum, num_pixels = 0, 0, 0
    out_dtype = torch.get_default_dtype()

    for phase in phases:
        for data, _ in dataloaders[phase]:
            # Accumulate in float64: E[x^2] - E[x]^2 is prone to cancellation
            # when summed over many pixels in float32.
            batch = data.to(torch.float64)
            channels_sum += batch.sum(dim=[0, 2, 3])
            channels_squared_sum += (batch ** 2).sum(dim=[0, 2, 3])
            # Number of values contributing to each channel in this batch.
            num_pixels += data.shape[0] * data.shape[2] * data.shape[3]
            if data.is_floating_point():
                out_dtype = data.dtype

    if num_pixels == 0:
        raise ValueError('Cannot compute mean and std: the dataloaders yielded no data')

    mean = channels_sum / num_pixels
    # Rounding can push the variance marginally below zero for (near-)constant
    # channels; clamp so the square root never produces NaN.
    variance = (channels_squared_sum / num_pixels - mean ** 2).clamp(min=0)
    std = variance ** 0.5

    return mean.to(out_dtype), std.to(out_dtype)

In [6]:
# Per-channel (mean, std) computed across the train, val and test loaders;
# the returned tuple is displayed as the cell output.
get_mean_and_std(data_loaders)

(tensor([0.1220, 0.1220, 0.1220]), tensor([0.2058, 0.2058, 0.2058]))