In [1]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
%matplotlib inline

import gc
from tqdm import tqdm

import numpy as np

import tensorflow as tf

import medim
import tfmod
from data_loaders import Brats2015, Brats2017

from utils import *
from tools import train_model, find_threshold, get_stats_and_dices

%load_ext autoreload
%autoreload 2

%load_ext line_profiler

In [2]:
# Training length and convolution kernel size.
n_epoch = 100
kernel_size = 3

# Edge lengths, one per spatial axis, of the patches cut for the two model
# inputs ("det" and "con"); the "con" patch is three times a 19-wide base.
patch_size_x_det = np.array([25] * 3)
patch_size_x_con = 3 * np.array([19] * 3)

# Edge lengths of the target (label) patch.
patch_size_y = np.array([5] * 3)

# Half of the size difference between each input patch and the target patch,
# i.e. the per-side margin an input patch has around the target patch.
x_det_padding, x_con_padding = (
    (input_size - patch_size_y) // 2
    for input_size in (patch_size_x_det, patch_size_x_con)
)

def make_val_inputs(x, y):
    """Build one flat list of validation inputs for all given cases.

    Each (image, target) pair from ``x`` and ``y`` is passed to
    ``get_val_deepmedic`` together with the module-level paddings, and the
    items it yields are concatenated in case order.
    """
    return [
        item
        for image, target in zip(x, y)
        for item in get_val_deepmedic(image, target, x_det_padding, x_con_padding)
    ]

def get_pred_and_true(model_controller, xo, yo):
    """Predict one case part by part and reassemble prediction and target.

    The case is split by ``get_val_deepmedic`` into
    (detailed input, context input, target) triples. Each triple's inputs go
    through ``model_controller.predict_proba``; the per-part predictions and
    targets are then stitched together with ``medim.utils.combine``.

    Returns:
        (combined prediction, combined target)
    """
    parts = get_val_deepmedic(xo, yo, x_det_padding, x_con_padding)

    # Single pass over the parts so predictions and targets stay aligned.
    pairs = [
        (model_controller.predict_proba([part_det, part_con]), part_true)
        for part_det, part_con, part_true in parts
    ]
    predictions = [pred for pred, _ in pairs]
    targets = [true for _, true in pairs]

    # The second argument describes the part layout per axis; predictions
    # carry one more leading axis than targets (presumably the class
    # channel -- TODO confirm against medim.utils.combine).
    combined_pred = medim.utils.combine(predictions, [1, 1, 2, 1])
    combined_true = medim.utils.combine(targets, [1, 2, 1])
    return combined_pred, combined_true

def make_batch_iterator(x, y, batch_size):
    """Create the training patch iterator.

    Delegates to ``medim.batch_iter.patch.foreground``. The images ``x`` are
    sampled twice (with the "det" and the "con" patch size) and the targets
    ``y`` once (with the target patch size), over the last three axes.
    ``f_fraction=0.5`` and the ``y > 0`` condition configure the foreground
    sampling (presumably half of the patches are centred on non-zero
    labels -- TODO confirm against medim).
    """
    sources = [x, x, y]
    patch_sizes = [patch_size_x_det, patch_size_x_con, patch_size_y]
    return medim.batch_iter.patch.foreground(
        sources,
        patch_sizes,
        batch_size=batch_size,
        spatial_dims=(-3, -2, -1),
        f_fraction=0.5,
        f_condition=lambda y: y > 0,
    )

In [3]:
# Log directory handed to the model controller, and location of the
# preprocessed dataset read by the loader.
log_path = '/tmp/tf'
processed_path = '/mount/export/brats2017/processed'
data_loader = Brats2017(processed_path)

# Dataset description exposed by the loader.
patients, metadata = data_loader.patients, data_loader.metadata
spatial_size = data_loader.spatial_size
n_modalities, n_chans_msegm, n_classes = (
    data_loader.n_modalities,
    data_loader.n_chans_msegm,
    data_loader.n_classes,
)

# Index sets of the four subsets used by the cells below.
train, train_val, val, test = split_data(metadata)

# Survival class of every case, per subset.
sd_train, sd_train_val, sd_val, sd_test = (
    metadata.iloc[subset].survival_class.values
    for subset in (train, train_val, val, test)
)

In [None]:
# Train and checkpoint four models: two fold arrangements (ds) times two
# repeated runs (i). In each arrangement the first fold is added to the
# training indices and the second one is left for the evaluation cell.

# Some TF 1.x releases make Saver.save raise when the parent directory of the
# checkpoint is missing, so make sure it exists before any training starts.
tf.gfile.MakeDirs('./checkpoints')

for ds, (val_set, test_set) in enumerate([[val, test], [test, val]]):
    for i in range(2):
        # Every run gets a fresh graph; otherwise the ops and variables of
        # the previous model would stay in the default graph.
        tf.reset_default_graph()
        model = tfmod.models.DeepMedic(n_modalities, n_classes)
        model_controller = tfmod.ModelController(
            model, log_path, restore_ckpt_path=None)

        # `train` and `val_set` are index lists, so `+` concatenates them.
        train_idx = train + val_set

        train_model(model_controller, make_batch_iterator, make_val_inputs, data_loader,
                    train_idx, train_val, n_epoch=n_epoch)

        # Must match the path pattern used by the evaluation cell.
        ckpt_path = './checkpoints/deepmedic_ds{}_i{}'.format(ds, i)

        saver = tf.train.Saver()
        saver.save(model_controller.session, ckpt_path)

        # reset_default_graph() does not close sessions; release this run's
        # session explicitly so resources do not pile up across the runs.
        model_controller.session.close()

Loading train data


100%|██████████| 201/201 [00:10<00:00, 19.93it/s]

Loading val data



100%|██████████| 5/5 [00:00<00:00, 60.94it/s]


Starting training
Epoch 0
Train: 0.477661
Val  : 0.0727244


Epoch 1
Train: 0.386076
Val  : 0.0510379


Epoch 2
Train: 0.38044
Val  : 0.0518527


Epoch 3
Train: 0.367373
Val  : 0.0466552


Epoch 4
Train: 0.357965
Val  : 0.0494435


Epoch 5
Train: 0.350993
Val  : 0.0578951


Epoch 6
Train: 0.349023
Val  : 0.0593572


Epoch 7
Train: 0.354617
Val  : 0.050134


Epoch 8
Train: 0.345927
Val  : 0.0533569


Epoch 9
Train: 0.328811
Val  : 0.0602118


Epoch 10
Train: 0.329306
Val  : 0.0712461


Epoch 11
Train: 0.332963
Val  : 0.0455981


Epoch 12
Train: 0.314303
Val  : 0.0471626


Epoch 13
Train: 0.309911
Val  : 0.0431976


Epoch 14
Train: 0.310848
Val  : 0.0463591


Epoch 15
Train: 0.303386
Val  : 0.040426


Epoch 16
Train: 0.305941
Val  : 0.0465811


Epoch 17
Train: 0.308396
Val  : 0.0516438


Epoch 18
Train: 0.295223
Val  : 0.0466127


Epoch 19
Train: 0.302727
Val  : 0.044664


Epoch 20
Train: 0.28589
Val  : 0.0421595


Epoch 21
Train: 0.286503
Val  : 0.0596751


Epoch 22
Train: 0.299428
Val 

100%|██████████| 201/201 [00:40<00:00,  4.37it/s]

Loading val data



100%|██████████| 5/5 [00:01<00:00,  3.32it/s]


Starting training
Epoch 0
Train: 0.479901
Val  : 0.0547631


Epoch 1
Train: 0.404338
Val  : 0.0758471


Epoch 2
Train: 0.391734
Val  : 0.0519936


Epoch 3
Train: 0.37366
Val  : 0.055003


Epoch 4
Train: 0.367022
Val  : 0.0506675


Epoch 5
Train: 0.356299
Val  : 0.0765495


Epoch 6
Train: 0.353944
Val  : 0.145635


Epoch 7
Train: 0.347365
Val  : 0.0509335


Epoch 8
Train: 0.337988
Val  : 0.051526


Epoch 9
Train: 0.3326
Val  : 0.0429982


Epoch 10
Train: 0.338001
Val  : 0.048546


Epoch 11
Train: 0.340128
Val  : 0.048509


Epoch 12
Train: 0.338562
Val  : 0.0419999


Epoch 13
Train: 0.328221
Val  : 0.0413538


Epoch 14
Train: 0.321392
Val  : 0.0411788


Epoch 15
Train: 0.322244
Val  : 0.0451424


Epoch 16
Train: 0.311659
Val  : 0.041934


Epoch 17
Train: 0.31656
Val  : 0.0391717


Epoch 18
Train: 0.312871
Val  : 0.0445733


Epoch 19
Train: 0.309822
Val  : 0.0367793


Epoch 20
Train: 0.31054
Val  : 0.042032


Epoch 21
Train: 0.297076
Val  : 0.0523524


Epoch 22
Train: 0.30087
Val  : 0.043

100%|██████████| 196/196 [00:35<00:00,  3.07it/s]

Loading val data



100%|██████████| 5/5 [00:01<00:00,  3.11it/s]


Starting training
Epoch 0
Train: 0.494547


In [4]:
# Evaluate every saved checkpoint on the fold that was held out when it was
# trained, and collect the per-case Dice scores of each run in `all_dices`.
all_dices = []

for ds, (val_set, test_set) in enumerate([[val, test], [test, val]]):
    for i in range(2):
        # Must match the path pattern used by the training cell.
        ckpt_path = './checkpoints/deepmedic_ds{}_i{}'.format(ds, i)

        # Rebuild the model in a fresh graph and restore its weights.
        tf.reset_default_graph()
        model = tfmod.models.DeepMedic(n_modalities, n_classes)
        model_controller = tfmod.ModelController(
            model, log_path, restore_ckpt_path=ckpt_path)

        test_idx = test_set

        # Thresholds are picked on the train_val cases only, so the held-out
        # fold is not used for tuning.
        threshold = find_threshold(model_controller, get_pred_and_true, data_loader, train_val)
        print('Threshold:', threshold)

        stats_pred, stats_true, dices = get_stats_and_dices(
            model_controller, get_pred_and_true, data_loader, test_idx, threshold)

        all_dices.append(dices)
        # Mean over cases (axis 0) leaves one Dice value per evaluated target.
        print(ckpt_path, np.mean(dices, axis=0))

        # reset_default_graph() does not close sessions; release this run's
        # session explicitly so resources do not pile up across checkpoints.
        model_controller.session.close()

INFO:tensorflow:Restoring parameters from ./checkpoints/deepmedic_ds0_i0
Loading val data


100%|██████████| 5/5 [00:00<00:00, 49.89it/s]

Starting prediction





Treshold: [ 0.72947368  0.62526316  0.88578947]
Loading val data


100%|██████████| 79/79 [00:01<00:00, 55.61it/s]

Starting prediction



79it [06:50,  5.21s/it]


./checkpoints/deepmedic_ds0_i0 [ 0.84210196  0.83000463  0.68612976]
INFO:tensorflow:Restoring parameters from ./checkpoints/deepmedic_ds0_i1
Loading val data


100%|██████████| 5/5 [00:00<00:00, 63.48it/s]

Starting prediction





Treshold: [ 0.78157895  0.78157895  0.88578947]
Loading val data


100%|██████████| 79/79 [00:01<00:00, 64.12it/s]

Starting prediction



79it [06:51,  5.24s/it]


./checkpoints/deepmedic_ds0_i1 [ 0.83893755  0.77068328  0.66020181]
INFO:tensorflow:Restoring parameters from ./checkpoints/deepmedic_ds1_i0
Loading val data


100%|██████████| 5/5 [00:00<00:00, 74.44it/s]

Starting prediction





Treshold: [ 0.72947368  0.41684211  0.83368421]
Loading val data


100%|██████████| 84/84 [00:01<00:00, 64.14it/s]

Starting prediction



84it [07:17,  5.24s/it]


./checkpoints/deepmedic_ds1_i0 [ 0.85993372  0.8472613   0.70158777]
INFO:tensorflow:Restoring parameters from ./checkpoints/deepmedic_ds1_i1
Loading val data


100%|██████████| 5/5 [00:00<00:00, 77.83it/s]

Starting prediction





Treshold: [ 0.78157895  0.72947368  0.93789474]
Loading val data


100%|██████████| 84/84 [00:01<00:00, 64.43it/s]

Starting prediction



84it [07:16,  5.21s/it]

./checkpoints/deepmedic_ds1_i1 [ 0.8684001   0.84442664  0.6780873 ]





In [5]:
# Mean Dice per evaluated run (averaged over cases), in checkpoint order.
[np.mean(run_dices, axis=0) for run_dices in all_dices]

[array([ 0.84210196,  0.83000463,  0.68612976]),
 array([ 0.83893755,  0.77068328,  0.66020181]),
 array([ 0.85993372,  0.8472613 ,  0.70158777]),
 array([ 0.8684001 ,  0.84442664,  0.6780873 ])]