In [None]:
#mixup.py
import math
import random

import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt

from keras import backend as K
from keras.models import Model
from keras.regularizers import l2
from keras.utils import to_categorical
from keras.datasets import mnist, cifar10
from keras.layers import Activation, Input, Dense, Conv2D, LeakyReLU
from keras.layers import Dropout, BatchNormalization, Flatten, Reshape, SpatialDropout2D

In [None]:
def _mixup_batch(in_batch: np.ndarray, out_batch: np.ndarray, alpha: int = 1.0):
    """Mixup the batch by sampling from a beta distribution and 
    computing a weighted average of the first half of the batch with last half."""
    half = in_batch.shape[0] // 2
    mixed_ins = np.zeros((half,) + in_batch.shape[1:])
    mixed_outs = np.zeros((half,) + out_batch.shape[1:])
    for i in range(half):
        weight0 = np.random.beta(alpha, alpha)
        weight1 = 1 - weight0
        mixed_ins[i] = (in_batch[i, ...] * weight0) + (in_batch[half+i, ...] * weight1)
        mixed_outs[i] = (out_batch[i, ...] * weight0) + (out_batch[half+i, ...] * weight1)
    return mixed_ins, mixed_outs

In [None]:
def load_mnist():
    """Load MNIST as NHWC float32 images with one-hot labels.

    Images get a trailing channel axis and are rescaled as ``x / 128 - 1``.
    For 8-bit pixels (0..255) this gives values in [-1.0, 0.9921875].
    Labels are one-hot encoded over 10 classes. Summary statistics are
    printed as a side effect.

    Returns:
        ``((x_train, y_train), (x_test, y_test))``.
    """
    rows, cols, num_classes = 28, 28, 10

    def _prepare(images):
        # Add the channel axis, convert to float32, then shift/scale in place.
        scaled = images.reshape(images.shape[0], rows, cols, 1).astype('float32')
        scaled /= 128.0
        scaled -= 1.0
        return scaled

    # mnist.load_data() returns separate train and test splits.
    (raw_train, labels_train), (raw_test, labels_test) = mnist.load_data()
    x_train = _prepare(raw_train)
    x_test = _prepare(raw_test)
    y_train = to_categorical(labels_train, num_classes)
    y_test = to_categorical(labels_test, num_classes)

    for fields in (
        ('bounds:', np.min(x_train), np.max(x_train)),
        ('x_train shape:', x_train.shape),
        (x_train.shape[0], 'train samples'),
        (x_test.shape[0], 'test samples'),
        ('y_train shape:', y_train.shape),
        ('y_train sum:', np.sum(y_train, axis=0)),
    ):
        print(*fields)
    return (x_train, y_train), (x_test, y_test)

In [None]:
def load_cifar():
    """Load CIFAR-10 as float32 images divided by 255, with one-hot labels.

    Summary statistics are printed as a side effect.

    Returns:
        ``((x_train, y_train), (x_test, y_test))``.
    """
    def _scale(images):
        # Convert to float32 first so the in-place division is floating point.
        scaled = images.astype('float32')
        scaled /= 255.0
        return scaled

    (raw_train, labels_train), (raw_test, labels_test) = cifar10.load_data()
    x_train = _scale(raw_train)
    x_test = _scale(raw_test)
    y_train = to_categorical(labels_train, 10)
    y_test = to_categorical(labels_test, 10)

    for fields in (
        ('bounds:', np.min(x_train), np.max(x_train)),
        ('x_train shape:', x_train.shape),
        (x_train.shape[0], 'train samples'),
        (x_test.shape[0], 'test samples'),
    ):
        print(*fields)
    return (x_train, y_train), (x_test, y_test)

In [None]:
def _plot_batch(image_batch):
    """Show every image of a batch on a near-square grid of subplots.

    Args:
        image_batch: Array indexed as ``(batch, height, width, channels)``.
            Single-channel images are drawn from channel 0. Otherwise the
            whole ``(height, width, channels)`` slice is passed to ``imshow``.
    """
    batch_size = image_batch.shape[0]
    if batch_size == 0:
        # plt.subplots(0, 0) raises, and there is nothing to draw anyway.
        return
    grid = math.ceil(math.sqrt(batch_size))
    # squeeze=False keeps `axes` 2-D even for a 1x1 grid; without it a
    # single-image batch returned a bare Axes and the indexing crashed.
    _, axes = plt.subplots(grid, grid, figsize=(18, 14), squeeze=False)
    single_channel = image_batch.shape[-1] == 1
    # axes.flat walks the grid row-major: index i -> (i // grid, i % grid).
    for i, ax in enumerate(axes.flat):
        if i >= batch_size:
            # Hide the unused cells of the grid.
            ax.axis('off')
            continue
        image = image_batch[i, :, :, 0] if single_channel else image_batch[i, ...]
        ax.imshow(image)

In [None]:
def build_discriminative_model(in_shape, out_classes):
    """Build and compile a small CNN classifier.

    The network is built as follows:
    - Three Conv2D(3x3) -> BatchNormalization -> LeakyReLU(0.2) blocks with
      128, 256 and 256 filters (the last two use stride 2).
    - A 128-unit dense layer.
    - A softmax output layer.

    The model is compiled with categorical cross-entropy and Adam, and its
    summary is printed.

    Args:
        in_shape: Input shape without the batch axis, e.g. ``(28, 28, 1)``.
        out_classes: Width of the softmax output layer.

    Returns:
        The compiled Keras ``Model``.
    """
    def _conv_block(x, filters, stride):
        # Conv -> BN -> LeakyReLU: the repeating unit of the feature extractor.
        x = Conv2D(filters, (3, 3), strides=(stride, stride), padding='same')(x)
        x = BatchNormalization()(x)
        return LeakyReLU(0.2)(x)

    d_input = Input(in_shape)
    H = _conv_block(d_input, 128, 1)
    H = _conv_block(H, 256, 2)
    H = _conv_block(H, 256, 2)
    H = Flatten()(H)
    H = Dense(128)(H)
    H = BatchNormalization()(H)
    # Without a non-linearity here, Dense -> BN -> Dense(softmax) is a single
    # affine map followed by softmax, so the hidden dense layer adds no capacity.
    H = LeakyReLU(0.2)(H)
    d_V = Dense(out_classes, activation='softmax')(H)
    discriminator = Model(d_input, d_V)
    discriminator.compile(loss='categorical_crossentropy', optimizer='adam')
    discriminator.summary()
    return discriminator

In [None]:
def _accuracy_on_batch(discriminator, batch, y):
    y_hat = discriminator.predict(batch)
    y_hat_idx = np.argmax(y_hat, axis=-1)
    y_idx = np.argmax(y, axis=-1)
    diff = y_idx-y_hat_idx
    n_tot = y.shape[0]
    n_rig = (diff==0).sum()
    acc = n_rig*100.0/n_tot
    print(f'Accuracy: {acc:0.02f} pct ({n_rig} of {n_tot}).')

In [None]:
def _train_on_batches(discriminator, x_train, y_train, steps, batch_size, mixup=False, alpha=1.0, steps_per_print=100):
    """Train ``discriminator`` for ``steps`` randomly sampled batches.

    Each step draws indices uniformly (with replacement) from ``x_train``.
    If ``mixup`` is set, the batch is blended with ``_mixup_batch``. Then one
    ``train_on_batch`` update is applied. Every ``steps_per_print`` steps the
    accuracy on the just-trained batch is printed.

    Returns:
        List of the values returned by ``train_on_batch``, one per step.
    """
    # _mixup_batch returns one blended sample per pair, so draw twice as many
    # samples to keep `batch_size` samples per update.
    draw_size = batch_size * 2 if mixup else batch_size
    n_samples = x_train.shape[0]
    history = []
    for step in range(steps):
        picks = np.random.randint(0, n_samples, size=draw_size)
        inputs, targets = x_train[picks, ...], y_train[picks, ...]
        if mixup:
            inputs, targets = _mixup_batch(inputs, targets, alpha)
        history.append(discriminator.train_on_batch(inputs, targets))
        if not step % steps_per_print:
            _accuracy_on_batch(discriminator, inputs, targets)
    return history

In [None]:
# Seed Python's and NumPy's global RNGs so batch sampling and mixup weights
# are reproducible under Restart & Run All.
# NOTE(review): Keras weight initialisation presumably uses the backend's own
# RNG and is not covered by these seeds; seed the backend too if needed.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)

(mnist_x, mnist_y), (mnist_test_x, mnist_test_y) = load_mnist()
mnist_discriminator = build_discriminative_model((28, 28, 1), 10)
losses = {}
losses['erm'] = _train_on_batches(mnist_discriminator, mnist_x, mnist_y, 2000, 256, mixup=False)

In [None]:
# MNIST with mixup, alpha = 0.1, same budget as the ERM baseline.
mnist_discriminator = build_discriminative_model(in_shape=(28, 28, 1), out_classes=10)
losses['mixup_0.1'] = _train_on_batches(
    mnist_discriminator, mnist_x, mnist_y,
    steps=2000, batch_size=256, mixup=True, alpha=0.1,
)

In [None]:
# MNIST with mixup, alpha = 0.5.
# NOTE(review): this run uses 1000 steps x batch 64, unlike the 2000 x 256
# budget of the other MNIST runs plotted with it, so its curve is shorter and
# noisier and is not directly comparable.
mnist_discriminator = build_discriminative_model(in_shape=(28, 28, 1), out_classes=10)
losses['mixup_0.5'] = _train_on_batches(
    mnist_discriminator, mnist_x, mnist_y,
    steps=1000, batch_size=64, mixup=True, alpha=0.5,
)

In [None]:
# Training loss per step for each MNIST run. Mixup runs are scored against
# blended targets, so their losses are not directly comparable to ERM's.
fig, ax = plt.subplots()
for name, run_losses in losses.items():
    ax.plot(run_losses, label=name)
ax.set(title='MNIST training loss', xlabel='training step', ylabel='categorical cross-entropy')
ax.legend(loc='upper right')
plt.show()

In [None]:
# CIFAR-10 baseline without mixup; the test split is discarded.
(cifar_x, cifar_y), _ = load_cifar()
cifar_discriminator = build_discriminative_model(in_shape=(32, 32, 3), out_classes=10)
cifar_losses = dict()
cifar_losses['erm'] = _train_on_batches(
    cifar_discriminator, cifar_x, cifar_y,
    steps=5000, batch_size=256, mixup=False,
)

In [None]:
# CIFAR-10 with mixup, alpha = 0.1, same budget as the ERM baseline.
cifar_discriminator = build_discriminative_model(in_shape=(32, 32, 3), out_classes=10)
cifar_losses['mixup_0.1'] = _train_on_batches(
    cifar_discriminator, cifar_x, cifar_y,
    steps=5000, batch_size=256, mixup=True, alpha=0.1,
)

In [None]:
# CIFAR-10 with mixup, alpha = 0.3.
# NOTE(review): 2000 steps x batch 128 here vs 5000 x 256 for the other
# CIFAR-10 runs in the same comparison plot; curves are not directly comparable.
cifar_discriminator = build_discriminative_model(in_shape=(32, 32, 3), out_classes=10)
cifar_losses['mixup_0.3'] = _train_on_batches(
    cifar_discriminator, cifar_x, cifar_y,
    steps=2000, batch_size=128, mixup=True, alpha=0.3,
)

In [None]:
# Training loss per step for each CIFAR-10 run. Mixup runs are scored against
# blended targets, so their losses are not directly comparable to ERM's.
fig, ax = plt.subplots()
for name, run_losses in cifar_losses.items():
    ax.plot(run_losses, label=name)
ax.set(title='CIFAR-10 training loss', xlabel='training step', ylabel='categorical cross-entropy')
ax.legend(loc='upper right')
plt.show()

In [None]:
# CIFAR-10 with mixup, alpha = 1.0 (mixing weights drawn from Beta(1, 1)).
cifar_discriminator = build_discriminative_model((32, 32, 3), 10)
cifar_losses['mixup_1.0'] = _train_on_batches(cifar_discriminator, cifar_x, cifar_y, 2000, 128, mixup=True, alpha=1.0)

# The comparison plot earlier in the notebook was drawn before this run
# existed, so redraw it to include 'mixup_1.0' alongside the other runs.
fig, ax = plt.subplots()
for name, run_losses in cifar_losses.items():
    ax.plot(run_losses, label=name)
ax.set(title='CIFAR-10 training loss (incl. mixup_1.0)', xlabel='training step', ylabel='categorical cross-entropy')
ax.legend(loc='upper right')
plt.show()

In [None]:
# Small-batch (9 samples per step) ERM run on MNIST. The loss history is kept
# in a variable instead of being echoed as a 1000-element cell output.
mnist_discriminator = build_discriminative_model((28, 28, 1), 10)
mnist_erm_bs9_losses = _train_on_batches(mnist_discriminator, mnist_x, mnist_y, 1000, 9, mixup=False, steps_per_print=100)

In [None]:
# Small-batch (9 samples per step) mixup run on MNIST. With alpha=500,
# Beta(alpha, alpha) is tightly concentrated around 0.5, so almost every
# sample is a near-even blend of two images. The loss history is kept in a
# variable instead of being echoed as a 10000-element cell output.
mixup_mnist_model = build_discriminative_model((28, 28, 1), 10)
mixup_mnist_bs9_losses = _train_on_batches(mixup_mnist_model, mnist_x, mnist_y, 10000, 9, mixup=True, alpha=500.0, steps_per_print=1000)

In [None]:
# Mixup on MNIST with 32 samples per step and the default alpha (1.0).
# NOTE(review): this rebinds `mixup_mnist_model`, replacing the model trained
# in the previous cell. The loss history is kept in a variable instead of
# being echoed as a 10000-element cell output.
mixup_mnist_model = build_discriminative_model((28, 28, 1), 10)
mixup_mnist_bs32_losses = _train_on_batches(mixup_mnist_model, mnist_x, mnist_y, 10000, 32, mixup=True, steps_per_print=500)