In [1]:
import tensorflow as tf
import scipy.io 
import matplotlib.pyplot as plt
import cv2
import keras
from glob import glob
import numpy as np
from tqdm import tqdm
import os
from PIL import Image
import pandas as pd
import cv2

from sklearn.model_selection import KFold
# from keras.preprocessing.image import ImageDataGenerator

# import keras_metrics

from keras.applications import mobilenet, resnet50 #, vgg16, inception_v3, resnet50, 
from keras.optimizers import Adam
from keras.utils import to_categorical
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, History

from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array

import logging
# logging.getLogger().setLevel(logging.DEBUG)
import pickle


import seaborn
seaborn.set_style("darkgrid")

from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.metrics import precision_recall_curve

Using TensorFlow backend.


In [2]:
keras.__version__

'2.2.4'

## Params

In [3]:
# all_data_dir = 'E:\\Work/PathoBarIlan/Shlomi2018/'
all_data_dir = '/media/leetwito/DATA/Datasets/PathoBarIlan/Shlomi2018'

is_relative_path_csv = False
seed = 4221
k_idx = 3 # take one of the K-Folds
pretrained_model_path = "/media/leetwito/Windows/Users/leetw/PycharmProjects/PathoBarIlan/my_models/k=3, lr1e-2/model_spec_weights_epoch36-val_loss0.006-train_loss0.004-seed4221-i=3.hdf5"


pos_name_init = 'Cancer'
neg_name_init = 'Normal'

use_rgb = False # True=rgb, False=spectral
if use_rgb:
    file_ext = '.png'
else:
    file_ext = '.npy'
    
window_size = (200, 200)
shift = (100, 100)

In [4]:
w,h = window_size
if use_rgb:
    input_shape = (w,h,3)
else:
    input_shape = (w,h,40)
batch_size = 16

## utils

In [5]:
def read_slide(path):
    mat = scipy.io.loadmat(path)
    spectral = mat["Spec"]
    rgb = mat["Section"]
    shape = rgb.shape
    
    return spectral, rgb

In [6]:
def visualize_batch_of_crops(crops, n_iter_y, n_iter_x):
    fig, axes = plt.subplots(n_iter_y, n_iter_x, figsize=(5, 5), gridspec_kw = {'wspace':0, 'hspace':0})

    for i in range(n_iter_x):
        for j in range(n_iter_y):
            axes[j, i].imshow(crops[i*n_iter_y + j])
            axes[j, i].axis('off')
            axes[j, i].set_aspect('equal')
    plt.show()

In [7]:
def create_csv_for_folder(data_dir, ext):
    if ext[0] == '.':
        ext = ext[1:]
    data_df = pd.DataFrame(columns=['filename', 'label'])
    files = glob(os.path.join(data_dir,'*', '*.{}'.format(ext)))
    files = [file for file in files if "Mixed" not in file]
#     print(data_dir+'/*/*.{}'.format(ext))
    
    init_len = len(data_dir)
    delete_folder = all_data_dir
    if not is_relative_path_csv:
        delete_folder = '/'
    if not delete_folder[-1] == '/':
        delete_folder += '/'
    files = [file.replace(delete_folder, '/') for file in files]
#     print(files)
    labels = [1 if pos_name_init in file else 0 for file in files]
#     print(labels)
    data_df['filename'] = files
    data_df['label'] = labels
#     data_df.to_csv(os.path.join(data_dir, os.path.basename(data_dir)+'.csv'), index=False)
#     print('Created CSV successfully for folder {}'.format(data_dir))
    
    return data_df    

In [8]:
slides = glob(os.path.join(all_data_dir, "*/"))
slides

['/media/leetwito/DATA/Datasets/PathoBarIlan/Shlomi2018/Case10/',
 '/media/leetwito/DATA/Datasets/PathoBarIlan/Shlomi2018/Case11/',
 '/media/leetwito/DATA/Datasets/PathoBarIlan/Shlomi2018/Case12/',
 '/media/leetwito/DATA/Datasets/PathoBarIlan/Shlomi2018/Case14/',
 '/media/leetwito/DATA/Datasets/PathoBarIlan/Shlomi2018/Case16/',
 '/media/leetwito/DATA/Datasets/PathoBarIlan/Shlomi2018/Case16b/',
 '/media/leetwito/DATA/Datasets/PathoBarIlan/Shlomi2018/Case17/',
 '/media/leetwito/DATA/Datasets/PathoBarIlan/Shlomi2018/Case18/',
 '/media/leetwito/DATA/Datasets/PathoBarIlan/Shlomi2018/Case19484/',
 '/media/leetwito/DATA/Datasets/PathoBarIlan/Shlomi2018/Case8/']

In [9]:
skf = KFold(n_splits=5, shuffle=True, random_state=seed)

train_slides_all = []
test_slides_all = []
val_slides_all = []

for train_index, test_index in skf.split(np.arange(len(slides)).T, np.arange(len(slides)).T):
    print("TRAIN:", train_index, "TEST:", test_index)
    train_slides_all.append(train_index)
    val_slides_all.append([test_index[0]])
    test_slides_all.append([test_index[1]])

TRAIN: [2 3 4 5 6 7 8 9] TEST: [0 1]
TRAIN: [0 1 2 3 5 7 8 9] TEST: [4 6]
TRAIN: [0 1 3 4 5 6 7 8] TEST: [2 9]
TRAIN: [0 1 2 4 6 7 8 9] TEST: [3 5]
TRAIN: [0 1 2 3 4 5 6 9] TEST: [7 8]


In [10]:
train_index = train_slides_all[k_idx]
val_index = val_slides_all[k_idx]
test_index = test_slides_all[k_idx]

train_index, val_index, test_index

(array([0, 1, 2, 4, 6, 7, 8, 9]), [3], [5])

In [11]:
def get_dfs_for_indices(slides, index_list):
    dfs = []
    for slide in np.array(slides)[index_list]:
        data_dir = slide
        dfs.append(create_csv_for_folder(data_dir, file_ext))
    df = pd.concat(dfs, ignore_index=True)
    df = df.sample(frac=1, random_state=seed)  # frac=1 is same as shuffling df.
    return df

In [12]:
df_train = get_dfs_for_indices(slides, train_index)
df_test = get_dfs_for_indices(slides, test_index)
df_val = get_dfs_for_indices(slides, val_index)

In [13]:
pd.options.display.max_colwidth = 150

In [14]:
assert len(set(df_train.label.values)) == 2 and len(set(df_val.label.values)) == 2 and len(set(df_test.label.values)) == 2  

In [15]:
def batch_norm(x):
#     print("x.shape:", x.shape)
    maxi = x.max(axis=1).max(axis=1)
#     print("maxi.shape:", maxi.shape)
    maxi = np.repeat(maxi[:, np.newaxis, : ], window_size[0], axis=1)
    maxi = np.repeat(maxi[:, np.newaxis, : ], window_size[1], axis=1)
    return x/maxi


def generator_from_df(df, batch_size, shuffle=True): 
    
    n_batches = df.shape[0]//batch_size
    while True:
        if shuffle:
            df_tmp = df.copy().sample(frac=1)  # frac=1 is same as shuffling df.
        else:
            df_tmp = df
        
        for i in range(n_batches):
            sub = df_tmp.iloc[batch_size*i:batch_size*(i+1)]
            if use_rgb:
                X = [img_to_array(load_img(f, target_size=input_shape)) for f in sub.filename]
            else:
                X = [np.load(f) for f in sub.filename]
                
            X = batch_norm(np.stack(X))
            logging.debug("from file {}\nto file {}".format(sub.iloc[0].filename, sub.iloc[-1].filename))

            Y = sub.label.values
            Y = to_categorical(Y, num_classes=2)
            # Simple model, one input, one output.
            
            yield X, Y

In [23]:
def plot_roc_curve(y_true, y_scores, figsize=(15, 8)):
    fpr, tpr, threshold = roc_curve(y_true, y_scores)
    plt.figure(figsize=figsize)
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], "k--")
    plt.axes()

    return fpr, tpr, threshold

In [16]:
train_generator = generator_from_df(df_train, batch_size)
val_generator = generator_from_df(df_val, batch_size)
test_generator = generator_from_df(df_test, batch_size, shuffle=False)

In [17]:
loaded_model = keras.models.load_model(pretrained_model_path)

In [18]:
generator = train_generator
df = df_train

data_len = len(df)//batch_size
y_proba = []
y_gt = []
loss = []
for i in tqdm(range(data_len)):
    x, y = next(generator)
    y_gt.append(y)
    y_proba.append(loaded_model.predict(x))
    loss.append(loaded_model.evaluate(x, y))
with open(pretrained_model_path.replace('.hdf5', '-[y_gt,y_proba,loss].pkl'), 'wb') as f:
    pickle.dump([y_gt, y_proba, loss], f)

  0%|          | 0/187 [00:00<?, ?it/s]



  1%|          | 1/187 [00:04<15:22,  4.96s/it]



  1%|          | 2/187 [00:08<13:56,  4.52s/it]



  2%|▏         | 3/187 [00:11<12:56,  4.22s/it]



  2%|▏         | 4/187 [00:15<12:11,  3.99s/it]



  3%|▎         | 5/187 [00:18<11:34,  3.82s/it]



  3%|▎         | 6/187 [00:22<11:14,  3.73s/it]



  4%|▎         | 7/187 [00:25<11:03,  3.69s/it]



  4%|▍         | 8/187 [00:29<10:53,  3.65s/it]



  5%|▍         | 9/187 [00:32<10:39,  3.59s/it]



  5%|▌         | 10/187 [00:36<10:38,  3.61s/it]



  6%|▌         | 11/187 [00:40<10:34,  3.60s/it]



  6%|▋         | 12/187 [00:43<10:27,  3.59s/it]



  7%|▋         | 13/187 [00:47<10:20,  3.57s/it]



  7%|▋         | 14/187 [00:50<10:22,  3.60s/it]



  8%|▊         | 15/187 [00:54<10:18,  3.60s/it]



  9%|▊         | 16/187 [00:58<10:19,  3.62s/it]



  9%|▉         | 17/187 [01:01<10:20,  3.65s/it]



 10%|▉         | 18/187 [01:05<10:06,  3.59s/it]



 10%|█         | 19/187 [01:08<10:01,  3.58s/it]



 11%|█         | 20/187 [01:12<09:55,  3.57s/it]



 11%|█         | 21/187 [01:16<09:52,  3.57s/it]



 12%|█▏        | 22/187 [01:19<09:46,  3.55s/it]



 12%|█▏        | 23/187 [01:23<09:49,  3.59s/it]



 13%|█▎        | 24/187 [01:27<09:56,  3.66s/it]



 13%|█▎        | 25/187 [01:30<09:49,  3.64s/it]



 14%|█▍        | 26/187 [01:34<09:57,  3.71s/it]



 14%|█▍        | 27/187 [01:38<09:51,  3.70s/it]



 15%|█▍        | 28/187 [01:42<09:53,  3.73s/it]



 16%|█▌        | 29/187 [01:45<09:44,  3.70s/it]



 16%|█▌        | 30/187 [01:49<09:31,  3.64s/it]



 17%|█▋        | 31/187 [01:52<09:30,  3.66s/it]



 17%|█▋        | 32/187 [01:56<09:31,  3.69s/it]



 18%|█▊        | 33/187 [02:00<09:27,  3.69s/it]



 18%|█▊        | 34/187 [02:04<09:26,  3.70s/it]



 19%|█▊        | 35/187 [02:07<09:18,  3.67s/it]



 19%|█▉        | 36/187 [02:11<09:13,  3.67s/it]



 20%|█▉        | 37/187 [02:14<09:08,  3.66s/it]



 20%|██        | 38/187 [02:18<09:08,  3.68s/it]



 21%|██        | 39/187 [02:22<09:08,  3.70s/it]



 21%|██▏       | 40/187 [02:26<09:07,  3.73s/it]



 22%|██▏       | 41/187 [02:30<09:13,  3.79s/it]



 22%|██▏       | 42/187 [02:33<09:06,  3.77s/it]



 23%|██▎       | 43/187 [02:37<09:02,  3.77s/it]



 24%|██▎       | 44/187 [02:41<09:00,  3.78s/it]



 24%|██▍       | 45/187 [02:45<08:55,  3.77s/it]



 25%|██▍       | 46/187 [02:48<08:47,  3.74s/it]



 25%|██▌       | 47/187 [02:52<08:40,  3.72s/it]



 26%|██▌       | 48/187 [02:56<08:40,  3.75s/it]



 26%|██▌       | 49/187 [03:00<08:33,  3.72s/it]



 27%|██▋       | 50/187 [03:03<08:30,  3.73s/it]



 27%|██▋       | 51/187 [03:07<08:22,  3.70s/it]



 28%|██▊       | 52/187 [03:10<08:15,  3.67s/it]



 28%|██▊       | 53/187 [03:14<08:19,  3.73s/it]



 29%|██▉       | 54/187 [03:18<08:10,  3.68s/it]



 29%|██▉       | 55/187 [03:22<08:07,  3.69s/it]



 30%|██▉       | 56/187 [03:25<08:06,  3.71s/it]



 30%|███       | 57/187 [03:29<08:03,  3.72s/it]



 31%|███       | 58/187 [03:33<08:00,  3.73s/it]



 32%|███▏      | 59/187 [03:37<07:53,  3.70s/it]



 32%|███▏      | 60/187 [03:40<07:51,  3.71s/it]



 33%|███▎      | 61/187 [03:44<07:50,  3.74s/it]



 33%|███▎      | 62/187 [03:48<07:42,  3.70s/it]



 34%|███▎      | 63/187 [03:51<07:35,  3.67s/it]



 34%|███▍      | 64/187 [03:55<07:36,  3.71s/it]



 35%|███▍      | 65/187 [03:59<07:39,  3.76s/it]



 35%|███▌      | 66/187 [04:03<07:40,  3.80s/it]



 36%|███▌      | 67/187 [04:07<07:40,  3.84s/it]



 36%|███▋      | 68/187 [04:11<07:36,  3.83s/it]



 37%|███▋      | 69/187 [04:14<07:33,  3.84s/it]



 37%|███▋      | 70/187 [04:18<07:25,  3.81s/it]



 38%|███▊      | 71/187 [04:22<07:16,  3.76s/it]



 39%|███▊      | 72/187 [04:26<07:16,  3.79s/it]



 39%|███▉      | 73/187 [04:30<07:12,  3.79s/it]



 40%|███▉      | 74/187 [04:33<07:07,  3.78s/it]



 40%|████      | 75/187 [04:37<07:02,  3.77s/it]



 41%|████      | 76/187 [04:41<07:02,  3.81s/it]



 41%|████      | 77/187 [04:45<06:58,  3.80s/it]



 42%|████▏     | 78/187 [04:48<06:52,  3.78s/it]



 42%|████▏     | 79/187 [04:52<06:49,  3.79s/it]



 43%|████▎     | 80/187 [04:56<06:45,  3.79s/it]



 43%|████▎     | 81/187 [05:00<06:41,  3.79s/it]



 44%|████▍     | 82/187 [05:04<06:41,  3.83s/it]



 44%|████▍     | 83/187 [05:08<06:40,  3.85s/it]



 45%|████▍     | 84/187 [05:11<06:34,  3.83s/it]



 45%|████▌     | 85/187 [05:15<06:29,  3.82s/it]



 46%|████▌     | 86/187 [05:19<06:23,  3.80s/it]



 47%|████▋     | 87/187 [05:23<06:19,  3.79s/it]



 47%|████▋     | 88/187 [05:26<06:13,  3.77s/it]



 48%|████▊     | 89/187 [05:30<06:08,  3.77s/it]



 48%|████▊     | 90/187 [05:34<06:03,  3.75s/it]



 49%|████▊     | 91/187 [05:38<06:01,  3.77s/it]



 49%|████▉     | 92/187 [05:42<05:58,  3.78s/it]



 50%|████▉     | 93/187 [05:45<05:54,  3.77s/it]



 50%|█████     | 94/187 [05:49<05:47,  3.73s/it]



 51%|█████     | 95/187 [05:53<05:43,  3.73s/it]



 51%|█████▏    | 96/187 [05:56<05:37,  3.71s/it]



 52%|█████▏    | 97/187 [06:00<05:33,  3.71s/it]



 52%|█████▏    | 98/187 [06:04<05:31,  3.73s/it]



 53%|█████▎    | 99/187 [06:07<05:24,  3.69s/it]



 53%|█████▎    | 100/187 [06:11<05:17,  3.64s/it]



 54%|█████▍    | 101/187 [06:15<05:15,  3.67s/it]



 55%|█████▍    | 102/187 [06:18<05:09,  3.64s/it]



 55%|█████▌    | 103/187 [06:22<05:05,  3.63s/it]



 56%|█████▌    | 104/187 [06:26<05:02,  3.64s/it]



 56%|█████▌    | 105/187 [06:29<04:59,  3.65s/it]



 57%|█████▋    | 106/187 [06:33<04:56,  3.67s/it]



 57%|█████▋    | 107/187 [06:36<04:51,  3.64s/it]



 58%|█████▊    | 108/187 [06:40<04:47,  3.64s/it]



 58%|█████▊    | 109/187 [06:44<04:43,  3.64s/it]



 59%|█████▉    | 110/187 [06:47<04:39,  3.63s/it]



 59%|█████▉    | 111/187 [06:51<04:33,  3.59s/it]



 60%|█████▉    | 112/187 [06:55<04:31,  3.61s/it]



 60%|██████    | 113/187 [06:58<04:27,  3.61s/it]



 61%|██████    | 114/187 [07:02<04:25,  3.63s/it]



 61%|██████▏   | 115/187 [07:05<04:22,  3.65s/it]



 62%|██████▏   | 116/187 [07:09<04:21,  3.68s/it]



 63%|██████▎   | 117/187 [07:13<04:17,  3.68s/it]



 63%|██████▎   | 118/187 [07:16<04:10,  3.62s/it]



 64%|██████▎   | 119/187 [07:20<04:09,  3.67s/it]



 64%|██████▍   | 120/187 [07:24<04:11,  3.76s/it]



 65%|██████▍   | 121/187 [07:28<04:08,  3.76s/it]



 65%|██████▌   | 122/187 [07:32<04:01,  3.72s/it]



 66%|██████▌   | 123/187 [07:35<03:57,  3.71s/it]



 66%|██████▋   | 124/187 [07:39<03:54,  3.72s/it]



 67%|██████▋   | 125/187 [07:43<03:46,  3.66s/it]



 67%|██████▋   | 126/187 [07:46<03:43,  3.67s/it]



 68%|██████▊   | 127/187 [07:50<03:41,  3.70s/it]



 68%|██████▊   | 128/187 [07:54<03:40,  3.73s/it]



 69%|██████▉   | 129/187 [07:57<03:33,  3.69s/it]



 70%|██████▉   | 130/187 [08:01<03:29,  3.68s/it]



 70%|███████   | 131/187 [08:05<03:24,  3.66s/it]



 71%|███████   | 132/187 [08:08<03:21,  3.66s/it]



 71%|███████   | 133/187 [08:12<03:18,  3.67s/it]



 72%|███████▏  | 134/187 [08:16<03:12,  3.63s/it]



 72%|███████▏  | 135/187 [08:19<03:08,  3.62s/it]



 73%|███████▎  | 136/187 [08:23<03:04,  3.61s/it]



 73%|███████▎  | 137/187 [08:26<03:00,  3.60s/it]



 74%|███████▍  | 138/187 [08:30<02:56,  3.60s/it]



 74%|███████▍  | 139/187 [08:33<02:52,  3.58s/it]



 75%|███████▍  | 140/187 [08:37<02:48,  3.58s/it]



 75%|███████▌  | 141/187 [08:41<02:44,  3.58s/it]



 76%|███████▌  | 142/187 [08:44<02:41,  3.58s/it]



 76%|███████▋  | 143/187 [08:48<02:38,  3.61s/it]



 77%|███████▋  | 144/187 [08:51<02:34,  3.59s/it]



 78%|███████▊  | 145/187 [08:55<02:30,  3.59s/it]



 78%|███████▊  | 146/187 [08:59<02:28,  3.61s/it]



 79%|███████▊  | 147/187 [09:02<02:24,  3.60s/it]



 79%|███████▉  | 148/187 [09:06<02:20,  3.60s/it]



 80%|███████▉  | 149/187 [09:09<02:16,  3.60s/it]



 80%|████████  | 150/187 [09:13<02:14,  3.62s/it]



 81%|████████  | 151/187 [09:17<02:11,  3.65s/it]



 81%|████████▏ | 152/187 [09:21<02:08,  3.66s/it]



 82%|████████▏ | 153/187 [09:24<02:04,  3.66s/it]



 82%|████████▏ | 154/187 [09:28<02:00,  3.64s/it]



 83%|████████▎ | 155/187 [09:31<01:56,  3.65s/it]



 83%|████████▎ | 156/187 [09:35<01:53,  3.66s/it]



 84%|████████▍ | 157/187 [09:39<01:50,  3.69s/it]



 84%|████████▍ | 158/187 [09:43<01:46,  3.68s/it]



 85%|████████▌ | 159/187 [09:46<01:42,  3.67s/it]



 86%|████████▌ | 160/187 [09:50<01:38,  3.66s/it]



 86%|████████▌ | 161/187 [09:53<01:35,  3.66s/it]



 87%|████████▋ | 162/187 [09:57<01:31,  3.65s/it]



 87%|████████▋ | 163/187 [10:01<01:27,  3.66s/it]



 88%|████████▊ | 164/187 [10:05<01:25,  3.70s/it]



 88%|████████▊ | 165/187 [10:08<01:21,  3.71s/it]



 89%|████████▉ | 166/187 [10:12<01:17,  3.71s/it]



 89%|████████▉ | 167/187 [10:16<01:14,  3.71s/it]



 90%|████████▉ | 168/187 [10:19<01:10,  3.70s/it]



 90%|█████████ | 169/187 [10:23<01:06,  3.68s/it]



 91%|█████████ | 170/187 [10:27<01:02,  3.66s/it]



 91%|█████████▏| 171/187 [10:30<00:59,  3.69s/it]



 92%|█████████▏| 172/187 [10:34<00:54,  3.66s/it]



 93%|█████████▎| 173/187 [10:38<00:51,  3.64s/it]



 93%|█████████▎| 174/187 [10:41<00:47,  3.66s/it]



 94%|█████████▎| 175/187 [10:45<00:43,  3.62s/it]



 94%|█████████▍| 176/187 [10:48<00:39,  3.58s/it]



 95%|█████████▍| 177/187 [10:52<00:36,  3.61s/it]



 95%|█████████▌| 178/187 [10:56<00:32,  3.63s/it]



 96%|█████████▌| 179/187 [10:59<00:29,  3.63s/it]



 96%|█████████▋| 180/187 [11:03<00:25,  3.64s/it]



 97%|█████████▋| 181/187 [11:07<00:21,  3.63s/it]



 97%|█████████▋| 182/187 [11:10<00:18,  3.63s/it]



 98%|█████████▊| 183/187 [11:14<00:14,  3.61s/it]



 98%|█████████▊| 184/187 [11:17<00:10,  3.58s/it]



 99%|█████████▉| 185/187 [11:21<00:07,  3.62s/it]



 99%|█████████▉| 186/187 [11:25<00:03,  3.64s/it]



100%|██████████| 187/187 [11:28<00:00,  3.61s/it]


In [19]:
y_proba = np.concatenate(y_proba)
y_gt = np.concatenate(y_gt)

In [20]:
results = y_proba.argmax(axis=1)==y_gt.argmax(axis=1)

In [21]:
print("correct: {}/{}".format(results.sum(), len(results)))

correct: 2992/2992


In [25]:
y_pred = np.argmax(y_proba, axis=1)
plot_roc_curve(y_gt, y_pred)
plt.show()

ValueError: multilabel-indicator format is not supported

In [27]:
y_gt.argmax(axis=1)

array([1, 0, 0, ..., 0, 0, 1])

In [31]:
y_proba.shape

(2992, 2)

In [36]:
scores = [y_proba[idx, pred] for idx, pred in enumerate(y_gt.argmax(axis=1))]

In [37]:
low_score_pred = sorted(scores)

In [43]:
low_score_pred[:15]

[0.81256026,
 0.8556575,
 0.85801196,
 0.86075246,
 0.8644029,
 0.86545503,
 0.8771054,
 0.8848943,
 0.88885045,
 0.893748,
 0.8950029,
 0.90370256,
 0.9097792,
 0.92767656,
 0.9307996]

In [None]:
y_pred = np.argmax(y_pred, axis=1)
y_pred.shape

In [None]:
y_pred.sum()

In [None]:
y_test.sum()

In [None]:
(y_pred==y_test).sum()/220

### plot roc curve

In [None]:
def plot_roc_curve(y_true, y_scores, figsize=(15, 8)):
    fpr, tpr, threshold = roc_curve(y_true, y_scores)
    plt.figure(figsize=figsize)
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], "k--")
    plt.axes()

    return fpr, tpr, threshold

In [None]:
def plot_precision_recall_curve(y_true, y_scores, figsize=(15, 8)):
    precision, recall, thresholds = precision_recall_curve(y_true, y_scores)
    plt.figure(figsize=figsize)
    plt.plot(precision, recall)
#     plt.plot([0, 1], [0, 1], "k--")
    plt.axes()
    
    return precision, recall, thresholds

In [None]:
# mobilenet_model.load_weights("my_models/model_spec_weights_epoch36-val_loss0.006-train_loss0.004.hdf5")
mobilenet_model.load_weights("my_models/model_spec_weights_epoch14-val_loss0.019-seed4221-k_idx=1.hdf5")

In [None]:
def get_y_test_and_pred():
    y_tests = []
    y_preds = []
    for i in range(len(df_test)//batch_size):
        print(i, "out of", len(df_test)//batch_size)
        x, y = next(test_generator)
        print(mobilenet_model.evaluate(x, y))
        y_tests.append(y.argmax(axis=1))
        y_preds.append(mobilenet_model.predict(x)[:,1])
    
    y_test = np.stack(y_tests)
    y_pred = np.stack(y_preds)

    y_test = y_test.reshape((-1, 1))
    y_pred = y_pred.reshape((-1, 1))
    
    return y_test, y_pred

In [None]:
y_test, y_pred = get_y_test_and_pred()

In [None]:
print(roc_auc_score(y_test, y_pred))
fpr, tpr, threshold = plot_roc_curve(y_test, y_pred, figsize=(15, 8))

In [None]:
# precision, recall, threshold = plot_precision_recall_curve(y_test, y_pred, figsize=(15, 8)) # kills the kernel for me.