In [1]:
from DeepMRI import DeepMRI
import tensorflow as tf
import numpy as np
import SegAN_IO_arch as seganio
import SegAN_arch as seganorig
import glob
import dataset_helpers as dh
import os

In [2]:
# --- Checkpoint to evaluate ---
models_path = 'models/'
checkpoint_basename = 'best_dice_score_'
model_name = "Segan_IO_TF2_brats_ALL"
model_epoch = 1123

# --- Network inputs / architecture ---
mri_types = ["MR_T1", "MR_T1c", "MR_T2", "MR_Flair"]
architecture = seganio

# --- Dataset ---
dataset = 'brats'
dataset_name = 'brats2015_testing_challenge_crop_mri'
dataset_center_crop = [180, 180, 128]              # Cropping done at dataset generation
network_center_crop = [160, 160, len(mri_types)]   # Cropping done to each slice

# --- Outputs: one folder per (dataset, model) pair ---
out_folder = f'results/{dataset_name}/{model_name}/'
description = f'{model_name}_{dataset_name}'
os.makedirs(out_folder, exist_ok=True)

In [3]:
# Resolve the checkpoint prefix for `model_epoch`. Saved prefixes look like
# "<basename><epoch>-<save counter>" (e.g. best_dice_score_1123-51), so match
# "<epoch>-*" (or an exact "<epoch>") instead of "<epoch>*", which would also
# pick up epochs such as 11230 or 11231.
checkpoint_prefix = models_path + model_name + '/' + checkpoint_basename + str(model_epoch)
checkpoint_indexes = sorted(
    glob.glob(glob.escape(checkpoint_prefix) + '-*.index')
    + glob.glob(glob.escape(checkpoint_prefix) + '.index')
)
if not checkpoint_indexes:
    raise FileNotFoundError(
        'No checkpoint found for epoch {} (looked for {}-*.index)'.format(model_epoch, checkpoint_prefix))
if len(checkpoint_indexes) > 1:
    print('Warning: {} checkpoints match epoch {}; using {}'.format(
        len(checkpoint_indexes), model_epoch, checkpoint_indexes[0]))
# Strip only the trailing '.index' suffix (str.replace would also rewrite a
# '.index' occurring elsewhere in the path).
model_checkpoint = checkpoint_indexes[0][:-len('.index')]

model = DeepMRI(batch_size=64, size=network_center_crop[0], mri_channels=len(mri_types), model_name=model_name)
model.build_model(load_model=model_checkpoint, seed=1234567890, arch=architecture)

Using architecture: SegAN_IO_arch
Loaded history from models/Segan_IO_TF2_brats_ALL/log_train.csv
Loaded history from models/Segan_IO_TF2_brats_ALL/log_valid.csv
Loading models/Segan_IO_TF2_brats_ALL/best_dice_score_1123-51
Loaded model from: models/Segan_IO_TF2_brats_ALL/best_dice_score_1123-51, next epoch: 1272


In [4]:
def challenge_dataset():
    '''
    Build the inference dataset for `dataset_name`.

    Each call invokes `dh.load_dataset` anew, with has_ground_truth=False,
    shuffle=False, infinite=False and cache=False, cropping every slice to
    `network_center_crop`.

    Returns
    -------
    Whatever `dh.load_dataset` returns; the prediction loop iterates over it
    and indexes each element by key (e.g. 'mri', 'MR_Flair_path').
    '''
    return dh.load_dataset('../datasets/' + dataset_name,
                           mri_type=mri_types,
                           has_ground_truth=False,
                           center_crop=network_center_crop,
                           batch_size=64,
                           prefetch_buffer=1,
                           clip_labels_to=1.0,
                           infinite=False,
                           cache=False,
                           shuffle=False
                           )

In [None]:
def pad_volume(volume, original_dimension, verbose=True):
    '''
    Zero-pad a centre-cropped volume back to its original dimensions.

    Volume and original dimensions must be of shape (Z, X, Y).

    Along each axis the padding is split evenly between both sides. When the
    total is odd, the extra voxel goes on the trailing side (e.g. 27 -> (13, 14)).
    An axis that needs no padding gets (0, 0).

    Parameters
    ----------
    volume : np.ndarray
        Cropped volume.
    original_dimension : sequence of int
        Target size for each axis of `volume`.
    verbose : bool, optional
        If True (default), print the input shape and the computed pad widths.

    Returns
    -------
    np.ndarray
        `volume` zero-padded to `original_dimension`.

    Raises
    ------
    ValueError
        If `original_dimension` does not have one entry per axis of `volume`,
        or is smaller than `volume` along some axis.
    '''
    if len(original_dimension) != volume.ndim:
        raise ValueError('original_dimension {} does not match volume rank {}'.format(
            list(original_dimension), volume.ndim))

    pads = []
    for full, crop in zip(original_dimension, volume.shape):
        total = int(full) - int(crop)
        if total < 0:
            raise ValueError('Volume shape {} exceeds original dimension {}'.format(
                volume.shape, list(original_dimension)))
        # Integer split: avoids the float `pad % int(pad)` test, which divided
        # by zero whenever an axis needed 0 or 1 voxel of padding.
        before = total // 2
        pads.append((before, total - before))

    if verbose:
        print("Volume shape: " + str(volume.shape))
        print(pads)
    return np.pad(volume, pads, mode='constant')

def build_output_filename(flair_path, description, out_dir=None):
    '''
    Build the output .mha path for one predicted volume.

    The volume id is the second-to-last dot-separated field of the FLAIR file
    name; for a file ending in ".<id>.mha" that is "<id>".

    Parameters
    ----------
    flair_path : bytes or str
        Path of the source FLAIR file (bytes are decoded as UTF-8).
    description : str
        Free-text tag inserted into the output name.
    out_dir : str, optional
        Output folder, including a trailing separator. Defaults to the
        notebook-level `out_folder`.

    Returns
    -------
    str
        "<out_dir>VSD.<description>.<id>.mha"
    '''
    if isinstance(flair_path, bytes):
        flair_path = flair_path.decode('utf-8')
    # Parse the file name only, so dots in directory names cannot leak into the id.
    flair_id = os.path.basename(flair_path).split('.')[-2]
    if out_dir is None:
        out_dir = out_folder
    return out_dir + 'VSD.' + description + '.' + str(flair_id) + '.mha'

# Run the generator over the whole dataset and group predicted slices (plus
# their source geometry) by the FLAIR file they came from.
results = {}

for b, row in enumerate(challenge_dataset()):
    pred = model.generator(row['mri'], training=False).numpy()

    # Convert each per-batch tensor to NumPy once. Calling `.numpy()` inside the
    # per-slice loop converted the whole batch again for every slice.
    flair_paths = row['MR_Flair_path'].numpy()
    origins_zxy = [row[mri_types[0]+'_'+ax+'_origin_src'].numpy() for ax in ['z', 'x', 'y']]
    spacings_zxy = [row[mri_types[0]+'_'+ax+'_spacing_src'].numpy() for ax in ['z', 'x', 'y']]
    dims_zxy = [row[mri_types[0]+'_'+ax+'_dimension_src'].numpy() for ax in ['z', 'x', 'y']]

    for i, path in enumerate(flair_paths):
        if path not in results:
            results[path] = {'pred':list(), 'origins':list(), 'spacing':list(), 'original_dimensions':list()}
        results[path]['pred'].append(pred[i])
        results[path]['origins'].append([values[i] for values in origins_zxy])
        results[path]['spacing'].append([values[i] for values in spacings_zxy])
        results[path]['original_dimensions'].append([values[i] for values in dims_zxy])
        print('\rSlice: {} batch: {} done'.format(i, b), end='')
print('\nPrediction done. Repacking Volumes...')


    
    

In [15]:
# Stack each file's slices back into a volume, undo the crops and save as MHA.
for path, res in results.items():
    # Format is (Z, X, Y). Slices were appended in dataset order (the dataset
    # is loaded with shuffle=False), so stacking them rebuilds the volume.
    volume = np.stack(res['pred'])
    orig_dims = [int(d) for d in res['original_dimensions'][0]]
    # Collapse the label dimension
    if volume.shape[-1] == 1:
        # Binary label: drop only the trailing channel axis. A bare squeeze()
        # would also drop any other size-1 axis (e.g. a one-slice volume).
        # NOTE(review): this saves the raw generator output; confirm it is a
        # 0/1 mask rather than a probability map before submission.
        volume = volume[..., 0]
    else:
        # Argmax over the label channels
        volume = volume.argmax(axis=-1)
    if list(volume.shape) != orig_dims:
        # Undo the centre crops (dataset generation + per-slice network crop)
        # by zero-padding back to the source dimensions.
        volume = pad_volume(volume, orig_dims)
    assert list(volume.shape) == orig_dims, 'Shape {} != original {}'.format(volume.shape, orig_dims)
    origin = np.array(res['origins'][0], dtype=np.float64)
    spacing = np.array(res['spacing'][0], dtype=np.float64)

    filename = build_output_filename(path, description)
    dh.save_itk(volume, origin, spacing, filename)
    print('Saved MHA of shape {} at {}'.format(volume.shape, filename))

Volume shape: (128, 160, 160)
[(13, 14), (40, 40), (40, 40)]
(155, 240, 240)
[155, 240, 240]
Saved MHA of shape (155, 240, 240) at results/brats2015_training_challenge_crop_mri/Segan_IO_TF2_brats_ALL/VSD.Segan_IO_TF2_brats_ALL_brats2015_training_challenge_crop_mri.54193.mha
Volume shape: (128, 160, 160)
[(13, 14), (40, 40), (40, 40)]
(155, 240, 240)
[155, 240, 240]
Saved MHA of shape (155, 240, 240) at results/brats2015_training_challenge_crop_mri/Segan_IO_TF2_brats_ALL/VSD.Segan_IO_TF2_brats_ALL_brats2015_training_challenge_crop_mri.54199.mha
Volume shape: (128, 160, 160)
[(13, 14), (40, 40), (40, 40)]
(155, 240, 240)
[155, 240, 240]
Saved MHA of shape (155, 240, 240) at results/brats2015_training_challenge_crop_mri/Segan_IO_TF2_brats_ALL/VSD.Segan_IO_TF2_brats_ALL_brats2015_training_challenge_crop_mri.54205.mha
Volume shape: (128, 160, 160)
[(13, 14), (40, 40), (40, 40)]
(155, 240, 240)
[155, 240, 240]
Saved MHA of shape (155, 240, 240) at results/brats2015_training_challenge_crop_mr

In [9]:
# Scratch check: view the last exported spacing reordered from (Z, X, Y) to (Z, Y, X).
spacing.take([0, 2, 1])

array([1., 1., 1.], dtype=float32)

In [12]:
# Load one source FLAIR file for comparison with the exported predictions.
# It is picked explicitly from `results` (the last one, as the export loop left
# it) instead of relying on the loop variable `path` leaking out of that loop.
# NOTE(review): the three return values are presumably image, origin and spacing — confirm in dataset_helpers.
reference_path = list(results)[-1]
im, orig, sp = dh.load_itk(reference_path.decode('utf-8'))

In [14]:
# Scratch check: dtype of the spacing returned by dh.load_itk. The export loop
# casts spacing to float64 before dh.save_itk; the recorded output is float64.
sp.dtype

dtype('float64')