# Overview

This notebook shows the effects of the chosen transforms on CIFAR 10 images.

In [1]:
import numpy as np
import plotly.express as px
import torch
from torchvision import transforms
from torchvision.datasets import CIFAR10

from visionpod import config

## Mean and Standard Deviation Used in Normalization and Inverse Normalization

In [2]:
# Per-channel mean and standard deviation handed to transforms.Normalize.
mean = [0.49139968, 0.48215841, 0.44653091]
stddev = [0.24703223, 0.24348513, 0.26158784]
# Reciprocal standard deviations and negated means, used to undo the normalization.
inverse_stddev = [1 / channel_std for channel_std in stddev]
inverse_mean = [-channel_mean for channel_mean in mean]

## Transforms

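The CIFAR 10 AutoAugment policy is shown below. Each line is one sub-policy made of two operations, each written as (operation name, probability, magnitude index):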

```python
(("Invert", 0.1, None), ("Contrast", 0.2, 6)),
(("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
(("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
(("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
(("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
(("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
(("Color", 0.4, 3), ("Brightness", 0.6, 7)),
(("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
(("Equalize", 0.6, None), ("Equalize", 0.5, None)),
(("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
(("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
(("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
(("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
(("Brightness", 0.9, 6), ("Color", 0.2, 8)),
(("Solarize", 0.5, 2), ("Invert", 0.0, None)),
(("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
(("Equalize", 0.2, None), ("Equalize", 0.6, None)),
(("Color", 0.9, 9), ("Equalize", 0.6, None)),
(("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
(("Brightness", 0.1, 3), ("Color", 0.7, 0)),
(("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
(("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
(("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
(("Equalize", 0.8, None), ("Invert", 0.1, None)),
(("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
```

In [3]:
# Normalization with the per-channel statistics defined earlier in the notebook.
cifar_norm = transforms.Normalize(mean=mean, std=stddev)

# Evaluation pipelines: tensor conversion only, optionally followed by normalization.
test_transform = transforms.Compose([transforms.ToTensor()])
norm_test_transform = transforms.Compose([transforms.ToTensor(), cifar_norm])

# Training pipelines: CIFAR10 AutoAugment policy, tensor conversion, and
# (for the "norm" variant) normalization as the final step.
train_transform = transforms.Compose(
    [transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10), transforms.ToTensor()]
)
norm_train_transform = transforms.Compose(
    [
        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        cifar_norm,
    ]
)

# Undo cifar_norm with two Normalize steps: dividing by 1/stddev multiplies by
# stddev, then subtracting -mean adds the mean back.
# see https://discuss.pytorch.org/t/simple-way-to-inverse-transform-normalization/4821
inverse_transform = transforms.Compose(
    [
        transforms.Normalize(mean=[0.0] * 3, std=inverse_stddev),
        transforms.Normalize(mean=inverse_mean, std=[1.0] * 3),
    ]
)

## Read in the CIFAR 10 Dataset and Take a Single Sample PIL Image

In [4]:
# download=False: the dataset is expected to already exist under config.Paths.dataset.
dataset = CIFAR10(root=config.Paths.dataset, download=False)
# Keep only the first element (the image) of the first dataset item.
image = dataset[0][0]

In [5]:
def create_figure(image, title_text):
    """Render an image tensor as a Plotly figure with a caption-style title.

    Args:
        image: object exposing ``.numpy()``; the resulting array's axes are
            reordered with ``(1, 2, 0)`` before plotting (for the tensors in
            this notebook that moves the leading channel axis to the end).
        title_text: text shown as the figure title.

    Returns:
        The figure created by ``px.imshow`` after the layout has been updated.
    """
    pixels = np.transpose(image.numpy(), axes=(1, 2, 0))
    # Title is horizontally centered and anchored near the bottom of the figure.
    title_style = {
        "text": title_text,
        "font_family": "Ucityweb, sans-serif",
        "font": {"size": 24},
        "y": 0.05,
        "yanchor": "bottom",
        "x": 0.5,
    }
    figure = px.imshow(pixels)
    figure.update_layout(title=title_style, height=300)
    return figure

## The Sample Image with No Augmentations or Normalization

In [6]:
# test_transform only converts to a tensor: no augmentation, no normalization.
raw_image_tensor = test_transform(image)
create_figure(image=raw_image_tensor, title_text="raw image")

## The Sample Image with CIFAR 10 Augmentation Policy and No Normalization

In [7]:
# AutoAugment picks its operations at random; seed torch's RNG so that
# re-running this cell (or the whole notebook) shows the same augmented image.
torch.manual_seed(0)
transformed_image = train_transform(image)
# train_transform applies no normalization, so the figure is labelled as
# augmented rather than normalized.
create_figure(transformed_image, "augmented image")

## The Transformed Image with Normalization Applied

In [8]:
# Normalize the already-augmented tensor directly instead of calling
# norm_train_transform, which would run AutoAugment again on the source image.
normed_image = cifar_norm(transformed_image)
create_figure(image=normed_image, title_text="normalized image")

## The Transformed Image with Inverse Normalization

In [9]:
# Send the normalized tensor back through inverse_transform and display the result.
inversed_transform_image = inverse_transform(normed_image)
create_figure(image=inversed_transform_image, title_text="inversed transform image")

## Check that the Cosine Similarity is 1

In [10]:
from torch.nn.functional import cosine_similarity

In [11]:
# Cosine similarity is insensitive to overall scale, so it cannot by itself
# prove the round trip is exact. Assert element-wise agreement first (the
# tolerance only has to absorb float rounding), then display the similarity.
assert inversed_transform_image.allclose(transformed_image, atol=1e-5)
cosine_similarity(inversed_transform_image.flatten(), transformed_image.flatten(), dim=0)

tensor(1.0000)