In [1]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
%matplotlib inline

import gc
import os

from tqdm import tqdm

import numpy as np

import tensorflow as tf

import medim
import tfmod
from data_loaders import Brats2015, Brats2017

from utils import *
from tools import train_model, find_threshold, get_stats_and_dices

%load_ext autoreload
%autoreload 2

%load_ext line_profiler

In [2]:
# Training length (passed to train_model as n_epoch).
n_epoch = 100

# NOTE(review): not referenced by any visible cell — confirm it is still needed.
kernel_size = 3

# Input patch sizes for the two DeepMedic inputs (detailed / context), per spatial axis.
patch_size_x_det = np.array([25, 25, 25])
# NOTE(review): the context input is 3x a 19^3 base, presumably because the context
# pathway works on a 3x-downsampled view — confirm against tfmod.models.DeepMedic.
patch_size_x_con = np.array([19, 19, 19]) * 3

# Label patch size predicted per sample.
patch_size_y = np.array([5, 5, 5])

# Symmetric margin between each input patch and the label patch, per axis.
x_det_padding = (patch_size_x_det - patch_size_y) // 2
x_con_padding = (patch_size_x_con - patch_size_y) // 2

# Floor division silently truncates an odd margin, and a too-small input patch would
# give a negative padding; fail fast on either misconfiguration.
assert np.all(x_det_padding >= 0) and np.array_equal(2 * x_det_padding + patch_size_y, patch_size_x_det), \
    'patch_size_x_det - patch_size_y must be non-negative and even along every axis'
assert np.all(x_con_padding >= 0) and np.array_equal(2 * x_con_padding + patch_size_y, patch_size_x_con), \
    'patch_size_x_con - patch_size_y must be non-negative and even along every axis'

def make_val_inputs(x, y):
    """Collect the validation inputs of every (x, y) pair into one flat list.

    ``get_val_deepmedic`` is called once per pair, in order, with the module-level
    ``x_det_padding`` / ``x_con_padding``. Its outputs are concatenated.
    Pairing follows ``zip``, so extra items in the longer sequence are ignored.
    """
    return [
        item
        for volume, target in zip(x, y)
        for item in get_val_deepmedic(volume, target, x_det_padding, x_con_padding)
    ]

def get_pred_and_true(model_controller, xo, yo):
    """Score one (xo, yo) pair patch-by-patch and reassemble prediction and target.

    ``get_val_deepmedic`` (with the module-level paddings) yields triples that are
    unpacked as (detailed input, context input, target). Each triple's two inputs
    are scored with ``model_controller.predict_proba``. The per-patch predictions
    and targets are then stitched back together with ``medim.utils.combine``.

    Returns:
        (combined prediction, combined target).
    """
    patches = get_val_deepmedic(xo, yo, x_det_padding, x_con_padding)

    # Walk `patches` exactly once (it might be a one-shot iterator), pairing each
    # prediction with its target in the same order.
    scored = [
        (model_controller.predict_proba([det_patch, con_patch]), target_patch)
        for det_patch, con_patch, target_patch in patches
    ]
    predictions = [pred for pred, _ in scored]
    targets = [target for _, target in scored]

    # NOTE(review): the layouts differ ([1, 1, 2, 1] vs [1, 2, 1]), presumably because
    # predictions carry an extra leading class/channel axis — confirm.
    return (medim.utils.combine(predictions, [1, 1, 2, 1]),
            medim.utils.combine(targets, [1, 2, 1]))

def make_batch_iterator(x, y, batch_size):
    """Build the training batch iterator via ``medim.batch_iter.patch.foreground``.

    Each batch samples three aligned patches: a detailed input and a context input
    from ``x``, and a label patch from ``y``. Their sizes are the module-level
    ``patch_size_x_det``, ``patch_size_x_con`` and ``patch_size_y``. Sampling uses the
    last three axes as spatial dims. Voxels with label > 0 count as foreground, and
    ``f_fraction=0.5`` is presumably the share of foreground-centred patches — confirm.
    """
    sources = [x, x, y]
    patch_sizes = [patch_size_x_det, patch_size_x_con, patch_size_y]
    return medim.batch_iter.patch.foreground(
        sources, patch_sizes,
        batch_size=batch_size,
        spatial_dims=(-3, -2, -1),
        f_fraction=0.5,
        f_condition=lambda label: label > 0,
    )

In [4]:
# Paths are overridable through environment variables so the notebook can run on
# another machine; the defaults are the original locations.
log_path = os.environ.get('DEEPMEDIC_LOG_PATH', '/tmp/tf')
processed_path = os.environ.get('BRATS2017_PROCESSED_PATH', '/mount/export/brats2017/processed')
data_loader = Brats2017(processed_path)

patients = data_loader.patients
metadata = data_loader.metadata

# Dataset geometry used to build the model.
spatial_size = data_loader.spatial_size
n_modalities = data_loader.n_modalities
n_chans_msegm = data_loader.n_chans_msegm
n_classes = data_loader.n_classes

# Four disjoint index sets from utils.split_data (presumably lists of row positions — confirm).
train, train_val, val, test = split_data(metadata)

# Survival class of every patient in each split.
sd_train = metadata.iloc[train].survival_class.values
sd_train_val = metadata.iloc[train_val].survival_class.values
sd_val = metadata.iloc[val].survival_class.values
sd_test = metadata.iloc[test].survival_class.values

In [5]:
# Train `n_runs` models per data split and checkpoint each one. The two splits
# swap the roles of `val` and `test`: one half is added to the training set, the
# other is held out for the evaluation cell.
n_runs = 2
checkpoint_dir = './checkpoints'
# tf.train.Saver.save errors out when the parent directory does not exist.
os.makedirs(checkpoint_dir, exist_ok=True)

for ds, (val_set, _held_out) in enumerate([[val, test], [test, val]]):
    for i in range(n_runs):
        tf.reset_default_graph()
        model = tfmod.models.DeepMedic(n_modalities, n_classes)
        model_controller = tfmod.ModelController(
            model, log_path, restore_ckpt_path=None)

        # list(...) guarantees concatenation; with numpy index arrays `+` would add elementwise.
        train_idx = list(train) + list(val_set)

        train_model(model_controller, make_batch_iterator, make_val_inputs, data_loader,
                    train_idx, train_val, n_epoch=n_epoch)

        ckpt_path = os.path.join(checkpoint_dir, 'deepmedic_ds{}_i{}'.format(ds, i))

        saver = tf.train.Saver()
        saver.save(model_controller.session, ckpt_path)
        # Release this run's session before the next graph is built.
        model_controller.session.close()

Loading train data


100%|██████████| 201/201 [00:44<00:00,  3.15it/s]

Loading val data



100%|██████████| 5/5 [00:00<00:00,  7.67it/s]


Starting training
Epoch 0
Train: 0.478607
Val  : 0.0601769


Epoch 1
Train: 0.390074
Val  : 0.0497063


Epoch 2
Train: 0.394486
Val  : 0.0746902


Epoch 3
Train: 0.377737
Val  : 0.0496886


Epoch 4
Train: 0.375905
Val  : 0.0584974


Epoch 5


KeyboardInterrupt: 

In [7]:
# Evaluate every checkpoint written by the training loop. Thresholds are chosen with
# find_threshold on `train_val`; Dice scores come from get_stats_and_dices on the
# half of each split that was held out of training.
# NOTE(review): the training cell's saved output ends in a KeyboardInterrupt, so the
# checkpoints restored here predate that run — re-run training to completion for a
# reproducible notebook.
all_dices = []

for ds, (_, test_set) in enumerate([[val, test], [test, val]]):
    for i in range(2):
        ckpt_path = './checkpoints/deepmedic_ds{}_i{}'.format(ds, i)
        tf.reset_default_graph()
        model = tfmod.models.DeepMedic(n_modalities, n_classes)
        model_controller = tfmod.ModelController(
            model, log_path, restore_ckpt_path=ckpt_path)

        threshold = find_threshold(model_controller, get_pred_and_true, sum_probs2017,
                                   data_loader, train_val)
        print('Threshold:', threshold)

        stats_pred, stats_true, dices = get_stats_and_dices(
            model_controller, get_pred_and_true, sum_probs2017,
            data_loader, test_set, threshold)

        all_dices.append(dices)
        print(ckpt_path, np.mean(dices, axis=0))
        # Release this checkpoint's session before the next graph is built.
        model_controller.session.close()

INFO:tensorflow:Restoring parameters from ./checkpoints/deepmedic_ds0_i0
Loading val data


100%|██████████| 5/5 [00:00<00:00, 56.75it/s]

Starting prediction





Treshold: [ 0.72947368  0.62526316  0.88578947]
Loading test data


100%|██████████| 79/79 [00:10<00:00,  7.60it/s]

Starting prediction



79it [06:49,  5.23s/it]


./checkpoints/deepmedic_ds0_i0 [ 0.84210196  0.83000463  0.68612976]
INFO:tensorflow:Restoring parameters from ./checkpoints/deepmedic_ds0_i1
Loading val data


100%|██████████| 5/5 [00:00<00:00, 60.18it/s]

Starting prediction





Treshold: [ 0.78157895  0.78157895  0.88578947]
Loading test data


100%|██████████| 79/79 [00:01<00:00, 60.46it/s]

Starting prediction



79it [06:49,  5.21s/it]


./checkpoints/deepmedic_ds0_i1 [ 0.83893755  0.77068328  0.66020181]
INFO:tensorflow:Restoring parameters from ./checkpoints/deepmedic_ds1_i0
Loading val data


100%|██████████| 5/5 [00:00<00:00, 61.10it/s]

Starting prediction





Treshold: [ 0.72947368  0.41684211  0.83368421]
Loading test data


100%|██████████| 84/84 [00:01<00:00, 64.33it/s]

Starting prediction



84it [07:16,  5.18s/it]


./checkpoints/deepmedic_ds1_i0 [ 0.85993372  0.8472613   0.70158777]
INFO:tensorflow:Restoring parameters from ./checkpoints/deepmedic_ds1_i1
Loading val data


100%|██████████| 5/5 [00:00<00:00, 61.95it/s]

Starting prediction





Treshold: [ 0.78157895  0.72947368  0.93789474]
Loading test data


100%|██████████| 84/84 [00:01<00:00, 64.31it/s]

Starting prediction



84it [07:16,  5.19s/it]

./checkpoints/deepmedic_ds1_i1 [ 0.8684001   0.84442664  0.6780873 ]





In [8]:
# Column-wise mean of each evaluated checkpoint's Dice scores (one array per checkpoint).
list(map(lambda run_dices: np.mean(run_dices, axis=0), all_dices))

[array([ 0.84210196,  0.83000463,  0.68612976]),
 array([ 0.83893755,  0.77068328,  0.66020181]),
 array([ 0.85993372,  0.8472613 ,  0.70158777]),
 array([ 0.8684001 ,  0.84442664,  0.6780873 ])]