📝 **Author:** Amirhossein Heydari - 📧 **Email:** amirhosseinheydari78@gmail.com - 📍 **Linktree:** [linktr.ee/mr_pylin](https://linktr.ee/mr_pylin)

---

# Dependencies

In [35]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import CIFAR10
from torchvision.transforms import v2

In [36]:
# reproducibility: one fixed seed for PyTorch's global RNG
random_state = 42
torch.manual_seed(random_state)

# cuDNN flags: turn benchmarking off and request deterministic algorithms
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Load Dataset

In [37]:
# CIFAR-10 training split, downloaded on demand, with no transform attached
trainset = CIFAR10(
    root='./dataset',
    train=True,
    download=True,
    transform=None,
)

# first three raw images and their labels, used as the running example
x, y = trainset.data[:3], trainset.targets[:3]

Files already downloaded and verified


In [None]:
# plot the three sample images, titled with the class name and its numeric label
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(12, 4), layout='compressed')
for ax, img, label in zip(axs, x, y):
    ax.imshow(img)
    # trainset.classes maps a label index to its human-readable class name
    ax.set(title=f"{trainset.classes[label]} ({label})")
plt.show()

# Transforms
   - [pytorch.org/vision/main/transforms.html](https://pytorch.org/vision/main/transforms.html)

## Common Transformations
   - v2.ToImage
   - v2.ToDtype
   - v2.Normalize

In [39]:
# log shape, dtype, container type and value range of every raw sample
for i, img in enumerate(x):
    print(f"x[{i}].shape : {img.shape}")
    print(f"x[{i}].dtype : {img.dtype}")
    print(f"type(x[{i}]) : {type(img)}")
    print(f"x[{i}].min() : {img.min()}")
    print(f"x[{i}].max() : {img.max()}")
    print('-' * 50)

x[0].shape : (32, 32, 3)
x[0].dtype : uint8
type(x[0]) : <class 'numpy.ndarray'>
x[0].min() : 0
x[0].max() : 255
--------------------------------------------------
x[1].shape : (32, 32, 3)
x[1].dtype : uint8
type(x[1]) : <class 'numpy.ndarray'>
x[1].min() : 5
x[1].max() : 254
--------------------------------------------------
x[2].shape : (32, 32, 3)
x[2].dtype : uint8
type(x[2]) : <class 'numpy.ndarray'>
x[2].min() : 20
x[2].max() : 255
--------------------------------------------------


### v2.ToImage
   - [pytorch.org/vision/main/generated/torchvision.transforms.v2.ToImage.html](https://pytorch.org/vision/main/generated/torchvision.transforms.v2.ToImage.html)

In [40]:
# wrap every raw sample in a torchvision Image tensor
to_image_transform = v2.ToImage()
x_2 = [to_image_transform(img) for img in x]

# log: iterate the converted list itself so the loop bound always matches x_2
for i, img in enumerate(x_2):
    print(f"x_2[{i}].shape : {img.shape}")
    print(f"x_2[{i}].dtype : {img.dtype}")
    print(f"type(x_2[{i}]) : {type(img)}")
    print(f"x_2[{i}].min() : {img.min()}")
    print(f"x_2[{i}].max() : {img.max()}")
    print('-' * 50)

x_2[0].shape : torch.Size([3, 32, 32])
x_2[0].dtype : torch.uint8
type(x_2[0]) : <class 'torchvision.tv_tensors._image.Image'>
x_2[0].min() : 0
x_2[0].max() : 255
--------------------------------------------------
x_2[1].shape : torch.Size([3, 32, 32])
x_2[1].dtype : torch.uint8
type(x_2[1]) : <class 'torchvision.tv_tensors._image.Image'>
x_2[1].min() : 5
x_2[1].max() : 254
--------------------------------------------------
x_2[2].shape : torch.Size([3, 32, 32])
x_2[2].dtype : torch.uint8
type(x_2[2]) : <class 'torchvision.tv_tensors._image.Image'>
x_2[2].min() : 20
x_2[2].max() : 255
--------------------------------------------------


### v2.ToDtype
   - [pytorch.org/vision/main/generated/torchvision.transforms.v2.ToDtype.html#torchvision.transforms.v2.ToDtype](https://pytorch.org/vision/main/generated/torchvision.transforms.v2.ToDtype.html#torchvision.transforms.v2.ToDtype)

In [41]:
# cast to float32; scale=True also rescales the value range (see the logged min/max)
to_dtype_transform = v2.ToDtype(dtype=torch.float32, scale=True)
x_3 = [to_dtype_transform(img) for img in x_2]

# log: iterate the converted list itself so the loop bound always matches x_3
for i, img in enumerate(x_3):
    print(f"x_3[{i}].shape: {img.shape}")
    print(f"x_3[{i}].dtype: {img.dtype}")
    print(f"type(x_3[{i}]): {type(img)}")
    print(f"x_3[{i}].min(): {img.min()}")
    print(f"x_3[{i}].max(): {img.max()}")
    print('-' * 50)

x_3[0].shape: torch.Size([3, 32, 32])
x_3[0].dtype: torch.float32
type(x_3[0]): <class 'torchvision.tv_tensors._image.Image'>
x_3[0].min(): 0.0
x_3[0].max(): 1.0
--------------------------------------------------
x_3[1].shape: torch.Size([3, 32, 32])
x_3[1].dtype: torch.float32
type(x_3[1]): <class 'torchvision.tv_tensors._image.Image'>
x_3[1].min(): 0.019607843831181526
x_3[1].max(): 0.9960784912109375
--------------------------------------------------
x_3[2].shape: torch.Size([3, 32, 32])
x_3[2].dtype: torch.float32
type(x_3[2]): <class 'torchvision.tv_tensors._image.Image'>
x_3[2].min(): 0.0784313753247261
x_3[2].max(): 1.0
--------------------------------------------------


### v2.Normalize
   - [pytorch.org/vision/main/generated/torchvision.transforms.v2.Normalize.html#torchvision.transforms.v2.Normalize](https://pytorch.org/vision/main/generated/torchvision.transforms.v2.Normalize.html#torchvision.transforms.v2.Normalize)

In [42]:
# stack the samples once; reducing over axes (0, 2, 3) keeps one value per channel
# NOTE: these statistics come from the few images in x_3 only, not the whole dataset
stacked = np.array(x_3)
mus = stacked.mean(axis=(0, 2, 3))
stds = stacked.std(axis=(0, 2, 3))

normalize_transform = v2.Normalize(mean=mus, std=stds)
x_4 = [normalize_transform(img) for img in x_3]

# log: iterate the normalized list itself so the loop bound always matches x_4
for i, img in enumerate(x_4):
    print(f"x_4[{i}].shape: {img.shape}")
    print(f"x_4[{i}].dtype: {img.dtype}")
    print(f"type(x_4[{i}]): {type(img)}")
    print(f"x_4[{i}].min(): {img.min()}")
    print(f"x_4[{i}].max(): {img.max()}")
    print('-' * 50)

x_4[0].shape: torch.Size([3, 32, 32])
x_4[0].dtype: torch.float32
type(x_4[0]): <class 'torchvision.tv_tensors._image.Image'>
x_4[0].min(): -2.086963653564453
x_4[0].max(): 1.9853252172470093
--------------------------------------------------
x_4[1].shape: torch.Size([3, 32, 32])
x_4[1].dtype: torch.float32
type(x_4[1]): <class 'torchvision.tv_tensors._image.Image'>
x_4[1].min(): -2.0096473693847656
x_4[1].max(): 2.059528112411499
--------------------------------------------------
x_4[2].shape: torch.Size([3, 32, 32])
x_4[2].dtype: torch.float32
type(x_4[2]): <class 'torchvision.tv_tensors._image.Image'>
x_4[2].min(): -1.731309175491333
x_4[2].max(): 2.073734760284424
--------------------------------------------------


### plot

In [None]:
# plot: raw samples on the top row, preprocessed tensors on the bottom row
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(12, 8), layout='compressed')
for top, bottom, before, after in zip(axs[0], axs[1], x, x_4):
    top.imshow(before)
    top.set(title='Original')
    # move the channel axis last, which is the layout imshow expects for RGB data
    bottom.imshow(after.permute(1, 2, 0))
    bottom.set(title='Normalize(ToDtype(ToImage()))')
plt.show()

## Data Augmentation Techniques
   - v2.RandomCrop
   - v2.Resize
   - v2.RandomVerticalFlip
   - v2.RandomHorizontalFlip
   - v2.RandomRotation
   - v2.ColorJitter
   - v2.RandomAffine
   - ...

### v2.RandomCrop
   - [pytorch.org/vision/main/generated/torchvision.transforms.v2.RandomCrop.html#torchvision.transforms.v2.RandomCrop](https://pytorch.org/vision/main/generated/torchvision.transforms.v2.RandomCrop.html#torchvision.transforms.v2.RandomCrop)

In [44]:
# crop a random window covering three quarters of each spatial dimension
_, height, width = x_4[0].shape
random_crop_transform = v2.RandomCrop(size=(int(height / 4 * 3), int(width / 4 * 3)))
x_5 = [random_crop_transform(img) for img in x_4]

# log: iterate the cropped list itself so the loop bound always matches x_5
for i, img in enumerate(x_5):
    print(f"x_5[{i}].shape: {img.shape}")
    print(f"x_5[{i}].dtype: {img.dtype}")
    print(f"type(x_5[{i}]): {type(img)}")
    print(f"x_5[{i}].min(): {img.min()}")
    print(f"x_5[{i}].max(): {img.max()}")
    print('-' * 50)

x_5[0].shape: torch.Size([3, 24, 24])
x_5[0].dtype: torch.float32
type(x_5[0]): <class 'torchvision.tv_tensors._image.Image'>
x_5[0].min(): -1.5571343898773193
x_5[0].max(): 1.9853252172470093
--------------------------------------------------
x_5[1].shape: torch.Size([3, 24, 24])
x_5[1].dtype: torch.float32
type(x_5[1]): <class 'torchvision.tv_tensors._image.Image'>
x_5[1].min(): -1.9632577896118164
x_5[1].max(): 2.045320987701416
--------------------------------------------------
x_5[2].shape: torch.Size([3, 24, 24])
x_5[2].dtype: torch.float32
type(x_5[2]): <class 'torchvision.tv_tensors._image.Image'>
x_5[2].min(): -1.731309175491333
x_5[2].max(): 2.073734760284424
--------------------------------------------------


In [None]:
# plot: inputs on the top row, randomly cropped results on the bottom row
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(12, 8), layout='compressed')
for top, bottom, before, after in zip(axs[0], axs[1], x_4, x_5):
    # move the channel axis last, which is the layout imshow expects for RGB data
    top.imshow(before.permute(1, 2, 0))
    top.set(title='x_4')
    bottom.imshow(after.permute(1, 2, 0))
    bottom.set(title='v2.RandomCrop')
plt.show()

### v2.Resize
   - [pytorch.org/vision/main/generated/torchvision.transforms.v2.Resize.html#torchvision.transforms.v2.Resize](https://pytorch.org/vision/main/generated/torchvision.transforms.v2.Resize.html#torchvision.transforms.v2.Resize)

In [46]:
# scale the crops back to the spatial size of the raw samples
height, width = x[0].shape[:2]
resize_transform = v2.Resize(size=(height, width))
x_6 = [resize_transform(img) for img in x_5]

# log: iterate the resized list itself so the loop bound always matches x_6
for i, img in enumerate(x_6):
    print(f"x_6[{i}].shape: {img.shape}")
    print(f"x_6[{i}].dtype: {img.dtype}")
    print(f"type(x_6[{i}]): {type(img)}")
    print(f"x_6[{i}].min(): {img.min()}")
    print(f"x_6[{i}].max(): {img.max()}")
    print('-' * 50)

x_6[0].shape: torch.Size([3, 32, 32])
x_6[0].dtype: torch.float32
type(x_6[0]): <class 'torchvision.tv_tensors._image.Image'>
x_6[0].min(): -1.5120868682861328
x_6[0].max(): 1.9321309328079224
--------------------------------------------------
x_6[1].shape: torch.Size([3, 32, 32])
x_6[1].dtype: torch.float32
type(x_6[1]): <class 'torchvision.tv_tensors._image.Image'>
x_6[1].min(): -1.9098613262176514
x_6[1].max(): 2.0131337642669678
--------------------------------------------------
x_6[2].shape: torch.Size([3, 32, 32])
x_6[2].dtype: torch.float32
type(x_6[2]): <class 'torchvision.tv_tensors._image.Image'>
x_6[2].min(): -1.7102888822555542
x_6[2].max(): 2.073734760284424
--------------------------------------------------


In [None]:
# plot: cropped inputs on the top row, resized results on the bottom row
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(12, 8), layout='compressed')
for top, bottom, before, after in zip(axs[0], axs[1], x_5, x_6):
    # move the channel axis last, which is the layout imshow expects for RGB data
    top.imshow(before.permute(1, 2, 0))
    top.set(title='x_5')
    bottom.imshow(after.permute(1, 2, 0))
    bottom.set(title='v2.Resize')
plt.show()

### v2.RandomVerticalFlip
   - [pytorch.org/vision/main/generated/torchvision.transforms.v2.RandomVerticalFlip.html#torchvision.transforms.v2.RandomVerticalFlip](https://pytorch.org/vision/main/generated/torchvision.transforms.v2.RandomVerticalFlip.html#torchvision.transforms.v2.RandomVerticalFlip)

In [48]:
# flip each image upside down with probability 0.6
random_vertical_flip_transform = v2.RandomVerticalFlip(p=.6)
# the original (misspelled) name is kept as an alias so existing references still work
random_verical_flip_transform = random_vertical_flip_transform
x_7 = [random_vertical_flip_transform(img) for img in x_6]

# log: iterate the flipped list itself so the loop bound always matches x_7
for i, img in enumerate(x_7):
    print(f"x_7[{i}].shape: {img.shape}")
    print(f"x_7[{i}].dtype: {img.dtype}")
    print(f"type(x_7[{i}]): {type(img)}")
    print(f"x_7[{i}].min(): {img.min()}")
    print(f"x_7[{i}].max(): {img.max()}")
    print('-' * 50)

x_7[0].shape: torch.Size([3, 32, 32])
x_7[0].dtype: torch.float32
type(x_7[0]): <class 'torchvision.tv_tensors._image.Image'>
x_7[0].min(): -1.5120868682861328
x_7[0].max(): 1.9321309328079224
--------------------------------------------------
x_7[1].shape: torch.Size([3, 32, 32])
x_7[1].dtype: torch.float32
type(x_7[1]): <class 'torchvision.tv_tensors._image.Image'>
x_7[1].min(): -1.9098613262176514
x_7[1].max(): 2.0131337642669678
--------------------------------------------------
x_7[2].shape: torch.Size([3, 32, 32])
x_7[2].dtype: torch.float32
type(x_7[2]): <class 'torchvision.tv_tensors._image.Image'>
x_7[2].min(): -1.7102888822555542
x_7[2].max(): 2.073734760284424
--------------------------------------------------


In [None]:
# plot: inputs on the top row, vertically flipped results on the bottom row
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(12, 8), layout='compressed')
for top, bottom, before, after in zip(axs[0], axs[1], x_6, x_7):
    # move the channel axis last, which is the layout imshow expects for RGB data
    top.imshow(before.permute(1, 2, 0))
    top.set(title='x_6')
    bottom.imshow(after.permute(1, 2, 0))
    bottom.set(title='v2.RandomVerticalFlip')
plt.show()

### v2.RandomHorizontalFlip
   - [pytorch.org/vision/main/generated/torchvision.transforms.v2.RandomHorizontalFlip.html#torchvision.transforms.v2.RandomHorizontalFlip](https://pytorch.org/vision/main/generated/torchvision.transforms.v2.RandomHorizontalFlip.html#torchvision.transforms.v2.RandomHorizontalFlip)

In [50]:
# mirror each image left-to-right with probability 0.7
random_horizontal_flip_transform = v2.RandomHorizontalFlip(p=.7)
x_8 = [random_horizontal_flip_transform(img) for img in x_7]

# log: iterate the flipped list itself so the loop bound always matches x_8
for i, img in enumerate(x_8):
    print(f"x_8[{i}].shape: {img.shape}")
    print(f"x_8[{i}].dtype: {img.dtype}")
    print(f"type(x_8[{i}]): {type(img)}")
    print(f"x_8[{i}].min(): {img.min()}")
    print(f"x_8[{i}].max(): {img.max()}")
    print('-' * 50)

x_8[0].shape: torch.Size([3, 32, 32])
x_8[0].dtype: torch.float32
type(x_8[0]): <class 'torchvision.tv_tensors._image.Image'>
x_8[0].min(): -1.5120868682861328
x_8[0].max(): 1.9321309328079224
--------------------------------------------------
x_8[1].shape: torch.Size([3, 32, 32])
x_8[1].dtype: torch.float32
type(x_8[1]): <class 'torchvision.tv_tensors._image.Image'>
x_8[1].min(): -1.9098613262176514
x_8[1].max(): 2.0131337642669678
--------------------------------------------------
x_8[2].shape: torch.Size([3, 32, 32])
x_8[2].dtype: torch.float32
type(x_8[2]): <class 'torchvision.tv_tensors._image.Image'>
x_8[2].min(): -1.7102888822555542
x_8[2].max(): 2.073734760284424
--------------------------------------------------


In [None]:
# plot: inputs on the top row, horizontally flipped results on the bottom row
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(12, 8), layout='compressed')
for top, bottom, before, after in zip(axs[0], axs[1], x_7, x_8):
    # move the channel axis last, which is the layout imshow expects for RGB data
    top.imshow(before.permute(1, 2, 0))
    top.set(title='x_7')
    bottom.imshow(after.permute(1, 2, 0))
    bottom.set(title='v2.RandomHorizontalFlip')
plt.show()

### v2.RandomRotation
   - [pytorch.org/vision/main/generated/torchvision.transforms.v2.RandomRotation.html#torchvision.transforms.v2.RandomRotation](https://pytorch.org/vision/main/generated/torchvision.transforms.v2.RandomRotation.html#torchvision.transforms.v2.RandomRotation)

In [52]:
# rotate each image by a random angle drawn from the range [0, 45] degrees
random_rotation_transform = v2.RandomRotation(degrees=[0, 45])
x_9 = [random_rotation_transform(img) for img in x_8]

# log: iterate the rotated list itself so the loop bound always matches x_9
for i, img in enumerate(x_9):
    print(f"x_9[{i}].shape: {img.shape}")
    print(f"x_9[{i}].dtype: {img.dtype}")
    print(f"type(x_9[{i}]): {type(img)}")
    print(f"x_9[{i}].min(): {img.min()}")
    print(f"x_9[{i}].max(): {img.max()}")
    print('-' * 50)

x_9[0].shape: torch.Size([3, 32, 32])
x_9[0].dtype: torch.float32
type(x_9[0]): <class 'torchvision.tv_tensors._image.Image'>
x_9[0].min(): -1.5120868682861328
x_9[0].max(): 1.9321309328079224
--------------------------------------------------
x_9[1].shape: torch.Size([3, 32, 32])
x_9[1].dtype: torch.float32
type(x_9[1]): <class 'torchvision.tv_tensors._image.Image'>
x_9[1].min(): -1.9098613262176514
x_9[1].max(): 2.0131337642669678
--------------------------------------------------
x_9[2].shape: torch.Size([3, 32, 32])
x_9[2].dtype: torch.float32
type(x_9[2]): <class 'torchvision.tv_tensors._image.Image'>
x_9[2].min(): -1.7102888822555542
x_9[2].max(): 2.073734760284424
--------------------------------------------------


In [None]:
# plot: inputs on the top row, randomly rotated results on the bottom row
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(12, 8), layout='compressed')
for top, bottom, before, after in zip(axs[0], axs[1], x_8, x_9):
    # move the channel axis last, which is the layout imshow expects for RGB data
    top.imshow(before.permute(1, 2, 0))
    top.set(title='x_8')
    bottom.imshow(after.permute(1, 2, 0))
    bottom.set(title='v2.RandomRotation')
plt.show()

### v2.ColorJitter
   - [pytorch.org/vision/main/generated/torchvision.transforms.v2.ColorJitter.html#torchvision.transforms.v2.ColorJitter](https://pytorch.org/vision/main/generated/torchvision.transforms.v2.ColorJitter.html#torchvision.transforms.v2.ColorJitter)

In [54]:
# randomly perturb brightness, contrast, saturation and hue
# NOTE(review): x_9 is already normalized, and the recorded output of this cell is
# confined to [0, 1], so the normalization does not survive this step. ColorJitter
# is normally applied before Normalize - confirm the intended order.
color_jitter_transform = v2.ColorJitter(brightness=.7, contrast=.5, saturation=.9, hue=.3)
x_10 = [color_jitter_transform(img) for img in x_9]

# log: iterate the jittered list itself so the loop bound always matches x_10
for i, img in enumerate(x_10):
    print(f"x_10[{i}].shape: {img.shape}")
    print(f"x_10[{i}].dtype: {img.dtype}")
    print(f"type(x_10[{i}]): {type(img)}")
    print(f"x_10[{i}].min(): {img.min()}")
    print(f"x_10[{i}].max(): {img.max()}")
    print('-' * 50)

x_10[0].shape: torch.Size([3, 32, 32])
x_10[0].dtype: torch.float32
type(x_10[0]): <class 'torchvision.tv_tensors._image.Image'>
x_10[0].min(): 0.0
x_10[0].max(): 0.9999585747718811
--------------------------------------------------
x_10[1].shape: torch.Size([3, 32, 32])
x_10[1].dtype: torch.float32
type(x_10[1]): <class 'torchvision.tv_tensors._image.Image'>
x_10[1].min(): 0.0
x_10[1].max(): 0.677256166934967
--------------------------------------------------
x_10[2].shape: torch.Size([3, 32, 32])
x_10[2].dtype: torch.float32
type(x_10[2]): <class 'torchvision.tv_tensors._image.Image'>
x_10[2].min(): 0.0
x_10[2].max(): 1.0
--------------------------------------------------


In [None]:
# plot: inputs on the top row, color-jittered results on the bottom row
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(12, 8), layout='compressed')
for top, bottom, before, after in zip(axs[0], axs[1], x_9, x_10):
    # move the channel axis last, which is the layout imshow expects for RGB data
    top.imshow(before.permute(1, 2, 0))
    top.set(title='x_9')
    bottom.imshow(after.permute(1, 2, 0))
    bottom.set(title='v2.ColorJitter')
plt.show()

### v2.RandomAffine
   - [pytorch.org/vision/main/generated/torchvision.transforms.v2.RandomAffine.html#torchvision.transforms.v2.RandomAffine](https://pytorch.org/vision/main/generated/torchvision.transforms.v2.RandomAffine.html#torchvision.transforms.v2.RandomAffine)

In [56]:
# random affine warp: no rotation, a shear range of .5 and a scale factor in [0.5, 1.5]
random_affine_transform = v2.RandomAffine(degrees=0, shear=.5, scale=[0.5, 1.5])
x_11 = [random_affine_transform(img) for img in x_10]

# log: iterate the warped list itself so the loop bound always matches x_11
for i, img in enumerate(x_11):
    print(f"x_11[{i}].shape: {img.shape}")
    print(f"x_11[{i}].dtype: {img.dtype}")
    print(f"type(x_11[{i}]): {type(img)}")
    print(f"x_11[{i}].min(): {img.min()}")
    print(f"x_11[{i}].max(): {img.max()}")
    print('-' * 50)

x_11[0].shape: torch.Size([3, 32, 32])
x_11[0].dtype: torch.float32
type(x_11[0]): <class 'torchvision.tv_tensors._image.Image'>
x_11[0].min(): 0.0
x_11[0].max(): 0.9999585747718811
--------------------------------------------------
x_11[1].shape: torch.Size([3, 32, 32])
x_11[1].dtype: torch.float32
type(x_11[1]): <class 'torchvision.tv_tensors._image.Image'>
x_11[1].min(): 0.0
x_11[1].max(): 0.677256166934967
--------------------------------------------------
x_11[2].shape: torch.Size([3, 32, 32])
x_11[2].dtype: torch.float32
type(x_11[2]): <class 'torchvision.tv_tensors._image.Image'>
x_11[2].min(): 0.0
x_11[2].max(): 1.0
--------------------------------------------------


In [None]:
# plot: inputs on the top row, affine-warped results on the bottom row
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(12, 8), layout='compressed')
for top, bottom, before, after in zip(axs[0], axs[1], x_10, x_11):
    # move the channel axis last, which is the layout imshow expects for RGB data
    top.imshow(before.permute(1, 2, 0))
    top.set(title='x_10')
    bottom.imshow(after.permute(1, 2, 0))
    bottom.set(title='v2.RandomAffine')
plt.show()

## Mix transforms
   - v2.Compose
      - [pytorch.org/vision/main/generated/torchvision.transforms.v2.Compose.html#torchvision.transforms.v2.Compose](https://pytorch.org/vision/main/generated/torchvision.transforms.v2.Compose.html#torchvision.transforms.v2.Compose)

In [58]:
# Full pipeline: type conversion -> augmentation -> normalization.
# Normalize is deliberately placed LAST. When ColorJitter ran on already-normalized
# images in the step-by-step cells, its recorded output was confined to [0, 1],
# i.e. the normalization was lost. Normalizing at the end keeps it intact.
height, width = x[0].shape[:2]  # spatial size of the raw samples

transforms = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(dtype=torch.float32, scale=True),
        # crop three quarters of each side, then scale back to the original size
        v2.RandomCrop(size=(int(height / 4 * 3), int(width / 4 * 3))),
        v2.Resize(size=(height, width)),
        v2.RandomVerticalFlip(p=.6),
        v2.RandomHorizontalFlip(p=.7),
        v2.RandomRotation(degrees=[0, 45]),
        v2.ColorJitter(brightness=.7, contrast=.5, saturation=.9, hue=.3),
        v2.RandomAffine(degrees=0, shear=.5, scale=[0.5, 1.5]),
        v2.Normalize(mean=mus, std=stds),
    ]
)

In [59]:
# run every raw sample through the composed pipeline
x_12 = list(map(transforms, x))

In [None]:
# plot: raw samples on the top row, pipeline outputs on the bottom row
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(12, 8), layout='compressed')
for top, bottom, before, after in zip(axs[0], axs[1], x, x_12):
    top.imshow(before)
    top.set(title='x')
    # move the channel axis last, which is the layout imshow expects for RGB data
    bottom.imshow(after.permute(1, 2, 0))
    bottom.set(title='x_12')
plt.show()

# Effect of transforms

In [70]:
# pytorch dataset (no transform attached yet)
trainset = CIFAR10(root='./dataset', train=True, transform=None, download=True)

# pytorch subset: a view onto the first `num_samples` items of trainset
num_samples = 10
trainsubset = Subset(trainset, indices=range(num_samples))

# fetch the first sample of each container once and reuse it in the log below
set_img, set_label = trainset[0]
sub_img, sub_label = trainsubset[0]

# log
print('trainset:')
print(f"\tlen(trainset)          : {len(trainset)}")
print(f"\ttrainset.transform     : {trainset.transform}")
print(f"\ttype(trainset[0][0])   : {type(set_img)}")
print(f"\ttype(trainset[0][1])   : {type(set_label)}")
print(f"\ttype(trainset.data[0]) : {type(trainset.data[0])}")
print('-' * 50)
print('trainsubset:')
# the label needs "\t" followed by "trainsubset"; the previous "\trainsubset" let the
# tab escape swallow the leading "t". Sibling labels are padded to keep colons aligned.
print(f"\tlen(trainsubset)              : {len(trainsubset)}")
print(f"\ttrainsubset.dataset.transform : {trainsubset.dataset.transform}")
print(f"\ttype(trainsubset[0][0])       : {type(sub_img)}")
print(f"\ttype(trainsubset[0][1])       : {type(sub_label)}")

Files already downloaded and verified
trainset:
	len(trainset)          : 50000
	trainset.transform     : None
	type(trainset[0][0])   : <class 'PIL.Image.Image'>
	type(trainset[0][1])   : <class 'int'>
	type(trainset.data[0]) : <class 'numpy.ndarray'>
--------------------------------------------------
trainsubset:
	len(trainsubset)             : 10
	rainsubset.dataset.transform : None
	type(trainsubset[0][0])      : <class 'PIL.Image.Image'>
	type(trainsubset[0][1])      : <class 'int'>


In [66]:
# minimal pipeline: convert to an Image tensor, then cast to float32
# (scale is left at its default; the recorded repr shows ToDtype(scale=False))
transforms = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32)])

# attach the pipeline to the dataset object that trainsubset wraps
trainset.transform = transforms

# fetch the first sample of each container once and reuse it in the log below
set_img, set_label = trainset[0]
sub_img, sub_label = trainsubset[0]

# log
print('trainset:')
print(f"\tlen(trainset): {len(trainset)}")
print(f"\ttrainset.transform:\n{trainset.transform}")
print(f"\ttype(trainset[0][0])   : {type(set_img)}")
print(f"\ttrainset[0][0].dtype   : {set_img.dtype}")
print(f"\ttype(trainset[0][1])   : {type(set_label)}")
print(f"\ttype(trainset.data[0]) : {type(trainset.data[0])}")
print('-' * 50)
print('trainsubset:')
print(f"\tlen(trainsubset): {len(trainsubset)}")
print(f"\ttrainsubset.dataset.transform:\n{trainsubset.dataset.transform}")
print(f"\ttype(trainsubset[0][0]) : {type(sub_img)}")
print(f"\ttrainsubset[0][0].dtype : {sub_img.dtype}")
print(f"\ttype(trainsubset[0][1]) : {type(sub_label)}")

trainset:
	len(trainset): 50000
	trainset.transform:
Compose(
      ToImage()
      ToDtype(scale=False)
)
	type(trainset[0][0])   : <class 'torchvision.tv_tensors._image.Image'>
	trainset[0][0].dtype   : torch.float32
	type(trainset[0][1])   : <class 'int'>
	type(trainset.data[0]) : <class 'numpy.ndarray'>
--------------------------------------------------
trainsubset:
	len(trainsubset): 10
	trainsubset.dataset.transform:
Compose(
      ToImage()
      ToDtype(scale=False)
)
	type(trainsubset[0][0]) : <class 'torchvision.tv_tensors._image.Image'>
	trainsubset[0][0].dtype : torch.float32
	type(trainsubset[0][1]) : <class 'int'>


In [64]:
# pytorch dataloader: batch_size equals num_samples, so the first batch holds the whole subset
trainloader = DataLoader(trainsubset, batch_size=num_samples, shuffle=False)
next_iter_trainloader = next(iter(trainloader))

# the batch is indexed as [0] = images, [1] = labels
batch_images = next_iter_trainloader[0]
batch_labels = next_iter_trainloader[1]

print('trainloader:')
print(f"\ttype(next_iter_trainloader[0]) : {type(batch_images)}")
print(f"\tnext_iter_trainloader[0].dtype : {batch_images.dtype}")
print(f"\ttype(next_iter_trainloader[1]) : {type(batch_labels)}")
print(f"\tnext_iter_trainloader[1].dtype : {batch_labels.dtype}")

trainloader:
	type(next_iter_trainloader[0]) : <class 'torch.Tensor'>
	next_iter_trainloader[0].dtype : torch.float32
	type(next_iter_trainloader[1]) : <class 'torch.Tensor'>
	next_iter_trainloader[1].dtype : torch.int64
