📝 **Author:** Amirhossein Heydari - 📧 **Email:** amirhosseinheydari78@gmail.com - 📍 **Linktree:** [linktr.ee/mr_pylin](https://linktr.ee/mr_pylin)

---

# Dependencies

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import CIFAR10
from torchvision.transforms import v2

In [2]:
# Seed PyTorch's global RNG so repeated runs of the notebook give the same results.
seed = 42
torch.manual_seed(seed)

# Ask cuDNN for deterministic kernels and turn off its benchmark auto-tuner,
# which may otherwise select different algorithms from run to run.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Load Dataset

In [3]:
# Load the CIFAR-10 training split from the local datasets folder (no download)
# without any transform, so samples come back untransformed.
trainset = CIFAR10(root="../../datasets", train=True, transform=None, download=False)

# Keep the first three stored images and their labels for the demos below.
x, y = trainset.data[:3], trainset.targets[:3]

In [None]:
# plot the three raw samples side by side, each titled with its label
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(12, 4), layout="compressed")
for ax, img, label in zip(axs, x, y):
    ax.imshow(img)
    ax.set(title=label)
plt.show()

# Transforms
   - [pytorch.org/vision/main/transforms.html](https://pytorch.org/vision/main/transforms.html)

## Common Transformations
   - v2.ToImage
   - v2.ToDtype
   - v2.Normalize

In [None]:
# log shape, dtype, container type and value range of each raw sample
for i, img in enumerate(x):
    print(f"x[{i}].shape : {img.shape}")
    print(f"x[{i}].dtype : {img.dtype}")
    print(f"type(x[{i}]) : {type(img)}")
    print(f"x[{i}].min() : {img.min()}")
    print(f"x[{i}].max() : {img.max()}")
    print("-" * 50)

### v2.ToImage
   - [pytorch.org/vision/main/generated/torchvision.transforms.v2.ToImage.html](https://pytorch.org/vision/main/generated/torchvision.transforms.v2.ToImage.html)

In [None]:
# v2.ToImage turns each channels-last array into a channels-first image tensor
# (the plotting cells undo this with .permute(1, 2, 0)).
to_image_transform = v2.ToImage()
x_2 = [to_image_transform(img) for img in x]

# log
for i, img in enumerate(x_2):
    print(f"x_2[{i}].shape : {img.shape}")
    print(f"x_2[{i}].dtype : {img.dtype}")
    print(f"type(x_2[{i}]) : {type(img)}")
    print(f"x_2[{i}].min() : {img.min()}")
    print(f"x_2[{i}].max() : {img.max()}")
    print("-" * 50)

### v2.ToDtype
   - [pytorch.org/vision/main/generated/torchvision.transforms.v2.ToDtype.html#torchvision.transforms.v2.ToDtype](https://pytorch.org/vision/main/generated/torchvision.transforms.v2.ToDtype.html#torchvision.transforms.v2.ToDtype)

In [None]:
# cast to float32; scale=True additionally rescales the integer pixel range
# (compare min()/max() in the log below with the previous cell)
to_dtype_transform = v2.ToDtype(dtype=torch.float32, scale=True)
x_3 = [to_dtype_transform(img) for img in x_2]

# log
for i, img in enumerate(x_3):
    print(f"x_3[{i}].shape: {img.shape}")
    print(f"x_3[{i}].dtype: {img.dtype}")
    print(f"type(x_3[{i}]): {type(img)}")
    print(f"x_3[{i}].min(): {img.min()}")
    print(f"x_3[{i}].max(): {img.max()}")
    print("-" * 50)

### v2.Normalize
   - [pytorch.org/vision/main/generated/torchvision.transforms.v2.Normalize.html#torchvision.transforms.v2.Normalize](https://pytorch.org/vision/main/generated/torchvision.transforms.v2.Normalize.html#torchvision.transforms.v2.Normalize)

In [None]:
# Per-channel mean/std over the demo images, reducing the batch, height and
# width axes (dims 0, 2, 3) of the stacked (N, C, H, W) tensor. Stack once and
# reuse it for both statistics; correction=0 gives the population std, the same
# as np.std's default ddof=0.
x_3_batch = torch.stack(x_3)
mus = x_3_batch.mean(dim=(0, 2, 3)).tolist()
stds = x_3_batch.std(dim=(0, 2, 3), correction=0).tolist()

normalize_transform = v2.Normalize(mean=mus, std=stds)
x_4 = [normalize_transform(img) for img in x_3]

# log
for i, img in enumerate(x_4):
    print(f"x_4[{i}].shape: {img.shape}")
    print(f"x_4[{i}].dtype: {img.dtype}")
    print(f"type(x_4[{i}]): {type(img)}")
    print(f"x_4[{i}].min(): {img.min()}")
    print(f"x_4[{i}].max(): {img.max()}")
    print("-" * 50)

### plot

In [None]:
# plot: raw samples (top row) vs. the preprocessed tensors (bottom row).
# Expect imshow to clip normalized float values outside [0, 1] for display.
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(12, 8), layout="compressed")
for col, (raw, processed) in enumerate(zip(x, x_4)):
    axs[0, col].imshow(raw)
    axs[0, col].set(title="Original")
    # tensors are channels-first; imshow wants channels-last
    axs[1, col].imshow(processed.permute(1, 2, 0))
    axs[1, col].set(title="Normalize(ToDtype(ToImage()))")
plt.show()

## Data Augmentation Techniques
   - v2.RandomCrop
   - v2.Resize
   - v2.RandomVerticalFlip
   - v2.RandomHorizontalFlip
   - v2.RandomRotation
   - v2.ColorJitter
   - v2.RandomAffine
   - ...

### v2.RandomCrop
   - [pytorch.org/vision/main/generated/torchvision.transforms.v2.RandomCrop.html#torchvision.transforms.v2.RandomCrop](https://pytorch.org/vision/main/generated/torchvision.transforms.v2.RandomCrop.html#torchvision.transforms.v2.RandomCrop)

In [None]:
# crop a random window covering 3/4 of the height (dim 1) and width (dim 2);
# integer arithmetic avoids the float round-trip of int(h / 4 * 3)
crop_height = x_4[0].shape[1] * 3 // 4
crop_width = x_4[0].shape[2] * 3 // 4
random_crop_transform = v2.RandomCrop(size=(crop_height, crop_width))
x_5 = [random_crop_transform(img) for img in x_4]

# log
for i, img in enumerate(x_5):
    print(f"x_5[{i}].shape: {img.shape}")
    print(f"x_5[{i}].dtype: {img.dtype}")
    print(f"type(x_5[{i}]): {type(img)}")
    print(f"x_5[{i}].min(): {img.min()}")
    print(f"x_5[{i}].max(): {img.max()}")
    print("-" * 50)

In [None]:
# plot: x_4 (top row) vs. x_5 after v2.RandomCrop (bottom row)
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(12, 8), layout="compressed")
for col, (before, after) in enumerate(zip(x_4, x_5)):
    # channels-first tensors -> channels-last for imshow
    axs[0, col].imshow(before.permute(1, 2, 0))
    axs[0, col].set(title="x_4")
    axs[1, col].imshow(after.permute(1, 2, 0))
    axs[1, col].set(title="v2.RandomCrop")
plt.show()

### v2.Resize
   - [pytorch.org/vision/main/generated/torchvision.transforms.v2.Resize.html#torchvision.transforms.v2.Resize](https://pytorch.org/vision/main/generated/torchvision.transforms.v2.Resize.html#torchvision.transforms.v2.Resize)

In [None]:
# resize the crops back to the original resolution; `x` holds channels-last
# arrays, so shape[0] is the height and shape[1] the width
resize_transform = v2.Resize(size=(x[0].shape[0], x[0].shape[1]))
x_6 = [resize_transform(img) for img in x_5]

# log
for i, img in enumerate(x_6):
    print(f"x_6[{i}].shape: {img.shape}")
    print(f"x_6[{i}].dtype: {img.dtype}")
    print(f"type(x_6[{i}]): {type(img)}")
    print(f"x_6[{i}].min(): {img.min()}")
    print(f"x_6[{i}].max(): {img.max()}")
    print("-" * 50)

In [None]:
# plot: x_5 (top row) vs. x_6 after v2.Resize (bottom row)
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(12, 8), layout="compressed")
for col, (before, after) in enumerate(zip(x_5, x_6)):
    # channels-first tensors -> channels-last for imshow
    axs[0, col].imshow(before.permute(1, 2, 0))
    axs[0, col].set(title="x_5")
    axs[1, col].imshow(after.permute(1, 2, 0))
    axs[1, col].set(title="v2.Resize")
plt.show()

### v2.RandomVerticalFlip
   - [pytorch.org/vision/main/generated/torchvision.transforms.v2.RandomVerticalFlip.html#torchvision.transforms.v2.RandomVerticalFlip](https://pytorch.org/vision/main/generated/torchvision.transforms.v2.RandomVerticalFlip.html#torchvision.transforms.v2.RandomVerticalFlip)

In [None]:
# flip each image upside-down with probability 0.6
random_vertical_flip_transform = v2.RandomVerticalFlip(p=0.6)
x_7 = [random_vertical_flip_transform(img) for img in x_6]

# log
for i, img in enumerate(x_7):
    print(f"x_7[{i}].shape: {img.shape}")
    print(f"x_7[{i}].dtype: {img.dtype}")
    print(f"type(x_7[{i}]): {type(img)}")
    print(f"x_7[{i}].min(): {img.min()}")
    print(f"x_7[{i}].max(): {img.max()}")
    print("-" * 50)

In [None]:
# plot: x_6 (top row) vs. x_7 after v2.RandomVerticalFlip (bottom row)
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(12, 8), layout="compressed")
for col, (before, after) in enumerate(zip(x_6, x_7)):
    # channels-first tensors -> channels-last for imshow
    axs[0, col].imshow(before.permute(1, 2, 0))
    axs[0, col].set(title="x_6")
    axs[1, col].imshow(after.permute(1, 2, 0))
    axs[1, col].set(title="v2.RandomVerticalFlip")
plt.show()

### v2.RandomHorizontalFlip
   - [pytorch.org/vision/main/generated/torchvision.transforms.v2.RandomHorizontalFlip.html#torchvision.transforms.v2.RandomHorizontalFlip](https://pytorch.org/vision/main/generated/torchvision.transforms.v2.RandomHorizontalFlip.html#torchvision.transforms.v2.RandomHorizontalFlip)

In [None]:
# mirror each image left-right with probability 0.7
random_horizontal_flip_transform = v2.RandomHorizontalFlip(p=0.7)
x_8 = [random_horizontal_flip_transform(img) for img in x_7]

# log
for i, img in enumerate(x_8):
    print(f"x_8[{i}].shape: {img.shape}")
    print(f"x_8[{i}].dtype: {img.dtype}")
    print(f"type(x_8[{i}]): {type(img)}")
    print(f"x_8[{i}].min(): {img.min()}")
    print(f"x_8[{i}].max(): {img.max()}")
    print("-" * 50)

In [None]:
# plot: x_7 (top row) vs. x_8 after v2.RandomHorizontalFlip (bottom row)
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(12, 8), layout="compressed")
for col, (before, after) in enumerate(zip(x_7, x_8)):
    # channels-first tensors -> channels-last for imshow
    axs[0, col].imshow(before.permute(1, 2, 0))
    axs[0, col].set(title="x_7")
    axs[1, col].imshow(after.permute(1, 2, 0))
    axs[1, col].set(title="v2.RandomHorizontalFlip")
plt.show()

### v2.RandomRotation
   - [pytorch.org/vision/main/generated/torchvision.transforms.v2.RandomRotation.html#torchvision.transforms.v2.RandomRotation](https://pytorch.org/vision/main/generated/torchvision.transforms.v2.RandomRotation.html#torchvision.transforms.v2.RandomRotation)

In [None]:
# rotate each image by an angle sampled from [0, 45] degrees
random_rotation_transform = v2.RandomRotation(degrees=[0, 45])
x_9 = [random_rotation_transform(img) for img in x_8]

# log
for i, img in enumerate(x_9):
    print(f"x_9[{i}].shape: {img.shape}")
    print(f"x_9[{i}].dtype: {img.dtype}")
    print(f"type(x_9[{i}]): {type(img)}")
    print(f"x_9[{i}].min(): {img.min()}")
    print(f"x_9[{i}].max(): {img.max()}")
    print("-" * 50)

In [None]:
# plot: x_8 (top row) vs. x_9 after v2.RandomRotation (bottom row)
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(12, 8), layout="compressed")
for col, (before, after) in enumerate(zip(x_8, x_9)):
    # channels-first tensors -> channels-last for imshow
    axs[0, col].imshow(before.permute(1, 2, 0))
    axs[0, col].set(title="x_8")
    axs[1, col].imshow(after.permute(1, 2, 0))
    axs[1, col].set(title="v2.RandomRotation")
plt.show()

### v2.ColorJitter
   - [pytorch.org/vision/main/generated/torchvision.transforms.v2.ColorJitter.html#torchvision.transforms.v2.ColorJitter](https://pytorch.org/vision/main/generated/torchvision.transforms.v2.ColorJitter.html#torchvision.transforms.v2.ColorJitter)

In [None]:
# randomly perturb brightness, contrast, saturation and hue
# NOTE(review): x_9 is already normalized (it can hold negative values), but
# ColorJitter's float ops appear to clamp images to [0, 1]; compare the min()
# logged here with the previous cell. Jitter is normally applied before
# Normalize; confirm before reusing this ordering in a training pipeline.
color_jitter_transform = v2.ColorJitter(brightness=0.7, contrast=0.5, saturation=0.9, hue=0.3)
x_10 = [color_jitter_transform(img) for img in x_9]

# log
for i, img in enumerate(x_10):
    print(f"x_10[{i}].shape: {img.shape}")
    print(f"x_10[{i}].dtype: {img.dtype}")
    print(f"type(x_10[{i}]): {type(img)}")
    print(f"x_10[{i}].min(): {img.min()}")
    print(f"x_10[{i}].max(): {img.max()}")
    print("-" * 50)

In [None]:
# plot: x_9 (top row) vs. x_10 after v2.ColorJitter (bottom row)
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(12, 8), layout="compressed")
for col, (before, after) in enumerate(zip(x_9, x_10)):
    # channels-first tensors -> channels-last for imshow
    axs[0, col].imshow(before.permute(1, 2, 0))
    axs[0, col].set(title="x_9")
    axs[1, col].imshow(after.permute(1, 2, 0))
    axs[1, col].set(title="v2.ColorJitter")
plt.show()

### v2.RandomAffine
   - [pytorch.org/vision/main/generated/torchvision.transforms.v2.RandomAffine.html#torchvision.transforms.v2.RandomAffine](https://pytorch.org/vision/main/generated/torchvision.transforms.v2.RandomAffine.html#torchvision.transforms.v2.RandomAffine)

In [None]:
# random affine warp: degrees=0 disables rotation, scale is sampled from
# [0.5, 1.5], and shear=0.5 samples an x-axis shear from (-0.5, 0.5).
# NOTE(review): torchvision expresses shear in degrees, so 0.5 is a barely
# visible shear; confirm this magnitude is intended.
random_affine_transform = v2.RandomAffine(degrees=0, shear=0.5, scale=[0.5, 1.5])
x_11 = [random_affine_transform(img) for img in x_10]

# log
for i, img in enumerate(x_11):
    print(f"x_11[{i}].shape: {img.shape}")
    print(f"x_11[{i}].dtype: {img.dtype}")
    print(f"type(x_11[{i}]): {type(img)}")
    print(f"x_11[{i}].min(): {img.min()}")
    print(f"x_11[{i}].max(): {img.max()}")
    print("-" * 50)

In [None]:
# plot: x_10 (top row) vs. x_11 after v2.RandomAffine (bottom row)
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(12, 8), layout="compressed")
for col, (before, after) in enumerate(zip(x_10, x_11)):
    # channels-first tensors -> channels-last for imshow
    axs[0, col].imshow(before.permute(1, 2, 0))
    axs[0, col].set(title="x_10")
    axs[1, col].imshow(after.permute(1, 2, 0))
    axs[1, col].set(title="v2.RandomAffine")
plt.show()

## Composing Transforms
   - v2.Compose
      - [pytorch.org/vision/main/generated/torchvision.transforms.v2.Compose.html#torchvision.transforms.v2.Compose](https://pytorch.org/vision/main/generated/torchvision.transforms.v2.Compose.html#torchvision.transforms.v2.Compose)

In [24]:
# Full preprocessing + augmentation pipeline applied to the raw channels-last
# arrays in `x`. Augmentations run on float images in [0, 1], and Normalize
# runs LAST: ColorJitter clamps float images to [0, 1], so placing it after
# Normalize would wipe out the negative normalized values.
transforms = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(dtype=torch.float32, scale=True),
        # crop 3/4 of the original height/width (x is channels-last), then resize back
        v2.RandomCrop(size=(x[0].shape[0] * 3 // 4, x[0].shape[1] * 3 // 4)),
        v2.Resize(size=(x[0].shape[0], x[0].shape[1])),
        v2.RandomVerticalFlip(p=0.6),
        v2.RandomHorizontalFlip(p=0.7),
        v2.RandomRotation(degrees=[0, 45]),
        v2.ColorJitter(brightness=0.7, contrast=0.5, saturation=0.9, hue=0.3),
        v2.RandomAffine(degrees=0, shear=0.5, scale=[0.5, 1.5]),
        v2.Normalize(mean=mus, std=stds),
    ]
)

In [25]:
# run every raw sample through the composed pipeline, in order
x_12 = list(map(transforms, x))

In [None]:
# plot: raw samples (top row) vs. the composed-pipeline output x_12 (bottom row)
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(12, 8), layout="compressed")
for col, (raw, transformed) in enumerate(zip(x, x_12)):
    axs[0, col].imshow(raw)
    axs[0, col].set(title="x")
    # channels-first tensor -> channels-last for imshow
    axs[1, col].imshow(transformed.permute(1, 2, 0))
    axs[1, col].set(title="x_12")
plt.show()

# Effect of transforms

In [None]:
# pytorch dataset (no transform yet)
trainset = CIFAR10(root="../../datasets", train=True, transform=None, download=False)

# pytorch subset: a view over the first `num_samples` items of `trainset`
num_samples = 10
trainsubset = Subset(trainset, indices=range(num_samples))

# log
print("trainset:")
print(f"\tlen(trainset)          : {len(trainset)}")
print(f"\ttrainset.transform     : {trainset.transform}")
print(f"\ttype(trainset[0][0])   : {type(trainset[0][0])}")
print(f"\ttype(trainset[0][1])   : {type(trainset[0][1])}")
print(f"\ttype(trainset.data[0]) : {type(trainset.data[0])}")
print("-" * 50)
print("trainsubset:")
print(f"\tlen(trainsubset)             : {len(trainsubset)}")
# "\t" + "trainsubset": the previous literal "\trainsubset" dropped the leading "t"
print(f"\ttrainsubset.dataset.transform : {trainsubset.dataset.transform}")
print(f"\ttype(trainsubset[0][0])      : {type(trainsubset[0][0])}")
print(f"\ttype(trainsubset[0][1])      : {type(trainsubset[0][1])}")

In [None]:
# Convert samples to float32 image tensors. scale=True matches the ToDtype
# usage earlier in the notebook; without it the pixels stay in the integer
# range, just stored as floats.
transforms = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ]
)

# add transforms to the dataset; `trainsubset` wraps this same object, so it
# sees the new transform as well (see trainsubset.dataset.transform below)
trainset.transform = transforms

# log
print("trainset:")
print(f"\tlen(trainset): {len(trainset)}")
print(f"\ttrainset.transform:\n{trainset.transform}")
print(f"\ttype(trainset[0][0])   : {type(trainset[0][0])}")
print(f"\ttrainset[0][0].dtype   : {trainset[0][0].dtype}")
print(f"\ttype(trainset[0][1])   : {type(trainset[0][1])}")
print(f"\ttype(trainset.data[0]) : {type(trainset.data[0])}")
print("-" * 50)
print("trainsubset:")
print(f"\tlen(trainsubset): {len(trainsubset)}")
print(f"\ttrainsubset.dataset.transform:\n{trainsubset.dataset.transform}")
print(f"\ttype(trainsubset[0][0]) : {type(trainsubset[0][0])}")
print(f"\ttrainsubset[0][0].dtype : {trainsubset[0][0].dtype}")
print(f"\ttype(trainsubset[0][1]) : {type(trainsubset[0][1])}")

In [None]:
# Wrap the subset in a DataLoader whose single, unshuffled batch holds all
# `num_samples` items, then fetch that first batch.
trainloader = DataLoader(trainsubset, shuffle=False, batch_size=num_samples)
batch_iterator = iter(trainloader)
next_iter_trainloader = next(batch_iterator)

# inspect how the loader collated the batch (index 0 and index 1)
print("trainloader:")
print(f"\ttype(next_iter_trainloader[0]) : {type(next_iter_trainloader[0])}")
print(f"\tnext_iter_trainloader[0].dtype : {next_iter_trainloader[0].dtype}")
print(f"\ttype(next_iter_trainloader[1]) : {type(next_iter_trainloader[1])}")
print(f"\tnext_iter_trainloader[1].dtype : {next_iter_trainloader[1].dtype}")