## Plot ROCs

In [None]:
from keras.applications.mobilenet import MobileNet
from keras.applications.mobilenet import preprocess_input as MobileNet_preprocess_input
from keras.applications.vgg19 import VGG19
from keras.applications.vgg19 import preprocess_input as VGG19_preprocess_input
from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.applications.inception_resnet_v2 import preprocess_input as InceptionResNetV2_preprocess_input
from keras.applications.inception_v3 import InceptionV3
from keras.applications.inception_v3 import preprocess_input as InceptionV3_preprocess_input
from keras.applications.mobilenet_v2 import MobileNetV2
from keras.applications.mobilenet_v2 import preprocess_input as MobileNetV2_preprocess_input
from keras.applications.nasnet import NASNetLarge
from keras.applications.nasnet import preprocess_input as NASNetLarge_preprocess_input

%matplotlib inline
import matplotlib.pyplot as plt

import data_preparation
import params
import os
import reset
import gradient_accumulation
from utils import plot_train_metrics, save_model
from sklearn.metrics import roc_curve, auc
from train import create_data_generator, _create_base_model, create_simple_model, create_attention_model

# Load the image metadata, derive the label set and rebuild the same
# stratified train/validation split used during training.
metadata = data_preparation.load_metadata()
metadata, labels = data_preparation.preprocess_metadata(metadata)
train, valid = data_preparation.stratify_train_test_split(metadata)

# Plain 'adam' unless params.DEFAULT_OPTIMIZER names something else, in which
# case the gradient-accumulating Adam is built instead. (Original note: for
# these image sizes gradient accumulation is not needed to reach
# BATCH_SIZE = 256.) In this notebook the optimizer is only handed to the
# model-building callbacks; the models are used for prediction only.
optimizer = (
    gradient_accumulation.AdamAccumulate(
        lr=params.LEARNING_RATE, accum_iters=params.ACCUMULATION_STEPS)
    if params.DEFAULT_OPTIMIZER != 'adam'
    else 'adam'
)

# One row per backbone to evaluate:
# [Keras application constructor, image size from params, matching preprocess_input].
base_models = [
    [MobileNet, params.MOBILENET_IMG_SIZE, MobileNet_preprocess_input],
    [InceptionResNetV2, params.INCEPTIONRESNETV2_IMG_SIZE, InceptionResNetV2_preprocess_input],
    [VGG19, params.VGG19_IMG_SIZE, VGG19_preprocess_input],
    [InceptionV3, params.INCEPTIONV3_IMG_SIZE, InceptionV3_preprocess_input],
    [MobileNetV2, params.MOBILENETV2_IMG_SIZE, MobileNetV2_preprocess_input],
    [NASNetLarge, params.NASNETLARGE_IMG_SIZE, NASNetLarge_preprocess_input],
]

# Reduce TensorFlow's C++ log verbosity.
# NOTE(review): TensorFlow usually reads this variable when it is first
# imported; setting it after the keras imports above may have no effect —
# consider moving it before them.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'


def plot_model_ROC(_Model, input_shape, preprocessing_function,
                   train, valid, labels,
                   extend_model_callback, optimizer, name_prefix):
    """Rebuild a trained model, load its best weights and plot per-label ROCs.

    The model is reconstructed from ``_Model`` via ``_create_base_model`` and
    ``extend_model_callback``. Its weights are read from
    ``<params.RESULTS_FOLDER>/<name_prefix>_<base model name>/weights.best.hdf5``.
    One ROC curve per label is drawn on a single figure, which is saved in the
    same folder as ``<model_name>_ROC.png``.

    Args:
        _Model: Keras application constructor (e.g. ``MobileNet``), forwarded
            to ``_create_base_model``.
        input_shape: image size forwarded as ``target_size`` to
            ``create_data_generator``.
        preprocessing_function: preprocessing function forwarded to
            ``create_data_generator``.
        train: unused; kept so the signature stays compatible with callers.
        valid: evaluation split passed to ``create_data_generator``.
        labels: label names. Column ``idx`` of the targets and predictions is
            plotted as ``labels[idx]``.
        extend_model_callback: called as
            ``extend_model_callback(base_model, labels, optimizer)`` and must
            return the full model.
        optimizer: forwarded to ``extend_model_callback``.
        name_prefix: prefix of the results sub-folder / model name
            (e.g. ``'simple'`` or ``'attention'``).

    Returns:
        dict mapping each label to the ROC AUC computed for it. Callers that
        ignore the return value are unaffected.
    """
    # Draw a single batch from the generator. The third argument (10000) is
    # presumably the batch size — TODO confirm against create_data_generator.
    test_X, test_Y = next(create_data_generator(
        valid, labels, 10000, preprocessing_function, target_size=input_shape))

    # weights=None presumably skips pretrained ImageNet weights; the trained
    # weights are loaded from disk below.
    baseModel = _create_base_model(_Model,
                                   labels,
                                   test_X.shape[1:],
                                   trainable=False,
                                   weights=None)
    model = extend_model_callback(baseModel, labels, optimizer)

    model_name = name_prefix + '_' + baseModel.name
    model_dir = os.path.join(params.RESULTS_FOLDER, model_name)
    weights = os.path.join(model_dir, 'weights.best.hdf5')

    print('Loading ' + weights)
    model.load_weights(weights, by_name=True)
    model.trainable = False

    pred_Y = model.predict(test_X, batch_size=32, verbose=True)

    aucs = {}
    fig, c_ax = plt.subplots(1, 1, figsize=(9, 9))
    for idx, c_label in enumerate(labels):
        fpr, tpr, _ = roc_curve(test_Y[:, idx].astype(int), pred_Y[:, idx])
        aucs[c_label] = auc(fpr, tpr)
        c_ax.plot(fpr, tpr, label='%s (AUC:%0.2f)' % (c_label, aucs[c_label]))
    c_ax.legend()
    c_ax.set_title(model_name + ' ROC Curve')
    c_ax.set_xlabel('False Positive Rate')
    c_ax.set_ylabel('True Positive Rate')

    ROC_image_file_path = os.path.join(model_dir, model_name + '_ROC.png')
    fig.savefig(ROC_image_file_path)
    # A space separates the message from the path (it was previously missing).
    print('Saved ROC plot at ' + ROC_image_file_path)

    return aucs


# Evaluate every backbone with the 'simple' head first, then every backbone
# again with the 'attention' head (same order as two consecutive loops).
for extend_model_callback, name_prefix in ((create_simple_model, 'simple'),
                                           (create_attention_model, 'attention')):
    for _Model, input_shape, preprocess_input in base_models:
        plot_model_ROC(_Model, input_shape, preprocess_input,
                       train, valid, labels,
                       extend_model_callback, optimizer, name_prefix)