In [1]:
import tensorflow as tf
import numpy as np
import skimage.transform as skimage_transform
import time, datetime
import matplotlib.pyplot as plt
import pickle
from sklearn.model_selection import train_test_split
from keras.utils import np_utils

# --- experiment configuration ---
CNN_NAME = 'pretrained_vgg16'  # prefix of the model/weights/history files written to disk
EPOCHS = 200                   # number of epochs passed to model.fit
NUM_CLASSES = 10               # size of the one-hot label vectors and of the softmax output
BATCH_SIZE_TRAIN = 32          # batch size used by the train and validation generators
LEARNING_RATE = 0.0001         # learning rate handed to the Adam optimizer


# build the neural net model
def build_model():
  """Assemble the transfer-learning classifier.

  An ImageNet-pretrained VGG16 (without its top) is instantiated for
  48x48x3 inputs and truncated at `block3_pool`; all of its layers are
  frozen. A small dense head ending in a softmax over NUM_CLASSES is
  stacked on top of the truncated network.

  Returns:
    An uncompiled `tf.keras.Model` mapping the VGG16 input to the softmax output.
  """
  vgg = tf.keras.applications.vgg16.VGG16(
      include_top=False, weights='imagenet', input_shape=(48, 48, 3))

  # keep the pretrained weights fixed, which reduces the number of trainable params
  for vgg_layer in vgg.layers:
    vgg_layer.trainable = False

  # features are taken from the pooling layer that closes the third VGG16 block
  features = vgg.get_layer('block3_pool').output

  # classification head: pooling -> batch norm -> 2 x dense(256) -> dropout -> softmax
  head = tf.keras.layers.GlobalAveragePooling2D()(features)
  head = tf.keras.layers.BatchNormalization()(head)
  for _ in range(2):
    head = tf.keras.layers.Dense(256, activation='relu')(head)
  head = tf.keras.layers.Dropout(0.6)(head)
  predictions = tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')(head)

  return tf.keras.Model(vgg.input, predictions)

In [2]:
# load cifar10 dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# split training dataset in train and validation; the split is stratified on
# the labels and seeded so that it is reproducible
x_train, x_val, y_train, y_val = train_test_split(
    x_train, y_train, test_size=0.15, stratify=np.array(y_train), random_state=42)

# transform labels into one-hot encoded vectors (tf.keras is used here for
# consistency with the rest of the notebook)
y_train = tf.keras.utils.to_categorical(y_train, NUM_CLASSES)
y_val = tf.keras.utils.to_categorical(y_val, NUM_CLASSES)
y_test = tf.keras.utils.to_categorical(y_test, NUM_CLASSES)


def _resize_images(images, size=(48, 48)):
  """Resize every image in `images` to `size` and stack them into one array.

  NOTE: with its default `preserve_range=False`, skimage's `resize` converts
  uint8 input to floats in [0, 1], so the returned pixels are already
  normalised; any additional 1/255 rescaling downstream would scale them twice.
  """
  return np.array([skimage_transform.resize(image, size) for image in images])


# resize images to fit in 48x48 VGG16 minimum image size
x_train = _resize_images(x_train)
x_val = _resize_images(x_val)
x_test = _resize_images(x_test)

model = build_model()
#model.summary()

# set optimizer and compile model
# The network ends in a NUM_CLASSES-way softmax and the labels are one-hot,
# so the matching loss is categorical cross-entropy. With binary cross-entropy
# Keras treats each output unit as an independent binary problem and resolves
# the 'accuracy' metric to binary accuracy, which over-reports performance.
opt_adam = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
model.compile(loss='categorical_crossentropy', optimizer=opt_adam, metrics=['accuracy'])

In [3]:
# prepare the batch generators for training and validation
#
# No `rescale=1. / 255` here: x_train / x_val come out of skimage's `resize`,
# which (with its default preserve_range=False) already converts the uint8
# CIFAR-10 images to floats in [0, 1]. Rescaling again would squeeze the
# training inputs into [0, 1/255] while the test images passed directly to
# `model.evaluate` stay in [0, 1].
#
# `datagen.fit(...)` is not called either: it only computes statistics for
# featurewise_center / featurewise_std_normalization / zca_whitening, none of
# which is enabled, so it would just copy the whole array for nothing.
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(horizontal_flip=False)
train_generator = train_datagen.flow(x_train, y_train, batch_size=BATCH_SIZE_TRAIN)

val_datagen = tf.keras.preprocessing.image.ImageDataGenerator(horizontal_flip=False)
# validation batches are served in a fixed order so that every epoch is
# validated on the same samples
val_generator = val_datagen.flow(
    x_val, y_val, batch_size=BATCH_SIZE_TRAIN, shuffle=False)

In [None]:
# training
start = time.time()

print("\tStart training [", time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()), "]\n")
# `batch_size` is intentionally not passed: the generators already yield
# batches of BATCH_SIZE_TRAIN samples, and Keras documents that the argument
# must not be specified for generator / Sequence inputs (depending on the
# version it is either rejected with a ValueError or silently ignored).
train_history = model.fit(
    train_generator,
    steps_per_epoch=x_train.shape[0] // BATCH_SIZE_TRAIN,
    validation_data=val_generator,
    validation_steps=x_val.shape[0] // BATCH_SIZE_TRAIN,
    epochs=EPOCHS,
    verbose=1)
print("\n\tEnd training [", time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()), "]")

end = time.time()
# wall-clock duration, rounded to whole seconds for display
print("\n\tTotal training time:", datetime.timedelta(seconds=round(end - start, 0)))

In [None]:
# testing: evaluate the model on the test set
# (the model is compiled with a single metric, so `scores` is [loss, accuracy])
scores = model.evaluate(x_test, y_test, verbose=1)
test_loss, test_accuracy = scores[0], scores[1]
print('\nTest result: %.3f loss: %.3f' % (test_accuracy*100, test_loss))

In [None]:
# save to disk: architecture as JSON, weights as .h5, training history as pickle
architecture_path = f"{CNN_NAME}_{EPOCHS}_model.json"
weights_path = f"{CNN_NAME}_{EPOCHS}_weights.h5"
history_path = f"{CNN_NAME}_{EPOCHS}_history.sav"

with open(architecture_path, 'w') as architecture_out:
  architecture_out.write(model.to_json())

model.save_weights(weights_path)

# only the plain `history` dict is pickled, not the History callback object
with open(history_path, 'wb') as history_out:
  pickle.dump(train_history.history, history_out)

In [None]:
# load history
# The file is opened in a `with` block so the handle is closed right after
# loading. NOTE: unpickling can execute arbitrary code - only load history
# files that were produced by this notebook.
with open(f"{CNN_NAME}_{EPOCHS}_history.sav", "rb") as history_file:
  history = pickle.load(history_file)

plt.figure(figsize=(15.0, 9.0))
plt.xlabel('Epoch')

# alternative view: uncomment to plot the loss curves instead of the accuracy
#plt.plot(history['loss'])
#plt.plot(history['val_loss'])
#plt.title('Custom CNN loss')
#plt.ylabel('Loss')
#plt.legend(['Train', 'Validation'], loc='upper center')

# per-epoch accuracy on the training and validation data
plt.plot(history['accuracy'])
plt.plot(history['val_accuracy'])
plt.title('Custom CNN accuracy')
plt.ylabel('Accuracy')
plt.legend(['Train', 'Validation'], loc='lower center')

plt.show()