This notebook defines the functions needed to train, evaluate, and test the models.

In [None]:
# First import the necessary libraries

import os
import random

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img


import itertools
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import roc_curve, auc, roc_auc_score, confusion_matrix, classification_report


In [None]:
# Display random OCT images of each class in the dataset

def display_image(train_dir):
    """Show one randomly chosen OCT image from each class, side by side.

    Args:
        train_dir: Path of the directory that holds the ``DRUSEN``, ``DME``,
            ``CNV`` and ``NORMAL`` sub-directories. A trailing path
            separator is optional.

    Raises:
        IndexError: If one of the class sub-directories is empty
            (raised by ``random.choice``).
    """
    # (sub-directory name, subplot title), in left-to-right display order.
    panels = (
        ("DRUSEN", "OCT with DRUSEN"),
        ("DME", "OCT with DME"),
        ("CNV", "OCT with CNV"),
        ("NORMAL", "NORMAL OCT"),
    )
    _, axes = plt.subplots(1, len(panels), figsize=(15, 10))
    for axis, (class_name, title) in zip(axes, panels):
        # os.path.join keeps the paths valid whether or not train_dir ends
        # with a separator.
        class_dir = os.path.join(train_dir, class_name)
        file_name = random.choice(os.listdir(class_dir))
        axis.imshow(load_img(os.path.join(class_dir, file_name)))
        axis.set_title(title)
        axis.axis("Off")
    plt.show()

In [None]:
#construct the models using pretrained CNN
def generate_model(pretrained_cnn, input_shape=(150, 150, 3), num_classes=4):
    """Build a classifier made of a frozen ImageNet backbone and a small head.

    The head is ``Conv2D(64, 3x3, relu) -> Flatten -> Dense(100, relu) ->
    Dense(num_classes, softmax)``. The backbone is created without its top
    layers and its weights are frozen (``trainable = False``).

    Args:
        pretrained_cnn: Name of the backbone. One of ``'VGG19'``,
            ``'DenseNet201'``, ``'ResNet101'``, ``'MobileNetV2'``,
            ``'Xception'``, ``'InceptionV3'`` or ``'EfficientNetB2'``.
        input_shape: Shape of the input images passed to the backbone.
        num_classes: Number of units in the final softmax layer.

    Returns:
        The assembled, uncompiled ``tf.keras.models.Sequential`` model.

    Raises:
        ValueError: If ``pretrained_cnn`` is not one of the supported names.
    """
    # Accepted name -> constructor name in tf.keras.applications.
    # Note that 'ResNet101' selects the V2 variant of the network.
    backbones = {
        'VGG19': 'VGG19',
        'DenseNet201': 'DenseNet201',
        'ResNet101': 'ResNet101V2',
        'MobileNetV2': 'MobileNetV2',
        'Xception': 'Xception',
        'InceptionV3': 'InceptionV3',
        'EfficientNetB2': 'EfficientNetB2',
    }
    if pretrained_cnn not in backbones:
        # Previously an unknown name fell through every `if` and crashed
        # later with an UnboundLocalError on `base_model`.
        raise ValueError(
            "Unknown pretrained_cnn {!r}; expected one of: {}".format(
                pretrained_cnn, ', '.join(sorted(backbones))))

    build_backbone = getattr(tf.keras.applications, backbones[pretrained_cnn])
    # `classes` is not passed: it only applies when include_top is True.
    base_model = build_backbone(include_top=False,
                                weights='imagenet',
                                input_shape=input_shape)
    base_model.trainable = False

    model = tf.keras.models.Sequential([
        base_model,
        tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(100, activation='relu'),
        tf.keras.layers.Dense(num_classes, activation='softmax'),
    ])
    return model


In [None]:
#train the model
def train_model(model, train_generator, validation_generator, num_epochs, optimizer, metrics,
                steps_per_epoch=100, validation_steps=2):
    """Compile the model and fit it with checkpointing and early stopping.

    The model is compiled with categorical cross-entropy. Two callbacks are
    attached:

    * ``ModelCheckpoint`` saves the full model to
      ``<name of model.layers[0]>_best_model.h5`` whenever ``val_accuracy``
      improves.
    * ``EarlyStopping`` stops training once the *training* ``loss`` has not
      improved for 20 epochs.

    Args:
        model: Keras model to train; it is compiled in place.
        train_generator: Training data passed to ``model.fit``.
        validation_generator: Validation data passed to ``model.fit``.
        num_epochs: Maximum number of epochs.
        optimizer: Optimizer passed to ``model.compile``.
        metrics: Metrics passed to ``model.compile``.
            NOTE(review): the checkpoint monitors ``val_accuracy``, so this
            presumably needs to include an accuracy metric - confirm at the
            call site.
        steps_per_epoch: Number of training batches per epoch.
        validation_steps: Number of validation batches per epoch. Must be an
            integer; the previous hard-coded ``32/16`` evaluated to the float
            ``2.0``.

    Returns:
        A ``(model, history)`` tuple, where ``history`` is the object
        returned by ``model.fit``.
    """
    model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=metrics)

    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        model.layers[0].name + '_best_model.h5',
        monitor='val_accuracy',
        verbose=1,
        save_best_only=True,
        save_weights_only=False,
        mode='auto')
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='loss', patience=20, mode='auto', verbose=1)

    history = model.fit(train_generator,
                        steps_per_epoch=steps_per_epoch,
                        epochs=num_epochs,
                        validation_data=validation_generator,
                        validation_steps=validation_steps,
                        callbacks=[early_stopping, checkpoint],
                        verbose=1)
    return model, history
    

In [None]:
#plot the learning curves of the models
def plot_acc_loss(model,history):
    """Draw the accuracy and loss learning curves on two separate figures.

    Training values are drawn in red and validation values in blue, against
    the epoch index. Both figures are displayed with a single ``plt.show()``.

    Args:
        model: Not used by this function.
        history: Object whose ``history`` mapping provides the ``accuracy``,
            ``val_accuracy``, ``loss`` and ``val_loss`` series.
    """
    train_acc = history.history['accuracy']
    valid_acc = history.history['val_accuracy']
    train_loss = history.history['loss']
    valid_loss = history.history['val_loss']
    epoch_axis = range(len(train_acc))

    # One row per figure: series, legend labels, then the figure title.
    figures = (
        (train_acc, valid_acc,
         'Training accuracy', 'Validation accuracy',
         'Training and validation accuracy'),
        (train_loss, valid_loss,
         'Training Loss', 'Validation Loss',
         'Training and validation loss'),
    )
    for train_series, valid_series, train_label, valid_label, heading in figures:
        plt.figure(figsize=(7, 7))
        plt.plot(epoch_axis, train_series, 'r', label=train_label)
        plt.plot(epoch_axis, valid_series, 'b', label=valid_label)
        plt.title(heading)
        plt.legend()

    plt.show()

In [None]:
#evaluate the model
def evaluate_model(model, history, test_generator):
    """Evaluate the model on the test data, print its scores and plot curves.

    The values returned by ``model.evaluate`` are read by position: index 1
    is printed as accuracy, 2 as AUC, 3 as Cohen kappa, 4 (averaged with
    ``np.mean``) as F-score, 5 as precision and 6 as recall.
    NOTE(review): this presumably mirrors the order of the metrics list the
    model was compiled with - confirm at the call site.

    Args:
        model: Compiled Keras model to evaluate.
        history: Training history forwarded to ``plot_acc_loss``.
        test_generator: Test data passed to ``model.evaluate``.
    """
    score = model.evaluate(test_generator)

    # (printed label, position in `score`, optional reduction of the value)
    report_rows = (
        ('\nTest set accuracy:', 1, None),
        ('\nTest AUC:', 2, None),
        ('\nTest Cohen Kappa:', 3, None),
        ('\nTest F-Score:', 4, np.mean),
        ('\nTest Precision:', 5, None),
        ('\nTest Recall:', 6, None),
    )
    for label, position, reduce_fn in report_rows:
        value = score[position]
        if reduce_fn is not None:
            value = reduce_fn(value)
        print(label, value, '\n')

    plot_acc_loss(model, history)

In [None]:
#plot confusion matrix for evaluation

def plot_confusion_matrix(cm, classes,normalize=False,title='Confusion matrix',cmap=plt.cm.Blues):
    """Draw a confusion matrix as an annotated heat map on a new 5x5 figure.

    Rows are labelled as the true class and columns as the predicted class.
    The figure is created but not shown; the caller (or the notebook) is
    responsible for displaying it.

    Args:
        cm: 2-D confusion matrix (array-like).
        classes: Class names used as tick labels on both axes.
        normalize: If True, divide every row by its sum before drawing, so
            both the colours and the cell text show per-row fractions.
            Rows whose sum is zero are shown as all zeros.
        title: Title of the figure.
        cmap: Matplotlib colour map used for the heat map.
    """
    cm = np.asarray(cm)
    if normalize:
        # Normalise BEFORE drawing so that the colours agree with the cell
        # text; skip empty rows to avoid dividing by zero.
        row_totals = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm.astype('float'), row_totals,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_totals != 0)

    plt.figure(figsize=(5, 5))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=90)
    plt.yticks(tick_marks, classes)

    # Cells darker than half of the maximum get white text for contrast.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        cell_text = format(cm[i, j], '.2f') if normalize else cm[i, j]
        plt.text(j, i, cell_text,
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Run after the axis labels are set so the layout accounts for them.
    plt.tight_layout()



In [None]:
#plot ROC curve
def plot_roc_curves(y_true, y_pred, num_classes, class_labels):
    """Plot a one-vs-rest ROC curve per class and return the overall ROC AUC.

    The true labels are binarized with ``LabelBinarizer``; column ``i`` of
    the binarized labels is compared with column ``i`` of ``y_pred``. Every
    class gets its own 6x4 figure, shown immediately.

    Args:
        y_true: True class labels.
        y_pred: Per-class scores, indexed as ``y_pred[:, i]``.
        num_classes: Number of classes to draw curves for.
        class_labels: Class names, indexed by class position.

    Returns:
        The value of ``roc_auc_score`` computed on the binarized labels and
        ``y_pred``.
    """
    binarizer = LabelBinarizer().fit(y_true)
    y_test = binarizer.transform(y_true)

    # First pass: compute (fpr, tpr, area) for every class.
    curves = []
    for class_idx in range(num_classes):
        false_pos, true_pos, _ = roc_curve(y_test[:, class_idx], y_pred[:, class_idx])
        curves.append((false_pos, true_pos, auc(false_pos, true_pos)))

    # Second pass: one figure per class.
    for class_idx, (false_pos, true_pos, area) in enumerate(curves):
        _, c_ax = plt.subplots(1, 1, figsize=(6, 4))
        class_name = class_labels[class_idx]
        legend_text = 'ROC curve of class {0} (area = {1:0.4f})'.format(class_name, area)
        c_ax.plot(false_pos, true_pos, label=legend_text)
        c_ax.set_xlabel('False Positive Rate')
        c_ax.set_ylabel('True Positive Rate')
        c_ax.set_title('ROC curve of class {0}'.format(class_name))
        c_ax.legend(loc="lower right")
        plt.show()

    return roc_auc_score(y_test, y_pred)

In [None]:
#test the performance of the model
def test_model(model,test_generator):
    """Print the test accuracy and plot the confusion matrix of the model.

    Args:
        model: Compiled Keras model. ``model.evaluate`` must return accuracy
            at index 1 of its result.
        test_generator: Test data iterator that supports ``len()`` and
            exposes ``labels`` and ``class_indices``.
            NOTE(review): predictions are matched to ``labels`` by position,
            so the generator presumably has to be created with
            ``shuffle=False`` - confirm at the call site.
    """
    score = model.evaluate(test_generator, verbose=0)
    print('\nTest set accuracy:', score[1], '\n')

    # Predict over the whole generator instead of a hard-coded 968/44
    # batches, so the number of predictions always matches the number of
    # labels.
    y_pred = model.predict(test_generator, steps=len(test_generator))
    y_true = np.array(test_generator.labels)
    y_pred_classes = np.argmax(y_pred, axis=1)
    class_labels = list(test_generator.class_indices.keys())

    confusion_mtx = confusion_matrix(y_true, y_pred_classes)
    plot_confusion_matrix(confusion_mtx, classes=class_labels)