In [None]:
from utils import display_image_grid
import matplotlib.pyplot as plt

import torch
from torchvision import transforms
from torchvision.datasets import CIFAR10, ImageFolder
from torch.utils.data import DataLoader

In [None]:
# Load the raw images with only ToTensor (no resizing). If the images in
# clean_dataset differ in size, the default collate function cannot stack them
# into one batch. That failure is what this cell demonstrates, so the error is
# printed instead of stopping the notebook. The next cell adds a Resize to fix it.
ds = ImageFolder(root="../assets/clean_dataset", transform=transforms.ToTensor())
dl = DataLoader(ds, batch_size=2, shuffle=True)

try:
    batch = next(iter(dl))
except Exception as err:
    print(err)

In [None]:
# Resize every image to a fixed 64x64 before converting it to a tensor, so that
# images in a batch share one shape.
resize_transforms = transforms.Compose(
    [transforms.Resize((64, 64)), transforms.ToTensor()]
)

# `dl` wraps this very dataset object, so swapping the transform here is enough
# for the existing loader to start yielding resized images.
ds.transform = resize_transforms

images, labels = next(iter(dl))
display_image_grid(images, labels, classes=ds.classes, nrow=5)

In [None]:
# Per-channel mean and standard deviation over every pixel of the (resized)
# dataset, computed in a single pass via Var[x] = E[x^2] - E[x]^2.
channel_sum = torch.zeros(3, dtype=torch.float64)
channel_sq_sum = torch.zeros(3, dtype=torch.float64)
n_pixels = 0

for images, _ in dl:
    b, c, h, w = images.shape   # batch size, channels, height, width
    n_pixels += b * h * w

    # Accumulate in float64. Summing many float32 values loses precision, and
    # E[x^2] - E[x]^2 is prone to cancellation.
    pixels = images.to(torch.float64)
    channel_sum += pixels.sum(dim=[0, 2, 3])
    channel_sq_sum += (pixels ** 2).sum(dim=[0, 2, 3])

if n_pixels == 0:
    raise ValueError("dataset is empty; cannot compute mean/std")

mean = channel_sum / n_pixels
# Rounding can make the variance a tiny negative number; clamp before sqrt to avoid NaN.
std = (channel_sq_sum / n_pixels - mean ** 2).clamp_min(0).sqrt()

# Back to float32 so later arithmetic with image tensors keeps their dtype.
mean = mean.float()
std = std.float()

print("Mean:", mean)
print("Std:", std)

In [None]:
# Same resize + tensor conversion as before, followed by per-channel
# normalization using the mean/std computed in the previous cell.
normalize_transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean.tolist(), std=std.tolist())
])

ds_normalized = ImageFolder(
    root="../assets/clean_dataset",
    transform=normalize_transforms
)

dl_normalized = DataLoader(
    ds_normalized, 
    batch_size=10, 
    shuffle=False
)

# Compare the SAME image with and without normalization. `dl` shuffles, so its
# first batch would hold an unrelated image. Instead, index both datasets at 0.
# They read the same root (ImageFolder orders files deterministically), and `ds`
# carries the Resize + ToTensor transform assigned in the earlier resize cell.
normalized_tensor = ds_normalized[0][0]
original_tensor = ds[0][0]
assert original_tensor.shape == normalized_tensor.shape, "transforms disagree on output shape"

print("Original tensor:", original_tensor[:, 0, 0])  # print first pixel values for each channel
print("Normalized tensor:", normalized_tensor[:, 0, 0])

In [None]:
# Undo Normalize (x_norm * std + mean) so the normalized image can be shown next
# to the original. torch.as_tensor accepts the existing tensors (or a list/tuple)
# without the copy-construct warning that torch.tensor(<Tensor>) emits.
mean_t = torch.as_tensor(mean, dtype=torch.float32).view(3, 1, 1)
std_t = torch.as_tensor(std, dtype=torch.float32).view(3, 1, 1)

# Clamp: float rounding can push values slightly outside [0, 1], which imshow
# would clip with a warning.
denormalized_tensor = (normalized_tensor * std_t + mean_t).clamp(0, 1)

fig, axes = plt.subplots(1, 2, figsize=(6, 3))

axes[0].imshow(original_tensor.permute(1, 2, 0))
axes[0].set_title("Original (no normalization)")
axes[0].axis("off")

axes[1].imshow(denormalized_tensor.permute(1, 2, 0))
axes[1].set_title("After normalization\n(denormalized for display)")
axes[1].axis("off")

plt.show()

In [None]:
# Per-channel mean/std of the CIFAR-10 training set, taken over all pixels. This
# is the same statistic the mean/std cell above computes for clean_dataset.
# NOTE: the frequently copied std (0.2023, 0.1994, 0.2010) reportedly comes from
# a different (per-image) computation; the dataset-wide std is used here.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std  = (0.2470, 0.2435, 0.2616)

# Training: light geometric/colour augmentation on the PIL image, then tensor + normalize.
train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar10_mean, std=cifar10_std)
])

# Evaluation: deterministic, with no augmentation, only tensor conversion + the same normalization.
validation_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar10_mean, std=cifar10_std)
])

In [None]:
# CIFAR-10 train/test splits stored under ../assets/cifar10. download=True lets
# torchvision fetch the data if it is missing. Each split gets its own transforms.
train_dataset = CIFAR10(root="../assets/cifar10", train=True, download=True, transform=train_transforms)
test_dataset = CIFAR10(root="../assets/cifar10", train=False, download=True, transform=validation_transforms)

# Only the training loader shuffles; the test loader keeps a fixed order.
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [None]:
images, labels = next(iter(train_dataloader))

# Undo Normalize (x * std + mean) so colours look natural, then clip any rounding
# overshoot back into [0, 1] for display.
images = (
    images * torch.tensor(cifar10_std).view(3, 1, 1)
    + torch.tensor(cifar10_mean).view(3, 1, 1)
).clamp(0, 1)

display_image_grid(images, labels, classes=train_dataset.classes, nrow=8)