#  Compute Image Mean & Std (for Normalization)

In [2]:
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torch
import numpy as np

# Define minimal transform for loading
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # Important: This turns images to range [0, 1]
])

# Load full dataset
data_dir = "../data/PlantVillage"
dataset = datasets.ImageFolder(root=data_dir, transform=transform)

loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=2)

mean = torch.zeros(3)
std = torch.zeros(3)
total_images = 0

for images, _ in loader:
    total_images += images.size(0)
    mean += images.mean(dim=[0, 2, 3]) * images.size(0)  # mean over C,H,W per image
    std += images.std(dim=[0, 2, 3]) * images.size(0)    # std over C,H,W per image

mean /= total_images
std /= total_images

print(f"Computed mean: {mean}")
print(f"Computed std:  {std}")


Computed mean: tensor([0.4591, 0.4753, 0.4116])
Computed std:  tensor([0.1812, 0.1573, 0.1957])
