In [2]:
import tensorflow as tf
import os
import matplotlib.pyplot as plt
import numpy as np
import network
from data_pipeline import set_up_data_pipeline, decode, augment
from params import patches_root, train_db_path, test_db_path
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, Callback

In [None]:
# Input patch size in pixels and the mini-batch size used by the data generators.
img_height, img_width = 256, 256
batch_size = 64

In [None]:
# Load the training and validation datasets.
# Pixel values are multiplied by 1/255 (maps 8-bit images to [0, 1]); only the
# training set is augmented, with random horizontal and vertical flips.
# NOTE(review): `patches_root` is imported from params but these directories
# are hard-coded — consider deriving them from it.

train_generator = ImageDataGenerator(
    rescale=1./255, horizontal_flip=True, vertical_flip=True)

validation_generator = ImageDataGenerator(rescale=1./255)

# Arguments shared by both iterators. Keras expects target_size as
# (height, width).
flow_kwargs = dict(
    target_size=(img_height, img_width),
    color_mode='grayscale',
    batch_size=batch_size,
    class_mode="categorical",
)

train_data_gen = train_generator.flow_from_directory(
    directory=r"./patches/train/", shuffle=True, **flow_kwargs)

# Keep the validation order fixed so that predictions can be matched against
# `validation_data_gen.classes` / `validation_data_gen.filenames`.
validation_data_gen = validation_generator.flow_from_directory(
    directory=r"./patches/test/", shuffle=False, **flow_kwargs)

In [None]:
# Draw one (augmented) training batch; keep only the images for a visual check.
sample_training_images, _sample_labels = next(train_data_gen)

In [None]:
# Plot images side by side in a single row (5 columns by default).
def plotImages(images_arr, ncols=5):
    """Show up to `ncols` images side by side in one row, in grayscale.

    Args:
        images_arr: sequence of images. Each image is either 2-D (H, W) or
            3-D (H, W, C); for 3-D images only channel 0 is displayed.
        ncols: number of panels in the row (default 5). Images beyond the
            first `ncols` are not shown.
    """
    # squeeze=False keeps `axes` a 2-D array even when ncols == 1.
    _, axes = plt.subplots(1, ncols, figsize=(4 * ncols, 4), squeeze=False)
    axes = axes.flatten()
    # Hide every panel first so unused panels (fewer images than ncols) stay blank.
    for ax in axes:
        ax.axis('off')
    for img, ax in zip(images_arr, axes):
        img = np.asarray(img)
        ax.imshow(img if img.ndim == 2 else img[:, :, 0], cmap='gray')
    plt.tight_layout()
    plt.show()

In [None]:
# Display the first five images of the sample batch.
plotImages(images_arr=sample_training_images[:5])

In [None]:
# Build the network architecture from the project-local `network` module.
print("[*] Define model")
model = network.build()

In [None]:
# SGD with momentum and inverse-time learning-rate decay, applied per batch:
#     lr(step) = 0.001 / (1 + 0.0005 * step)
# This is the schedule the legacy `decay=` argument implemented. Newer Keras
# optimizers (TF >= 2.11) reject `decay=` and only warn on `lr=`, falling back
# to the default learning rate, so the schedule is spelled out explicitly.
lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=0.001, decay_steps=1, decay_rate=0.0005)
sgd = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)
model.compile(
    optimizer=sgd,
    loss='categorical_crossentropy',
    metrics=['accuracy'])

In [None]:
#           ------------ Train the Model ------------
# Project-local callback built from the compiled model (see network.ConstrainLayer).
ConstrainLayer = network.ConstrainLayer(model)

# Write one checkpoint at the end of every epoch. An integer `save_freq`
# counts *batches*, so `save_freq=1` would write a file after every batch.
# With save_best_only=False, `monitor` is not used to select checkpoints;
# 'accuracy' matches the metric name passed to model.compile.
os.makedirs('./saved_model', exist_ok=True)
callbacks = [
    ModelCheckpoint('./saved_model/weights.{epoch:02d}.h5',
                    monitor='accuracy', verbose=1, save_best_only=False,
                    save_freq='epoch'),
    ConstrainLayer,
]

# `fit` accepts Keras iterators directly; `fit_generator` is deprecated.
# Validation was deliberately removed; pass
# `validation_data=validation_data_gen` to re-enable it.
history = model.fit(train_data_gen, epochs=45, callbacks=callbacks)

In [None]:
#           ------------------------------------------

# Possible next step: an "Extremely Randomized Trees" classifier applied to the
# 1x200 deep-convolutional-feature vector. This is reported to give a small
# improvement in the model's accuracy.

# Force a garbage-collection pass. Per the original author, this lets the
# program exit cleanly by working around the "Exception ignored in BaseSession"
# bug; it is not otherwise important.
import gc
gc.collect()