In [0]:
# All Includes
from sklearn.metrics import classification_report
from keras.models import Sequential, load_model, clone_model
from keras.layers import Dense, Activation, Convolution2D, MaxPooling2D, Dropout, Flatten, BatchNormalization
from keras import models, layers, optimizers, utils
import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw, PIL.Image
import matplotlib.pyplot as plt
import numpy as np
from scipy.io import loadmat
import time
import random

**=== Child Model ===**

In [0]:


def create_model(trainX, n_classes):
    # https://stats.stackexchange.com/questions/272607/cifar-10-cant-get-above-60-accuracy-keras-with-tensorflow-backend
    model = Sequential()
    model.add(Convolution2D(input_shape=trainX[0,:,:,:].shape, filters=96, kernel_size=(3,3)))
    model.add(Activation('relu'))
    model.add(Convolution2D(filters=96, kernel_size=(3,3), strides=2))
    model.add(Activation('relu'))
    model.add(Dropout(0.2))
    model.add(Convolution2D(filters=192, kernel_size=(3,3)))
    model.add(Activation('relu'))
    model.add(Convolution2D(filters=192, kernel_size=(3,3), strides=2))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Flatten())
    model.add(BatchNormalization())
    model.add(Dense(256))
    model.add(Activation('relu'))
    model.add(Dense(n_classes, activation="softmax"))
    optimizer = optimizers.Adadelta(lr=0.05, rho=0.95, epsilon=None, decay=0.0)
    model.compile(optimizer, 'categorical_crossentropy', ['accuracy'])
    return model

  

  
def model_cond_accuracy(model, X, y):
    y_prob = model.predict(X)
    y_classes = y_prob.argmax(axis=-1).tolist()
    y_test = y.argmax(axis=-1).tolist()
    total = [0] * 10
    counts = [0] * 10
    for i in range(len(y_classes)):
      if y_classes[i] == y_test[i]:
        total[y_test[i]] += 1
      counts[y_test[i]] += 1
    acc = [0.0] * 10
    for i in range(10):
      if 0 != counts[i]:
        acc[i] = total[i] / counts[i]
    return acc

def model_fit(model, gen, val_data, nbatches, epochs):
    history = model.fit_generator(
      gen, nbatches, epochs, verbose=1, use_multiprocessing=True, validation_data =  val_data)
    return history

def model_evaluate(model, X, y):
  return model.evaluate(X, y, verbose=0)[1]

**=== Transforms ===**

In [0]:
# Code below adapted from augmentation_transforms.py
# Modified to support transforms at the image class level
# Original copywright below:

# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================



IMAGE_SIZE = 32
MEANS = [0.49139968, 0.48215841, 0.44653091]
STDS = [0.24703223, 0.24348513, 0.26158784]
PARAMETER_MAX = 10  # What is the max 'level' a transform could be predicted

def pil_wrap(img):
  """Convert the `img` numpy tensor to a PIL Image."""
  return PIL.Image.fromarray(
      np.uint8((img * STDS + MEANS) * 255.0)).convert('RGBA')


def pil_unwrap(pil_img):
  """Converts the PIL img to a numpy array."""
  pic_array = (np.array(pil_img.getdata()).reshape((32, 32, 4)) / 255.0)
  i1, i2 = np.where(pic_array[:, :, 3] == 0)
  pic_array = (pic_array[:, :, :3] - MEANS) / STDS
  pic_array[i1, i2] = [0, 0, 0]
  return pic_array

class Operation:
    def __init__(self, t, p = 0.5):

        self.prob = p
        self.magnitude = t[1]
        self.transformation = t[0]

    def __call__(self, X, Y):
        _X = []
        #self.magnitude = random.randint(0,9)
        for x,y in zip(X,Y):
            if np.random.rand() < self.prob:
                x = pil_wrap(x)
                x = self.transformation[np.argmax(y)](x, self.magnitude)
                x = pil_unwrap(x)
            _X.append(np.array(x))
        return np.array(_X)
    

class Transform:
    def __init__(self, *operations):
        self.operations = operations

    def __call__(self, X, Y):
        for op in self.operations:
            X = op(X, Y)
        return X


def autoaugment(transforms, X, y, batch_size):
    while True:
        ix = np.arange(len(X))
        np.random.shuffle(ix)
        for i in range(len(X) // batch_size):
            _ix = ix[i*batch_size:(i+1)*batch_size]
            _X = X[_ix]
            _y = y[_ix]
            if 0 != len(transforms):
              transform = np.random.choice(transforms)
              _X = transform(_X, _y)
            yield _X, _y

# modified from https://github.com/rpmcruz/autoaugment/blob/master/transformations.py
def create_cutout_mask(img_height, img_width, num_channels, size):
  """Creates a zero mask used for cutout of shape `img_height` x `img_width`.

  Args:
    img_height: Height of image cutout mask will be applied to.
    img_width: Width of image cutout mask will be applied to.
    num_channels: Number of channels in the image.
    size: Size of the zeros mask.

  Returns:
    A mask of shape `img_height` x `img_width` with all ones except for a
    square of zeros of shape `size` x `size`. This mask is meant to be
    elementwise multiplied with the original image. Additionally returns
    the `upper_coord` and `lower_coord` which specify where the cutout mask
    will be applied.
  """
  assert img_height == img_width

  # Sample center where cutout mask will be applied
  height_loc = np.random.randint(low=0, high=img_height)
  width_loc = np.random.randint(low=0, high=img_width)

  # Determine upper right and lower left corners of patch
  upper_coord = (max(0, height_loc - size // 2), max(0, width_loc - size // 2))
  lower_coord = (min(img_height, height_loc + size // 2),
                 min(img_width, width_loc + size // 2))
  mask_height = lower_coord[0] - upper_coord[0]
  mask_width = lower_coord[1] - upper_coord[1]
  assert mask_height > 0
  assert mask_width > 0

  mask = np.ones((img_height, img_width, num_channels))
  zeros = np.zeros((mask_height, mask_width, num_channels))
  mask[upper_coord[0]:lower_coord[0], upper_coord[1]:lower_coord[1], :] = (
      zeros)
  return mask, upper_coord, lower_coord

def cutout_numpy(img, size=16):
  """Apply cutout with mask of shape `size` x `size` to `img`.

  The cutout operation is from the paper https://arxiv.org/abs/1708.04552.
  This operation applies a `size`x`size` mask of zeros to a random location
  within `img`.

  Args:
    img: Numpy image that cutout will be applied to.
    size: Height/width of the cutout mask that will be

  Returns:
    A numpy tensor that is the result of applying the cutout mask to `img`.
  """
  img_height, img_width, num_channels = (img.shape[0], img.shape[1],
                                         img.shape[2])
  assert len(img.shape) == 3
  mask, _, _ = create_cutout_mask(img_height, img_width, num_channels, size)
  return img * mask

def float_parameter(level, maxval):
  """Helper function to scale `val` between 0 and maxval .

  Args:
    level: Level of the operation that will be between [0, `PARAMETER_MAX`].
    maxval: Maximum value that the operation can have. This will be scaled
      to level/PARAMETER_MAX.

  Returns:
    A float that results from scaling `maxval` according to `level`.
  """
  return float(level) * maxval / PARAMETER_MAX


def int_parameter(level, maxval):
  """Helper function to scale `val` between 0 and maxval .

  Args:
    level: Level of the operation that will be between [0, `PARAMETER_MAX`].
    maxval: Maximum value that the operation can have. This will be scaled
      to level/PARAMETER_MAX.

  Returns:
    An int that results from scaling `maxval` according to `level`.
  """
  return int(level * maxval / PARAMETER_MAX)

def _cutout_pil_impl(pil_img, level):
  """Apply cutout to pil_img at the specified level."""
  size = int_parameter(level, 20)
  if size <= 0:
    return pil_img
  img_height, img_width, num_channels = (32, 32, 3)
  _, upper_coord, lower_coord = (
      create_cutout_mask(img_height, img_width, num_channels, size))
  pixels = pil_img.load()  # create the pixel map
  for i in range(upper_coord[0], lower_coord[0]):  # for every col:
    for j in range(upper_coord[1], lower_coord[1]):  # For every row
      pixels[i, j] = (125, 122, 113, 0)  # set the colour accordingly
  return pil_img

def _enhancer_impl(enhancer):
  """Sets level to be between 0.1 and 1.8 for ImageEnhance transforms of PIL."""
  def impl(pil_img, level):
    v = float_parameter(level, 1.8) + .1  # going to 0 just destroys it
    return enhancer(pil_img).enhance(v)
  return impl

# =============================================================================

def ShearX(img, v):  # [-0.3, 0.3]
    if random.random() > 0.5:
      v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
  
def ShearY(img, v):  # [-0.3, 0.3]
    if random.random() > 0.5:
      v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))  
  
def TranslateX(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
    if random.random() > 0.5:
      v = -v
    return img.transform((32, 32), PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))

def TranslateY(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
    if random.random() > 0.5:
      v = -v
    return img.transform((32, 32), PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))

def Rotate(img, v):  # [-30, 30]
    if random.random() > 0.5:
      v = -v
    return img.rotate(v)

def AutoContrast(img, _):
    return PIL.ImageOps.autocontrast(img.convert('RGB')).convert('RGBA')

def Invert(img, _):
    return PIL.ImageOps.invert(img.convert('RGB')).convert('RGBA')

def Equalize(img, _):
    return PIL.ImageOps.equalize(img.convert('RGB')).convert('RGBA')

def Flip_LR(img, _):  # not from the paper
    return img.transpose(PIL.Image.FLIP_LEFT_RIGHT)

def Flip_UD(img, _):
    return img.transpose(PIL.Image.FLIP_TOP_BOTTOM)
  
def Solarize(img, v):  # [0, 256]
    v = int_parameter(v, 256)
    return PIL.ImageOps.solarize(img.convert('RGB'), 256 - v).convert('RGBA')

def Posterize(img, v):  # [4, 8]
    v = int_parameter(v, 4)
    return PIL.ImageOps.posterize(img.convert('RGB'), 4 - v).convert('RGBA')

def Contrast(img, v):  # [0.1,1.9]
    return _enhancer_impl(PIL.ImageEnhance.Contrast)(img, v)

def Blur(img, v):
    return img.filter(PIL.ImageFilter.BLUR)
  
def Color(img, v):  # [0.1,1.9]
    return _enhancer_impl(PIL.ImageEnhance.Color)(img, v)

def Smooth(img, v):
    return img.filter(PIL.ImageFilter.SMOOTH)
  
def Brightness(img, v):  # [0.1,1.9]
    return _enhancer_impl(PIL.ImageEnhance.Brightness)(img, v)

def Sharpness(img, v):  # [0.1,1.9]
    return _enhancer_impl(PIL.ImageEnhance.Sharpness)(img, v)

def Cutout(img, v):  # [0, 60] => percentage: [0, 0.2]
    return _cutout_pil_impl(img, v)

def Crop(img, v, interpolation=PIL.Image.BILINEAR):
    cropped = img.crop((v, v, IMAGE_SIZE - v, IMAGE_SIZE - v))
    resized = cropped.resize((IMAGE_SIZE, IMAGE_SIZE), interpolation)
    return resized

def Identity(img, v):
  return img

  
opmap = {
    'Flip_LR' : Flip_LR,
    'Flip_UD' : Flip_UD,
    'AutoContrast' : AutoContrast,
    'Equalize' : Equalize,
    'Invert' : Invert,
    'Rotate' : Rotate,
    'Poserize' : Posterize,
    'Crop' : Crop,
    'Solarize' : Solarize,
    'Color' : Color,
    'Contrast' : Contrast,
    'Brightness' : Brightness,
    'Sharpness' : Sharpness,
    'ShearX' : ShearX,
    'ShearY' : ShearY,
    'TranslateX' : TranslateX,
    'TranslateY' : TranslateY,
    'Cutout' : Cutout,
    'Blur' : Blur,
    'Smooth' : Smooth
    
}

**=== Get Ready... ===**

In [4]:
# Load CIFAR-10
from keras.datasets import cifar10
(X, y), (X_test, y_test) = cifar10.load_data()



# Create the Reduced CIFAR-10 dataset
#ix = np.random.choice(len(X), 4000, False)
#x_reduced_train = X[ix]
#y_reduced_train = y[ix]

# Shuffle the training data
shuffling = np.random.permutation(X.shape[0])   
X = X[shuffling, :]
y = y[shuffling]

# Split Training --> Training + Validation
nTrain = int(0.9 * X.shape[0])
X_train = X[0:nTrain, :, :, :]
y_train = y[:nTrain]

X_validation = X[nTrain:, :, :, :]
y_validation = y[nTrain:]

print(X_train.shape)
print(X_validation.shape)

y_train = utils.to_categorical(y_train)
y_validation = utils.to_categorical(y_validation)
y_test = utils.to_categorical(y_test)


categories = ['airplane', 'auto', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
(45000, 32, 32, 3)
(5000, 32, 32, 3)


In [0]:
batch_size = 250

def train_model(model, aug, epochs):

  tic = time.clock()
  model_fit(model, aug, (X_validation, y_validation), len(X_train) // batch_size, epochs)
  toc = time.clock()
  
  accuracy = model_evaluate(model, X_test, y_test)
  
  print('Test accuracy: %.3f (elaspsed time: %ds)' % (accuracy, (toc-tic)))
  acc = model_cond_accuracy(model, X_test, y_test)
  
  print("Accuracy")
  for cat, a in zip(categories, acc):
    print(cat, a)

**Baseline Training Without Data Augmentation**

In [29]:
ops = [Operation(([Identity] * 10, 0), 0)]
transform1 = Transform(*ops)

transforms = [transform1]
aug = autoaugment(transforms, X_train, y_train, batch_size) 

model0 = create_model(X_train, 10)

history0 = train_model(model0, aug, 50)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50
Test accuracy: 0.745 (elaspsed time: 722s)
Accuracy
airplane 0.742
auto 0.86
bird 0.63
cat 0.583
deer 0.681
dog 0.701
frog 0.823
horse 0.792
ship 0.835
truck 0.798


**Train for each individual image transform**
Do transforms have significantly different effects for each image class?

In [18]:
# Try one to see how many epochs to try... 100 seems to work well; need lots of extra training!
op = opmap['Flip_LR']
print('=========================================================')
print('=========================================================')
print('=== ', 'Flip_LR', ' ===')

# 50% prob of transform
ops = [Operation(([op] * 10, 5), 0.5)]
transform = [Transform(*ops)]

aug = autoaugment(transform, X_train, y_train, batch_size) 

model = create_model(X_train, 10)

history = train_model(model, aug, 100)
  

===  Flip_LR  ===
Epoch 1/100


TypeError: ignored

In [13]:
for name, op in opmap.items():
  print('=========================================================')
  print('=========================================================')
  print('=== ', name, ' ===')

  # 50% prob of transform
  ops = [Operation(([op] * 10, 5), 0.5)]
  transform = [Transform(*ops)]

  aug = autoaugment(transform, X_train, y_train, batch_size) 
  
  model = create_model(X_train, 10)

  history = train_model(model, aug, 100)

===  Flip_LR  ===
Epoch 1/100

Process ForkPoolWorker-23:
Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/usr/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.6/multiprocessing/pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
  File "/usr/local/lib/python3.6/dist-packages/keras/utils/data_utils.py", line 626, in next_sample
    return six.next(_SHARED_SEQUENCES[uid])
  File "<ipython-input-4-d4f43fd572ab>", line 62, in autoaugment
    _X = transform(_X, _y)
  File "<ipython-input-4-d4f43fd572ab>", line 48, in __call__
    X = op(X, Y)
  File "<ipython-input-4-d4f43fd572ab>", line 37, in __call__
    x = pil_unwrap(x)
KeyboardInterrupt
  File "<ipython-input-4-d4f43fd572ab>", line 19, in pil_unwrap
    pic_array = (pic_array[:, :, :3] - MEANS) / STDS


Epoch 1/100


KeyboardInterrupt: ignored

**AutoAugment Policy**

In [7]:
# Utility function to create policy, tform, below
def mk_op2(op1, p1, v1, op2, p2, v2):
  ops = [Operation(([op1] * 10, v1), p1), Operation(([op2] * 10, v2), p2)]
  return Transform(*ops)

# Duplicate the AutoAugment CIFAR-10 policy selected by concatenations of AutoAugment
tform = [
    #0_0
    mk_op2(Invert, 0.1, 7, Contrast, 0.2, 6),
    mk_op2(Rotate, 0.7, 2, TranslateX, 0.3, 9),
    mk_op2(Sharpness, 0.8, 1, Sharpness, 0.9, 3),
    mk_op2(ShearY, 0.5, 8, TranslateY, 0.7, 9),
    mk_op2(AutoContrast, 0.5, 8, Equalize, 0.9, 2),
    
    #0_1
    mk_op2(Solarize, 0.4, 5, AutoContrast, 0.9, 3),
    mk_op2(TranslateY, 0.9, 9, TranslateY, 0.7, 9),
    mk_op2(AutoContrast, 0.9, 2, Solarize, 0.8, 3),
    mk_op2(Equalize, 0.8, 8, Invert, 0.1, 3),
    mk_op2(TranslateY, 0.7, 9, AutoContrast, 0.9, 1),
    
    #0_2
    mk_op2(Solarize, 0.4, 5, AutoContrast, 0.0, 2),
    mk_op2(TranslateY, 0.7, 9, TranslateY, 0.7, 9),
    mk_op2(AutoContrast, 0.9, 0, Solarize, 0.4, 3),
    mk_op2(Equalize, 0.7, 5, Invert, 0.1, 3),
    mk_op2(TranslateY, 0.7, 9, TranslateY, 0.7, 9),

    #0_3
    mk_op2(Solarize, 0.4, 5, AutoContrast, 0.9, 1),
    mk_op2(TranslateY, 0.8, 9, TranslateY, 0.9, 9),
    mk_op2(AutoContrast, 0.8, 0, TranslateY, 0.7, 9),
    mk_op2(TranslateY, 0.2, 7, Color, 0.9, 6),
    mk_op2(Equalize, 0.7, 6, Color, 0.4, 9),
    
    #1_0
    mk_op2(ShearY, 0.2, 7, Posterize, 0.3, 7),
    mk_op2(Color, 0.4, 3, Brightness, 0.6, 7),
    mk_op2(Sharpness, 0.3, 9, Brightness, 0.7, 9),
    mk_op2(Equalize, 0.6, 5, Equalize, 0.5, 1),
    mk_op2(Contrast, 0.6, 7, Sharpness, 0.6, 5),
    
    #1_1
    mk_op2(Brightness, 0.3, 7, AutoContrast, 0.5, 8),
    mk_op2(AutoContrast, 0.9, 4, AutoContrast, 0.5, 6),
    mk_op2(Solarize, 0.3, 5, Equalize, 0.6, 5),
    mk_op2(TranslateY, 0.2, 4, Sharpness, 0.3, 3),
    mk_op2(Brightness, 0.0, 8, Color, 0.8, 8),
    
    #1_2
    mk_op2(Solarize, 0.2, 6, Color, 0.8, 6),
    mk_op2(Solarize, 0.2, 6, AutoContrast, 0.8, 1),
    mk_op2(Solarize, 0.4, 1, Equalize, 0.6, 5),
    mk_op2(Brightness, 0.0, 0, Solarize, 0.5, 2),
    mk_op2(AutoContrast, 0.9, 5, Brightness, 0.5, 3),
    
    #1_3
    mk_op2(Contrast, 0.7, 5, Brightness, 0.0, 2),
    mk_op2(Solarize, 0.2, 8, Solarize, 0.1, 5),
    mk_op2(Contrast, 0.5, 1, TranslateY, 0.2, 9),
    mk_op2(AutoContrast, 0.6, 5, TranslateY, 0.0, 9),
    mk_op2(AutoContrast, 0.9, 4, Equalize, 0.8, 4),
    
    #1_4
    mk_op2(Brightness, 0.0, 7, Equalize, 0.4, 7),
    mk_op2(Solarize, 0.2, 5, Equalize, 0.7, 5),
    mk_op2(Equalize, 0.6, 8, Color, 0.6, 2),
    mk_op2(Color, 0.3, 7, Color, 0.2, 4),
    mk_op2(AutoContrast, 0.5, 2, Solarize, 0.7, 2),
    
    #1_5
    mk_op2(AutoContrast, 0.2, 0, Equalize, 0.1, 0),
    mk_op2(ShearY, 0.6, 5, Equalize, 0.6, 5),
    mk_op2(Brightness, 0.9, 3, AutoContrast, 0.4, 1),
    mk_op2(Equalize, 0.8, 8, Equalize, 0.7, 7),
    mk_op2(Equalize, 0.7, 7, Solarize, 0.5, 0),
    
    #1_6
    mk_op2(Equalize, 0.8, 4, TranslateY, 0.8, 9),
    mk_op2(TranslateY, 0.8, 9, TranslateY, 0.6, 9),
    mk_op2(TranslateY, 0.9, 0, TranslateY, 0.5, 9),
    mk_op2(AutoContrast, 0.5, 3, Solarize, 0.3, 4),
    mk_op2(Solarize, 0.5, 3, Equalize, 0.4, 4),
    
    #2_0
    mk_op2(Color, 0.7, 7, TranslateX, 0.5, 8),
    mk_op2(Equalize, 0.3, 7, AutoContrast, 0.4, 8),
    mk_op2(TranslateY, 0.4, 3, Sharpness, 0.2, 6),
    mk_op2(Brightness, 0.9, 6, Color, 0.2, 8),
    mk_op2(Solarize, 0.5, 2, Invert, 0.0, 3),
    
    #2_1
    mk_op2(AutoContrast, 0.1, 5, Brightness, 0.0, 0),
    mk_op2(Cutout, 0.2, 4, Equalize, 0.1, 1),
    mk_op2(Equalize, 0.7, 7, AutoContrast, 0.6, 4),
    mk_op2(Color, 0.1, 8, ShearY, 0.2, 3),
    mk_op2(ShearY, 0.4, 2, Rotate, 0.7, 0),
    
    #2_2
    mk_op2(ShearY, 0.1, 3, AutoContrast, 0.9, 5),
    mk_op2(TranslateY, 0.3, 6, Cutout, 0.3, 3),
    mk_op2(Equalize, 0.5, 0, Solarize, 0.6, 6),
    mk_op2(AutoContrast, 0.3, 5, Rotate, 0.2, 7),
    mk_op2(Equalize, 0.8, 2, Invert, 0.4, 0),
    
    #2_3
    mk_op2(Equalize, 0.9, 5, Color, 0.7, 0),
    mk_op2(Equalize, 0.1, 1, ShearY, 0.1, 3),
    mk_op2(AutoContrast, 0.7, 3, Equalize, 0.7, 0),
    mk_op2(Brightness, 0.5, 1, Contrast, 0.1, 7),
    mk_op2(Contrast, 0.1, 4, Solarize, 0.6, 5),
    
    #2_4
    mk_op2(Solarize, 0.2, 3, ShearX, 0.0, 0),
    mk_op2(TranslateX, 0.3, 0, TranslateX, 0.6, 0),
    mk_op2(Equalize, 0.5, 9, TranslateY, 0.6, 7),
    mk_op2(ShearX, 0.1, 0, Sharpness, 0.5, 1),
    mk_op2(Equalize, 0.8, 6, Invert, 0.3, 6),
    
    #2_5
    mk_op2(AutoContrast, 0.3, 9, Cutout, 0.5, 3),
    mk_op2(ShearX, 0.4, 4, AutoContrast, 0.9, 2),
    mk_op2(ShearX, 0.0, 3, Posterize, 0.0, 3),
    mk_op2(Solarize, 0.4, 3, Color, 0.2, 4),
    mk_op2(Equalize, 0.1, 4, Equalize, 0.7, 6),
    
    #2_6
    mk_op2(Equalize, 0.3, 8, AutoContrast, 0.4, 3),
    mk_op2(Solarize, 0.6, 4, AutoContrast, 0.7, 6),
    mk_op2(AutoContrast, 0.2, 9, Brightness, 0.4, 8),
    mk_op2(Equalize, 0.1, 0, Equalize, 0.0, 6),
    mk_op2(Equalize, 0.8, 4, Equalize, 0.0, 4),
    
    #2_7
    mk_op2(Equalize, 0.5, 5, AutoContrast, 0.1, 2),
    mk_op2(Solarize, 0.5, 5, AutoContrast, 0.9, 5),
    mk_op2(AutoContrast, 0.6, 1, AutoContrast, 0.7, 8),
    mk_op2(Equalize, 0.2, 0, AutoContrast, 0.1, 2),
    mk_op2(Equalize, 0.6, 9, Equalize, 0.4, 4)
]


aug = autoaugment(tform, X_train, y_train, batch_size) 

model = create_model(X_train, 10)

train_model(model, aug, 100)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

In [0]:
# Train some more!
train_model(model, aug, 50)