# Training a ConvNet on CIFAR10


In [None]:
import numpy as np
from datetime import datetime 
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter
import torchvision
import matplotlib.pyplot as plt
import time
import collections
from functools import partial
from cifar_model import MobileNet
from utils import *

In [None]:
# Select the compute device: use the GPU when CUDA is available, otherwise
# fall back to the CPU instead of failing on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# parameters
RANDOM_SEED = 42          # seed passed to torch.manual_seed before building the model
learning_rate = 0.001     # Adam learning rate
batch_size_train = 128
batch_size_test = 1000
num_workers = 10          # DataLoader worker processes
n_classes = 10            # CIFAR10 has 10 classes
activation_cutoff = 99.9  # NOTE(review): not referenced by the cells shown here — confirm it is still needed

## Data

In [None]:
# Training pipeline: random horizontal flip plus a small random affine warp
# (rotation within +/-15 degrees, translation up to 5% per axis), then
# conversion to a tensor.
transform_train = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(15, translate=(0.05, 0.05)),
        transforms.ToTensor(),
    ]
)

# Evaluation pipeline: no augmentation, tensor conversion only.
transform_test = transforms.Compose([transforms.ToTensor()])

# Build the CIFAR10 train/test splits from ./data (flip `download` to True to
# fetch them first).
download = False
train_dataset, valid_dataset = (
    datasets.CIFAR10(root='./data', train=is_train, transform=pipeline, download=download)
    for is_train, pipeline in ((True, transform_train), (False, transform_test))
)

# Only the training data is shuffled; both loaders use pinned host memory.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size_train,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=batch_size_test,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
)

In [None]:
# Visual sanity check: draw one random augmented training image.
# The batch is (1, C, H, W); take the single image and move channels last for imshow.
image, sample = next(iter(
    DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=num_workers)
))
plt.imshow(image[0].permute(1, 2, 0))

In [None]:
# Shared cache written by forward hooks: layer name -> that layer's latest output.
activations = {}


def save_activation(name, mod, inp, out):
    """Forward-hook callback that records ``out`` under ``name`` in ``activations``.

    Bind ``name`` with ``functools.partial(save_activation, name)`` so the
    remaining arguments match the ``hook(module, input, output)`` signature.
    The tensor is stored as-is, neither detached nor moved off its device,
    so a penalty computed from it can still backpropagate.
    """
    activations.update({name: out})

In [None]:
def training_loop(model, criterion, optimizer, train_loader, valid_loader, epochs, device, print_every=1):
    """Train ``model`` for ``epochs`` epochs and record per-epoch losses.

    During each training epoch a forward hook is attached to every
    ``nn.Conv2d`` and ``nn.Linear`` module. The hook stores that module's
    output in the module-level ``activations`` dict, which ``train`` uses
    for its activation penalty. Before validation runs, the hooks are
    removed and the captured outputs dropped.

    Args:
        model: network to train.
        criterion: loss function passed to ``train`` and ``validate``.
        optimizer: optimizer updating ``model``'s parameters.
        train_loader: loader for the training data.
        valid_loader: loader for the validation data.
        epochs: number of epochs to run.
        device: device the batches are moved to.
        print_every: every ``print_every`` epochs, also compute train/valid
            accuracy and print a progress line.

    Returns:
        ``(model, optimizer, (train_losses, valid_losses))``, where each list
        holds one loss per epoch. The losses are also plotted with
        ``plot_losses`` once training finishes.
    """
    train_losses = []
    valid_losses = []

    for epoch in range(epochs):
        handles = [
            module.register_forward_hook(partial(save_activation, name))
            for name, module in model.named_modules()
            if isinstance(module, (nn.Conv2d, nn.Linear))
        ]
        try:
            model, optimizer, train_loss = train(train_loader, model, criterion, optimizer, device, activations)
        finally:
            # Always detach the hooks, even if training raised, so later
            # forward passes (validation, accuracy, inference) stop capturing.
            for handle in handles:
                handle.remove()
            # The last batch's captured outputs would otherwise stay alive
            # (on the training device) through validation and beyond.
            activations.clear()
        train_losses.append(train_loss)

        with torch.no_grad():
            model, valid_loss = validate(valid_loader, model, criterion, device)
            valid_losses.append(valid_loss)

        if epoch % print_every == (print_every - 1):
            train_acc = get_accuracy(model, train_loader, device=device)
            valid_acc = get_accuracy(model, valid_loader, device=device)
            print(f'{datetime.now().time().replace(microsecond=0)} --- '
                  f'Epoch: {epoch}\t'
                  f'Train loss: {train_loss:.4f}\t'
                  f'Valid loss: {valid_loss:.4f}\t'
                  f'Train accuracy: {100 * train_acc:.2f}\t'
                  f'Valid accuracy: {100 * valid_acc:.2f}')

    plot_losses(train_losses, valid_losses)
    return model, optimizer, (train_losses, valid_losses)

In [None]:
def _regularisation_penalty(model, activations, norm, beta):
    """Return the summed activation, BatchNorm-weight and bottleneck-weight penalty for one step."""
    # Outputs captured during the preceding forward pass (layer name -> tensor).
    penalty = sum(
        0.1 * beta * torch.linalg.vector_norm(activation, ord=norm)
        for activation in activations.values()
    )
    # Weight (scale) parameter of every BatchNorm2d layer.
    penalty = penalty + sum(
        500 * beta * torch.linalg.vector_norm(module.weight, ord=norm)
        for module in model.modules()
        if isinstance(module, nn.BatchNorm2d)
    )
    # Extra penalty on one specific layer; requires `model.features[1].bottleneck[0]` to exist.
    penalty = penalty + 300 * beta * torch.linalg.vector_norm(
        model.features[1].bottleneck[0].weight, ord=norm
    )
    return penalty


def train(train_loader, model, criterion, optimizer, device, activations, norm=2, beta=1e-4):
    """Run one training epoch and return the sample-weighted average loss.

    The per-batch objective is ``criterion(model(X), y)`` plus a penalty made
    of three terms:

    - the norms of the tensors in ``activations``, scaled by ``0.1 * beta``;
    - the weights of every ``nn.BatchNorm2d`` module, scaled by ``500 * beta``;
    - the weight of ``model.features[1].bottleneck[0]``, scaled by ``300 * beta``.

    Args:
        train_loader: loader yielding ``(X, y_true)`` batches.
        model: network being trained (switched to train mode here).
        criterion: loss function applied to ``(model(X), y_true)``.
        optimizer: optimizer stepped once per batch.
        device: device each batch is moved to.
        activations: dict of layer name -> output tensor, read after every
            forward pass. ``training_loop`` fills it through forward hooks.
        norm: order of the norm applied to every penalised tensor. The
            default 2 is the Euclidean norm over all elements.
        beta: base coefficient for all three penalty terms.

    Returns:
        ``(model, optimizer, epoch_loss)``. ``epoch_loss`` averages the full
        objective, including the penalty, over ``len(train_loader.dataset)``.
        NOTE(review): if ``validate`` reports only the bare criterion loss,
        the two numbers are not directly comparable.
    """
    model.train()
    running_loss = 0

    for X, y_true in train_loader:
        optimizer.zero_grad()
        X = X.to(device)
        y_true = y_true.to(device)
        # Forward pass; any registered forward hooks refresh `activations` here.
        y_hat = model(X)
        loss = criterion(y_hat, y_true) + _regularisation_penalty(model, activations, norm, beta)
        # Backward pass
        loss.backward()
        optimizer.step()
        # Weight by batch size so the final average is per sample.
        running_loss += loss.item() * X.size(0)
    epoch_loss = running_loss / len(train_loader.dataset)
    return model, optimizer, epoch_loss

In [None]:
# Seed torch's RNG before the model is built so weight initialisation is reproducible.
torch.manual_seed(RANDOM_SEED)

model = MobileNet(n_classes)
model = model.to(device)
# Multi-GPU training (wrapping with nn.DataParallel) is currently disabled.

criterion = nn.CrossEntropyLoss()
# Adam with a small L2 weight-decay term.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=2e-5)

In [None]:
# Train for 200 epochs, reporting accuracy every 5th epoch; the loss history is discarded.
model, optimizer, _ = training_loop(
    model, criterion, optimizer,
    train_loader, valid_loader,
    epochs=200, device=device, print_every=5,
)

In [None]:
# Persist only the learned parameters (state_dict), not the module object.
# After loading them back into a model, call model.eval() before inference.
torch.save(obj=model.state_dict(), f="./cifar-convnet.pth")