## Learning Data Augmentation

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import numpy as np
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn as nn
import torchvision

import datasets
import autoencoder
import cnn
import vae
import utils

## Dataset

In [None]:
# Seed the global NumPy and PyTorch RNGs before anything stochastic runs, so a
# fresh "Restart Kernel -> Run All" starts from the same random state each time.
# NOTE(review): which generators the `datasets` helpers actually draw from is
# not visible here -- confirm, and seed Python's `random` too if they use it.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Build the train/test datasets (5000 samples requested for each split) and
# wrap them in loaders that use a batch size of 64.
trn_dset, tst_dset = datasets.get_cifar_dataset(trn_size=5000, tst_size=5000)
trn_loader, tst_loader = datasets.get_cifar_loader(trn_dset, tst_dset, batch_size=64)

# Pull a single training batch for a visual sanity check; `inputs` is reused
# by the VAE demo cells later in the notebook.
inputs, targets = next(iter(trn_loader))
utils.plot_batch(inputs)
print("Train:", len(trn_loader.dataset), "Test:", len(tst_loader.dataset),
      "Input:", inputs.size(), "Target:", targets.size())

## Classifier

In [None]:
# Baseline classifier: construct it first, then move it to the GPU.
model = cnn.CNN(in_shape=(3, 32, 32), n_classes=10)
model = model.cuda()

In [None]:
# Optimisation setup for the baseline classifier.
epochs = 20
# Number of loader batches across all epochs; not referenced again in the
# cells visible here.
iters = len(trn_loader) * epochs

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=5e-4,
)
# ExponentialLR multiplies the learning rate by `gamma` on each scheduler step.
# NOTE(review): how often the Trainer steps the scheduler is not visible here.
lr_adjuster = torch.optim.lr_scheduler.ExponentialLR(
    optimizer,
    gamma=0.999,
)
trainer = cnn.Trainer(optimizer, lr_adjuster)

In [None]:
# Run the trainer on the baseline classifier, passing both loaders, the loss
# and the epoch count configured in the previous cell.
trainer.run(
    model,
    trn_loader,
    tst_loader,
    criterion,
    epochs,
)

In [None]:
# Plot the 'trn' and 'tst' series recorded by the trainer: loss first, then
# accuracy.
utils.plot_metric(
    trainer.metrics['loss']['trn'],
    trainer.metrics['loss']['tst'],
    'Loss',
)
utils.plot_metric(
    trainer.metrics['accuracy']['trn'],
    trainer.metrics['accuracy']['tst'],
    'Accuracy',
)

## Autoencoder

In [None]:
# Rebinds `model`: from here on it refers to the convolutional autoencoder
# (moved to the GPU), no longer to the baseline classifier.
model = autoencoder.ConvAE(in_shape=(3, 32, 32))
model = model.cuda()

In [None]:
# Adam over the autoencoder's parameters, paired with a mean-squared-error
# loss (default MSELoss settings).
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=5e-4,
)
criterion = nn.MSELoss()

In [None]:
# Train the autoencoder on the training loader for 50 epochs and keep the
# returned loss history.
losses = autoencoder.run(
    model,
    trn_loader,
    criterion,
    optimizer,
    epochs=50,
)
# The same series is passed for both curve arguments of plot_metric, so the
# two plotted curves coincide.
utils.plot_metric(losses, losses, 'Loss')

## VAE

In [None]:
# Rebinds `model` to the VAE (100 latent units requested), moved to the GPU.
# The augmentation section further down reuses this binding.
model = vae.VAE(in_shape=(3, 32, 32), n_latent=100)
model = model.cuda()

In [None]:
# Adam over the VAE's parameters.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=5e-4,
)
# size_average=False makes MSELoss sum the squared errors instead of
# averaging them.
# NOTE(review): newer PyTorch releases deprecate `size_average` in favour of
# `reduction='sum'`; switch once the minimum supported version allows it.
criterion = nn.MSELoss(size_average=False)

In [None]:
# Train the VAE with both loaders; the returned mapping is indexed by
# 'trn' and 'tst' below.
losses = vae.run(
    model,
    trn_loader,
    tst_loader,
    criterion,
    optimizer,
    epochs=50,
    plot_interval=25,
)
utils.plot_metric(
    losses['trn'],
    losses['tst'],
    'Loss',
)

In [None]:
# Single Image
img_idx = 1
# Multiplicative factor centred on 1: one standard-normal draw scaled by 0.1.
noise = 1.0 + 0.1 * torch.randn(1)
recon, mean, var = vae.predict(model, inputs[img_idx])
# Both `mean` and `var` are scaled by the same factor before being handed
# to generate().
out = vae.generate(
    model,
    mean * noise,
    var * noise,
)
# Show the source image followed by the generated one.
utils.plot_tensor(inputs[img_idx], title="Input", fs=(4, 4))
utils.plot_tensor(out, title="Generated", fs=(4, 4))

In [None]:
# Batch
# Same predict -> generate round trip as the single-image cell, but applied
# to the whole `inputs` batch and with `mean`/`var` passed through unscaled.
recon, mean, var = vae.predict(model, inputs)
out = vae.generate(model, mean, var)
# Plot the original batch first, then the generated batch, for comparison.
utils.plot_batch(inputs)
utils.plot_batch(out)

## Classifier with VAE Augmentation

In [None]:
# Fresh classifier for the augmented run, moved to the GPU.
classifier = cnn.CNN(in_shape=(3, 32, 32), n_classes=10)
classifier = classifier.cuda()
# `model` is whatever the most recently executed model cell bound; when the
# notebook is run top to bottom that is the VAE from the section above.
augmentor = model

In [None]:
# Optimisation setup for the augmented classifier. Compared with the baseline
# setup this uses 50 epochs and weight_decay=1e-4, and passes the augmentor
# to the Trainer as a third argument.
epochs = 50

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    classifier.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
)
lr_adjuster = torch.optim.lr_scheduler.ExponentialLR(
    optimizer,
    gamma=0.999,
)
trainer = cnn.Trainer(optimizer, lr_adjuster, augmentor)

In [None]:
# Run the augmentor-equipped trainer on the new classifier with the same
# loaders as the baseline run.
trainer.run(
    classifier,
    trn_loader,
    tst_loader,
    criterion,
    epochs,
)

In [None]:
# Plot the 'trn' and 'tst' series recorded by the augmented trainer: loss
# first, then accuracy.
utils.plot_metric(
    trainer.metrics['loss']['trn'],
    trainer.metrics['loss']['tst'],
    'Loss',
)
utils.plot_metric(
    trainer.metrics['accuracy']['trn'],
    trainer.metrics['accuracy']['tst'],
    'Accuracy',
)

## GAN

In [None]:
# TODO: GAN-based augmentation is not implemented in this notebook yet; this
# cell only records the reference to work from.
# NOTE(review): the link is believed to be the "Data Augmentation Generative
# Adversarial Networks" (DAGAN) paper -- confirm before citing.
# https://arxiv.org/abs/1711.04340