In [None]:
import sys
sys.path.append("..")
from train_by_reconnect.LaPerm import LaPermTrainLoop
from train_by_reconnect.weight_utils import random_prune
from train_by_reconnect.viz_utils import Profiler

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.keras import optimizers
from tensorflow.keras import regularizers

In [None]:
def normalize(x_train, x_test):
    """Standardize train and test inputs using training-set statistics only.

    A single scalar mean and standard deviation are computed over *all*
    elements of ``x_train`` and applied to both arrays. The test set is
    therefore scaled exactly like the training set, and no test statistics
    leak into preprocessing.

    Parameters
    ----------
    x_train, x_test : numpy array-like of floats
        Training and test inputs.

    Returns
    -------
    tuple
        ``(x_train_normalized, x_test_normalized)``, i.e.
        ``(x - mean) / (std + 1e-8)`` for each input.
    """
    mean, std = np.mean(x_train), np.std(x_train)
    scale = std + 1e-8  # epsilon avoids division by zero for constant input
    # Bug fix: the original computed the normalized arrays but returned the
    # untouched inputs, so no normalization was ever applied.
    return (x_train - mean) / scale, (x_test - mean) / scale

In [None]:
# ---- training parameters -------------------------------------------------
batch_size = 50
epochs = 125

learning_rate = 0.001  # initial learning rate
lr_drop = 10           # epochs between learning-rate drops (see lr_scheduler)

tsize = 30000          # amount of training data used to compute train accuracy
vali_freq = 250        # run validation once every `vali_freq` batches

# ---- CIFAR-10: load, cast to float32, standardize --------------------------
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = [images.astype('float32') for images in (x_train, x_test)]
x_train, x_test = normalize(x_train, x_test)

def lr_scheduler(epoch, *, initial_lr=None, drop_every=None, drop_factor=0.6):
    """Step-decay learning-rate schedule.

    Returns ``initial_lr * drop_factor ** (epoch // drop_every)``, i.e. the
    rate is multiplied by ``drop_factor`` once every ``drop_every`` epochs.

    Parameters
    ----------
    epoch : int
        Current epoch index.
    initial_lr : float, optional (keyword-only)
        Starting learning rate. Defaults to the notebook's ``learning_rate``
        config value, read at call time.
    drop_every : int, optional (keyword-only)
        Epochs between drops. Defaults to the notebook's ``lr_drop`` value.
    drop_factor : float, keyword-only
        Multiplicative decay applied at each drop (default 0.6).

    Returns
    -------
    float
        Learning rate for ``epoch``.
    """
    # Extra knobs are keyword-only so that a caller trying
    # ``schedule(epoch, current_lr)`` (as Keras' LearningRateScheduler does
    # before falling back to ``schedule(epoch)``) cannot bind the current lr
    # to ``initial_lr`` and compound the decay.
    if initial_lr is None:
        # Previously shadowed by a hard-coded local 0.001, which made the
        # `learning_rate` config value dead.
        initial_lr = learning_rate
    if drop_every is None:
        drop_every = lr_drop
    return initial_lr * (drop_factor ** (epoch // drop_every))

def k_scheduler(epoch):
    """Return the LaPerm ``k`` value for ``epoch``.

    The schedule is constant: every epoch gets ``k = 1000``.
    NOTE(review): presumably ``k`` is LaPerm's reconnection/synchronization
    period; confirm against LaPermTrainLoop.
    """
    del epoch  # unused; the argument exists only to satisfy the schedule signature
    return 1000

# Data augmentation passed to loop.fit: small random rotations and
# width/height shifts plus horizontal flips.
datagen = ImageDataGenerator(
    horizontal_flip=True,
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
)
# NOTE(review): fit() computes data statistics, which Keras only uses for
# featurewise/ZCA options; none are enabled here. Kept for parity.
datagen.fit(x_train)

In [None]:
# VGG-style CNN for 32x32x3 inputs: two convolutional stages (64 then 128
# filters, each closed by 2x2 max-pooling), a two-layer 256-unit dense head,
# and a 10-way softmax output.
model = Sequential()
x_shape = (32, 32, 3)

# NOTE(review): biases also use he_uniform (not zeros); presumably deliberate
# so LaPerm has non-trivial bias values to permute. Confirm.
initializer = 'he_uniform'

# One L2 penalty object, shared by every regularized layer below.
regularizer = regularizers.l2(1e-4)


def _add_conv_bn(net, filters, dropout_rate=None, **conv_extra):
    """Append Conv2D(3x3, 'same') -> ReLU -> BatchNorm, plus optional Dropout."""
    net.add(Conv2D(filters, (3, 3), padding='same',
                   kernel_regularizer=regularizer,
                   kernel_initializer=initializer,
                   bias_initializer=initializer,
                   **conv_extra))
    net.add(Activation('relu'))
    net.add(BatchNormalization())
    if dropout_rate is not None:
        net.add(Dropout(dropout_rate))


# Stage 1: 64 filters.
_add_conv_bn(model, 64, dropout_rate=0.3, input_shape=x_shape)
_add_conv_bn(model, 64)
model.add(MaxPooling2D(pool_size=(2, 2)))

# Stage 2: 128 filters.
_add_conv_bn(model, 128, dropout_rate=0.4)
_add_conv_bn(model, 128)
model.add(MaxPooling2D(pool_size=(2, 2)))

# Dense head. The two hidden layers place BatchNorm/Dropout differently;
# this is preserved as-is.
model.add(Flatten())
model.add(Dense(256, kernel_regularizer=regularizer,
                kernel_initializer=initializer, bias_initializer=initializer))
model.add(Activation('relu'))
model.add(Dropout(0.5))

model.add(BatchNormalization())
model.add(Dense(256, kernel_regularizer=regularizer,
                kernel_initializer=initializer, bias_initializer=initializer))
model.add(Activation('relu'))
model.add(BatchNormalization())

# Output layer (no L2 penalty).
model.add(Dropout(0.5))
model.add(Dense(10, kernel_initializer=initializer,
                bias_initializer=initializer))
model.add(Activation('softmax'))

# random_prune(model, prune_rate=0.7) # uncomment for random pruning

In [None]:
# Build the LaPerm training loop around an Adam inner optimizer. The learning
# rate and k come from lr_scheduler / k_scheduler.
loop = LaPermTrainLoop(
    model=model,
    loss='sparse_categorical_crossentropy',
    inner_optimizer=tf.keras.optimizers.Adam(),
    k_schedule=k_scheduler,
    lr_schedule=lr_scheduler,
)

# Train with augmentation. Validation on the test split runs every
# `vali_freq` batches, and train accuracy is measured on `tsize` of data.
loop.fit(
    x_train, y_train,
    batch_size,  # positional, as in the original call
    epochs=epochs,
    datagen=datagen,
    validation_data=(x_test, y_test),
    validation_freq=vali_freq,
    tsize=tsize,
)

In [None]:
# Inspect the trained model with the project's Profiler (train_by_reconnect.viz_utils).
# NOTE(review): what it reports or renders is defined in viz_utils, not visible here.
Profiler(model)

In [None]:
# Visualize train and validation accuracies recorded by the LaPerm loop.
import numpy as np
import matplotlib.pyplot as plt

# NOTE(review): `_history` is a private attribute of LaPermTrainLoop; switch to
# a public accessor if the class provides one.
# NOTE(review): validation runs every `vali_freq` batches, so these entries may
# be per-validation rather than per-epoch. Confirm before reading the x-axis
# as epochs.
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(loop._history['val accuracy'], label='Validation Accuracy')
ax.plot(loop._history['accuracy'], label='Train Accuracy')
ax.grid(linestyle='--')
ax.set_xlabel('Epochs', size=15)
ax.set_ylabel('Accuracy', size=15)
ax.set_title('LaPerm training on CIFAR-10: accuracy', size=15)
ax.legend(prop={'size': 15})
plt.show()