In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
import tensorflow as tf
import splitfolders
import sys
import pathlib
import pandas as pd
import pickle

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import SGD
import tensorflow_datasets as tfds

from sklearn.metrics import roc_curve, auc, roc_auc_score
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import KFold, StratifiedKFold

from helper_functions import print_sens_spec_3class, print_sens_spec_2class

In [None]:
from IPython.display import HTML, display

# Widen the notebook container so wide figures and tables are not clipped
wide_container_css = "<style>.container { width:95% !important; }</style>"
display(HTML(wide_container_css))

In [None]:
# Define parameters
# Images are resized to img_height x img_width by the dataset loaders
img_height, img_width = 256, 256
batch_size = 32
num_splits = 5          # number of cross-validation folds
binary_threshold = 0.5  # NOTE(review): not referenced elsewhere in this notebook section
epochs = 35
learning_rate = 0.001

# Class folder names; their order defines the integer labels 0, 1, 2
class_names = ['0-Normal', '1-Non-CI-DME', '2-CI-DME']
num_classes = len(class_names)

## Read Data and Train Model

In [None]:
# Directory with one sub-folder per class holding the train+val images
trainval_dir = pathlib.Path("Data-New/Primary-3-class-train-val")

# Count only .png files exactly one level down (inside the class folders)
trainval_count = sum(1 for _ in trainval_dir.glob('*/*.png'))
print(f"Training & Val Count: {trainval_count}")

In [None]:
# Load the train+val images as single-channel tensors with integer class labels.
# seed= fixes the loader's shuffling so the image order is repeatable.
trainval_loader_args = dict(
    class_names=class_names,
    label_mode='int',
    batch_size=batch_size,
    image_size=(img_height, img_width),
    color_mode='grayscale',
    seed=124,
)
trainval_ds = tf.keras.utils.image_dataset_from_directory(trainval_dir, **trainval_loader_args)

In [None]:
# Flatten the batched dataset into per-image numpy arrays so sklearn's
# StratifiedKFold can index images and labels directly.
# A single pass is used so every image stays paired with its own label.
trainval_ds_ub = trainval_ds.unbatch()
trainval_ds_np = tfds.as_numpy(trainval_ds_ub)

trainval_data, trainval_labels = [], []
for image, label in trainval_ds_np:
    trainval_data.append(image)
    trainval_labels.append(label)

trainval_data_np = np.array(trainval_data)
trainval_labels_np = np.array(trainval_labels)

print(trainval_data_np.shape)
print(trainval_labels_np.shape)

In [None]:
# Number of train+val images per class (list index = integer class label).
# minlength keeps an entry for every class even if one has no images; the int cast
# lets bincount also accept an empty label array.
labelscount = np.bincount(np.asarray(trainval_labels_np, dtype=np.int64), minlength=num_classes).tolist()
labelscount
        

In [None]:
# Define kfold split
# Stratified so each fold keeps the overall class proportions. A fixed random_state
# makes the shuffled fold assignment reproducible across kernel restarts
# (with shuffle=True and no random_state, every run produces different folds).
kfold = StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=42)

In [None]:
# Data augmentation (if required)
# Random horizontal flip, rotation and zoom, each with its own fixed seed.
# The layers are created under a CPU device scope.
with tf.device('/CPU:0'):

    augmentation_layers = [
        layers.RandomFlip("horizontal", input_shape=(img_height, img_width, 1), seed=136),
        layers.RandomRotation(0.1, seed=175),
        layers.RandomZoom(0.1, seed=181),
    ]
    data_augmentation = keras.Sequential(augmentation_layers)

In [None]:
# Define Model
def getModel():
    """Build a new, uncompiled Sequential CNN classifier.

    The network starts with the shared ``data_augmentation`` block and scales
    pixel values by 1/255. It then applies five 3x3 conv + max-pool stages
    (16, 32, 64, 128, 256 filters; batch norm on the last) and ends with a
    softmax over ``num_classes`` outputs.

    Reads the notebook globals ``data_augmentation``, ``img_height``,
    ``img_width`` and ``num_classes``.

    Returns:
        keras.Sequential: the uncompiled model.
    """
    feature_layers = [
        data_augmentation,
        layers.Rescaling(1./255, input_shape=(img_height, img_width, 1)),
    ]

    # Four conv/pool stages with doubling filter counts
    for n_filters in (16, 32, 64, 128):
        feature_layers.append(layers.Conv2D(n_filters, 3, padding='same', activation='relu'))
        feature_layers.append(layers.MaxPooling2D())

    # Final stage adds batch normalisation between the conv and the pooling
    feature_layers += [
        layers.Conv2D(256, 3, padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D(),
    ]

    # NOTE(review): Dense(256) and Dropout come before Flatten, so they act on the
    # 4-D feature map (per spatial position) — confirm this ordering is intended.
    head_layers = [
        layers.Dense(256, activation='relu'),
        layers.Dropout(0.2),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(num_classes, activation='softmax'),
    ]

    return Sequential(feature_layers + head_layers)

In [None]:
# Convert labels to one-hot (categorical) form to match the CategoricalCrossentropy loss used in training
trainval_labels_categorical = tf.keras.utils.to_categorical(trainval_labels_np, num_classes=num_classes)

In [None]:
%%time
tf.keras.backend.clear_session()

loss_values = []
acc_scores = []
auc_scores = []
prec_scores = []
recall_scores = []

histories = []
models = []

val_predictions_values_cv = []
val_predictions_cv = []
val_labels_values_cv = []
val_labels_cv = []

i = 1

# Comvert labels to categorical for use in model as per appraoch used previously
trainval_labels_categorical = tf.keras.utils.to_categorical(trainval_labels_np, num_classes=3)

for train_index, val_index in kfold.split(trainval_data_np, trainval_labels_np):
    
    tf.keras.backend.clear_session()

    print("\n#### This is split: " + str(i) + " ####")
    
    train_data_fold, val_data_fold = trainval_data_np[train_index], trainval_data_np[val_index]
    train_labels_fold, val_labels_fold = trainval_labels_categorical[train_index], trainval_labels_categorical[val_index]
    
    print("Training Entries: " + str(train_data_fold.shape[0]))
    print("Validation Entries: " + str(val_data_fold.shape[0]))
    print("\n")
    
    model = getModel()
    
    opt = keras.optimizers.Adam(learning_rate=learning_rate)
    
    model.compile(optimizer=opt,
              loss=tf.keras.losses.CategoricalCrossentropy(),
              metrics=[tf.keras.metrics.CategoricalAccuracy(),
                       tf.keras.metrics.AUC(),
                       tf.keras.metrics.Precision(),
                       tf.keras.metrics.Recall()])

    history = model.fit(train_data_fold,train_labels_fold,
                        epochs=epochs,
                        batch_size = batch_size, 
                        validation_data = (val_data_fold, val_labels_fold))
    
    histories.append(history)

    # Could also take score from val history results
    scores = model.evaluate(val_data_fold, val_labels_fold)
    
    v_predictions_values = model.predict(val_data_fold)
    v_predictions = np.argmax(v_predictions_values, axis=1)
    v_labels_values = val_labels_fold
    v_labels = np.argmax(v_labels_values, axis=1)
    
    val_predictions_values_cv.append(v_predictions_values)
    val_predictions_cv.append(v_predictions)
    val_labels_values_cv.append(v_labels_values)
    val_labels_cv.append(v_labels)
    

    print("%s: %.2f%% \n" % (model.metrics_names[1], scores[1]*100))

    loss_values.append(scores[0])
    acc_scores.append(scores[1])
    auc_scores.append(scores[2])
    prec_scores.append(scores[3])
    recall_scores.append(scores[4])
    i += 1

print("\n#### OVERALL SUMMARY - FROM EVALUATION SETS ####")
print("Average final loss: %.2f (+/- %.2f)" % (np.mean(loss_values), np.std(loss_values)))
print("Average accuracy: %.2f%% (+/- %.2f%%)" % (np.mean(acc_scores)*100, np.std(acc_scores)*100))
print("Average final AUC: %.2f (+/- %.2f)" % (np.mean(auc_scores), np.std(auc_scores)))
print("Average precision: %.2f%% (+/- %.2f%%)" % (np.mean(prec_scores)*100, np.std(prec_scores)*100))
print("Average recall: %.2f%% (+/- %.2f%%)" % (np.mean(recall_scores)*100, np.std(recall_scores)*100))

In [None]:
# Plot each fold's training/validation curves, and sum the per-epoch validation
# loss and accuracy across folds (the next cells divide by num_splits to average).
loss_val_values = np.zeros(epochs)
acc_val_values = np.zeros(epochs)

# (History key, display name, legend location) for each of the five panels
metric_panels = [
    ('categorical_accuracy', 'Accuracy', 'lower right'),
    ('loss', 'Loss', 'upper right'),
    ('precision', 'Precision', 'lower right'),
    ('recall', 'Recall', 'lower right'),
    ('auc', 'AUC', 'lower right'),
]

for fold, history in enumerate(histories, start=1):

    acc_val_values = np.add(acc_val_values, history.history['val_categorical_accuracy'])
    loss_val_values = np.add(loss_val_values, history.history['val_loss'])

    fig, axes = plt.subplots(1, len(metric_panels), figsize=(20, 8))
    for ax, (key, display_name, legend_loc) in zip(axes, metric_panels):
        ax.plot(history.history[key], label='Training ' + display_name)
        ax.plot(history.history['val_' + key], label='Validation ' + display_name)
        ax.set_xlabel('Epoch')
        ax.legend(loc=legend_loc)
        ax.set_title('Training and Validation ' + display_name)

    # Identify which CV fold this row of panels belongs to
    fig.suptitle('Fold ' + str(fold))
    plt.show()

In [None]:
# Mean validation loss per epoch across folds (loss_val_values holds the per-fold sum)
avg_loss_val_values = loss_val_values / num_splits
min_avg_loss = min(avg_loss_val_values)
min_loss_epoch = int(np.argmin(avg_loss_val_values)) + 1  # 1-based epoch number

print("Average minimum loss: " + str(min_avg_loss))
print("Epoch number for min loss: " + str(min_loss_epoch))

epochs_plot = range(1, epochs + 1)

plt.figure()
plt.plot(epochs_plot, avg_loss_val_values)
plt.xlabel('Epoch')
plt.ylabel('Avg Loss')
plt.title('Loss over epochs')
# Explicit show so the cell output is the figure, not the Text repr of the title
plt.show()

In [None]:
# Mean validation accuracy per epoch across folds (acc_val_values holds the per-fold sum)
avg_acc_val_values = acc_val_values / num_splits
max_avg_acc = max(avg_acc_val_values)
max_acc_epoch = int(np.argmax(avg_acc_val_values)) + 1  # 1-based epoch number

print("Average maximum accuracy: " + str(max_avg_acc))
print("Epoch number for max acc: " + str(max_acc_epoch))

epochs_plot = range(1, epochs + 1)

plt.figure()
plt.plot(epochs_plot, avg_acc_val_values)
plt.xlabel('Epoch')
plt.ylabel('Avg Accuracy')
plt.title('Accuracy over epochs')
# Explicit show so the cell output is the figure, not the Text repr of the title
plt.show()

## Metrics on the Cross-Validation Folds

In [None]:
# For each CV fold's held-out data: classification report, confusion matrix,
# sensitivity/specificity, and one-vs-rest ROC curves per class.
for split_idx in range(num_splits):

    val_labels = val_labels_cv[split_idx]
    val_labels_values = val_labels_values_cv[split_idx]
    val_predictions = val_predictions_cv[split_idx]
    val_predictions_values = val_predictions_values_cv[split_idx]

    print("\n### Results for split number: " + str(split_idx+1) + "\n")

    # labels= keeps target_names aligned even if a class is missing from a fold
    print(classification_report(val_labels, val_predictions, labels=[0,1,2], target_names=['M0','M1','M2']))

    cm = confusion_matrix(val_labels, val_predictions, labels=[0,1,2])
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['M0','M1','M2'])
    disp.plot(cmap=plt.cm.Blues)

    print_sens_spec_3class(cm)

    fpr_val = dict()
    tpr_val = dict()
    roc_auc_val = dict()

    plt.figure()

    # Separate loop variable so the fold index above is not overwritten
    for class_idx in range(num_classes):
        fpr_val[class_idx], tpr_val[class_idx], _ = roc_curve(val_labels_values[:, class_idx], val_predictions_values[:, class_idx])
        roc_auc_val[class_idx] = auc(fpr_val[class_idx], tpr_val[class_idx])
        plt.plot(fpr_val[class_idx], tpr_val[class_idx],
                 label='ROC curve of class {0} (area = {1:0.2f})'.format(class_idx, roc_auc_val[class_idx]))

    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic - Validation Data')
    plt.legend(loc="lower right")
    plt.show()
    

## Retrain the model on all of the training and validation data

In [None]:
# Reload the train+val images with one-hot (categorical) labels for training the final model
full_loader_args = dict(
    class_names=class_names,
    label_mode='categorical',
    batch_size=batch_size,
    image_size=(img_height, img_width),
    color_mode='grayscale',
    seed=124,
)
trainval_ds = tf.keras.utils.image_dataset_from_directory(trainval_dir, **full_loader_args)

In [None]:
AUTOTUNE = tf.data.AUTOTUNE

# Cache decoded batches, shuffle with a 1000-element buffer, and prefetch ahead of training
trainval_ds_p = (
    trainval_ds
    .cache()
    .shuffle(1000)
    .prefetch(buffer_size=AUTOTUNE)
)

In [None]:
tf.keras.backend.clear_session()

# Final model trained on the whole train+val set for the same number of epochs as in CV
full_model = getModel()

# Use the configured learning rate (same value as the CV runs)
opt = keras.optimizers.Adam(learning_rate=learning_rate)

# Explicit metric names give stable History keys, matching the CV training cell
full_model.compile(optimizer=opt,
                   loss=tf.keras.losses.CategoricalCrossentropy(),
                   metrics=[tf.keras.metrics.CategoricalAccuracy(name='categorical_accuracy'),
                            tf.keras.metrics.AUC(name='auc'),
                            tf.keras.metrics.Precision(name='precision'),
                            tf.keras.metrics.Recall(name='recall')])

# batch_size is not passed: trainval_ds_p is already batched by the loader,
# and Keras documents that batch_size should not be given for tf.data inputs.
history = full_model.fit(trainval_ds_p,
                         epochs=epochs)

In [None]:
# Save the model
# Uncomment to persist the model trained above. This is the same path the next cell
# loads from, so saving overwrites that copy.
# full_model.save('Models/DME-classification-model-centre.h5')

In [None]:
# Load the model
# NOTE: loading replaces the `full_model` trained above with the copy saved on disk,
# so all test-set results below come from the saved model. Set LOAD_SAVED_MODEL to
# False to evaluate the model trained in this session instead.
LOAD_SAVED_MODEL = True

if LOAD_SAVED_MODEL:
    full_model = keras.models.load_model("Models/DME-classification-model-centre.h5")

In [None]:
# Print the layer-by-layer architecture and parameter counts of the model evaluated below
full_model.summary()

## Check Test Dataset - First Ophthalmologist

In [None]:
# Test images graded by the 1st ophthalmologist (per the directory name)
test_dir_first = "Data-New/Test-3-class-1st-ophth-adjust"

test_ds_first = tf.keras.utils.image_dataset_from_directory(
    test_dir_first,
    class_names=class_names,
    label_mode='categorical',
    image_size=(img_height, img_width),
    batch_size=batch_size,
    color_mode='grayscale',
    shuffle=False,  # fixed order so predictions, labels and file_paths line up
)

In [None]:
# Loss and compiled metrics on the 1st-ophthalmologist test set. batch_size is omitted
# because the dataset is already batched (Keras: do not pass batch_size for tf.data inputs).
full_model.evaluate(test_ds_first)

In [None]:
# Class probabilities from the model, and the predicted class index per image
test_predictions_values_first = full_model.predict(test_ds_first)
test_predictions_first = test_predictions_values_first.argmax(axis=1)

# One-hot ground truth in the same (unshuffled) order, then its class index
test_label_batches_first = [labels for _, labels in test_ds_first]
test_labels_values_first = np.concatenate(test_label_batches_first)
test_labels_first = test_labels_values_first.argmax(axis=1)

print(classification_report(test_labels_first, test_predictions_first, target_names=class_names))

In [None]:
# Short display names for classes 0, 1, 2 used on confusion-matrix axes
disp_labels = 'NORMAL NCI-DME CI-DME'.split()

In [None]:
# Test-set confusion matrix (sklearn convention: rows = true class, columns = predicted)
cm_first = confusion_matrix(test_labels_first, test_predictions_first, labels=[0, 1, 2])
disp = ConfusionMatrixDisplay(confusion_matrix=cm_first, display_labels=disp_labels)

cm_ax_first = disp.plot(cmap=plt.cm.Blues).ax_
cm_ax_first.set_title('1st Ophthalmologist')

In [None]:
# Save the 1st-ophthalmologist confusion matrix. Uses lower-case 'figures/' like every
# other saved figure; 'Figures/' is a different directory on case-sensitive filesystems.
disp.figure_.savefig('figures/cm-3class-test-first-ophth.pdf', bbox_inches='tight')

In [None]:
# Sensitivity/specificity summary for the 1st-ophthalmologist confusion matrix
# (project helper from helper_functions; output format is defined there)
print_sens_spec_3class(cm_first)

In [None]:
# Examine predictions for the test data ~ per-image probabilities, to consider the effect of the model.
# Short ID per image: last 6 characters of the file name without its extension. Taking the
# path stem (rather than slicing the full path with [-10:-4]) does not assume a 4-character
# extension, and never pulls directory characters into IDs of short file names.
test_filenames_first = [pathlib.Path(path).stem[-6:] for path in test_ds_first.file_paths]

In [None]:
# Per-image table: predicted probabilities, predicted/actual class, one-hot labels
df_test_first = pd.DataFrame({
    'Filename': test_filenames_first,
    'NORMAL-Prob': test_predictions_values_first[:, 0],
    'NCI-DME-Prob': test_predictions_values_first[:, 1],
    'CI-DME-Prob': test_predictions_values_first[:, 2],
    'Predicted': test_predictions_first,
    'Actual': test_labels_first,
    'NORMAL-Label': test_labels_values_first[:, 0],
    'NCI-DME-Label': test_labels_values_first[:, 1],
    'CI-DME-Label': test_labels_values_first[:, 2],
})

# Human-readable class names for the integer class columns
class_codes = {0: 'NORMAL', 1: 'NCI-DME', 2: 'CI-DME'}
df_test_first['Predicted.name'] = df_test_first['Predicted'].map(class_codes)
df_test_first['Actual.name'] = df_test_first['Actual'].map(class_codes)
df_test_first['Correct'] = df_test_first['Predicted'] == df_test_first['Actual']

In [None]:
# Show every row, with floats to 4 decimal places
pd.set_option("display.max_rows", None)
pd.set_option("display.float_format", '{:,.4f}'.format)
display(df_test_first)

In [None]:
# Persist the per-image predictions table (index included, as in DataFrame.to_csv's default)
df_test_first.to_csv(path_or_buf='predictions/predictions_classification_test_first.csv')

In [None]:
# Compute one-vs-rest ROC curve and ROC area for each class on the 1st-ophthalmologist test set
fpr_test_first = dict()
tpr_test_first = dict()
roc_auc_test_first = dict()

plt.figure()

# One loop instead of three copy-pasted blocks; names follow class indices 0, 1, 2
for class_idx, class_label in enumerate(['NORMAL', 'NCI-DME', 'CI-DME']):
    fpr_test_first[class_idx], tpr_test_first[class_idx], _ = roc_curve(
        test_labels_values_first[:, class_idx], test_predictions_values_first[:, class_idx])
    roc_auc_test_first[class_idx] = auc(fpr_test_first[class_idx], tpr_test_first[class_idx])
    plt.plot(fpr_test_first[class_idx], tpr_test_first[class_idx],
             label='ROC curve of {0} class (area = {1:0.2f})'.format(class_label, roc_auc_test_first[class_idx]))

plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('1st Ophthalmologist')
plt.legend(loc="lower right")
plt.savefig('figures/ROC-3class-test-first-ophth.pdf', bbox_inches='tight')
plt.show()

## Check Test Dataset - Second Ophthalmologist

In [None]:
# Test images graded by the 2nd ophthalmologist (per the directory name)
test_dir_second = "Data-New/Test-3-class-2nd-ophth-adjust"

test_ds_second = tf.keras.utils.image_dataset_from_directory(
    test_dir_second,
    class_names=class_names,
    label_mode='categorical',
    image_size=(img_height, img_width),
    batch_size=batch_size,
    color_mode='grayscale',
    shuffle=False,  # fixed order so predictions, labels and file_paths line up
)

In [None]:
# Loss and compiled metrics on the 2nd-ophthalmologist test set. batch_size is omitted
# because the dataset is already batched (Keras: do not pass batch_size for tf.data inputs).
full_model.evaluate(test_ds_second)

In [None]:
# Class probabilities from the model, and the predicted class index per image
test_predictions_values_second = full_model.predict(test_ds_second)
test_predictions_second = test_predictions_values_second.argmax(axis=1)

# One-hot ground truth in the same (unshuffled) order, then its class index
test_label_batches_second = [labels for _, labels in test_ds_second]
test_labels_values_second = np.concatenate(test_label_batches_second)
test_labels_second = test_labels_values_second.argmax(axis=1)

print(classification_report(test_labels_second, test_predictions_second, target_names=class_names))

In [None]:
# Test-set confusion matrix (sklearn convention: rows = true class, columns = predicted)
cm_second = confusion_matrix(test_labels_second, test_predictions_second, labels=[0, 1, 2])
disp = ConfusionMatrixDisplay(confusion_matrix=cm_second, display_labels=disp_labels)

cm_ax_second = disp.plot(cmap=plt.cm.Blues).ax_
cm_ax_second.set_title('2nd Ophthalmologist')

In [None]:
# Save the 2nd-ophthalmologist confusion matrix figure as PDF
cm_figure_second = disp.figure_
cm_figure_second.savefig('figures/cm-3class-test-second-ophth.pdf', bbox_inches='tight')

In [None]:
# Sensitivity/specificity summary for the 2nd-ophthalmologist confusion matrix
# (project helper from helper_functions; output format is defined there)
print_sens_spec_3class(cm_second)

In [None]:
# Examine predictions for the test data ~ per-image probabilities, to consider the effect of the model.
# Short ID per image: last 6 characters of the file name without its extension. Taking the
# path stem (rather than slicing the full path with [-10:-4]) does not assume a 4-character
# extension, and never pulls directory characters into IDs of short file names.
test_filenames_second = [pathlib.Path(path).stem[-6:] for path in test_ds_second.file_paths]

In [None]:
# Per-image table: predicted probabilities, predicted/actual class, one-hot labels
df_test_second = pd.DataFrame({
    'Filename': test_filenames_second,
    'NORMAL-Prob': test_predictions_values_second[:, 0],
    'NCI-DME-Prob': test_predictions_values_second[:, 1],
    'CI-DME-Prob': test_predictions_values_second[:, 2],
    'Predicted': test_predictions_second,
    'Actual': test_labels_second,
    'NORMAL-Label': test_labels_values_second[:, 0],
    'NCI-DME-Label': test_labels_values_second[:, 1],
    'CI-DME-Label': test_labels_values_second[:, 2],
})

# Human-readable class names for the integer class columns
class_codes = {0: 'NORMAL', 1: 'NCI-DME', 2: 'CI-DME'}
df_test_second['Predicted.name'] = df_test_second['Predicted'].map(class_codes)
df_test_second['Actual.name'] = df_test_second['Actual'].map(class_codes)
df_test_second['Correct'] = df_test_second['Predicted'] == df_test_second['Actual']

In [None]:
# Show every row, with floats to 4 decimal places
pd.set_option("display.max_rows", None)
pd.set_option("display.float_format", '{:,.4f}'.format)
display(df_test_second)

In [None]:
# Persist the per-image predictions table (index included, as in DataFrame.to_csv's default)
df_test_second.to_csv(path_or_buf='predictions/predictions_classification_test_second.csv')

In [None]:
# Compute one-vs-rest ROC curve and ROC area for each class on the 2nd-ophthalmologist test set
fpr_test_second = dict()
tpr_test_second = dict()
roc_auc_test_second = dict()

plt.figure()

# One loop instead of three copy-pasted blocks; names follow class indices 0, 1, 2
for class_idx, class_label in enumerate(['NORMAL', 'NCI-DME', 'CI-DME']):
    fpr_test_second[class_idx], tpr_test_second[class_idx], _ = roc_curve(
        test_labels_values_second[:, class_idx], test_predictions_values_second[:, class_idx])
    roc_auc_test_second[class_idx] = auc(fpr_test_second[class_idx], tpr_test_second[class_idx])
    plt.plot(fpr_test_second[class_idx], tpr_test_second[class_idx],
             label='ROC curve of {0} class (area = {1:0.2f})'.format(class_label, roc_auc_test_second[class_idx]))

plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('2nd Ophthalmologist')
plt.legend(loc="lower right")
plt.savefig('figures/ROC-3class-test-second-ophth.pdf', bbox_inches='tight')
plt.show()