In [None]:
import sys
sys.path.append("..")
from train_by_reconnect.LaPerm import LaPermTrainLoop
from train_by_reconnect.weight_utils import random_prune
from train_by_reconnect.viz_utils import Profiler

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Conv2D, BatchNormalization, Activation
from tensorflow.keras.layers import AveragePooling2D, Input, Flatten
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.regularizers import l2
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import cifar10

In [None]:
# Training hyperparameters (read by the scheduler and training cells below).
batch_size = 50  # orig paper trained all networks with batch_size=128
epochs = 200

tsize = 30000  # size of data for getting the train accuracy
vali_freq = 800  # validate per vali_freq batches

learning_rate = 0.001  # base rate that lr_scheduler scales per epoch

# Load the CIFAR10 data.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Input image dimensions (per-sample shape, i.e. without the batch axis).
input_shape = x_train.shape[1:]

# Convert to float32 and scale pixel values by 1/255.
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# Subtract the per-pixel mean of the *training* set from both splits.
# The ndarray method is used rather than np.mean(): numpy is only imported in
# the final plotting cell, so np.mean raised NameError when the notebook was
# run top to bottom on a fresh kernel.
x_train_mean = x_train.mean(axis=0)
x_train -= x_train_mean
x_test -= x_train_mean

print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
print('y_train shape:', y_train.shape)

# Convert class vectors to binary class matrices.
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)


In [None]:
# `depth` must have the form 6n + 2; resnet() asserts (depth - 2) % 6 == 0
# and builds n residual blocks in each of its three stacks.
n = 8
depth = 6 * n + 2

In [None]:
# Modified based on https://keras.io/examples/cifar10_resnet/

# Module-level settings shared by the layers built below:
# `initializer` is passed as kernel_initializer to every Conv2D and to the
# final Dense layer; `regularizer` is passed as kernel_regularizer to every
# Conv2D (the Dense layer is built without a regularizer).
initializer = 'he_uniform'
regularizer = l2(1e-4)


def resnet_layer(inputs,
                 num_filters=16,
                 kernel_size=3,
                 strides=1,
                 activation='relu',
                 batch_normalization=True):
    """Apply Conv2D, then optional BatchNormalization, then optional Activation.

    Args:
        inputs: tensor fed to the convolution.
        num_filters: number of Conv2D filters.
        kernel_size: Conv2D kernel size.
        strides: Conv2D strides.
        activation: name handed to `Activation`; `None` skips the activation.
        batch_normalization: when true, BatchNormalization follows the conv.

    Returns:
        The output tensor of the last layer that was applied.
    """
    # 'same' padding, with the module-level initializer and regularizer.
    out = Conv2D(num_filters,
                 kernel_size=kernel_size,
                 strides=strides,
                 padding='same',
                 kernel_initializer=initializer,
                 kernel_regularizer=regularizer)(inputs)

    if batch_normalization:
        out = BatchNormalization()(out)

    if activation is None:
        return out
    return Activation(activation)(out)


def resnet(input_shape, depth, num_classes=10):
    """Build the residual classification network as a Keras `Model`.

    The network is one stem `resnet_layer`, then three stacks of residual
    blocks using 16, 32 and 64 filters, then average pooling, flatten and a
    softmax Dense layer.

    Args:
        input_shape: shape of one input sample, passed to `Input`.
        depth: must satisfy (depth - 2) % 6 == 0; each stack gets
            (depth - 2) / 6 residual blocks.
        num_classes: number of units in the softmax output layer.

    Returns:
        A `Model` mapping the input tensor to the softmax outputs.

    Raises:
        AssertionError: if `depth` is not of the form 6n + 2.
    """
    assert (depth - 2) % 6 == 0, "incorrect depth."

    blocks_per_stack = int((depth - 2) / 6)

    inputs = Input(shape=input_shape)
    x = resnet_layer(inputs=inputs)

    # The filter count doubles from one stack to the next.
    for stack, filters in enumerate((16, 32, 64)):
        for res_block in range(blocks_per_stack):
            # The first block of every stack except the first downsamples.
            downsample = stack > 0 and res_block == 0
            strides = 2 if downsample else 1

            # Residual branch: two conv layers, the second without activation.
            branch = resnet_layer(inputs=x,
                                  num_filters=filters,
                                  strides=strides)
            branch = resnet_layer(inputs=branch,
                                  num_filters=filters,
                                  activation=None)

            if downsample:
                # Project the shortcut with a strided 1x1 conv so that it can
                # be added to the downsampled, wider residual branch.
                x = resnet_layer(inputs=x,
                                 num_filters=filters,
                                 kernel_size=1,
                                 strides=strides,
                                 activation=None,
                                 batch_normalization=False)

            x = tf.keras.layers.add([x, branch])
            x = Activation('relu')(x)

    # NOTE(review): pool_size=8 presumably assumes 32x32 inputs (8x8 feature
    # maps after two downsamplings) -- confirm before using other image sizes.
    pooled = AveragePooling2D(pool_size=8)(x)
    flat = Flatten()(pooled)
    outputs = Dense(num_classes,
                    activation='softmax',
                    kernel_initializer=initializer)(flat)

    return Model(inputs=inputs, outputs=outputs)

In [None]:
# Augmentation applied to training batches: random horizontal flips and
# shifts of up to 10% of the image height/width.
datagen = ImageDataGenerator(
    horizontal_flip=True,
    height_shift_range=0.1,
    width_shift_range=0.1)

# fit() computes the statistics needed for featurewise normalization
# (std, mean, and principal components if ZCA whitening is applied).
# NOTE(review): none of those options are enabled above, so this call is
# presumably not required -- confirm before removing it.
datagen.fit(x_train)

In [None]:
def lr_scheduler(epoch):
    """Return the learning rate to use at `epoch`.

    The module-level `learning_rate` is returned unchanged up to and including
    epoch 80, then scaled by 1e-1, 1e-2, 1e-3 and 0.5e-3 once `epoch` exceeds
    80, 120, 160 and 180 respectively.
    """
    # (epoch boundary, multiplier), checked from the latest boundary backwards
    # so the first match is the one that applies.
    decay_steps = (
        (180, 0.5e-3),
        (160, 1e-3),
        (120, 1e-2),
        (80, 1e-1),
    )
    for boundary, factor in decay_steps:
        if epoch > boundary:
            return learning_rate * factor
    return learning_rate

def k_scheduler(epoch):
    """Return k for `epoch`; the schedule is constant at 800.

    Passed to LaPermTrainLoop as `k_schedule`.
    NOTE(review): k is presumably LaPerm's sync period in batches -- confirm.
    """
    constant_k = 800
    return constant_k

In [None]:
# Build the network from the input shape and depth defined in earlier cells.
model = resnet(input_shape=input_shape, depth=depth)
# random_prune(model, prune_rate=0.7) # uncomment for random pruning

In [None]:
# Sanity check: display the label array shapes after the to_categorical conversion.
y_test.shape, y_train.shape

In [None]:
# Train with LaPerm. batch_size, epochs, vali_freq and tsize all come from the
# hyperparameter cell; `epochs` is deliberately not reassigned here, so that
# changing it there actually takes effect.
loop = LaPermTrainLoop(model=model,
                       loss='categorical_crossentropy',
                       inner_optimizer=Adam(),
                       k_schedule=k_scheduler,
                       lr_schedule=lr_scheduler)

# Augmented batches come from `datagen`; the test split is used as
# validation data.
loop.fit(x_train, y_train, batch_size, epochs=epochs,
         datagen=datagen, validation_data=(x_test, y_test),
         validation_freq=vali_freq,
         tsize=tsize)

In [None]:
# Profiler is imported from train_by_reconnect.viz_utils.
# NOTE(review): presumably displays a profile/visualisation of the trained
# model's weights -- confirm against viz_utils.
Profiler(model)

In [None]:
# Visualize train and validation accuracies
import numpy as np
import matplotlib.pyplot as plt
# Draw both accuracy curves recorded in the training loop's history.
plt.figure(figsize=(8, 6))
for history_key, curve_label in (('val accuracy', 'Validation Accuracy'),
                                 ('accuracy', 'Train Accuracy')):
    plt.plot(loop._history[history_key], label=curve_label)

# NOTE(review): validation runs every `vali_freq` batches, so the x-axis may
# be counting validation steps rather than epochs -- confirm.
plt.xlabel('Epochs', size=15)
plt.ylabel('Accuracy', size=15)
plt.grid(linestyle='--')
plt.legend(prop={'size': 15})
plt.show()