In [None]:
#=========================================================================
# Trains VGG16 DNN on 2 classes from CIFAR-10 dataset
# and stores h5 model in models directory
#=========================================================================
from __future__ import print_function
import keras
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model, Sequential
from keras.layers import Input, Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.callbacks import ModelCheckpoint
import os
import datetime
import numpy as np
import matplotlib.pyplot as plt

from keras.applications.vgg16 import VGG16

#=========================================================================
def filterDataByClass(x_data, y_data, class_array):
    """Keep only the samples whose label is one of `class_array`.

    Args:
        x_data: array of samples, indexed along its first axis.
        y_data: array of labels aligned with `x_data` along the first axis
            (assumes one label per row, e.g. shape (N,) or (N, 1)).
        class_array: label values to keep.

    Returns:
        (x_subset, y_subset): the rows of `x_data` and `y_data` whose label
        matches any value in `class_array`, in their original order.
    """
    # np.nonzero(...)[0] yields the first-axis (row) index of every match.
    keep_rows = np.nonzero(np.isin(y_data, class_array))[0]
    return x_data[keep_rows], y_data[keep_rows]

#=========================================================================

# ---- run parameters --------------------------------------------------------
# Train and test only on these CIFAR-10 label values.
classes = [0, 1]
baseModelName = 'vgg16'
datasetName = 'cifar10'
batch_size = 32
epochs = 50
# ----------------------------------------------------------------------------

num_classes = len(classes)
# e.g. [0, 1] -> '01'; used as a suffix in the saved model's file name.
classesName = ''.join(map(str, classes))
model_name = '_'.join([baseModelName,
                       datasetName,
                       str(num_classes) + 'classes' + classesName]) + '.h5'
modelLocation = 'models/%s' % model_name

# Timestamp marking the start of the run.
print(datetime.datetime.now())

# Load CIFAR-10, already split into train and test sets.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

print('model will be stored in %s' % modelLocation)

# Drop every sample whose label is not listed in `classes`.
x_train, y_train = filterDataByClass(x_train, y_train, classes)
x_test, y_test = filterDataByClass(x_test, y_test, classes)

# Remap the original CIFAR-10 labels to 0..num_classes-1, in the order given
# by `classes`, because to_categorical expects consecutive class indices.
def remapLabels(y_data, class_array):
    """Return a copy of `y_data` where label class_array[i] is replaced by i.

    Every mask is computed on the untouched input, so the remapping cannot
    chain. Rewriting in place one class at a time would, for classes=[1, 0],
    first turn every 1 into 0 and then every 0 (old and new) into 1,
    collapsing both classes into a single label.

    Raises:
        ValueError: if `y_data` contains a label that is not in `class_array`
            (filterDataByClass should already have removed those).
    """
    remapped = np.zeros_like(y_data)
    matched = np.zeros(y_data.shape, dtype=bool)
    for newLabel, originalLabel in enumerate(class_array):
        mask = (y_data == originalLabel)
        remapped[mask] = newLabel
        matched |= mask
    if not matched.all():
        raise ValueError('labels outside of class_array: %s'
                         % np.unique(y_data[~matched]))
    return remapped

y_train = remapLabels(y_train, classes)
y_test = remapLabels(y_test, classes)

# Convert class vectors to binary class matrices.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# model starts here--------------
# VGG16 architecture with no pretrained weights (weights=None), including its
# classifier head (include_top=True) sized to num_classes, for 32x32x3 inputs.
vggmodel = VGG16(
    weights=None,
    include_top=True,
    classes=num_classes,
    input_shape=(32, 32, 3)
)

vggmodel.summary()

model = vggmodel

# Compile with plain SGD: the 'sgd' string uses Keras' default SGD settings.
# (The original author notes this is the optimiser family used in the paper.)
# An RMSprop optimizer (lr=0.0001, decay=1e-6, marked "original") used to be
# constructed here but was never passed to compile(). It was removed because
# it was dead code, and because the lowercase `keras.optimizers.rmsprop` alias
# it relied on is absent in newer Keras releases, where it would crash this
# cell. To switch back, pass a keras.optimizers.RMSprop instance as
# `optimizer` below.
model.compile(
    loss='categorical_crossentropy',
    optimizer='sgd',
    metrics=['accuracy'])

# Convert to float32 and divide by 255 so 0..255 pixel values land in [0, 1].
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# Save the model from the epoch with the best validation accuracy.
# The monitored key must match what Keras logs for metrics=['accuracy']: recent
# Keras versions log 'val_accuracy' (the same key the plotting code reads from
# history.history), while older ones used 'val_acc'. If the monitored key is
# never logged, ModelCheckpoint only warns and skips saving on every epoch.
# Note: this best-epoch file is written to `model_name` (current directory),
# separately from the final model saved to `modelLocation` after training.
checkpoint = ModelCheckpoint(
    model_name,
    monitor='val_accuracy',
    verbose=0,
    save_best_only=True,
    save_weights_only=False,
    mode='max')

# The last 10% of the training arrays is held out as validation data.
history = model.fit(
    x=x_train,
    y=y_train,
    validation_split=0.1,
    batch_size=batch_size,
    epochs=epochs,
    callbacks=[checkpoint],
    verbose=1)

# Save the final (last-epoch) model and weights. The target directory is
# created first because the HDF5 file cannot be created inside a missing
# directory.
modelDir = os.path.dirname(modelLocation)
if modelDir and not os.path.isdir(modelDir):
    os.makedirs(modelDir)
model.save(modelLocation)
print('Saved trained model at %s ' %(modelLocation))

# Score trained model on the held-out test split.
scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
print(datetime.datetime.now())

def plotTrainingCurve(history, metric, title, ylabel):
    """Plot the training and validation curves of one metric logged by fit().

    Args:
        history: the History object returned by model.fit().
        metric: training key in history.history (e.g. 'loss'); the validation
            curve is read from 'val_' + metric.
        title: figure title.
        ylabel: y-axis label.
    """
    fig, ax = plt.subplots()
    ax.plot(history.history[metric])
    ax.plot(history.history['val_' + metric])
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Epoch')
    # The second curve comes from validation_split in fit(), not the test set.
    ax.legend(['Train', 'Validation'], loc='upper left')
    plt.show()

# Recent Keras versions log accuracy as 'accuracy'; older ones used 'acc'.
accuracyKey = 'accuracy' if 'accuracy' in history.history else 'acc'

# Plot training & validation accuracy values
plotTrainingCurve(history, accuracyKey, 'Model accuracy', 'Accuracy')

# Plot training & validation loss values
plotTrainingCurve(history, 'loss', 'Model loss', 'Loss')

Using TensorFlow backend.


2020-03-06 10:25:45.470654
x_train shape: (50000, 32, 32, 3)
50000 train samples
10000 test samples
model will be stored in models/vgg16_cifar10_2classes01.h5


In [3]:
# Scratch check: how the class list becomes the model-name suffix.
classes = [0, 1]
classesName = ''.join(map(str, classes))
print(classesName)

01
