# Import libraries

In [0]:
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import random

# Define data augmentation transforms

In [0]:
# Baseline preprocessing shared by every pipeline: convert to a tensor, then
# normalize each channel with a fixed mean / std.
# NOTE(review): the constants look like CIFAR-10 per-channel statistics - confirm.
transform_no_aug = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Each augmentation is instantiated exactly once and reused below, so the
# single-augmentation pipelines and the combined pipeline always share the same
# parameters instead of relying on duplicated literals staying in sync.
crop_aug = transforms.RandomCrop(32, padding=8)
flip_aug = transforms.RandomHorizontalFlip()
rotation_aug = transforms.RandomRotation(15)
perspective_aug = transforms.RandomPerspective()
color_aug = transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1)

# One pipeline per augmentation: the augmentation runs first, then the shared
# baseline preprocessing.
transform_first_aug = transforms.Compose([crop_aug, transform_no_aug])
transform_second_aug = transforms.Compose([flip_aug, transform_no_aug])
transform_third_aug = transforms.Compose([rotation_aug, transform_no_aug])
transform_fourth_aug = transforms.Compose([perspective_aug, transform_no_aug])
transform_fifth_aug = transforms.Compose([color_aug, transform_no_aug])

# All five augmentations chained in the same order as the single pipelines
# above, followed by the baseline preprocessing.
transform_combined_aug = transforms.Compose([
    crop_aug,
    flip_aug,
    rotation_aug,
    perspective_aug,
    color_aug,
    transform_no_aug,
])

# Index 0 is the no-augmentation baseline, 1-5 are the single augmentations,
# and the last entry is the combined pipeline.
transform_aug_list = [
    transform_no_aug,
    transform_first_aug,
    transform_second_aug,
    transform_third_aug,
    transform_fourth_aug,
    transform_fifth_aug,
    transform_combined_aug,
]
# Derived from the list so the count cannot go stale when a pipeline is added
# or removed.
aug_total = len(transform_aug_list)

# Import CIFAR-10

In [0]:
# Training split: downloaded into ./data if needed and exposed through a
# shuffled loader. It is converted with ToTensor only, i.e. without the
# normalization that the test split receives below.
# NOTE(review): confirm the train/test preprocessing difference is intended.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
# Test split: uses the baseline (no-augmentation) transform and is read in a
# fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_no_aug,
)

trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=100,
    shuffle=False,
    num_workers=2,
)

# Human-readable class names (presumably in label-index order) and their count.
classes = (
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
num_classes = len(classes)

0it [00:00, ?it/s]

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


170500096it [00:06, 26614380.71it/s]                               


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


# Create balanced subsets

In [0]:
# Build class-balanced training subsets of increasing size. With
# num_subset = 10 the subsets hold 10%, 20%, ..., 100% of every class.
num_subset = 10

# A dedicated, seeded generator makes the sampled subsets reproducible across
# kernel restarts without touching the global `random` state.
subset_seed = 0
subset_rng = random.Random(subset_seed)

# Group the training samples by label. The raw arrays in `trainset.data` are
# read directly (bypassing the dataset's own transform) and converted to
# tensors here. The converter is built once instead of once per image.
to_tensor = transforms.ToTensor()
class_split = [[] for _ in range(num_classes)]
for img, target in zip(trainset.data, trainset.targets):
    class_split[target].append([to_tensor(img), target])

# Store one DataLoader per subset size in a list.
subset_list = []
for k in range(num_subset):
    crt_subset = []
    for class_samples in class_split:
        # Re-shuffle each class for every subset so each one is an independent
        # draw. Smaller subsets are therefore not nested inside larger ones.
        subset_rng.shuffle(class_samples)
        # Integer arithmetic: take (k + 1) / num_subset of this class. This
        # avoids float rounding and cannot index past the end of a class even
        # when the classes are not all the same size.
        n_take = len(class_samples) * (k + 1) // num_subset
        crt_subset.extend(class_samples[:n_take])
    # Mix the classes together (the loader below also shuffles on iteration).
    subset_rng.shuffle(crt_subset)
    subset_list.append(torch.utils.data.DataLoader(crt_subset, batch_size=128, shuffle=True, num_workers=2))