In [None]:
import os

import keras
import numpy as np
from keras.callbacks import ModelCheckpoint
from keras.datasets import cifar10
from keras.layers import Dense, Dropout, Flatten, Activation, Conv2D, MaxPooling2D
from keras.models import Sequential

In [None]:
# Constants
# Original CIFAR-10 labels to keep; a label's position in this list becomes
# its new class index (0, 1, 2, ...) after extract_classes.
CLASSES = [3, 4, 5, 7]
# Derived from CLASSES so the softmax width can never drift out of sync with the label list
NUM_CLASSES = len(CLASSES)
IMG_ROWS, IMG_COLS = 32, 32
INPUT_SHAPE = (IMG_ROWS, IMG_COLS, 3)
BATCH_SIZE = 200
EPOCHS = 24
CHECKPOINTS_PATH = 'checkpoints/cifar/'
CHECKPOINTS_NAME = 'cifar4.hdf5'

In [None]:
# Load dataset function
def read_CIFAR(classes):
    '''
        Load CIFAR-10 and keep only the images whose label is in `classes`.

        Labels are remapped to 0..len(classes)-1 by their position in `classes`
        (see extract_classes) and returned one-hot encoded with exactly
        len(classes) columns. Pixel values are returned exactly as produced by
        cifar10.load_data (no scaling is applied here).

        Parameters:
            classes: sequence of original CIFAR-10 label values to keep

        Returns:
            images_train, labels_train, images_test, labels_test
    '''
    # load data with keras function
    (images_train, labels_train), (images_test, labels_test) = cifar10.load_data()
    # Image category selection
    images_train, labels_train = extract_classes(images_train, labels_train, classes)
    images_test, labels_test = extract_classes(images_test, labels_test, classes)
    # One hot encoding of labels. num_classes is explicit: without it the width is
    # inferred from the largest label present, which would not match a
    # len(classes)-wide softmax if the last class had no samples in a split.
    num_classes = len(classes)
    labels_train = keras.utils.to_categorical(labels_train, num_classes=num_classes)
    labels_test = keras.utils.to_categorical(labels_test, num_classes=num_classes)

    return images_train, labels_train, images_test, labels_test

def extract_classes(images, labels, classes):
    '''
        Select the images whose label is one of `classes` and relabel them.

        Output samples are grouped by class in the order given by `classes`;
        within a group the original dataset order is kept. The new label of a
        sample is the position of its original label in `classes`.

        Parameters:
            images: array whose first axis indexes samples
            labels: original labels, one per sample (e.g. shape (n,) or (n, 1))
            classes: sequence of original label values to keep

        Returns:
            (images_return, labels_return): the selected images and a 1-D
            float64 array of new class indices aligned with them. If `classes`
            is empty, both are empty (images keep their trailing dimensions).
    '''
    flat_labels = np.asarray(labels).reshape(-1)

    image_chunks = []
    label_chunks = []
    for new_label, original_label in enumerate(classes):
        chunk = images[flat_labels == original_label]
        image_chunks.append(chunk)
        # float64 labels match the historical output; to_categorical casts them to int
        label_chunks.append(np.full(chunk.shape[0], new_label, dtype=np.float64))

    if not image_chunks:
        # np.concatenate rejects an empty list, so build the empty result explicitly
        return images[:0], np.zeros(0)

    # A single concatenate avoids re-copying the accumulated images once per class
    images_return = np.concatenate(image_chunks, axis=0)
    labels_return = np.concatenate(label_chunks)
    return images_return, labels_return

def use_checkpoints(path, file_name, monitor='loss'):
    '''
        Create a ModelCheckpoint callback that writes the model to `path`/`file_name`.

        The target directory (and any missing parents) is created first.
        Because save_best_only=True, the file is only overwritten when
        `monitor` improves.

        Parameters:
            path: directory for the checkpoint file ('' means current directory)
            file_name: checkpoint file name inside `path`
            monitor: quantity to track; defaults to 'loss' (training loss).
                     'val_loss' is usually preferable when a validation split is used.

        Returns:
            a ModelCheckpoint callback
    '''
    # makedirs creates missing parents; exist_ok avoids failing if the directory
    # already exists (or appears between a check and the creation)
    if path:
        os.makedirs(path, exist_ok=True)
    # os.path.join works whether or not `path` ends with a separator
    return ModelCheckpoint(os.path.join(path, file_name), monitor=monitor, verbose=1,
                           save_best_only=True, mode='auto')

In [None]:
# Model definition
# Two conv -> 4x4 max-pool stages downsample the 32x32 input, then a
# dropout-regularised dense head produces one softmax score per class.
model = Sequential([
    # stage 1
    Conv2D(64, kernel_size=(3, 3), padding='same', activation='tanh', input_shape=INPUT_SHAPE),
    MaxPooling2D(pool_size=(4, 4), padding='same'),
    # stage 2
    Conv2D(64, kernel_size=(3, 3), padding='same', activation='tanh'),
    MaxPooling2D(pool_size=(4, 4), padding='same'),
    # classifier head
    Flatten(),
    Dropout(0.5),
    Dense(1024, activation='sigmoid'),
    Dropout(0.5),
    Dense(NUM_CLASSES, activation='softmax'),
])

# NOTE(review): Adadelta's default learning rate differs between Keras versions
# (1.0 in older standalone Keras, 0.001 in tf.keras / Keras 3) — confirm it is intended.
model.compile(
    loss=keras.losses.categorical_crossentropy,
    optimizer=keras.optimizers.Adadelta(),
    metrics=['accuracy'],
)

In [None]:
# View model summary: print Keras' layer-by-layer description of `model`
model.summary()

In [None]:
# Load data
images_train, labels_train, images_test, labels_test = read_CIFAR(CLASSES)

# Scale pixels to [0, 1] (CIFAR-10 is loaded as uint8 in [0, 255]); raw values
# would push the tanh convolutions straight into saturation. Train and test are
# scaled together so predictions later see the same input range. Rebinding here,
# in the same cell as the load, keeps the cell idempotent on re-run.
images_train = images_train.astype('float32') / 255.0
images_test = images_test.astype('float32') / 255.0

# extract_classes returns samples grouped by class, and Keras' validation_split
# takes the *last* fraction of the arrays before any shuffling. Without this
# permutation the validation set would be almost entirely the last class, and
# that class would never be seen during training.
SHUFFLE_SEED = 0
VALIDATION_SPLIT = 0.33
permutation = np.random.RandomState(SHUFFLE_SEED).permutation(images_train.shape[0])
images_train = images_train[permutation]
labels_train = labels_train[permutation]

# Train model
history = model.fit(images_train,
                    labels_train,
                    batch_size=BATCH_SIZE,
                    epochs=EPOCHS,
                    shuffle=True,
                    verbose=1,
                    validation_split=VALIDATION_SPLIT,
                    callbacks=[use_checkpoints(CHECKPOINTS_PATH, CHECKPOINTS_NAME)])

In [None]:
# Test model
# Show N random test images with their true and predicted labels.
# NOTE: requires `import matplotlib.pyplot as plt` in the imports cell.
N = 10
# CIFAR-10 label names, indexed by the original 0-9 label
CIFAR10_LABEL_NAMES = ('airplane', 'automobile', 'bird', 'cat', 'deer',
                       'dog', 'frog', 'horse', 'ship', 'truck')

# Distinct indices drawn from the whole test set (shape[0] is the sample count)
sample_indices = np.random.choice(images_test.shape[0], size=N, replace=False)
images_plot = images_test[sample_indices]
predicted_idx = np.argmax(model.predict(images_plot), axis=1)
true_idx = np.argmax(labels_test[sample_indices], axis=1)

# squeeze=False keeps `axes` 2-D even when N == 1
fig, axes = plt.subplots(1, N, figsize=(20, 4), squeeze=False)
for ax, image, true_i, pred_i in zip(axes[0], images_plot, true_idx, predicted_idx):
    ax.imshow(image)
    # map the 0..NUM_CLASSES-1 index back to the original CIFAR-10 label, then to its name
    ax.set_title('true: {}\npred: {}'.format(CIFAR10_LABEL_NAMES[CLASSES[true_i]],
                                            CIFAR10_LABEL_NAMES[CLASSES[pred_i]]))
    ax.axis('off')
plt.show()