## Oxford-IIIT Pets Mean and Standard Deviation

This notebook calculates the per-channel mean and standard deviation of the Oxford-IIIT Pets dataset, which are then used for image normalisation in the CNN and ViT notebooks.

In [4]:
import torch
from torchvision import transforms
from torchvision.datasets import OxfordIIITPet
from torch.utils.data import DataLoader
import numpy as np
import sys # for checking if Google Colab is being used
import os # for importing Oxford-IIIT-Pet from where it is locally stored

# Ensures all images are the same dimensions and converts them to tensors
transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor()
])

# Checking if Google Colab or Jupyter Notebook is being used to determine root path for dataset
if "google.colab" in sys.modules:
    dataset_path = "/content/"
else: 
    dataset_path = os.path.join("..", "oxford-iiit-pet")

# Loads the training and validation data from OxfordIIITPets Dataset
train_dataset = OxfordIIITPet(root=dataset_path, download=True, target_types='category', split='trainval', transform=transform)
# Data is wrapped in a DataLoader to enable iteration over batches
loader = DataLoader(train_dataset, batch_size=32, shuffle=False)

dataset_mean = 0.0
dataset_std = 0.0
total_images = 0 # total image count

# Iterating over the batches
for images, annotations in loader:
  batch_count = images.size(0) # number of images in the batch
  # Result is [batch_size (32), channels (3), pixels per channel]
  image_data = images.view(batch_count, images.size(1), -1)
  # Computes the mean pixel value for each channel in an image, sums those means from all images, then adds that sum to running dataset_mean
  dataset_mean += image_data.mean(2).sum(0)
  # Same process as mean but for standard deviation
  dataset_std += image_data.std(2).sum(0)
  total_images += batch_count

# Final per-channel values calculating by dividing running total by total image count
dataset_mean /= total_images
dataset_std /= total_images

# Outputs the computed values for each RGB channel
print(f"Mean: {dataset_mean}")
print(f"Standard Deviation: {dataset_std}")

Mean: tensor([0.4783, 0.4459, 0.3957])
Standard Deviation: tensor([0.2222, 0.2191, 0.2206])
