## Load and normalize data
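Source: [CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html) (Python version; the code below expects it extracted to `datasets/cifar-10-batches-py/`).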

In [None]:
import torch
import numpy as np
import pickle
import os
from torchvision import transforms

# Target (height, width) that load_cifar10 resizes every image to.
INPUT_SIZE = (128, 128)

def unpickle(file):
    """Load and return the object stored in a pickle file.

    ``encoding='bytes'`` decodes Python 2 ``str`` objects as ``bytes``; this is
    why callers index the result with bytes keys such as ``b'data'``.

    Security: ``pickle.load`` can execute arbitrary code while loading, so only
    call this on files from a trusted source.

    Args:
        file: Path of the pickle file to read.

    Returns:
        The unpickled object.
    """
    with open(file, 'rb') as fo:
        # Named ``obj`` rather than ``dict`` to avoid shadowing the builtin.
        obj = pickle.load(fo, encoding='bytes')
    return obj


def load_cifar10_batch(file):
    """Load one pickled CIFAR-10 batch file as tensors.

    Args:
        file: Path to a batch file such as ``data_batch_1`` or ``test_batch``.

    Returns:
        Tuple ``(images, labels)``.
        - ``images`` has shape (N, 3, 32, 32), i.e. channel-first, and keeps the
          dtype of the stored array (presumably uint8, since callers divide by
          255).
        - ``labels`` is a 1-D tensor built from the batch's ``b'labels'`` list.
    """
    batch = unpickle(file)
    data = batch[b'data']
    labels = batch[b'labels']

    # Each row is split into 3 planes of 32x32, so a plain reshape already
    # yields CHW, the layout torchvision tensor transforms operate on. The
    # previous ``.transpose(0, 1, 2, 3)`` was an identity permutation (no-op)
    # and has been removed.
    images = data.reshape((len(data), 3, 32, 32))
    return torch.tensor(images), torch.tensor(labels)


def _resize_images(images, size):
    """Resize each image in ``images`` to ``size`` (H, W) and stack the results."""
    resize = transforms.Resize(size)
    return torch.stack([resize(img) for img in images])


def _scale_to_symmetric_range(images):
    """Return ``images`` as a new float32 tensor mapped from [0, 255] to [-1, 1].

    ``copy=True`` guarantees a fresh tensor, so the in-place ops never touch the
    caller's data. Doing the arithmetic in place avoids the two full-size
    temporaries that ``x * 2 - 1`` would allocate.
    """
    return images.to(dtype=torch.float32, copy=True).div_(255).mul_(2).sub_(1)


def load_cifar10(dir, input_size=None):
    """Load CIFAR-10, resize every image and scale pixel values to [-1, 1].

    Args:
        dir: Directory containing ``data_batch_1`` ... ``data_batch_5`` and
            ``test_batch``.
        input_size: Target (H, W) passed to ``transforms.Resize``. When None,
            the module-level ``INPUT_SIZE`` is used (read at call time).

    Returns:
        ``(train_data, train_labels, test_data, test_labels)``.
        - The image tensors are float32 with shape (N, 3, H, W).
        - The label tensors are exactly those returned by
          ``load_cifar10_batch``.

    Note:
        At 128x128 each image is 3*128*128 float32 values (192 KiB). With the
        standard 60,000 CIFAR-10 images, the image tensors alone need roughly
        12 GB of memory.
    """
    if input_size is None:
        input_size = INPUT_SIZE

    # Load and concatenate the five training batches.
    train_batches = [
        load_cifar10_batch(os.path.join(dir, 'data_batch_{}'.format(i)))
        for i in range(1, 6)
    ]
    train_data = torch.cat([images for images, _ in train_batches])
    train_labels = torch.cat([labels for _, labels in train_batches])
    del train_batches  # release per-batch tensors before the large resize

    test_data, test_labels = load_cifar10_batch(os.path.join(dir, 'test_batch'))

    train_data = _scale_to_symmetric_range(_resize_images(train_data, input_size))
    test_data = _scale_to_symmetric_range(_resize_images(test_data, input_size))

    return train_data, train_labels, test_data, test_labels


# Expects the extracted CIFAR-10 python batches in datasets/cifar-10-batches-py/
# (a path relative to the current working directory).
(train_data, train_labels,
 test_data, test_labels) = load_cifar10('datasets/cifar-10-batches-py')


## Visualize data

In [None]:
# Show the shape of the training images, then of the training labels.
for tensor in (train_data, train_labels):
    print(tensor.shape)

In [None]:
import matplotlib.pyplot as plt

# Per-class counts of the training labels; bins=range(11) gives unit-wide bins
# [0, 1), ..., [9, 10), one per class id 0-9.
plt.hist(
    train_labels.numpy(),
    bins=range(11),
    align='left',
    rwidth=0.9,
)

In [None]:
# Per-class counts of the test labels, using the same binning as the training histogram.
plt.hist(
    test_labels.numpy(),
    bins=range(11),
    align='left',
    rwidth=0.9,
)

## Create datasets and data loaders

In [None]:
from augmentation import show_random


# Preview randomly augmented training samples; print_filters=True presumably
# reports which filters were applied (see augmentation.show_random).
# NOTE(review): 6 probabilities are passed here while transform_wrapper passes
# 4 to apply_random_filters; confirm each list matches what augmentation expects.
show_random(
    train_data,
    train_labels,
    probabilities=[0.1, 0.2, 0.2, 0.2, 0.15, 0.15],
    print_filters=True,
)

In [None]:
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torch.utils.data import Dataset
from augmentation import apply_random_filters


class CustomDataset(Dataset):
    """Map-style dataset over in-memory images and labels.

    The previous body was an empty ``pass`` stub. That made
    ``CustomDataset(data, labels, transform=...)`` fail with a TypeError.

    Args:
        data: Indexable collection of images; ``data[i]`` is one sample.
        labels: Indexable collection of labels aligned with ``data``.
        transform: Optional callable applied to each image when it is fetched
            (e.g. random augmentation). ``None`` returns the image unchanged.

    Raises:
        ValueError: If ``data`` and ``labels`` have different lengths.
    """

    def __init__(self, data, labels, transform=None):
        super().__init__()
        if len(data) != len(labels):
            raise ValueError(
                "data and labels must have the same length, "
                f"got {len(data)} and {len(labels)}"
            )
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image = self.data[index]
        # Applied per access so random augmentations differ between epochs.
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[index]

# Mini-batch size used by both the train and the test DataLoader.
BATCH = 8

def transform_wrapper(image):
    """Apply the training-time random filters to a single image.

    Thin adapter so the dataset can take a one-argument transform while
    apply_random_filters is given fixed per-filter probabilities.
    """
    filter_probabilities = [0.2, 0.2, 0.3, 0.3]
    return apply_random_filters(image, probabilities=filter_probabilities)

# Training set: transform_wrapper is handed to the dataset as its per-sample
# transform (augmentation); batches are shuffled.
train_dataset = CustomDataset(
    train_data, train_labels.long(), transform=transform_wrapper
)
train_loader = DataLoader(train_dataset, batch_size=BATCH, shuffle=True)

# Test set: no augmentation and a fixed order, as is usual for evaluation.
test_dataset = CustomDataset(test_data, test_labels.long())
test_loader = DataLoader(test_dataset, batch_size=BATCH, shuffle=False)

## Create model

In [None]:
import residual
import torch.nn as nn

# Initialize the model
# Build the residual CNN with 10 output classes and initialise its weights.
# Assumes CNNv1 is a torch nn.Module: Module.apply(fn) calls fn on every
# submodule and returns the module itself, so the calls can be chained.
model = residual.CNNv1(10).apply(residual.init_weights)

# Show the architecture.
print(model)

## Train model

In [None]:
import callback as cb
import torch.optim as optim
from fit import fit


LR = 0.0001
EPOCH = 45

# Loss function and optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

# Train the model. The epoch count comes from EPOCH rather than a literal, so
# the constant above is the single place to change it. classes=10 matches the
# 10 outputs the model is built with. stopper=True presumably enables early
# stopping inside fit; confirm in fit.py.
fit(model, train_loader, test_loader, criterion, optimizer,
    epochs=EPOCH, classes=10, backup_path='backup/cifar10/residual_cnn_v1', stopper=True)