📝 **Author:** Amirhossein Heydari - 📧 **Email:** <amirhosseinheydari78@gmail.com> - 📍 **Origin:** [mr-pylin/pytorch-workshop](https://github.com/mr-pylin/pytorch-workshop)

---


**Table of contents**<a id='toc0_'></a>    
- [Dependencies](#toc1_)    
- [Load CIFAR-10 Dataset](#toc2_)    
- [Transforms](#toc3_)    
  - [Built-in Transforms](#toc3_1_)    
    - [Geometry](#toc3_1_1_)    
      - [Resize](#toc3_1_1_1_)    
      - [Cropping](#toc3_1_1_2_)    
      - [Others](#toc3_1_1_3_)    
    - [Color](#toc3_1_2_)    
    - [Composition](#toc3_1_3_)    
    - [Miscellaneous](#toc3_1_4_)    
    - [Conversion](#toc3_1_5_)    
    - [Auto-Augmentation](#toc3_1_6_)    
  - [Custom Transforms](#toc3_2_)    
    - [Approach 1: Using nn.Module](#toc3_2_1_)    
    - [Approach 2: Using v2.Transform](#toc3_2_2_)    
  - [A Typical Transform Pipeline](#toc3_3_)    


# <a id='toc1_'></a>[Dependencies](#toc0_)


In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import v2

In [None]:
# set a seed for deterministic results
seed = 42
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# update paths as needed based on your project structure
DATASET_DIR = Path("../../datasets")

In [None]:
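def plot(x1: list[torch.Tensor], x2: list[torch.Tensor], y: list[int], transform: nn.Module) -> None:
    """Show the original images (top row) above their transformed versions (bottom row)."""
    fig, axs = plt.subplots(nrows=2, ncols=5, figsize=(16, 8), layout="compressed")
    for i, (img1, img2, label) in enumerate(zip(x1, x2, y)):
        # matplotlib expects channels-last (H, W, C), tensors are channels-first (C, H, W)
        axs[0, i].imshow(img1.permute(1, 2, 0))
        axs[0, i].set(title="Original")
        axs[1, i].imshow(img2.permute(1, 2, 0))
        # use the transform's class name as the title, e.g. "Resize"
        axs[1, i].set(title=str(transform).split("(")[0])
    plt.show()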

# <a id='toc2_'></a>[Load CIFAR-10 Dataset](#toc0_)


In [None]:
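# `download=False` expects the dataset to already exist in DATASET_DIR; use `download=True` otherwise
trainset = CIFAR10(DATASET_DIR, train=True, transform=v2.ToImage(), download=False)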

x = [trainset[i][0] for i in range(5)]
y = [trainset[i][1] for i in range(5)]

In [None]:
# plot
fig, axs = plt.subplots(nrows=1, ncols=5, figsize=(16, 4), layout="compressed")
for i, (img, label) in enumerate(zip(x, y)):
    axs[i].imshow(img.permute(1, 2, 0))
    axs[i].set(title=label)
plt.show()

# <a id='toc3_'></a>[Transforms](#toc0_)

- Torchvision supports common **computer vision** transformations in the `torchvision.transforms` and `torchvision.transforms.v2` modules.
- Transforms can be used to **transform** or **augment** data for **training** or **inference** of different tasks (image classification, detection, segmentation, video classification).

📝 **Docs**:

- Transforming and augmenting images: [docs.pytorch.org/vision/stable/transforms.html](https://docs.pytorch.org/vision/stable/transforms.html)
- Illustration of transforms: [docs.pytorch.org/vision/stable/auto_examples/transforms/plot_transforms_illustrations.html](https://docs.pytorch.org/vision/stable/auto_examples/transforms/plot_transforms_illustrations.html)

✍️ **Performance considerations**:

- Rely on the v2 transforms from `torchvision.transforms.v2`.
- Use tensors instead of `PIL` images.
- Use `torch.uint8` dtype, especially for resizing.
- Resize with **bilinear** or **bicubic** mode.
- Consider `num_workers > 0` when creating `torch.utils.data.DataLoader`.


## <a id='toc3_1_'></a>[Built-in Transforms](#toc0_)


### <a id='toc3_1_1_'></a>[Geometry](#toc0_)

<table style="margin: 0 auto;">
  <thead>
    <tr>
      <th>Resizing</th>
      <th>Cropping</th>
      <th>Other</th>
    </tr>
  </thead>
  <tbody style="font-family: monospace">
    <tr>
      <td>v2.Resize()</td>
      <td>v2.RandomCrop()</td>
      <td>v2.RandomHorizontalFlip()</td>
    </tr>
    <tr>
      <td>v2.RandomResize()</td>
      <td>v2.RandomResizedCrop()</td>
      <td>v2.RandomVerticalFlip()</td>
    </tr>
    <tr>
      <td></td>
      <td>v2.RandomIoUCrop()</td>
      <td>v2.Pad()</td>
    </tr>
    <tr>
      <td></td>
      <td>v2.CenterCrop()</td>
      <td>v2.RandomZoomOut()</td>
    </tr>
    <tr>
      <td></td>
      <td></td>
      <td>v2.RandomRotation()</td>
    </tr>
    <tr>
      <td></td>
      <td></td>
      <td>v2.RandomAffine()</td>
    </tr>
  </tbody>
</table>


#### <a id='toc3_1_1_1_'></a>[Resize](#toc0_)


In [None]:
t_resize = v2.Resize(size=(16, 16), interpolation=v2.InterpolationMode.NEAREST)
x_resize = t_resize(x)

# plot
plot(x, x_resize, y, t_resize)

In [None]:
t_random_resize = v2.RandomResize(min_size=8, max_size=64)
x_random_resize = [t_random_resize(img) for img in x]

# plot
plot(x, x_random_resize, y, t_random_resize)

#### <a id='toc3_1_1_2_'></a>[Cropping](#toc0_)


In [None]:
t_random_crop = v2.RandomCrop(size=(24, 24))
x_random_crop = [t_random_crop(img) for img in x]

# plot
plot(x, x_random_crop, y, t_random_crop)

In [None]:
t_center_crop = v2.CenterCrop(size=(16, 16))
x_center_crop = t_center_crop(x)

# plot
plot(x, x_center_crop, y, t_center_crop)

#### <a id='toc3_1_1_3_'></a>[Others](#toc0_)


In [None]:
t_random_horizontal_flip = v2.RandomHorizontalFlip()
x_random_horizontal_flip = [t_random_horizontal_flip(img) for img in x]

# plot
plot(x, x_random_horizontal_flip, y, t_random_horizontal_flip)

In [None]:
t_pad = v2.Pad(padding=(1, 2, 3, 4), fill=0, padding_mode="constant")
x_pad = t_pad(x)

# plot
plot(x, x_pad, y, t_pad)

In [None]:
t_random_zoomout = v2.RandomZoomOut(fill=0)
x_random_zoomout = [t_random_zoomout(img) for img in x]

# plot
plot(x, x_random_zoomout, y, t_random_zoomout)

In [None]:
t_random_rotation = v2.RandomRotation(degrees=45, expand=True)
x_random_rotation = [t_random_rotation(img) for img in x]

# plot
plot(x, x_random_rotation, y, t_random_rotation)

In [None]:
t_random_affine = v2.RandomAffine(degrees=45, translate=(0.1, 0.1), scale=(0.8, 1.2), shear=2)
x_random_affine = [t_random_affine(img) for img in x]

# plot
plot(x, x_random_affine, y, t_random_affine)

### <a id='toc3_1_2_'></a>[Color](#toc0_)

<table style="margin: 0 auto;">
  <tbody style="font-family: monospace">
    <tr>
      <td>v2.ColorJitter()</td>
      <td>v2.Grayscale()</td>
      <td>v2.RGB()</td>
    </tr>
    <tr>
      <td>v2.RandomGrayscale()</td>
      <td>v2.GaussianBlur()</td>
      <td>v2.GaussianNoise()</td>
    </tr>
    <tr>
      <td>v2.RandomInvert()</td>
      <td>v2.RandomPosterize()</td>
      <td>v2.RandomSolarize()</td>
    </tr>
    <tr>
      <td>v2.RandomEqualize()</td>
      <td></td>
      <td></td>
    </tr>
  </tbody>
</table>


In [None]:
t_color_jitter = v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
x_color_jitter = [t_color_jitter(img) for img in x]

# plot
plot(x, x_color_jitter, y, t_color_jitter)

In [None]:
t_gray_scale = v2.Grayscale(num_output_channels=3)
x_gray_scale = t_gray_scale(x)

# plot
plot(x, x_gray_scale, y, t_gray_scale)

In [None]:
t_random_solarize = v2.RandomSolarize(threshold=100)
x_random_solarize = [t_random_solarize(img) for img in x]

# plot
plot(x, x_random_solarize, y, t_random_solarize)

In [None]:
t_gaussian_noise = v2.GaussianNoise(mean=0, sigma=0.1)
x_gaussian_noise = [t_gaussian_noise(img.to(torch.float32) / 255) for img in x]

# plot
plot(x, x_gaussian_noise, y, t_gaussian_noise)

### <a id='toc3_1_3_'></a>[Composition](#toc0_)

<table style="margin: 0 auto;">
  <thead>
    <tr>
      <th>Transform</th>
      <th>Always Applied?</th>
      <th>Applies All?</th>
      <th>Order Fixed?</th>
      <th>Use Case</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td style="font-family: monospace;">v2.Compose()</td>
      <td>✅ Yes</td>
      <td>✅ Yes</td>
      <td>✅ Yes</td>
      <td>Fixed pipeline of transformations</td>
    </tr>
    <tr>
      <td style="font-family: monospace;">v2.RandomApply()</td>
      <td>❌ No (probabilistic)</td>
      <td>✅ Yes (if applied)</td>
      <td>✅ Yes</td>
      <td>Conditional application</td>
    </tr>
    <tr>
      <td style="font-family: monospace;">v2.RandomChoice()</td>
      <td>✅ Yes</td>
      <td>❌ No (only one)</td>
      <td>➖ N/A (single transform)</td>
      <td>Random selection of one transform</td>
    </tr>
    <tr>
      <td style="font-family: monospace;">v2.RandomOrder()</td>
      <td>✅ Yes</td>
      <td>✅ Yes</td>
      <td>❌ No</td>
      <td>Random order of all transforms</td>
    </tr>
  </tbody>
</table>



In [None]:
t_compose = v2.Compose(
    [
        v2.Resize((64, 64)),
        v2.RandomHorizontalFlip(p=0.5),
        v2.ColorJitter(brightness=0.2, contrast=0.2),
    ]
)

x_compose = [t_compose(img) for img in x]

# plot
plot(x, x_compose, y, t_compose)

In [None]:
t_random_apply = v2.RandomApply(
    [
        v2.Resize((64, 64)),
        v2.Grayscale(num_output_channels=3),
    ],
    p=0.5,
)

x_random_apply = [t_random_apply(img) for img in x]

# plot
plot(x, x_random_apply, y, t_random_apply)

In [None]:
t_random_choice = v2.RandomChoice(
    [
        v2.RandomHorizontalFlip(p=1.0),
        v2.RandomVerticalFlip(p=1.0),
    ],
)

x_random_choice = [t_random_choice(img) for img in x]

# plot
plot(x, x_random_choice, y, t_random_choice)

In [None]:
t_random_order = v2.RandomOrder(
    [
        v2.RandomRotation(degrees=30),
        v2.RandomHorizontalFlip(p=1.0),
        v2.ColorJitter(0.2, 0.2, 0.2, 0.2),
    ],
)

x_random_order = [t_random_order(img) for img in x]

# plot
plot(x, x_random_order, y, t_random_order)

In [None]:
# combining compositions
t_combine = v2.Compose(
    [
        v2.Resize((224, 224)),
        v2.RandomApply([v2.ColorJitter(brightness=0.5, contrast=0.5)], p=0.3),
        v2.RandomChoice([v2.GaussianBlur(kernel_size=3), v2.RandomRotation(degrees=30)]),
        v2.RandomOrder(
            [v2.RandomHorizontalFlip(p=0.5), v2.RandomVerticalFlip(p=0.5), v2.Grayscale(num_output_channels=3)]
        ),
    ]
)

x_combine = [t_combine(img) for img in x]

# plot
plot(x, x_combine, y, t_combine)

### <a id='toc3_1_4_'></a>[Miscellaneous](#toc0_)

<table style="margin: 0 auto;">
  <tbody style="font-family: monospace">
    <tr>
      <td>v2.Normalize()</td>
      <td>v2.RandomErasing()</td>
      <td>v2.Lambda()</td>
    </tr>
    <tr>
      <td>v2.JPEG()</td>
      <td></td>
      <td></td>
    </tr>
  </tbody>
</table>


In [None]:
t_normalize = v2.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616))
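# apply the transform per image: given a list of plain tensors, v2 transforms only process the first one
x_normalize = [t_normalize(img.to(torch.float32) / 255) for img in x]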

# plot
plot(x, x_normalize, y, t_normalize)

In [None]:
t_random_erasing = v2.RandomErasing(p=0.7)
x_random_erasing = [t_random_erasing(img) for img in x]

# plot
plot(x, x_random_erasing, y, t_random_erasing)

### <a id='toc3_1_5_'></a>[Conversion](#toc0_)

<table style="margin: 0 auto;">
  <tbody style="font-family: monospace">
    <tr>
      <td>v2.ToImage()</td>
      <td>v2.ToDtype()</td>
      <td>v2.ToPILImage()</td>
    </tr>
  </tbody>
</table>

✍️ **Note**:

- These transforms are **deprecated**:
  - `v2.ToTensor()`
  - `v2.ConvertImageDtype()`


In [None]:
t_to_image = v2.ToImage()
x_to_image = t_to_image(x)

# plot
plot(x, x_to_image, y, t_to_image)

In [None]:
t_to_dtype = v2.ToDtype(torch.float32, scale=True)
x_to_dtype = t_to_dtype(x)

# plot
plot(x, x_to_dtype, y, t_to_dtype)

### <a id='toc3_1_6_'></a>[Auto-Augmentation](#toc0_)

<table style="margin: 0 auto">
  <thead>
    <tr>
      <th>Method</th>
      <th>Selection</th>
      <th>Applied Transforms</th>
      <th>Strength Control</th>
      <th>Blending</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td><span style="font-family: monospace;">v2.AutoAugment()</span></td>
      <td>Predefined policies</td>
      <td>Fixed sequence</td>
      <td>❌</td>
      <td>❌</td>
    </tr>
    <tr>
      <td><span style="font-family: monospace;">v2.RandAugment()</span></td>
      <td>Random (<span style="font-family: monospace;">num_ops</span>)</td>
      <td>Sequential</td>
      <td>✅ (<span style="font-family: monospace;">magnitude</span>)</td>
      <td>❌</td>
    </tr>
    <tr>
      <td><span style="font-family: monospace;">v2.TrivialAugmentWide()</span></td>
      <td>Random (1 transform)</td>
      <td>Single transform</td>
      <td>❌</td>
      <td>❌</td>
    </tr>
    <tr>
      <td><span style="font-family: monospace;">v2.AugMix()</span></td>
      <td>Random (multiple)</td>
      <td>Mixed (random order)</td>
      <td>✅ (<span style="font-family: monospace;">severity</span>)</td>
      <td>✅ (weighted mix)</td>
    </tr>
  </tbody>
</table>


In [None]:
t_auto_augment = v2.AutoAugment(policy=v2.AutoAugmentPolicy.CIFAR10)
x_auto_augment = [t_auto_augment(img) for img in x]

# plot
plot(x, x_auto_augment, y, t_auto_augment)

In [None]:
t_rand_augment = v2.RandAugment(num_ops=2, magnitude=9)
x_rand_augment = [t_rand_augment(img) for img in x]

# plot
plot(x, x_rand_augment, y, t_rand_augment)

In [None]:
t_trivial_augment_wide = v2.TrivialAugmentWide()
x_trivial_augment_wide = [t_trivial_augment_wide(img) for img in x]

# plot
plot(x, x_trivial_augment_wide, y, t_trivial_augment_wide)

In [None]:
t_aug_mix = v2.AugMix(severity=10, mixture_width=10, chain_depth=-1)
x_aug_mix = [t_aug_mix(img) for img in x]

# plot
plot(x, x_aug_mix, y, t_aug_mix)

## <a id='toc3_2_'></a>[Custom Transforms](#toc0_)

- You can define custom transforms by subclassing either `torch.nn.Module` or `torchvision.transforms.v2.Transform`.
- To create a custom transform:
  - Extend `torch.nn.Module` and implement the `forward` method for **simple** transforms.
  - Extend `torchvision.transforms.v2.Transform` and implement the `transform` method (named `_transform` in older torchvision releases) for **advanced** transforms that support **arbitrary** input structures (e.g., **images**, **bounding boxes**, **segmentation masks**).

📝 **Docs**:

- `nn.Module`: [docs.pytorch.org/docs/stable/generated/torch.nn.Module.html](https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html)
- `v2.Transform`: [docs.pytorch.org/vision/stable/generated/torchvision.transforms.v2.Transform.html](https://docs.pytorch.org/vision/stable/generated/torchvision.transforms.v2.Transform.html)
- How to write your own v2 transforms: [docs.pytorch.org/vision/stable/auto_examples/transforms/plot_custom_transforms.html](https://docs.pytorch.org/vision/stable/auto_examples/transforms/plot_custom_transforms.html)


### <a id='toc3_2_1_'></a>[Approach 1: Using nn.Module](#toc0_)

In [None]:
class CustomRandomColorInversion1(nn.Module):
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p  # probability of applying the transform

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        # check for unsupported dtypes
        if img.dtype not in (torch.uint8, torch.float16, torch.float32, torch.float64):
            raise ValueError(f"Unsupported dtype: {img.dtype}. Expected uint8 or float.")

        # check if float image is in range [0, 1]
        if img.dtype.is_floating_point:
            if img.min() < 0 or img.max() > 1:
                raise ValueError(f"Float image must be in range [0, 1]. Found range [{img.min()}, {img.max()}].")

        # Apply the transform
        if torch.rand(1) < self.p:
            if img.dtype == torch.uint8:
                img = 255 - img  # invert for uint8
            else:
                img = 1.0 - img  # invert for float

        return img

In [None]:
t_custom_random_color_inversion_1 = CustomRandomColorInversion1()
x_custom_random_color_inversion_1 = [t_custom_random_color_inversion_1(img) for img in x]

# plot
plot(x, x_custom_random_color_inversion_1, y, t_custom_random_color_inversion_1)

### <a id='toc3_2_2_'></a>[Approach 2: Using v2.Transform](#toc0_)

- The `transform` method (named `_transform` in older torchvision releases) takes two arguments:
  - `inpt`: The input to transform (e.g., an **image** tensor).
  - `params`: A **dictionary** of **parameters** generated by the `make_params` method (formerly `_get_params`), if implemented. `make_params` runs once per call, so random parameters drawn there are shared by every input of that call.


In [None]:
class CustomRandomColorInversion2(v2.Transform):
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p  # probability of applying the transform

    def transform(self, inpt: torch.Tensor, params: dict) -> torch.Tensor:
        # check for unsupported dtypes
        if inpt.dtype not in (torch.uint8, torch.float16, torch.float32, torch.float64):
            raise ValueError(f"Unsupported dtype: {inpt.dtype}. Expected uint8 or float.")

        # check if float image is in range [0, 1]
        if inpt.dtype.is_floating_point:
            if inpt.min() < 0 or inpt.max() > 1:
                raise ValueError(f"Float image must be in range [0, 1]. Found range [{inpt.min()}, {inpt.max()}].")

        # Apply the transform
        if torch.rand(1) < self.p:
            if inpt.dtype == torch.uint8:
                inpt = 255 - inpt  # invert for uint8
            else:
                inpt = 1.0 - inpt  # invert for float

        return inpt

In [None]:
t_custom_random_color_inversion_2 = CustomRandomColorInversion2()
x_custom_random_color_inversion_2 = [t_custom_random_color_inversion_2(img) for img in x]

# plot
plot(x, x_custom_random_color_inversion_2, y, t_custom_random_color_inversion_2)

## <a id='toc3_3_'></a>[A Typical Transform Pipeline](#toc0_)

In [None]:
t_typical = v2.Compose(
    [
        # 1. convert a PIL image or numpy ndarray to an image tensor
        #    (not needed if the input is already a tensor)
        v2.ToImage(),
        # 2. make sure the image is uint8 with values in [0, 255]
        #    (optional: most inputs are already uint8 at this point)
        v2.ToDtype(torch.uint8, scale=True),
        # 3. data augmentation: take a random crop and resize it to (224, 224)
        #    `antialias=True` improves resizing quality at some extra computational cost
        v2.RandomResizedCrop(size=(224, 224), antialias=True),
        # 4. data augmentation: flip the image horizontally with a probability of 0.5
        v2.RandomHorizontalFlip(p=0.5),
        # 5. convert to float32 and scale values to [0, 1] (required for normalization)
        v2.ToDtype(torch.float32, scale=True),
        # 6. normalize each channel using mean and std
        #    these values are the ImageNet statistics; dataset-specific values
        #    (see the `v2.Normalize` example above) are a common alternative
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [None]:
trainset = CIFAR10(DATASET_DIR, train=True, transform=t_typical, download=False)
train_loader = DataLoader(trainset, batch_size=5)

# plot the first batch (normalized values outside [0, 1] are clipped by imshow)
x_typical, y_typical = next(iter(train_loader))
plot(x, x_typical, y_typical, t_typical)