In [None]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
import sys
sys.path.append("..")

from Code.ResidualAttentionNetwork import ResidualAttentionNetwork

import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
import random
import os
from PIL import Image

import h5py

import matplotlib.pyplot as plt
%matplotlib inline

import tensorflow as tf

from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger

from keras import optimizers

from keras.models import load_model

# Image Generators

In [None]:
# For data, download from Kaggle: https://www.kaggle.com/paultimothymooney/kermany2018
# Set the EYE_IMAGES_DIR environment variable to the dataset root (the folder
# that holds train/ and test/) instead of editing this cell. The default is the
# original cluster location.
DATA_DIR = os.environ.get("EYE_IMAGES_DIR", "/pylon5/cc5614p/deopha32/eye_images")

train_dir = os.path.join(DATA_DIR, "train")
test_dir = os.path.join(DATA_DIR, "test")

In [None]:
# Network Metadata
# These are defined ahead of the generators. flow_from_directory below reads
# IMAGE_SIZE and batch_size, so both must exist before that code runs on a
# fresh kernel.
IMAGE_WIDTH = 32
IMAGE_HEIGHT = 32
# Keras expects target_size=(height, width) and input_shape=(height, width, channels).
IMAGE_SIZE = (IMAGE_HEIGHT, IMAGE_WIDTH)
IMAGE_CHANNELS = 1  # every generator below loads images with color_mode='grayscale'
IMAGE_SHAPE = (IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS)
batch_size = 32

epochs = 500

num_classes = 4

# Fraction of train_dir held out for validation. Both generators that read
# train_dir must use the same value. That way the "training" and "validation"
# subsets partition the directory in the same way and do not overlap.
VALIDATION_SPLIT = 0.33

# Image Generators
# Augmentation (rotation, shear, zoom, flips, shifts) is applied to training
# images only.
train_datagen = ImageDataGenerator(
    rotation_range=15,
    rescale=1./255,
    shear_range=0.1,
    zoom_range=0.2,
    horizontal_flip=True,
    width_shift_range=0.1,
    height_shift_range=0.1,
    validation_split=VALIDATION_SPLIT
)

# Validation images also come from train_dir, but they are only rescaled, so
# validation metrics are measured on unaugmented images.
valid_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=VALIDATION_SPLIT
)

test_datagen = ImageDataGenerator(
    rescale=1./255
)

train_generator = train_datagen.flow_from_directory(
    directory=train_dir,
    shuffle=True,
    target_size=IMAGE_SIZE,
    class_mode='categorical',
    color_mode='grayscale',
    batch_size=batch_size,
    subset="training"
)

valid_generator = valid_datagen.flow_from_directory(
    directory=train_dir,
    shuffle=True,
    target_size=IMAGE_SIZE,
    class_mode="categorical",
    color_mode='grayscale',
    batch_size=batch_size,
    subset="validation"
)

# class_mode=None yields images only (no labels), in a fixed order (shuffle=False).
# Grayscale loading matches the 1-channel IMAGE_SHAPE the model is built with.
test_generator = test_datagen.flow_from_directory(
    directory=test_dir,
    shuffle=False,
    target_size=IMAGE_SIZE,
    class_mode=None,
    color_mode='grayscale',
    batch_size=batch_size,
)

# The model is built with n_classes=num_classes. Fail fast if the directory
# layout disagrees instead of failing later with a shape error.
assert train_generator.num_classes == num_classes, (
    "expected %d class folders in %s, found %d"
    % (num_classes, train_dir, train_generator.num_classes)
)
assert valid_generator.class_indices == train_generator.class_indices

In [None]:
# Number of batches per epoch / evaluation pass.
# - Training uses floor division, so the last partial batch of each epoch is
#   skipped.
# - Validation and test use ceiling division, so every image is visited once
#   per pass (the directory iterators yield a smaller final batch when n is not
#   a multiple of batch_size). This also keeps the step count >= 1 for a
#   non-empty split that has fewer images than batch_size.
STEP_SIZE_TRAIN = train_generator.n // train_generator.batch_size
STEP_SIZE_VALID = (valid_generator.n + valid_generator.batch_size - 1) // valid_generator.batch_size
STEP_SIZE_TEST = (test_generator.n + test_generator.batch_size - 1) // test_generator.batch_size

In [None]:
# Checkpoint filename template. ModelCheckpoint fills in {epoch} and {val_acc}
# from each epoch's logs. With save_best_only=True it only saves when the
# monitored val_acc improves.
model_path = "../Saved_Model/eye-disorder-model-{epoch:02d}-{val_acc:.2f}.h5"
history_path = "../Saved_Model/model_history.csv"

# Both callbacks write into ../Saved_Model. Depending on the Keras version they
# may not create the directory themselves, so make sure it exists up front.
os.makedirs(os.path.dirname(model_path), exist_ok=True)

# early_stop = EarlyStopping(monitor='val_acc',  verbose=1, patience=50)
checkpoint = ModelCheckpoint(model_path, monitor='val_acc', verbose=1, save_best_only=True)
# append=True adds rows to an existing CSV instead of overwriting it.
csv_logger = CSVLogger(history_path, append=True)

callbacks = [checkpoint, csv_logger]

# Initial Model Training

In [None]:
# Build the residual attention network, compile it with Adam and categorical
# cross-entropy, then train it from the generators defined above.
# `model` and `history` are read by the cells that follow.
with tf.device('/gpu:0'):
    attention_net = ResidualAttentionNetwork(
        input_shape=IMAGE_SHAPE,
        n_classes=num_classes,
        activation='softmax',
    )
    model = attention_net.build_model()

    adam = optimizers.Adam(lr=0.0001)
    model.compile(
        optimizer=adam,
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )

    history = model.fit_generator(
        generator=train_generator,
        steps_per_epoch=STEP_SIZE_TRAIN,
        validation_data=valid_generator,
        validation_steps=STEP_SIZE_VALID,
        epochs=epochs,
        callbacks=callbacks,
        verbose=1,
        use_multiprocessing=True,
        workers=40,
    )

# Load Model and Resume Training

In [None]:
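# NOTE(review): `model_path` is the checkpoint filename template
# ("...-{epoch:02d}-{val_acc:.2f}.h5"), so load_model(model_path) below will not find a
# file. Point it at a concrete saved checkpoint (and set initial_epoch to match) before resuming.
# with tf.device('/gpu:0'):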
#     model = load_model(model_path)

#     history = model.fit_generator(generator=train_generator,
#                         steps_per_epoch=STEP_SIZE_TRAIN, verbose=1, callbacks=callbacks,
#                         validation_data=valid_generator, validation_steps=STEP_SIZE_VALID,
#                         epochs=epochs, use_multiprocessing=True, workers=40, initial_epoch=50)

# model.save(model_path)

# Visualize Training History

In [None]:
# Learning curves for the training run above: loss on top, accuracy below.
# Older Keras logs accuracy as 'acc'/'val_acc', while newer versions use
# 'accuracy'/'val_accuracy', so use whichever key is present.
train_hist = history.history
acc_key = 'acc' if 'acc' in train_hist else 'accuracy'
# Plot against 1-based epoch numbers. Tick placement is left to matplotlib
# because the run length (epochs) is configurable.
epoch_axis = np.arange(1, len(train_hist['loss']) + 1)

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 12))

ax1.plot(epoch_axis, train_hist['loss'], color='b', label="Training loss")
ax1.plot(epoch_axis, train_hist['val_loss'], color='r', label="Validation loss")
ax1.set(title="Loss", xlabel="Epoch", ylabel="Loss")
# Each axes needs its own legend; plt.legend() would only label the last one.
ax1.legend(loc='best', shadow=True)

ax2.plot(epoch_axis, train_hist[acc_key], color='b', label="Training accuracy")
ax2.plot(epoch_axis, train_hist['val_' + acc_key], color='r', label="Validation accuracy")
ax2.set(title="Accuracy", xlabel="Epoch", ylabel="Accuracy")
ax2.legend(loc='best', shadow=True)

plt.tight_layout()
plt.show()

# Evaluate Model on Validation Data

In [None]:
# Score the in-memory model on the validation subset. Its weights are from the
# last training epoch, which may differ from the best checkpoint written by
# ModelCheckpoint.
loss, accuracy = model.evaluate_generator(valid_generator, STEP_SIZE_VALID, verbose=1, use_multiprocessing=True, workers=40)
# Label as validation (not test): valid_generator is a held-out slice of train_dir.
print("Validation: accuracy = %f  ;  loss = %f " % (accuracy, loss))