In [1]:
from DeepMRI import DeepMRI
import tensorflow as tf
import numpy as np
import SegAN_IO_arch as seganio
import SegAN_arch as seganorig
import glob
import dataset_helpers as dh
import os

In [2]:
# Run this for original SegAN
models_path = 'models/'
checkpoint_basename = 'best_dice_score_'
model_name = "Segan_TF2_brats_ALL"

model_epoch = 1386
mri_types=["MR_T1", "MR_T1c", "MR_T2", "MR_Flair"]
architecture=seganorig
dataset_name = 'brats2015_train_full_challenge_crop_mri'

dataset='brats'
dataset_center_crop = [180, 180, 128] # Cropping done at dataset generation
network_center_crop = [160, 160, len(mri_types)] # Cropping done to each slice
out_folder = 'results/' + dataset_name + '/' + model_name + '/'
description = model_name + '_' + dataset_name
os.makedirs(out_folder, exist_ok=True)

In [2]:
# Run this for our SegAN
models_path = 'models/'
checkpoint_basename = 'best_dice_score_'
model_name = "Segan_IO_TF2_brats_ALL"

model_epoch = 1123
mri_types=["MR_T1", "MR_T1c", "MR_T2", "MR_Flair"]
architecture=seganio
dataset_name = 'brats2015_train_full_challenge_mri'

dataset='brats'
# dataset_center_crop = [180, 180, 128] # Cropping done at dataset generation
# network_center_crop = [160, 160, len(mri_types)] # Cropping done to each slice
dataset_center_crop = None
network_center_crop = None
out_folder = 'results/' + dataset_name + '/' + model_name + '/'
description = model_name + '_' + dataset_name
os.makedirs(out_folder, exist_ok=True)

In [3]:
size = network_center_crop[0] if network_center_crop is not None else 240
model_checkpoint = glob.glob(models_path + model_name + '/' + checkpoint_basename + str(model_epoch) + '*.index')[0].replace('.index', '')
model = DeepMRI(batch_size=64, size=size, mri_channels=len(mri_types), model_name=model_name)
#model.load_dataset(dataset=dataset, mri_types=mri_types)
model.build_model(load_model=model_checkpoint, seed=1234567890, arch=architecture)

Using architecture: SegAN_arch
Loading models/Segan_TF2_brats_ALL/best_dice_score_1386-13
Loaded model from: models/Segan_TF2_brats_ALL/best_dice_score_1386-13, next epoch: 1387


In [4]:
challenge_dataset = lambda: dh.load_dataset('../datasets/'+dataset_name,
                                                mri_type=mri_types,
                                                has_ground_truth=False,
                                                center_crop=network_center_crop,
                                                batch_size=64,
                                                prefetch_buffer=1,
                                                clip_labels_to=1.0,
                                                infinite=False, 
                                                cache=False,
                                                shuffle=False
                                            )

In [5]:
def pad_volume(volume, original_dimension):
    '''
    Volume and original dimensions must be of shape (Z, X, Y)
    '''
    if np.all(np.array(volume.shape) == original_dimension):
        return volume
    pads = []
    
    for full, crop in zip(original_dimension, volume.shape):
        pad = (full - crop)/2
        if pad % int(pad) == 0:
            pads.append((int(pad), int(pad)))
        else:
            pads.append((int(pad - .5), int(pad + .5)))
    #print("Volume shape: " + str(volume.shape))
    #print(pads)
    return np.pad(volume, pads, mode='constant')

def build_output_filename(flair_path, description):
    flair_id = path.decode('utf-8').split('.')[-2]
    return out_folder + 'VSD.' + description + '.' + str(flair_id) + '.mha'

results = {}

for b, row in enumerate(challenge_dataset()):
    pred = np.greater_equal(model.generator(row['mri'], training=False).numpy(), 0.5).astype(np.int32)
    
    for i, path in enumerate(row['MR_Flair_path'].numpy()):
        if path not in results:
            results[path] = {'pred':list(), 'origins':list(), 'spacing':list(), 'original_dimensions':list()}
        results[path]['pred'].append(pred[i])
        results[path]['origins'].append([row[mri_types[0]+'_'+ax+'_origin_src'].numpy()[i] for ax in ['z', 'x', 'y']])
        results[path]['spacing'].append([row[mri_types[0]+'_'+ax+'_spacing_src'].numpy()[i] for ax in ['z', 'x', 'y']])
        results[path]['original_dimensions'].append([row[mri_types[0]+'_'+ax+'_dimension_src'].numpy()[i] for ax in ['z', 'x', 'y']])
        print('\rSlice: {} batch: {} done'.format(i, b), end='')
print('\nPredictions done. Repacking Volumes...')

for path, res in results.items():
    # Format is (Z, X, Y)
    volume = np.stack(res['pred'])
    orig_dims = res['original_dimensions'][0]
    # Collapse the label dimension
    if volume.shape[-1] == 1:
        # Binary label
        volume = volume.squeeze()
    else:
        # Argmax
        volume = volume.argmax(axis=-1)
    if dataset_center_crop or network_center_crop:
        # Using original dimension to pad the images
        volume = pad_volume(volume, orig_dims)
        assert np.all(np.array(volume.shape) == orig_dims)
    origin = np.array(res['origins'][0], dtype=np.float64)
    spacing = np.array(res['spacing'][0], dtype=np.float64)
    
    filename = build_output_filename(path, description)
    dh.save_itk(volume, origin, spacing, filename)
    print('Saved MHA of shape {} at {}'.format(volume.shape, filename))
    
    

Slice: 63 batch: 547 done
Predictions done. Repacking Volumes...
Saved MHA of shape (155, 240, 240) at results/brats2015_train_full_challenge_crop_mri/Segan_TF2_brats_ALL/VSD.Segan_TF2_brats_ALL_brats2015_train_full_challenge_crop_mri.54512.mha
Saved MHA of shape (155, 240, 240) at results/brats2015_train_full_challenge_crop_mri/Segan_TF2_brats_ALL/VSD.Segan_TF2_brats_ALL_brats2015_train_full_challenge_crop_mri.54518.mha
Saved MHA of shape (155, 240, 240) at results/brats2015_train_full_challenge_crop_mri/Segan_TF2_brats_ALL/VSD.Segan_TF2_brats_ALL_brats2015_train_full_challenge_crop_mri.54524.mha
Saved MHA of shape (155, 240, 240) at results/brats2015_train_full_challenge_crop_mri/Segan_TF2_brats_ALL/VSD.Segan_TF2_brats_ALL_brats2015_train_full_challenge_crop_mri.54530.mha
Saved MHA of shape (155, 240, 240) at results/brats2015_train_full_challenge_crop_mri/Segan_TF2_brats_ALL/VSD.Segan_TF2_brats_ALL_brats2015_train_full_challenge_crop_mri.54536.mha
Saved MHA of shape (155, 240, 240) 

In [79]:
# Metrics checking
import glob
import dataset_helpers as dh
import os
import pandas as pd
import numpy as np

def binary_stats(true, pred, thr=0.5):
    true = np.greater(true, thr)
    pred = np.greater(pred, thr)
    
    tp = np.count_nonzero(np.logical_and(true, pred))
    tn = np.count_nonzero(np.logical_and(np.logical_not(true), np.logical_not(pred)))
    fp = np.count_nonzero(np.logical_and(np.logical_not(true), pred))
    fn = np.count_nonzero(np.logical_and(true, np.logical_not(pred)))
    
    return tp, tn, fp, fn


MAX_LABELS = 1
results = pd.DataFrame()
for i, pred_path in enumerate(glob.glob('results/brats2015_train_full_challenge_crop_mri/Segan_TF2_brats_ALL/*.mha')):
    index = pred_path.split('.')[-2]
    # Find the corresponding GT
    gt_patient_folder = ['/'.join(os.path.dirname(f).split('/')[:-1]) for f in glob.glob('../../datasets/BRATS2015/BRATS2015_Training/*/*/*/*.mha') if index in f][0]
    gt_path = [f for f in glob.glob(gt_patient_folder+'/*/*.mha') if 'OT' in f][0]
    gt, _, _ = dh.load_itk(gt_path)
    pred, _, _ = dh.load_itk(pred_path)
    tp, tn, fp, fn = binary_stats(gt, pred)
    results = results.append({'gt':gt_path, 'pred':pred_path, 'tp':tp, 'tn':tn, 'fp':fp, 'fn':fn}, ignore_index=True)
    print('\r'+str(i), end="")
    

273

In [80]:
results['dice'] = 2*results['tp']/(2*results['tp']+results['fp']+results['fn'])
results['sensitivity'] = results['tp']/(results['tp']+results['fn'])
results['specificity'] = results['tn']/(results['tn']+results['fp'])

In [81]:
results.mean()*100

fn             2.490427e+06
fp             8.220266e+05
tn             8.811074e+08
tp             8.380181e+06
dice           8.051681e+01
sensitivity    7.545045e+01
specificity    9.990678e+01
dtype: float64