In [1]:
import os
import shutil
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import random

from keras.datasets import cifar10
from sklearn.model_selection import train_test_split
from PIL import ImageOps, ImageEnhance, ImageFilter, Image
from tqdm import tqdm_notebook
from functools import partial

# Plot styling: ggplot theme, inline rendering and a large default figure size.
plt.style.use('ggplot')
%matplotlib inline
plt.rcParams['figure.figsize'] = 20, 20
# Restrict this process to CUDA device 0.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

Using TensorFlow backend.


In [2]:
# Simple wrappers for functional-like model building

class Flatten:
    """Layer that collapses every non-batch axis of a tensor into a single axis."""

    def __call__(self, input_tensor):
        """Reshape `input_tensor` to `(batch, -1)`, reading the batch size dynamically."""
        dynamic_batch = tf.shape(input_tensor)[0]
        flat_shape = (dynamic_batch, -1)
        return tf.reshape(input_tensor, flat_shape)
    
class Dense:
    """Fully connected layer computing `x @ weights + biases`.

    The parameters are created with `tf.get_variable` under the names
    'weights' and 'biases', in whatever variable scope is active when the
    layer is constructed. Both carry an L2 regularizer with scale 5e-5.
    """

    def __init__(self, in_planes, out_planes):
        """Create the parameters of an `in_planes` -> `out_planes` projection."""
        self.in_planes = in_planes
        self.out_planes = out_planes

        # Weights start uniform in [-1/sqrt(out_planes), 1/sqrt(out_planes)].
        bound = 1.0 / (out_planes)**(0.5)
        weight_init = tf.random_uniform_initializer(-bound, bound)
        self.weights = tf.get_variable(
            'weights',
            shape=[in_planes, out_planes],
            initializer=weight_init,
            regularizer=tf.contrib.layers.l2_regularizer(scale=5e-5))
        self.biases = tf.get_variable(
            'biases',
            shape=[out_planes],
            initializer=tf.constant_initializer(0.0),
            regularizer=tf.contrib.layers.l2_regularizer(scale=5e-5))

    def __call__(self, input_tensor):
        """Apply the affine map to `input_tensor`."""
        return tf.nn.xw_plus_b(input_tensor, self.weights, self.biases)
    
class Conv2d:
    """2-D convolution layer (no bias) backed by a `tf.get_variable` kernel.

    The kernel variable is named "kernels" and is created in whatever
    variable scope is active when the layer is constructed. It is initialised
    from a normal distribution with stddev `sqrt(2 / (kh * kw * out_planes))`
    and carries an L2 regularizer with scale 5e-5.
    """

    def __init__(self, in_planes, out_planes, filters: tuple, strides=(1, 1), padding='SAME', *args, **kwargs):
        """Create the kernel variable.

        Args:
            in_planes: number of input channels.
            out_planes: number of output channels.
            filters: `(height, width)` of the convolution kernel.
            strides: `(row, col)` strides; expanded to `(1, row, col, 1)`.
            padding: padding mode name, upper-cased before use.
            *args, **kwargs: extra arguments forwarded to `tf.nn.conv2d`
                after `input`, `filter`, `strides` and `padding`.
        """
        self.strides = (1, *strides, 1)
        self.padding = padding.upper()
        self.out_planes = out_planes
        n = int(filters[0] * filters[1] * out_planes)
        initializer = tf.random_normal_initializer(stddev=np.sqrt(2.0 / n))
        self.args = args
        self.kwargs = kwargs
        self.kernels = tf.get_variable("kernels", shape=[filters[0], filters[1], in_planes, out_planes], initializer=initializer,
                                       regularizer=tf.contrib.layers.l2_regularizer(scale=5e-5))

    def __call__(self, input_tensor):
        """Convolve `input_tensor` with the stored kernel."""
        # The four leading arguments are passed positionally so that extra
        # positional `args` land after them, and `kwargs` is unpacked with `**`
        # (a single `*` would pass the dict's keys as positional arguments).
        return tf.nn.conv2d(input_tensor, self.kernels, self.strides, self.padding,
                            *self.args, **self.kwargs)
    
class Sequential:
    """Chain of callables applied one after another."""

    def __init__(self, layers):
        """Remember `layers`, the callables to run in order."""
        self.layers = layers

    def __call__(self, x):
        """Feed `x` through every stored layer in order and return the final value."""
        result = x
        for layer in self.layers:
            result = layer(result)
        return result
    
class AvgPool:
    """Average-pooling layer wrapping `tf.nn.avg_pool`."""

    def __init__(self, kernel, strides=(1, 1), padding='VALID'):
        """Store the pooling window, strides and padding mode.

        `kernel` and `strides` are 2-element sequences that get wrapped with a
        leading and a trailing 1; `padding` is upper-cased.
        """
        one = (1,)
        self.kernel = one + tuple(kernel) + one
        self.strides = one + tuple(strides) + one
        self.padding = padding.upper()

    def __call__(self, x):
        """Average-pool `x` with the stored settings."""
        pool_args = dict(ksize=self.kernel, strides=self.strides, padding=self.padding)
        return tf.nn.avg_pool(x, **pool_args)

In [3]:
class Bottleneck:
    """Bottleneck residual block with optional shake-drop regularization.

    Residual branch, as built by `__call__`:
    BN -> 1x1 conv -> BN -> ReLU -> 3x3 conv (carries `strides`) -> BN -> ReLU
    -> 1x1 conv (widens to `outplanes * outchannel_ratio`) -> BN.
    The shortcut is the block input (optionally passed through `downsample`),
    zero-padded along the channel axis up to the residual's channel count.

    NOTE(review): `bn1`..`bn4` are `partial`s of `tf.layers.batch_normalization`.
    Nothing is created inside the `tf.variable_scope` blocks that wrap them in
    `__init__`; the BN variables only come into existence when `__call__` runs.
    Their names therefore presumably depend on the scope active at call time
    and on call order, and a second instance built after `reuse_variables()`
    likely gets its own BN parameters - confirm before relying on sharing.
    """
    # Channel multiplier applied by the last 1x1 convolution of the block.
    outchannel_ratio = 4

    def __init__(self, inplanes, outplanes, alpha=(-1, 1), beta=(0, 1), prob=None, 
                 is_training=True, phase=None, strides=(1, 1), downsample=None):
        """Create the convolution kernels and store the block configuration.

        Args:
            inplanes: number of input channels.
            outplanes: bottleneck width; the block outputs
                `outplanes * outchannel_ratio` channels.
            alpha: `(low, high)` range of the shake-drop forward scaling factor.
            beta: `(low, high)` range of the shake-drop backward scaling factor.
            prob: shake-drop probability term; `None` disables shake-drop.
            is_training: Python bool chosen at graph-build time; selects the
                random (training) or the expectation (inference) shake-drop branch.
            phase: value forwarded as `training` to every batch-norm call.
            strides: strides of the 3x3 convolution.
            downsample: optional callable applied to the shortcut input.
        """
        assert alpha[1] > alpha[0] and beta[1] > beta[0] 
        self.alpha = alpha
        self.beta = beta
        self.is_training = is_training
        self.prob = prob
        with tf.variable_scope("bn1"):
            self.bn1 = partial(tf.layers.batch_normalization, momentum=0.9, epsilon=1e-5, center=True, scale=True, 
                               training=phase, fused=True)
        with tf.variable_scope("conv1"):
            self.conv1 = Conv2d(inplanes, outplanes, (1, 1))
        with tf.variable_scope("bn2"):
            self.bn2 = partial(tf.layers.batch_normalization, momentum=0.9, epsilon=1e-5, center=True, scale=True, 
                               training=phase, fused=True)
        with tf.variable_scope("conv2"):
            self.conv2 = Conv2d(outplanes, (outplanes * 1), (3, 3), strides=strides)
        with tf.variable_scope("bn3"):
            self.bn3 = partial(tf.layers.batch_normalization, momentum=0.9, epsilon=1e-5, center=True, scale=True, 
                               training=phase, fused=True)
        with tf.variable_scope("conv3"):
            self.conv3 = Conv2d((outplanes * 1), outplanes * Bottleneck.outchannel_ratio, (1, 1))
        with tf.variable_scope("bn4"):
            self.bn4 = partial(tf.layers.batch_normalization, momentum=0.9, epsilon=1e-5, center=True, scale=True, 
                               training=phase, fused=True)
        self.relu = tf.nn.relu
        self.downsample = downsample
        self.strides = strides

    def __call__(self, x):
        """Build the block's ops on top of `x` and return `residual + padded shortcut`."""
        out = self.bn1(x)
        out = self.conv1(out)
        
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)
 
        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)

        out = self.bn4(out)
            
        batch_size = tf.shape(out)[0]
            
        # Apply shake-drop regularization
        # Got from https://openreview.net/pdf?id=S1NHaMW0b
        if self.prob is not None:
            if self.is_training:
                # One gate per example: floor(prob + U[0, 1)) is 1 with probability `prob`, else 0.
                bern_shape = [batch_size, 1, 1, 1]
                random_tensor = self.prob
                random_tensor += tf.random_uniform(bern_shape, dtype=tf.float32)
                binary_tensor = tf.floor(random_tensor)

                alpha_values = tf.random_uniform(
                    [batch_size, 1, 1, 1], minval=self.alpha[0], maxval=self.alpha[1],
                    dtype=tf.float32)
                beta_values = tf.random_uniform(
                    [batch_size, 1, 1, 1], minval=self.beta[0], maxval=self.beta[1],
                    dtype=tf.float32)

                # Scale is 1 where the gate is 1, otherwise alpha (forward) / beta (backward).
                rand_forward = (binary_tensor + alpha_values - binary_tensor * alpha_values)
                rand_backward = (binary_tensor + beta_values - binary_tensor * beta_values)
                # Forward value equals out * rand_forward; the gradient only flows through out * rand_backward.
                out = out * rand_backward + tf.stop_gradient(out * rand_forward - out * rand_backward)
            else:
                # Inference branch: scale by the expected value of the forward factor.
                expected_alpha = (self.alpha[1] + self.alpha[0]) / 2
                out = (self.prob + expected_alpha - self.prob * expected_alpha) * out
            
        if self.downsample is not None:
            shortcut = self.downsample(x)
            featuremap_size = tf.shape(shortcut)[1:3]
        else:
            shortcut = x
            featuremap_size = tf.shape(out)[1:3]
        
        residual_channel = tf.shape(out)[3]
        shortcut_channel = tf.shape(shortcut)[3]
        
        # Zero-pad the shortcut along the channel axis so it can be added to the wider residual.
        padding = tf.zeros((batch_size, featuremap_size[0], featuremap_size[1], residual_channel - shortcut_channel))
        return out + tf.concat([shortcut, padding], axis=3)

In [4]:
# Got from https://arxiv.org/pdf/1610.02915.pdf

class PyramidNet:
    """PyramidNet classifier built from `Bottleneck` blocks.

    The network is: 3x3 conv -> BN -> three stages of `n` bottleneck blocks
    (stages 2 and 3 use stride 2) -> BN -> ReLU -> 8x8 average pool ->
    flatten -> dense layer with `num_classes` outputs. The bottleneck width
    grows by `alpha / (3 * n)` channels (before rounding) from block to block.

    NOTE(review): `bn1` and `bn_final` are `partial`s of
    `tf.layers.batch_normalization`; their variables are created when
    `__call__` runs, not inside the `tf.variable_scope` blocks in `__init__`,
    so they presumably are not covered by that scope's reuse setting - confirm.
    """

    def __init__(self, depth, alpha, num_classes, is_training=True, phase=None, shake_drop=False, reuse=False):
        """Create all layers.

        Args:
            depth: network depth; each stage gets `int((depth - 2) / 9)` blocks.
            alpha: total widening spread evenly over all blocks.
            num_classes: size of the output layer.
            is_training: Python bool forwarded to every `Bottleneck`.
            phase: value forwarded as `training` to every batch-norm call.
            shake_drop: enable shake-drop regularization in the blocks.
            reuse: accepted but not used by this constructor.
        """
        self.inplanes = 16
        n = int((depth - 2) / 9)
        self.total_layers = n * 3
        
        # Best parameters according to the shake-drop paper
        self.shake_drop = shake_drop
        self.alpha_shake = (-1, 1)
        self.beta_shake = (0, 1)
        self.p_l = 0.5
        self.current_layer = 0
        self.phase = phase

        self.is_training = is_training
        # Width increase (in channels, before rounding) applied per block.
        self.addrate = alpha / (3 * n * 1.0)

        self.input_featuremap_dim = self.inplanes
        with tf.variable_scope("pyramidnet_conv1"):
            self.conv1 = Conv2d(3, self.input_featuremap_dim, filters=(3, 3), strides=(1, 1))
        with tf.variable_scope("pyramidnet_bn1"):
            self.bn1 = partial(tf.layers.batch_normalization, momentum=0.9, epsilon=1e-5, center=True, scale=True, 
                               training=self.phase, fused=True)

        self.featuremap_dim = self.input_featuremap_dim 
        with tf.variable_scope("pyramidnet_layer1"):
            self.layer1 = self.pyramidal_make_layer(n)
        with tf.variable_scope("pyramidnet_layer2"):
            self.layer2 = self.pyramidal_make_layer(n, strides=(2, 2))
        with tf.variable_scope("pyramidnet_layer3"):
            self.layer3 = self.pyramidal_make_layer(n, strides=(2, 2))

        # After the last stage `input_featuremap_dim` holds that stage's output channel count.
        self.final_featuremap_dim = self.input_featuremap_dim
        with tf.variable_scope("pyramidnet_bn_final"):
            self.bn_final = partial(tf.layers.batch_normalization, momentum=0.9, epsilon=1e-5, center=True, scale=True, 
                               training=self.phase, fused=True)
        self.relu_final = tf.nn.relu
        self.avgpool = AvgPool((8, 8), strides=(8, 8))  # TODO: Think about SSP
        self.flatten = Flatten()
        with tf.variable_scope("pyramidnet_fc"):
            self.fc = Dense(self.final_featuremap_dim, num_classes)

    def pyramidal_make_layer(self, block_depth, strides=(1, 1)):
        """Build one stage of `block_depth` bottleneck blocks and return it as a `Sequential`.

        Only the first block receives `strides`; when they differ from
        `(1, 1)` its shortcut is downsampled with a 2x2 average pool. The
        method advances `self.featuremap_dim` and `self.input_featuremap_dim`
        so that the next stage continues where this one ended.
        """
        downsample = None
        if strides != (1, 1):
            downsample = AvgPool((2, 2), strides=(2, 2))

        layers = []
        self.featuremap_dim = self.featuremap_dim + self.addrate
        prob = self.calc_prob()
        with tf.variable_scope("bottleneck_1"):
            layers.append(Bottleneck(self.input_featuremap_dim, int(round(self.featuremap_dim)), 
                                     alpha=self.alpha_shake, beta=self.beta_shake, prob=prob, # shake-drop regularization here
                                     is_training=self.is_training, phase=self.phase, strides=strides, downsample=downsample))
        for i in range(1, block_depth):
            temp_featuremap_dim = self.featuremap_dim + self.addrate
            prob = self.calc_prob()
            with tf.variable_scope(f"bottleneck_{i+1}"):
                # Input width is the previous block's output: rounded width times the bottleneck ratio.
                layers.append(Bottleneck(int(round(self.featuremap_dim)) * Bottleneck.outchannel_ratio, int(round(temp_featuremap_dim)),
                                         alpha=self.alpha_shake, beta=self.beta_shake, prob=prob, # shake-drop regularization here
                                         is_training=self.is_training, phase=self.phase, strides=(1, 1)))
            self.featuremap_dim  = temp_featuremap_dim
        self.input_featuremap_dim = int(round(self.featuremap_dim)) * Bottleneck.outchannel_ratio

        return Sequential(layers)
    
    def calc_prob(self):
        """Return the shake-drop probability term for the next block, or `None` when disabled.

        Each call advances `self.current_layer`; the value decays linearly
        from just under 1 down to `1 - p_l` at the last block.
        """
        if not self.shake_drop:
            return None
        self.current_layer += 1
        return 1 - (float(self.current_layer) / self.total_layers) * self.p_l

    def __call__(self, x):
        """Build the forward pass on top of `x` and return the logits tensor."""
        x = self.conv1(x)
        # The output of the first batch norm is also kept on the instance as `test_bn`.
        self.test_bn = self.bn1(x) 
        x = self.test_bn
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.bn_final(x)
        x = self.relu_final(x)
        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

In [5]:
# HParams for cifar10 normalization

# Got from https://arxiv.org/abs/1805.09501
# Side length in pixels used by the crop transform.
IMAGE_SIZE = 32
# Per-channel mean and standard deviation used to normalize images scaled to [0, 1].
MEANS = [0.49139968, 0.48215841, 0.44653091]
STDS = [0.24703223, 0.24348513, 0.26158784]
# Policy levels are divided by this value when mapped onto a transform's range.
PARAMETER_MAX = 10

def random_flip(x):
    """Return `x` mirrored left-to-right with probability 0.5, otherwise `x` unchanged."""
    coin = np.random.rand(1)[0]
    return np.fliplr(x) if coin > 0.5 else x

def zero_pad_and_crop(img, amount=4):
    """Zero-pad `img` by `amount` pixels per side and crop a random window of the original size.

    Args:
        img: array of shape (height, width, channels).
        amount: padding added on every side; the crop offset is drawn
            uniformly from 0..2*amount inclusive on each axis.

    Returns:
        A float array with the same shape as `img`.
    """
    height, width, channels = img.shape[0], img.shape[1], img.shape[2]
    padded_img = np.zeros((height + amount * 2, width + amount * 2, channels))
    padded_img[amount:height + amount, amount:width + amount, :] = img
    # `randint` excludes `high`, so it must be 2 * amount + 1 for the largest
    # offset to be reachable (and for amount == 0 to be a valid call).
    top = np.random.randint(low=0, high=2 * amount + 1)
    left = np.random.randint(low=0, high=2 * amount + 1)
    return padded_img[top:top + height, left:left + width, :]

def create_cutout_mask(img_height, img_width, num_channels, size):
    """Build a multiplicative cutout mask with one randomly placed patch of zeros.

    A centre is drawn uniformly inside the image and a patch reaching
    `size // 2` pixels to each side of it (clipped to the image) is zeroed.

    Returns:
        `(mask, upper_coord, lower_coord)`: the
        `(img_height, img_width, num_channels)` mask of ones and zeros, and the
        patch's inclusive upper and exclusive lower `(row, col)` corners.
    """
    assert img_height == img_width

    center_row = np.random.randint(low=0, high=img_height)
    center_col = np.random.randint(low=0, high=img_width)
    half = size // 2

    top, left = max(0, center_row - half), max(0, center_col - half)
    bottom, right = min(img_height, center_row + half), min(img_width, center_col + half)

    mask = np.ones((img_height, img_width, num_channels))
    patch = np.zeros((bottom - top, right - left, num_channels))
    mask[top:bottom, left:right, :] = patch
    return mask, (top, left), (bottom, right)

def cutout_numpy(img, size=16):
    """Zero out one random patch of `img`, a (height, width, channels) array.

    The patch comes from `create_cutout_mask` with the given `size`.
    """
    height, width, channels = img.shape[0], img.shape[1], img.shape[2]
    assert len(img.shape) == 3
    cutout_mask = create_cutout_mask(height, width, channels, size)[0]
    return img * cutout_mask

def float_parameter(level, maxval):
    """Map `level` (on the 0..PARAMETER_MAX scale) linearly onto 0..`maxval` as a float."""
    scaled = float(level) * maxval
    return scaled / PARAMETER_MAX

def int_parameter(level, maxval):
    """Map `level` (on the 0..PARAMETER_MAX scale) linearly onto 0..`maxval`, truncated to int."""
    scaled = level * maxval / PARAMETER_MAX
    return int(scaled)

def pil_wrap(img):
    """Convert a normalized image array into an RGBA PIL image.

    The normalization is undone with `STDS` / `MEANS`, and the result is
    scaled to 0..255. Values are rounded to the nearest integer and clipped to
    the uint8 range before the cast: a plain `np.uint8(...)` cast truncates
    (so 254.9999 becomes 254) and wraps out-of-range values.
    """
    pixels = (img * STDS + MEANS) * 255.0
    pixels = np.clip(np.rint(pixels), 0, 255).astype(np.uint8)
    return Image.fromarray(pixels).convert('RGBA')

def pil_unwrap(pil_img):
    """Convert an RGBA PIL image back into a normalized (height, width, 3) float array.

    Pixels whose alpha channel is 0 are set to 0 in every colour channel of
    the normalized result. The image is converted with `np.array(pil_img)`,
    so any image size works, not only 32x32.
    """
    rgba = np.array(pil_img) / 255.0
    transparent = rgba[:, :, 3] == 0
    pic_array = (rgba[:, :, :3] - MEANS) / STDS
    pic_array[transparent] = 0.0
    return pic_array

def apply_policy(policy, img):
    """Run every `(name, probability, level)` step of `policy` on `img`.

    The image is converted with `pil_wrap`, passed through the transformer
    of each named entry of `NAME_TO_TRANSFORM` in order, and converted back
    with `pil_unwrap`.
    """
    augmented = pil_wrap(img)
    for step in policy:
        assert len(step) == 3
        op_name, probability, level = step
        transformer = NAME_TO_TRANSFORM[op_name].pil_transformer(probability, level)
        augmented = transformer(augmented)
    return pil_unwrap(augmented)

class TransformFunction(object):
    """Callable wrapper that pairs a function with a display name."""

    def __init__(self, func, name):
        """Keep `func` (called with a PIL image) and its display `name`."""
        self.f = func
        self.name = name

    def __repr__(self):
        """Return the name wrapped in angle brackets."""
        return ''.join(('<', self.name, '>'))

    def __call__(self, pil_img):
        """Apply the wrapped function to `pil_img`."""
        wrapped = self.f
        return wrapped(pil_img)

class TransformT(object):
    """Named image transform taking `(pil_img, level)`."""

    def __init__(self, name, xform_fn):
        """Store the transform's `name` and its `xform_fn(pil_img, level)` callable."""
        self.name = name
        self.xform = xform_fn

    def pil_transformer(self, probability, level):
        """Return a `TransformFunction` applying the transform with the given probability.

        On each call the returned function draws `random.random()` and runs
        the transform at `level` only when the draw is below `probability`.
        """
        def maybe_apply(im):
            if not random.random() < probability:
                return im
            return self.xform(im, level)

        label = self.name + '({:.1f},{})'.format(probability, level)
        return TransformFunction(maybe_apply, label)

    def do_transform(self, image, level):
        """Apply the transform to a normalized array `image` and return a normalized array.

        `PARAMETER_MAX` is passed as the probability, so the random draw in
        `pil_transformer` is always below it.
        """
        transformer = self.pil_transformer(PARAMETER_MAX, level)
        wrapped = pil_wrap(image)
        return pil_unwrap(transformer(wrapped))

# Level-independent transforms. Operations that need RGB input convert the
# RGBA image to RGB first and back to RGBA afterwards.

def _identity_impl(pil_img, level):
    """Return the image untouched."""
    return pil_img

identity = TransformT('identity', _identity_impl)


def _flip_lr_impl(pil_img, level):
    """Mirror the image left-to-right."""
    return pil_img.transpose(Image.FLIP_LEFT_RIGHT)

flip_lr = TransformT('FlipLR', _flip_lr_impl)


def _flip_ud_impl(pil_img, level):
    """Mirror the image top-to-bottom."""
    return pil_img.transpose(Image.FLIP_TOP_BOTTOM)

flip_ud = TransformT('FlipUD', _flip_ud_impl)


def _auto_contrast_impl(pil_img, level):
    """Apply `ImageOps.autocontrast` to the RGB version of the image."""
    rgb = pil_img.convert('RGB')
    return ImageOps.autocontrast(rgb).convert('RGBA')

auto_contrast = TransformT('AutoContrast', _auto_contrast_impl)


def _equalize_impl(pil_img, level):
    """Apply `ImageOps.equalize` to the RGB version of the image."""
    rgb = pil_img.convert('RGB')
    return ImageOps.equalize(rgb).convert('RGBA')

equalize = TransformT('Equalize', _equalize_impl)


def _invert_impl(pil_img, level):
    """Apply `ImageOps.invert` to the RGB version of the image."""
    rgb = pil_img.convert('RGB')
    return ImageOps.invert(rgb).convert('RGBA')

invert = TransformT('Invert', _invert_impl)


def _blur_impl(pil_img, level):
    """Filter the image with `ImageFilter.BLUR`."""
    return pil_img.filter(ImageFilter.BLUR)

blur = TransformT('Blur', _blur_impl)


def _smooth_impl(pil_img, level):
    """Filter the image with `ImageFilter.SMOOTH`."""
    return pil_img.filter(ImageFilter.SMOOTH)

smooth = TransformT('Smooth', _smooth_impl)

def _rotate_impl(pil_img, level):
    """Rotate by `int_parameter(level, 30)` degrees, negating the angle half of the time."""
    angle = int_parameter(level, 30)
    negate = random.random() > 0.5
    return pil_img.rotate(-angle if negate else angle)

rotate = TransformT('Rotate', _rotate_impl)

def _posterize_impl(pil_img, level):
    """Posterize the RGB version of the image, keeping `4 - int_parameter(level, 4)` bits."""
    dropped_bits = int_parameter(level, 4)
    rgb = pil_img.convert('RGB')
    return ImageOps.posterize(rgb, 4 - dropped_bits).convert('RGBA')

posterize = TransformT('Posterize', _posterize_impl)

def _shear_x_impl(pil_img, level):
    """Shear along x by up to 0.3 (scaled by `level`), with a random sign.

    The output keeps the input image's own size instead of a hard-coded
    32x32, so the transform also works for other resolutions.
    """
    shear = float_parameter(level, 0.3)
    if random.random() > 0.5:
        shear = -shear
    return pil_img.transform(pil_img.size, Image.AFFINE, (1, shear, 0, 0, 1, 0))

shear_x = TransformT('ShearX', _shear_x_impl)

def _shear_y_impl(pil_img, level):
    """Shear along y by up to 0.3 (scaled by `level`), with a random sign.

    The output keeps the input image's own size instead of a hard-coded
    32x32, so the transform also works for other resolutions.
    """
    shear = float_parameter(level, 0.3)
    if random.random() > 0.5:
        shear = -shear
    return pil_img.transform(pil_img.size, Image.AFFINE, (1, 0, 0, shear, 1, 0))

shear_y = TransformT('ShearY', _shear_y_impl)

def _translate_x_impl(pil_img, level):
    """Translate along x by up to 10 pixels (scaled by `level`), with a random sign.

    The output keeps the input image's own size instead of a hard-coded
    32x32, so the transform also works for other resolutions.
    """
    shift = int_parameter(level, 10)
    if random.random() > 0.5:
        shift = -shift
    return pil_img.transform(pil_img.size, Image.AFFINE, (1, 0, shift, 0, 1, 0))

translate_x = TransformT('TranslateX', _translate_x_impl)

def _translate_y_impl(pil_img, level):
    """Translate along y by up to 10 pixels (scaled by `level`), with a random sign.

    The output keeps the input image's own size instead of a hard-coded
    32x32, so the transform also works for other resolutions.
    """
    shift = int_parameter(level, 10)
    if random.random() > 0.5:
        shift = -shift
    return pil_img.transform(pil_img.size, Image.AFFINE, (1, 0, 0, 0, 1, shift))

translate_y = TransformT('TranslateY', _translate_y_impl)

def _crop_impl(pil_img, level, interpolation=Image.BILINEAR):
    """Cut `level` pixels off every side and resize back to IMAGE_SIZE x IMAGE_SIZE."""
    far_edge = IMAGE_SIZE - level
    box = (level, level, far_edge, far_edge)
    return pil_img.crop(box).resize((IMAGE_SIZE, IMAGE_SIZE), interpolation)

crop_bilinear = TransformT('CropBilinear', _crop_impl)

def _solarize_impl(pil_img, level):
    """Solarize the RGB version of the image with threshold `256 - int_parameter(level, 256)`."""
    strength = int_parameter(level, 256)
    rgb = pil_img.convert('RGB')
    return ImageOps.solarize(rgb, 256 - strength).convert('RGBA')

solarize = TransformT('Solarize', _solarize_impl)

def _cutout_pil_impl(pil_img, level):
    """Overwrite one random patch of the image in place with a fixed, fully transparent colour.

    The patch size is `int_parameter(level, 20)`; a non-positive size leaves
    the image untouched. The patch location comes from `create_cutout_mask`
    for a 32x32x3 image. The same image object is returned.
    """
    patch_size = int_parameter(level, 20)
    if patch_size > 0:
        upper, lower = create_cutout_mask(32, 32, 3, patch_size)[1:]
        fill = (125, 122, 113, 0)  # alpha 0 marks the pixel as cut out
        pixel_access = pil_img.load()
        for first in range(upper[0], lower[0]):
            for second in range(upper[1], lower[1]):
                pixel_access[first, second] = fill
    return pil_img

cutout = TransformT('Cutout', _cutout_pil_impl)

def _enhancer_impl(enhancer):
    """Build a `(pil_img, level)` transform around an `ImageEnhance` class.

    The enhancement factor is `float_parameter(level, 1.8) + 0.1`, so it
    never reaches 0.
    """
    def impl(pil_img, level):
        factor = float_parameter(level, 1.8) + .1  # a factor of 0 would destroy the image
        tool = enhancer(pil_img)
        return tool.enhance(factor)
    return impl

color = TransformT('Color', _enhancer_impl(ImageEnhance.Color))
contrast = TransformT('Contrast', _enhancer_impl(ImageEnhance.Contrast))
brightness = TransformT('Brightness', _enhancer_impl(ImageEnhance.Brightness))
sharpness = TransformT('Sharpness', _enhancer_impl(ImageEnhance.Sharpness))

# Registry of the transforms that policies may reference by name.
ALL_TRANSFORMS = [
    flip_lr,
    flip_ud,
    auto_contrast,
    equalize,
    invert,
    rotate,
    posterize,
    crop_bilinear,
    solarize,
    color,
    contrast,
    brightness,
    sharpness,
    shear_x,
    shear_y,
    translate_x,
    translate_y,
    cutout,
    blur,
    smooth,
]
NAME_TO_TRANSFORM = dict((transform.name, transform) for transform in ALL_TRANSFORMS)
TRANSFORM_NAMES = NAME_TO_TRANSFORM.keys()

def good_policies():
    """Return the hand-listed augmentation sub-policies as one flat list.

    Every sub-policy is a list of two `(name, probability, level)` tuples.
    The groups `exp0_*`, `exp1_*` and `exp2_*` are concatenated in that
    order, giving 95 sub-policies in total.
    """
    exp0_0 = [
        [('Invert', 0.1, 7), ('Contrast', 0.2, 6)],
        [('Rotate', 0.7, 2), ('TranslateX', 0.3, 9)],
        [('Sharpness', 0.8, 1), ('Sharpness', 0.9, 3)],
        [('ShearY', 0.5, 8), ('TranslateY', 0.7, 9)],
        [('AutoContrast', 0.5, 8), ('Equalize', 0.9, 2)]]
    exp0_1 = [
        [('Solarize', 0.4, 5), ('AutoContrast', 0.9, 3)],
        [('TranslateY', 0.9, 9), ('TranslateY', 0.7, 9)],
        [('AutoContrast', 0.9, 2), ('Solarize', 0.8, 3)],
        [('Equalize', 0.8, 8), ('Invert', 0.1, 3)],
        [('TranslateY', 0.7, 9), ('AutoContrast', 0.9, 1)]]
    exp0_2 = [
        [('Solarize', 0.4, 5), ('AutoContrast', 0.0, 2)],
        [('TranslateY', 0.7, 9), ('TranslateY', 0.7, 9)],
        [('AutoContrast', 0.9, 0), ('Solarize', 0.4, 3)],
        [('Equalize', 0.7, 5), ('Invert', 0.1, 3)],
        [('TranslateY', 0.7, 9), ('TranslateY', 0.7, 9)]]
    exp0_3 = [
        [('Solarize', 0.4, 5), ('AutoContrast', 0.9, 1)],
        [('TranslateY', 0.8, 9), ('TranslateY', 0.9, 9)],
        [('AutoContrast', 0.8, 0), ('TranslateY', 0.7, 9)],
        [('TranslateY', 0.2, 7), ('Color', 0.9, 6)],
        [('Equalize', 0.7, 6), ('Color', 0.4, 9)]]
    exp1_0 = [
        [('ShearY', 0.2, 7), ('Posterize', 0.3, 7)],
        [('Color', 0.4, 3), ('Brightness', 0.6, 7)],
        [('Sharpness', 0.3, 9), ('Brightness', 0.7, 9)],
        [('Equalize', 0.6, 5), ('Equalize', 0.5, 1)],
        [('Contrast', 0.6, 7), ('Sharpness', 0.6, 5)]]
    exp1_1 = [
        [('Brightness', 0.3, 7), ('AutoContrast', 0.5, 8)],
        [('AutoContrast', 0.9, 4), ('AutoContrast', 0.5, 6)],
        [('Solarize', 0.3, 5), ('Equalize', 0.6, 5)],
        [('TranslateY', 0.2, 4), ('Sharpness', 0.3, 3)],
        [('Brightness', 0.0, 8), ('Color', 0.8, 8)]]
    exp1_2 = [
        [('Solarize', 0.2, 6), ('Color', 0.8, 6)],
        [('Solarize', 0.2, 6), ('AutoContrast', 0.8, 1)],
        [('Solarize', 0.4, 1), ('Equalize', 0.6, 5)],
        [('Brightness', 0.0, 0), ('Solarize', 0.5, 2)],
        [('AutoContrast', 0.9, 5), ('Brightness', 0.5, 3)]]
    exp1_3 = [
        [('Contrast', 0.7, 5), ('Brightness', 0.0, 2)],
        [('Solarize', 0.2, 8), ('Solarize', 0.1, 5)],
        [('Contrast', 0.5, 1), ('TranslateY', 0.2, 9)],
        [('AutoContrast', 0.6, 5), ('TranslateY', 0.0, 9)],
        [('AutoContrast', 0.9, 4), ('Equalize', 0.8, 4)]]
    exp1_4 = [
        [('Brightness', 0.0, 7), ('Equalize', 0.4, 7)],
        [('Solarize', 0.2, 5), ('Equalize', 0.7, 5)],
        [('Equalize', 0.6, 8), ('Color', 0.6, 2)],
        [('Color', 0.3, 7), ('Color', 0.2, 4)],
        [('AutoContrast', 0.5, 2), ('Solarize', 0.7, 2)]]
    exp1_5 = [
        [('AutoContrast', 0.2, 0), ('Equalize', 0.1, 0)],
        [('ShearY', 0.6, 5), ('Equalize', 0.6, 5)],
        [('Brightness', 0.9, 3), ('AutoContrast', 0.4, 1)],
        [('Equalize', 0.8, 8), ('Equalize', 0.7, 7)],
        [('Equalize', 0.7, 7), ('Solarize', 0.5, 0)]]
    exp1_6 = [
        [('Equalize', 0.8, 4), ('TranslateY', 0.8, 9)],
        [('TranslateY', 0.8, 9), ('TranslateY', 0.6, 9)],
        [('TranslateY', 0.9, 0), ('TranslateY', 0.5, 9)],
        [('AutoContrast', 0.5, 3), ('Solarize', 0.3, 4)],
        [('Solarize', 0.5, 3), ('Equalize', 0.4, 4)]]
    exp2_0 = [
        [('Color', 0.7, 7), ('TranslateX', 0.5, 8)],
        [('Equalize', 0.3, 7), ('AutoContrast', 0.4, 8)],
        [('TranslateY', 0.4, 3), ('Sharpness', 0.2, 6)],
        [('Brightness', 0.9, 6), ('Color', 0.2, 8)],
        [('Solarize', 0.5, 2), ('Invert', 0.0, 3)]]
    exp2_1 = [
        [('AutoContrast', 0.1, 5), ('Brightness', 0.0, 0)],
        [('Cutout', 0.2, 4), ('Equalize', 0.1, 1)],
        [('Equalize', 0.7, 7), ('AutoContrast', 0.6, 4)],
        [('Color', 0.1, 8), ('ShearY', 0.2, 3)],
        [('ShearY', 0.4, 2), ('Rotate', 0.7, 0)]]
    exp2_2 = [
        [('ShearY', 0.1, 3), ('AutoContrast', 0.9, 5)],
        [('TranslateY', 0.3, 6), ('Cutout', 0.3, 3)],
        [('Equalize', 0.5, 0), ('Solarize', 0.6, 6)],
        [('AutoContrast', 0.3, 5), ('Rotate', 0.2, 7)],
        [('Equalize', 0.8, 2), ('Invert', 0.4, 0)]]
    exp2_3 = [
        [('Equalize', 0.9, 5), ('Color', 0.7, 0)],
        [('Equalize', 0.1, 1), ('ShearY', 0.1, 3)],
        [('AutoContrast', 0.7, 3), ('Equalize', 0.7, 0)],
        [('Brightness', 0.5, 1), ('Contrast', 0.1, 7)],
        [('Contrast', 0.1, 4), ('Solarize', 0.6, 5)]]
    exp2_4 = [
        [('Solarize', 0.2, 3), ('ShearX', 0.0, 0)],
        [('TranslateX', 0.3, 0), ('TranslateX', 0.6, 0)],
        [('Equalize', 0.5, 9), ('TranslateY', 0.6, 7)],
        [('ShearX', 0.1, 0), ('Sharpness', 0.5, 1)],
        [('Equalize', 0.8, 6), ('Invert', 0.3, 6)]]
    exp2_5 = [
        [('AutoContrast', 0.3, 9), ('Cutout', 0.5, 3)],
        [('ShearX', 0.4, 4), ('AutoContrast', 0.9, 2)],
        [('ShearX', 0.0, 3), ('Posterize', 0.0, 3)],
        [('Solarize', 0.4, 3), ('Color', 0.2, 4)],
        [('Equalize', 0.1, 4), ('Equalize', 0.7, 6)]]
    exp2_6 = [
        [('Equalize', 0.3, 8), ('AutoContrast', 0.4, 3)],
        [('Solarize', 0.6, 4), ('AutoContrast', 0.7, 6)],
        [('AutoContrast', 0.2, 9), ('Brightness', 0.4, 8)],
        [('Equalize', 0.1, 0), ('Equalize', 0.0, 6)],
        [('Equalize', 0.8, 4), ('Equalize', 0.0, 4)]]
    exp2_7 = [
        [('Equalize', 0.5, 5), ('AutoContrast', 0.1, 2)],
        [('Solarize', 0.5, 5), ('AutoContrast', 0.9, 5)],
        [('AutoContrast', 0.6, 1), ('AutoContrast', 0.7, 8)],
        [('Equalize', 0.2, 0), ('AutoContrast', 0.1, 2)],
        [('Equalize', 0.6, 9), ('Equalize', 0.4, 4)]]
    # Concatenate the groups into one flat list of sub-policies.
    exp0s = exp0_0 + exp0_1 + exp0_2 + exp0_3
    exp1s = exp1_0 + exp1_1 + exp1_2 + exp1_3 + exp1_4 + exp1_5 + exp1_6
    exp2s = exp2_0 + exp2_1 + exp2_2 + exp2_3 + exp2_4 + exp2_5 + exp2_6 + exp2_7
    return  exp0s + exp1s + exp2s

In [6]:
# Common helpers

def preprocess_data(datasets):
    """Normalize every array in `datasets` and return the results as a list.

    Each array is reshaped to (-1, 32, 32, 3), scaled to [0, 1] by dividing
    by 255 and standardized with `MEANS` and `STDS`.
    """
    def _normalize(raw):
        images = raw.reshape(-1, 32, 32, 3)
        images = images / 255.0
        return (images - MEANS) / STDS

    return [_normalize(raw) for raw in datasets]

def aug_batch(policies, batch):
    """Augment every image of `batch` and return the result as a float32 array.

    For each image a sub-policy is drawn uniformly from `policies` and
    applied, followed by zero-pad-and-crop (4 pixels), a random horizontal
    flip and cutout. The sub-policy is drawn per image; drawing it once per
    batch would give every image in the batch the same operations.

    Args:
        policies: sequence of sub-policies accepted by `apply_policy`.
        batch: array of normalized images, shape (n, height, width, channels).
    """
    out = np.zeros(np.shape(batch), dtype=np.float32)
    for i, image in enumerate(batch):
        sub_policy = policies[np.random.choice(len(policies))]
        image = apply_policy(sub_policy, image)
        image = random_flip(zero_pad_and_crop(image, 4))
        image = cutout_numpy(image)
        out[i] = image
    return out

def cosine_lr(learning_rate, epoch, iteration, batches_per_epoch, total_epochs=1800):
    """Cosine-annealed learning rate without restarts.

    Returns `learning_rate` at step 0 and decays to 0 at step
    `total_epochs * batches_per_epoch`, where the current step is
    `epoch * batches_per_epoch + iteration`.
    """
    total_steps = total_epochs * batches_per_epoch
    current_step = float(epoch * batches_per_epoch + iteration)
    angle = np.pi * current_step / total_steps
    return 0.5 * learning_rate * (1 + np.cos(angle))

def get_lr(curr_epoch, iteration=None, initial_lr=0.05, batch_size=128, num_epochs=800, train_size=50000):
    """Return the cosine-annealed learning rate for the given training position.

    Args:
        curr_epoch: index of the current epoch.
        iteration: index of the current batch within the epoch (required).
        initial_lr: learning rate at the start of the schedule.
        batch_size: number of samples per batch.
        num_epochs: total number of epochs the schedule spans.
        train_size: number of training samples per epoch. This used to be a
            hard-coded 5000, which made the schedule assume ten times fewer
            batches per epoch than a loop over 50000 samples actually runs.
    """
    assert iteration is not None
    batches_per_epoch = int(train_size / batch_size)
    return cosine_lr(initial_lr, curr_epoch, iteration, batches_per_epoch, num_epochs)

def save_results(sess, saver, batch_size=32):
    """Evaluate on the test set, write `results.csv` and checkpoint the model.

    Reads the notebook globals `x_test`, `y_test`, `x`, `y`, `phase`,
    `learning_rate` and `logits`.

    Args:
        sess: session holding the current weights.
        saver: `tf.train.Saver` used to write `./model.ckpt` (the path that
            `restore()` reads); pass `None` to skip checkpointing.
        batch_size: approximate number of test images per forward pass.

    NOTE(review): predictions come from `logits`, i.e. the graph built with
    `is_training=True`; confirm whether `eval_logits` should be used instead.
    """
    index_test = np.arange(len(x_test))
    # At least one chunk, so batch sizes above the test-set size do not crash array_split.
    num_batches_test = max(1, int(len(index_test) / batch_size))
    batch_indexes_test = np.array_split(index_test, num_batches_test)

    predictions = np.empty(len(index_test), dtype=int)
    for batch_index in tqdm_notebook(batch_indexes_test, leave=False):
        feed_dict = {x: x_test[batch_index], y: y_test[batch_index], phase: 0, learning_rate:0.0}
        out, = sess.run([logits], feed_dict=feed_dict)
        predictions[batch_index] = np.argmax(out, axis=1)

    # One "index,prediction" row per test image, in index order.
    test_results = np.column_stack([index_test, predictions])
    correct = np.equal(predictions, y_test)
    np.savetxt("results.csv", test_results, delimiter=",", fmt='%d')
    if saver is not None:
        saver.save(sess, "./model.ckpt")
    print(f"Saved with accuracy: {np.mean(correct.astype(np.float32)):.6}")
    
def restore():
    """Open a new session, load `./model.ckpt` into it and return the session.

    A fresh `tf.train.Saver` is created under the '/cpu:0' device scope; the
    session is created and restored under the '/gpu:0' device scope.
    """
    checkpoint_path = "./model.ckpt"
    with tf.device('/cpu:0'):
        checkpoint_saver = tf.train.Saver()
    with tf.device("/gpu:0"):
        restored_session = tf.Session()
        checkpoint_saver.restore(restored_session, checkpoint_path)
    return restored_session

In [7]:
# Load data

# Load the CIFAR-10 train/test splits through Keras.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# Drop size-1 axes from the label arrays.
y_train, y_test = np.squeeze(y_train), np.squeeze(y_test)
# Scale the images to [0, 1] and standardize them with MEANS / STDS.
x_train, x_test = preprocess_data([x_train, x_test])

In [8]:
# Build the training / evaluation graph from scratch.
tf.reset_default_graph()

# True: load weights from ./model.ckpt via restore(); False: initialize fresh variables.
need_to_restore_last = True

# Graph inputs: image batch, integer labels, learning rate and the batch-norm mode flag.
x = tf.placeholder(tf.float32, [None, 32, 32, 3])
y = tf.placeholder(tf.int64, [None])
learning_rate = tf.placeholder(tf.float32)
phase = tf.placeholder(tf.bool, name='phase')  # 'is_training' for shake-drop regularization to prevent extra tf.cond, 'phase' is for BN 

model_args = {'depth': 272, 'alpha': 200, 'num_classes': 10, 'shake_drop': True}

# Two networks in the same variable scope: the second one is built after
# reuse_variables(), so its tf.get_variable calls resolve to the first one's variables.
# NOTE(review): batch-norm variables are created later, when model(x) /
# eval_model(x) run outside this scope, so they are presumably NOT shared
# between the two networks - confirm.
with tf.variable_scope('model', use_resource=False) as scope:
    model = PyramidNet(is_training=True, phase=phase, **model_args)
    scope.reuse_variables()
    eval_model = PyramidNet(is_training=False, phase=phase, **model_args)
    
# Training head: cross-entropy plus the collected regularization losses.
logits = model(x)
total_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
mean_loss = tf.reduce_mean(total_loss)
mean_loss += tf.losses.get_regularization_loss()
correct_prediction = tf.equal(tf.argmax(logits, 1), y)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Evaluation head, built from the is_training=False network.
eval_logits = eval_model(x)
eval_correct_prediction = tf.equal(tf.argmax(eval_logits, 1), y)
eval_accuracy = tf.reduce_mean(tf.cast(eval_correct_prediction, tf.float32))

train_summaries = tf.summary.merge([tf.summary.scalar('loss', mean_loss), tf.summary.scalar('accuracy', accuracy)])
eval_summaries = tf.summary.merge([tf.summary.scalar('eval_accuracy', eval_accuracy)])

# Make the train step depend on every op in the UPDATE_OPS collection.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    # Nesterov momentum (0.9) with gradients clipped to norm 5.0.
    optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9, use_nesterov=True)
    optimizer = tf.contrib.estimator.clip_gradients_by_norm(optimizer, clip_norm=5.0)
    train_step = optimizer.minimize(mean_loss)

with tf.device('/cpu:0'):
    saver = tf.train.Saver(max_to_keep=2)

# Either start from freshly initialized variables or load the last checkpoint.
if not need_to_restore_last:
    with tf.device("/gpu:0"):
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        sess.run(tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()))
else:
    sess = restore()

INFO:tensorflow:Restoring parameters from ./model.ckpt


In [None]:
# Training schedule: resume at `initial_epoch` and run up to `epochs`.
initial_epoch = 1400
epochs = 1800
batch_size = 16
current_lr = 0.05

index = np.arange(len(x_train))
num_batches = int(len(index) / batch_size)

index_test = np.arange(len(x_test))
num_batches_test = int(len(index_test) / batch_size)
batch_indexes_test = np.array_split(index_test, num_batches_test)

policies = good_policies()
# Global step used for the training summaries, continued from the resumed epoch.
iteration = (len(x_train) * initial_epoch) // batch_size
# Start from an empty log directory; ignore_errors keeps the first run
# (when ./logs/ does not exist yet) from raising.
shutil.rmtree("./logs/", ignore_errors=True)
summary_writer = tf.summary.FileWriter(f"./logs/pyramidnet_{model_args['depth']}/", graph=sess.graph)

try:
    for e in tqdm_notebook(range(initial_epoch, epochs), desc='Epochs:'):

        # Augmentation is always on; the data order is reshuffled every epoch.
        augment = True
        np.random.shuffle(index)
        batch_indexes = np.array_split(index, num_batches)

        # train epoch
        batch_gen = tqdm_notebook(enumerate(batch_indexes), total=len(batch_indexes), leave=False)
        losses = []
        for i, batch_index in batch_gen:
            if augment:
                current_lr = get_lr(e, i + 1, num_epochs=epochs, batch_size=batch_size)
                x_batch = aug_batch(policies, x_train[batch_index])
                feed_dict = {x: x_batch, y: y_train[batch_index], phase: 1, learning_rate: current_lr}
            else:
                # Without augmentation the learning rate keeps its previous value.
                feed_dict = {x: x_train[batch_index], y: y_train[batch_index], phase: 1, learning_rate: current_lr}

            # The logits are not needed here, so they are no longer fetched.
            loss, acc, _, summ_data = sess.run([mean_loss, accuracy, train_step, train_summaries], feed_dict=feed_dict)
            summary_writer.add_summary(summ_data, iteration)
            iteration += 1
            batch_gen.set_postfix(loss=loss, acc=acc, lr=current_lr)
            losses.append(loss)

        # validation step
        val_accuracy = []
        val_batches = tqdm_notebook(enumerate(batch_indexes_test), total=len(batch_indexes_test), leave=False)
        for i, batch_index in val_batches:
            feed_dict = {x: x_test[batch_index], y: y_test[batch_index], phase: 0, learning_rate:0.0}
            acc, summ_data = sess.run([eval_accuracy, eval_summaries], feed_dict=feed_dict)
            summary_writer.add_summary(summ_data, (e * int(len(index_test) / batch_size)) + i)
            val_accuracy.append(acc)

        print(f'train loss: {np.mean(losses):.6}, eval accuracy: {np.mean(val_accuracy):.6}')

        if e % 50 == 0 and e > 0:
            save_results(sess, saver)
finally:
    # Flush and release the event file even when training is interrupted.
    summary_writer.close()

In [12]:
# Final test-set evaluation. This cell used to repeat the body of
# `save_results` line for line, and in doing so overwrote the graph tensor
# `correct_prediction` with a NumPy array. Calling the helper keeps a single
# implementation and leaves the graph handles untouched.
save_results(sess, saver, batch_size=batch_size)

Saved with accuracy: 0.9751
