# Get mean and standard deviation for normalization

In [None]:
import torch
import numpy as np
import torchvision
from torchvision import datasets, transforms
import os

# Report the library versions in use so results can be reproduced later
print(f"PyTorch Version: {torch.__version__}")
print(f"Torchvision Version: {torchvision.__version__}")


In [None]:
# Select the first CUDA device when one is present, otherwise fall back to CPU
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# The device type is "cuda" exactly when CUDA was reported as available
print("CUDA available:", device.type == "cuda")


## Settings

In [None]:
# Root directory that is scanned for datasets: every sub-directory of it is
# treated as a dataset, and every sub-directory of a dataset is loaded as an
# ImageFolder.
baseDir = "./datasets"

In [None]:
def getMeanAndSDT(dataloader):
    """Compute the per-channel mean and standard deviation of a dataset.

    Statistics are accumulated as raw sums over every value, so batches of
    different sizes (typically a smaller last batch) are weighted correctly
    instead of each batch counting equally.

    Args:
        dataloader: Iterable yielding ``(data, target)`` pairs. ``data`` is
            reduced over dimensions 0, 2 and 3, i.e. it is expected to be
            laid out as (batch, channel, height, width).

    Returns:
        Tuple ``(mean, std)`` of tensors with one entry per channel. ``std``
        is the population standard deviation, sqrt(E[X^2] - (E[X])^2).

    Raises:
        ValueError: If the dataloader yields no values at all.
    """
    channels_sum, channels_squared_sum, num_values = 0, 0, 0
    out_dtype = None
    for data, _ in dataloader:
        if out_dtype is None and data.is_floating_point():
            # Remember the input precision so the results can be returned
            # in the same dtype the caller fed in.
            out_dtype = data.dtype
        # Sum over batch, height and width, but not over the channels.
        # Accumulating in float64 keeps the running totals accurate for
        # datasets with many millions of pixels.
        channels_sum += data.sum(dim=[0, 2, 3], dtype=torch.float64)
        channels_squared_sum += (data ** 2).sum(dim=[0, 2, 3], dtype=torch.float64)
        # Number of values per channel contributed by this batch
        num_values += data.size(0) * data.size(2) * data.size(3)

    if num_values == 0:
        raise ValueError("Cannot compute mean and std: the dataloader yielded no data")

    mean = channels_sum / num_values

    # var = E[X^2] - (E[X])^2; rounding can push it marginally below zero for
    # (near-)constant channels, which would turn the square root into NaN.
    variance = torch.clamp(channels_squared_sum / num_values - mean ** 2, min=0)
    std = variance ** 0.5

    if out_dtype is not None:
        mean, std = mean.to(out_dtype), std.to(out_dtype)
    return mean, std

## Load datasets and compute mean and std

In [None]:
# Walk baseDir/<dataset>/<folder>, load each folder as an ImageFolder, report
# how many images every class holds and print the per-channel mean / std.
print("[🧮 CALCULATE MEAN AND SDT]")

for dataset in list(filter(lambda x: os.path.isdir(os.path.join(baseDir, x)), os.listdir(baseDir))):
    print(f"\n\n[🗃️ DATASET] {dataset}")

    datasetDir = os.path.join(baseDir, dataset)

    # Only directories are datasets / folders; plain files are ignored
    for subdataset in list(filter(lambda x: os.path.isdir(os.path.join(datasetDir, x)), os.listdir(datasetDir))):
        print(f"\n[📂 FOLDER] {subdataset}")

        subDatasetDir = os.path.join(datasetDir, subdataset)

        # Images are only converted to tensors; no resizing or normalization
        imageDataset = datasets.ImageFolder(
            subDatasetDir,
            transform=transforms.Compose([transforms.ToTensor()]),
        )

        # Per-class image counts, derived from the dataset's target indices
        for cls, clsIndex in imageDataset.class_to_idx.items():
            numElements = np.count_nonzero(np.array(imageDataset.targets) == clsIndex)
            print(f"[🧮 # ELEMENTS] {cls.upper()}: {numElements}")

        dataLoader = torch.utils.data.DataLoader(imageDataset, batch_size=128)

        mean, std = getMeanAndSDT(dataLoader)
        print(f"[✔️ INFO] Mean: {mean}\n[✔️ INFO] Std: {std}")
