In [1]:
import json
import os
import warnings

import numpy as np
import pandas as pd
from tqdm import tqdm_notebook as tqdm

import tensorflow as tf
import keras
from keras.backend.tensorflow_backend import set_session
from sklearn.metrics import roc_curve, auc
from matplotlib import pyplot as plt

from utils import load_config
from models.net_keras import *
from data_loader.data_loader import get_patches, DataGenerator

Using TensorFlow backend.


In [2]:
# Training run to evaluate; both artifact paths below are derived from this name.
exp_name = 'seed_1'

# JSON written for this run (thresholds + training config are read from it below)
json_path = os.path.join('../log/', exp_name + '_info.json')
# Trained weights for this run
weights_path = os.path.join('../models/', exp_name, 'weights.h5')

In [3]:
# Load the JSON saved for the training run: the patch-level and patient-level
# thresholds, plus the training config (from which the patch size is taken).
with open(json_path, 'r', encoding='utf-8') as reader:
    jf = json.load(reader)
patch_threshold = jf['patch_threshold']
patient_threshold = jf['patient_threshold']
config = jf['config']
# First entry of the model input shape; assumes square patches — TODO confirm
patch_size = config['dataset']['input_dim'][0]

In [4]:
# Env settings: expose only GPU 1 to this process, and let TensorFlow grow its GPU
# memory allocation on demand (allow_growth) rather than reserving it up front.
# CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is initialized.
os.environ.update(CUDA_VISIBLE_DEVICES="1")
gpu_options = tf.GPUOptions(allow_growth=True)
sess_config = tf.ConfigProto(gpu_options=gpu_options)
# Register the configured session as the one Keras uses.
keras_session = tf.Session(config=sess_config)
set_session(keras_session)

In [5]:
# Load the model: build the architecture for the configured input shape, then
# restore the trained weights from this experiment's weights file.
model = simple_cnn_sigmoid(config['dataset']['input_dim'])
# by_name=True loads weights only into layers whose names match those in the file.
# Per the Keras docs, non-matching layers are skipped, so a renamed layer would
# silently keep its initial weights.
model.load_weights(weights_path, by_name=True)
# Binary cross-entropy + Adam(amsgrad); presumably mirrors the training setup — confirm.
model.compile(
    loss=keras.losses.binary_crossentropy,
    optimizer=keras.optimizers.Adam(amsgrad=True),
    metrics=['accuracy'])

W0210 13:42:01.948721 139697577998080 deprecation_wrapper.py:119] From /opt/python-3.6-packages/keras/2.2.4/keras/backend/tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.

W0210 13:42:01.951671 139697577998080 deprecation_wrapper.py:119] From /opt/python-3.6-packages/keras/2.2.4/keras/backend/tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W0210 13:42:01.955879 139697577998080 deprecation_wrapper.py:119] From /opt/python-3.6-packages/keras/2.2.4/keras/backend/tensorflow_backend.py:4138: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.

W0210 13:42:01.999218 139697577998080 deprecation_wrapper.py:119] From /opt/python-3.6-packages/keras/2.2.4/keras/backend/tensorflow_backend.py:3976: The name tf.nn.max_pool is deprecated. Please use tf.nn.max_pool2d instead.

W0210 13:42:02.089141 139697577998080 deprecation_wrapper.py:11

In [6]:
# Evaluation settings
# NOTE(review): `mode` is not read by any of the cells visible here.
mode = 'test'

# How each case is loaded by the evaluation loop:
#   'ori':     original (NIfTI) volume, cropped with get_boxdata, then preprocessed
#   'ori_box': box data loaded from config['dataset']['box_dir'], then preprocessed
#   'box':     preprocessed box data (numpy arrays) loaded from config['dataset']['box_dir']
load_way = 'ori'

# Kind of cases being evaluated:
#   'tumor':        patient score = detected lesion patches / ground-truth lesion patches,
#                   patient label 1
#   anything else:  treated as healthy; patient score = false-positive patches /
#                   pancreas patches, patient label 0
# NOTE(review): the patient-based ROI label is reportedly 2 for 'tumor' data and 1
# otherwise — confirm against the label maps.
data_type = 'tumor'

# Fail fast on a typo here rather than inside the (slow) evaluation loop.
assert load_way in ('ori', 'ori_box', 'box'), load_way

# Evaluate external data

In [7]:
# Case files to evaluate (presumably one file per patient volume — confirm).
# os.listdir returns entries in arbitrary, filesystem-dependent order, so sort them
# to make the per-case log and the result table reproducible across runs/machines.
data_list = sorted(os.listdir('../example_data/image/'))

In [8]:
# Evaluate every case. For each one: load and preprocess the volume according to
# `load_way`, cut it into patches, classify them with `model` at `patch_threshold`,
# and record the confusion counts (tp, fn, fp, tn).
# Patient-level score = fraction of ground-truth patches that were flagged:
#   tumor cases:   tp / ground-truth lesion patches   (patient label 1)
#   healthy cases: fp / ground-truth pancreas patches (patient label 0)
# Per-patch labels and probabilities are pooled across all cases.


def _patch_fraction(detected, total, case_id):
    """Return detected / total, or NaN (with a warning) when `total` is 0.

    A bare division would raise ZeroDivisionError for Python ints, or silently
    produce inf/nan for NumPy ints, when a case has no reference patches.
    """
    if total == 0:
        warnings.warn('{}: no reference patches (total_patches == 0); '
                      'patient score set to NaN'.format(case_id))
        return float('nan')
    return detected / total


prediction_records = []
patient_gt = []
patient_pd = []
patch_gt = []
patch_pd = []
datagen = DataGenerator(patch_size, data_type='msd')
for case_id in tqdm(data_list):
    if load_way == 'ori':
        img, lbl = datagen.load_image(case_id)
        box_img, box_pan, box_les = datagen.get_boxdata(img, lbl)
        image, pancreas, lesion = datagen.preprocessing(box_img, box_pan, box_les)
    elif load_way == 'ori_box':
        box_img, box_pan, box_les = datagen.load_box(case_id, config['dataset']['box_dir'])
        image, pancreas, lesion = datagen.preprocessing(box_img, box_pan, box_les)
    elif load_way == 'box':
        image, pancreas, lesion = datagen.load_box(case_id, config['dataset']['box_dir'])
    else:
        # Without this, an unknown value would silently reuse image/pancreas/lesion
        # left over from the previous case (or from an earlier kernel run).
        raise ValueError("unknown load_way: {!r}".format(load_way))

    datagen.generate_patch(image, pancreas, lesion)
    datagen.get_prediction(model, patch_threshold=patch_threshold)

    tp, fn, fp, tn = datagen.get_all_value()
    print(case_id, tp, fn, fp, tn)

    probs, Y = datagen.get_probs()
    patch_gt.extend(Y)
    patch_pd.extend(list(probs.T[0]))

    if data_type == 'tumor':
        gt_lesion_patches = datagen.gt_lesion_num()
        case_type, detected, total, label = 'tumor', tp, gt_lesion_patches, 1
    else:
        gt_pancreas_patches = datagen.gt_pancreas_num()
        case_type, detected, total, label = 'healthy', fp, gt_pancreas_patches, 0
    score = _patch_fraction(detected, total, case_id)

    prediction_records.append([case_id, case_type, detected, total, score, tp, fn, fp, tn])
    patient_gt.append(label)
    patient_pd.append(score)

# Build the table once (growing a DataFrame row by row is quadratic and leaves every
# column as object dtype).
test_prediction = pd.DataFrame(
    prediction_records,
    columns=['case_id', 'type', 'detected_patches', 'total_patches', 'prediction', 'tp', 'fn', 'fp', 'tn'])

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

pancreas_006.nii.gz 197 24 437 132
pancreas_005.nii.gz 37 17 112 339
pancreas_004.nii.gz 187 32 378 124
pancreas_001.nii.gz 363 36 259 26



In [None]:
# Per-case results; a bare last expression lets Jupyter render it as a rich table.
test_prediction