In [1]:
import tensorflow as tf
import numpy as np
import skimage.io as scio
from ipywidgets import interact, fixed, IntSlider, FloatSlider
import seaborn as sb
import pandas as pd

## Load the right network

In [33]:
import dataset_helpers as dh
batchsize = 64
# These shapes are considered after cropping
mri_shape = [batchsize, 160, 160, 3]
seg_shape = [batchsize, 160, 160, 1]
def load_dataset():
    validation_dataset_name = 'brats2015-Train-all_validation_crop_mri'
    validation_dataset = dh.load_dataset('../datasets/brats2015-Train-all_validation_crop_mri',
                                     mri_type=['MR_T1c', 'MR_T2', 'MR_Flair'],
                                     center_crop=mri_shape[1:],
                                     batch_size=batchsize,
                                     prefetch_buffer=1,
                                     clip_labels_to=1.0,
                                     infinite = False,
                                     interleave=1,
                                     shuffle=False
                                )
    return validation_dataset


def load_network(run_name, config):
    from SegAN_IO import SegAN_IO
    net = SegAN_IO(mri_shape, seg_shape, config=config, run_name=run_name)
    return net
    
    

In [31]:
def predict_from_validation(run_name, checkpoint):
    '''
    :param input_mri: NDArray of shape [batch_size, H, W, channels]
    '''

    result_list = list()
    
    # Creating the iterator
    tf.reset_default_graph()
    validation_dataset = load_dataset()
    iterator = validation_dataset.make_one_shot_iterator()
    next_batch = iterator.get_next()

    
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        net = load_network(run_name, config)
        # Load (build) the network from the checkpoint
        net.build_network(next_batch['mri'], next_batch['seg'], session=sess, load_checkpoint=checkpoint)
        # Running evaluations on the entire dataset
        while True:
            try:
                results = sess.run({'prediction': net.layers['S']['out'], 'input': net.layers['in']['mri'], 'ground_truth': net.layers['in']['seg']}, feed_dict={net.layers['in']['training']: False})  
                result_list.append(results)
            except tf.errors.OutOfRangeError:
                break
    return result_list
    
    
    
def evaluate_all_checkpoints(run_name, checkpoint_folder, mri_shape, seg_shape):
    import dataset_helpers as dh
    import pandas as pd
           
    all_checkpoints = tf.train.get_checkpoint_state(checkpoint_folder).all_model_checkpoint_paths
    result_list = list()
    for checkpoint in all_checkpoints:
        # Creating the iterator
        tf.reset_default_graph()
        iterator = validation_dataset.make_one_shot_iterator()
        next_batch = iterator.get_next()
        
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            net = load_network(run_name, config)
            # Load (build) the network from the checkpoint
            net.build_network(next_batch['mri'], next_batch['seg'], session=sess, load_checkpoint=checkpoint)
            # Running evaluations on the entire dataset
            while True:
                try:
                    results = sess.run(net.layers['eval'], feed_dict={net.layers['in']['training']: False})
                    results['checkpoint'] = checkpoint
                    result_list.append(results)
                except tf.errors.OutOfRangeError:
                    break
                    
    # Make a DataFrame with all the results
    pd_eval = pd.DataFrame(result_list)
    # Save to CSV
    pd_eval.to_csv('eval_{}.csv'.format(run_name))

In [4]:
# Generate stats for every batch and save them in a csv file (takes about an hour)
#evaluations = evaluate_all_checkpoints('loss_fixed_26_feb', mri_shape=[32, 240, 240, 1], seg_shape=[32, 240, 240, 1])

In [None]:
import pandas as pd
ev = pd.read_csv('eval_loss_fixed_26_feb.csv', index_col = 0)
ev["batch_n"] =ev.groupby("checkpoint").cumcount() # Create a column for the batch enumeration
ev = ev.set_index(["checkpoint", "batch_n"]) # Set the new indices
ev

In [None]:
# Check which metrics make sense to consider by cheching the global stability
ev.var()
# 26 feb run showed that the metrics with most variance were sensitivity (2%), dice_score (1.3%), precision (1.2%). Balanced accuracy was at (0.5%) and others were mostly stable.

In [None]:
# 
evstats = ev.groupby(level="checkpoint").agg(['mean', 'var'])
top_models = list()
for m in ["dice_score", "sensitivity", "precision", "balanced_accuracy"]:
    top = evstats[m].nlargest(1, "mean")
    print("Best {} model: {}".format(m, top.index[0]))
    top_models.append(top.index[0])
evstats.loc[top_models,:]

In [None]:
# Creating a version for plotting (Not actually feasible for too many checkpoints)
evp = ev.loc[top_models,:].stack().reset_index().rename(columns={'level_2':'metric', 0:'value'})
sb.catplot(x="value", y="checkpoint", col="metric", col_wrap=2, kind="box", data=evp, orient="h")

## Select the best checkpoint
according to the validation set dice score

In [5]:
dicescore = pd.read_csv('run_logs/18-apr-run_test-tag-avg_dice_score.csv')
max_dice = dicescore.Value.max()
best_step = int(dicescore.iloc[dicescore.Value.idxmax()].Step)
print("Best step is {} with a dice score of {}".format(best_step, max_dice))

Best step is 53500 with a dice score of 0.8258382678031921


# Generate some predictions from validation set
Load one checkpoint and predict

In [34]:
model = "../models/SegAN/SeganIO_18Apr_model/model.ckpt-51500"
results = predict_from_validation(run_name='SeganIO_18Apr', checkpoint=model)

INFO:tensorflow:Restoring parameters from ../models/SegAN/SeganIO_18Apr_model/model.ckpt-51500
Loaded model from ../models/SegAN/SeganIO_18Apr_model/ at global step 51500


In [35]:
import matplotlib.pyplot as plt
def visualize_results(results, batch_n, sample_n, gt_alpha, pred_alpha):
    input_sample = results[batch_n]['input'][sample_n, :,:,0]
    prediction = results[batch_n]['prediction'][sample_n, :,:,0]
    ground_truth = results[batch_n]['ground_truth'][sample_n, :,:,0]
    
    blank_channel = np.zeros([input_sample.shape[0], input_sample.shape[1]])
    
    # Make prediction and ground truth RGBA
    prediction = np.expand_dims(prediction, axis=-1)
    # Create a greyscale image with alpha channel
    prediction = prediction.repeat(4, axis=-1)
    # Make the prediction red
    prediction[:,:,1] = blank_channel
    prediction[:,:,2] = blank_channel
    # Same for the ground truth
    ground_truth = np.expand_dims(ground_truth, axis=-1)
    ground_truth = ground_truth.repeat(4, axis=-1)
    ground_truth[:,:,0] = blank_channel
    ground_truth[:,:,2] = blank_channel
    
    plt.figure(figsize=(8, 6), dpi=120)
    plt.imshow(input_sample, cmap="binary_r")
    plt.imshow(ground_truth, alpha=gt_alpha, cmap="Reds_r")
    plt.imshow(prediction, alpha=pred_alpha, cmap="Greens")
    


interact(visualize_results, results=fixed(results), batch_n=IntSlider(min=0,max=len(results)-1,step=1,value=0), sample_n=IntSlider(min=0,max=31,step=1,value=0), gt_alpha=FloatSlider(min=0,max=1,step=0.1,value=0), pred_alpha=FloatSlider(min=0,max=1,step=0.1,value=0));

interactive(children=(IntSlider(value=0, description='batch_n', max=55), IntSlider(value=0, description='sampl…