In [None]:
!pip install kornia

In [None]:
import json
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

from torchvision import transforms 
from torchvision import datasets as dts

import kornia

In [None]:
# Number of images per mini-batch; also used by the label printouts further down.
batch_size = 4

# Convert each image to a tensor, then normalize every RGB channel
# with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# CIFAR-10 training split stored under ../data (download requested).
train_dataset = dts.CIFAR10(
    root='../data',
    train=True,
    transform=transform,
    download=True,
)

# Shuffled loader over the training split, served by two worker processes.
trainloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

# Class names, looked up by integer label.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

In [None]:
# functions to show an image
from torchvision.utils import make_grid


def imshow(img):
    """Un-normalize an image tensor and display it with matplotlib.

    Args:
        img: tensor whose axis 0 is moved to the end before plotting
            (i.e. a channels-first image); values are mapped through
            ``x / 2 + 0.5`` first.
    """
    restored = img / 2 + 0.5  # undo the mean-0.5 / std-0.5 normalization
    # matplotlib wants the channel axis last, so rotate axis 0 to the end.
    channels_last = restored.numpy().transpose(1, 2, 0)
    plt.imshow(channels_last)
    plt.show()


# get some random training images (the loader shuffles, so this batch is random)
dataiter = iter(trainloader)
# Advance with the built-in next(): the loader iterator implements __next__,
# and the legacy `.next()` method is not available on current PyTorch releases.
images, labels = next(dataiter)

# show images
imshow(make_grid(images))
# print labels
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))

In [None]:
def augument_image(image, data_augs, mean=0.5, std=0.5):
    """Apply the requested kornia augmentations to a batch of normalized images.

    The batch is first mapped back to its un-normalized value range, run
    through the selected augmentations, and then normalized again so the
    result can be used exactly like the input.

    Args:
        image: float tensor normalized as ``(x - mean) / std``.
            Assumed to be a (B, C, H, W) batch -- TODO confirm for other callers.
        data_augs: whitespace-separated augmentation names. Recognized names
            are ``crop``, ``jitter``, ``erase``, ``hflip``, ``vflip`` and
            ``rot``; any other name is ignored.
        mean: mean used when the images were normalized. The default matches
            the ``Normalize(0.5, 0.5)`` transform used in this notebook.
        std: standard deviation used when the images were normalized.

    Returns:
        The augmented batch, normalized with the same ``mean`` and ``std``.
    """
    requested = set(data_augs.split())
    modules = []

    if 'crop' in requested:
        # Pad 4 px on every side, then randomly crop back to the input's own
        # spatial size (a fixed 64x64 crop is larger than a padded 32x32 image).
        modules.append(nn.ReplicationPad2d(4))
        modules.append(
            kornia.augmentation.RandomCrop(size=tuple(image.shape[-2:])))
    if 'jitter' in requested:
        modules.append(kornia.augmentation.ColorJitter(0.2, 0.3, 0.2, 0.3))
    if 'erase' in requested:
        modules.append(kornia.augmentation.RandomErasing())
    if 'hflip' in requested:
        modules.append(kornia.augmentation.RandomHorizontalFlip())
    if 'vflip' in requested:
        modules.append(kornia.augmentation.RandomVerticalFlip())
    if 'rot' in requested:
        modules.append(kornia.augmentation.RandomRotation(degrees=5.0))

    # Named `pipeline` so the torchvision `transforms` module is not shadowed.
    pipeline = nn.Sequential(*modules)

    # Invert the normalization before augmenting: the colour augmentations are
    # documented for inputs in [0, 1], which `image + 0.5` does not give for
    # data normalized with std 0.5.
    unnormalized = image * std + mean
    augmented = pipeline(unnormalized)
    return (augmented - mean) / std

In [None]:
# colour-jitter the sampled batch
aug_images = augument_image(images, data_augs='jitter')

In [None]:
# draw the augmented batch as a single image grid
imshow(make_grid(aug_images))
# class name of each image, padded to five characters and space-separated
print(*(f'{classes[labels[j]]:5s}' for j in range(batch_size)))

In [None]:
# apply only the random-rotation augmentation to the sampled batch
aug_images = augument_image(images, data_augs='rot')

# draw the augmented batch as a single image grid
imshow(make_grid(aug_images))
# class name of each image, padded to five characters and space-separated
print(*(f'{classes[labels[j]]:5s}' for j in range(batch_size)))

In [None]:
# apply only the random horizontal flip to the sampled batch
aug_images = augument_image(images, data_augs='hflip')

# draw the augmented batch as a single image grid
imshow(make_grid(aug_images))
# class name of each image, padded to five characters and space-separated
print(*(f'{classes[labels[j]]:5s}' for j in range(batch_size)))

In [None]:
# combine colour jitter with a random horizontal flip
aug_images = augument_image(images, data_augs='jitter hflip')

# draw the augmented batch as a single image grid
imshow(make_grid(aug_images))
# class name of each image, padded to five characters and space-separated
print(*(f'{classes[labels[j]]:5s}' for j in range(batch_size)))