In [1]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
%matplotlib inline

import time
import pickle
from imp import reload
from os.path import join
from itertools import product

from sklearn.model_selection import KFold
import numpy as np
from tqdm import tqdm

import tensorflow as tf

from data_loader import Brats
import medim
from models_tf import *

%load_ext autoreload
%autoreload 2

%load_ext line_profiler

In [2]:
def encode_msegm2015(s):
    r = np.zeros((3, *s.shape), dtype=bool)
    r[0] = s > 0
    r[1] = (s == 1) | (s == 3) | (s == 4)
    r[2] = (s == 4)
    return r

In [3]:
processed_path = '/home/mount/neuro-x02-ssd/brats2015/processed'
encode_msegm = encode_msegm2015

data_loader = Brats(processed_path)
patients = data_loader.patients

x = []
y = []

for patient in tqdm(patients):
    x.append(data_loader.load_mscan(patient))
    y.append(data_loader.load_segm(patient))

n_modalities = 4
n_msegm_chans = 3
n_classes = len(np.unique(y[0]))
n_classes

100%|██████████| 274/274 [00:42<00:00,  6.68it/s]


5

In [4]:
n_splits_train = 5
n_splits_val = 40

cv = KFold(n_splits_train, shuffle=True, random_state=42)
train, test = next(cv.split(y))

def extract(l, idx):
    return [l[i] for i in idx]

x_train, x_test = extract(x, train), extract(x, test)
y_train, y_test = extract(y, train), extract(y, test)

cv = KFold(n_splits_val, shuffle=True, random_state=21)
train, val = next(cv.split(x_train))

x_train, x_val = extract(x_train, train), extract(x_train, val)
y_train, y_val = extract(y_train, train), extract(y_train, val)

In [13]:
kernel_size = 3
blocks = [n_modalities, 32, 32, 64, 64]

patch_size_x = np.array([25, 25, 25])
patch_size_y = patch_size_x - 2*(len(blocks) - 1)

In [14]:
%%time

padding = (patch_size_x - patch_size_y) // 2

val_shape = np.max(list(map(np.shape, x_val)), axis=0)[1:]

def min_padding(x, padding, val_shape):
    # 3-dimentional spatial
    non_spatial = x.ndim - 3
    padding = np.array(padding)
    
    padding = list(zip(padding, padding + val_shape - np.array(x.shape[non_spatial:])))
    padding = [(0, 0)] * non_spatial + padding
    
    return np.pad(x, padding, mode='constant')

x_val_padded = [min_padding(s, padding, val_shape) for s in x_val]
y_val_padded = [min_padding(s, [0, 0, 0], val_shape) for s in y_val]

CPU times: user 352 ms, sys: 4.68 s, total: 5.03 s
Wall time: 4.88 s


In [15]:
tf.reset_default_graph()
model = EEnet(blocks, n_classes, kernel_size)

batch_size = 128

In [16]:
log_path = '/tmp/tf'
train_operation = make_training_operation(model, log_path)
val_operation = make_validation_operation(model, log_path)

session = tf.InteractiveSession()
tf.global_variables_initializer().run()

In [17]:
def train(model, batch_iter, lr, n_batches):
    losses = []
    for _ in tqdm(range(n_batches)):
        x_batch, y_batch = next(batch_iter)
        y_batch = np.int64(y_batch)
        
        loss = train_operation(x_batch, y_batch, lr, session)
        losses.append(loss)

    return np.mean(losses)

def predict_evaluate(model, x, y):
    y_pred, losses = [], []
    for xo, yo in tqdm(zip(x, y)):
        x_batch = xo[None, :]
        y_batch = np.int64(yo[None, :])
        yo_pred, loss = val_operation(x_batch, y_batch, session)
        y_pred.append(yo_pred)
        losses.append(loss)
    return y_pred, losses

In [18]:
lr = 0.1

n_epoch = 100
n_batches_per_epoch = 40

train_iter = medim.batch_iter.patch.uniform(
        [x_train, y_train], [patch_size_x, patch_size_y], batch_size=batch_size,
        spatial_dims=(-3, -2, -1), 
    )
    

for epoch in range(n_epoch):
    print('Epoch {}'.format(epoch), flush=True)
    
    train_loss = train(model, train_iter, lr, n_batches_per_epoch)
    print('Train:', train_loss, flush=True)
    
    losses = []
    dices = []
    
    y_pred, losses = predict_evaluate(model, x_val_padded, y_val_padded)
    
    dices = []
    for yo_pred, yo_true in tqdm(zip(y_pred, y_val_padded)):
        msegm_true = encode_msegm(yo_true)
        msegm_pred = encode_msegm(yo_pred)
        dices.append([medim.metrics.dice_score(msegm_pred[k], msegm_true[k])
                      for k in range(n_msegm_chans)])
    
    val_loss = np.mean(losses)
    val_dices = np.mean(dices, axis=0)

    print('Val       :', val_loss)
    print('Val dices :', val_dices)
    print('\n', flush=True)

Epoch 0


100%|██████████| 40/40 [00:17<00:00,  2.41it/s]

Train: 0.323581



6it [00:03,  1.84it/s]
6it [00:01,  4.63it/s]

Val       : 0.0620192
Val dices : [ 0.05997261  0.          0.        ]


Epoch 1



100%|██████████| 40/40 [00:16<00:00,  2.46it/s]

Train: 0.0866656



6it [00:03,  1.92it/s]
6it [00:01,  4.09it/s]

Val       : 0.0535373
Val dices : [ 0.05997261  0.          0.        ]


Epoch 2



100%|██████████| 40/40 [00:16<00:00,  2.41it/s]

Train: 0.0924303



6it [00:03,  1.96it/s]
6it [00:01,  4.10it/s]

Val       : 0.0582572
Val dices : [ 0.05997261  0.          0.        ]


Epoch 3



100%|██████████| 40/40 [00:16<00:00,  2.43it/s]

Train: 0.0826762



6it [00:03,  1.83it/s]
6it [00:01,  3.99it/s]

Val       : 0.0498347
Val dices : [ 0.05997261  0.          0.        ]


Epoch 4



100%|██████████| 40/40 [00:16<00:00,  2.42it/s]

Train: 0.0814916



6it [00:03,  1.91it/s]
6it [00:01,  4.38it/s]

Val       : 0.053347
Val dices : [ 0.05997261  0.          0.        ]


Epoch 5



100%|██████████| 40/40 [00:16<00:00,  2.40it/s]

Train: 0.0751868



6it [00:03,  1.79it/s]
6it [00:01,  3.87it/s]

Val       : 0.0557734
Val dices : [ 0.05997261  0.          0.        ]


Epoch 6



100%|██████████| 40/40 [00:16<00:00,  2.40it/s]

Train: 0.0697552



6it [00:03,  1.92it/s]
6it [00:01,  4.08it/s]

Val       : 0.0469683
Val dices : [ 0.05997261  0.          0.        ]


Epoch 7



100%|██████████| 40/40 [00:16<00:00,  2.40it/s]

Train: 0.07682



6it [00:03,  1.84it/s]
6it [00:01,  3.96it/s]

Val       : 0.043166
Val dices : [ 0.05997261  0.          0.        ]


Epoch 8



100%|██████████| 40/40 [00:16<00:00,  2.39it/s]

Train: 0.0689402



6it [00:03,  1.95it/s]
6it [00:01,  4.29it/s]

Val       : 0.0415084
Val dices : [ 0.05997261  0.          0.        ]


Epoch 9



100%|██████████| 40/40 [00:16<00:00,  2.42it/s]

Train: 0.0705614



6it [00:03,  1.81it/s]
6it [00:01,  3.74it/s]

Val       : 0.0444832
Val dices : [ 0.05997261  0.          0.        ]


Epoch 10



 15%|█▌        | 6/40 [00:02<00:14,  2.43it/s]


KeyboardInterrupt: 

In [15]:
np.mean(losses)

  ret = ret.dtype.type(ret / rcount)


nan

In [83]:
%%time

segm_log_loss(yo_pred, yo_true)

CPU times: user 632 ms, sys: 852 ms, total: 1.48 s
Wall time: 1.48 s


1.9770362883573271

In [77]:
np.mean(np.argmax(yo_pred, axis=0) == yo_true)

0.93480953059053296

In [72]:
n_msegm_chans = 3

In [None]:
n_parts = [1, 2, 1]

n_epoch = 100
batch_per_epoch = 40
batch_size = 128

for epoch in range(n_epoch):
    train_iter = medim.batch_iter.patch.uniform(
        x_train, x_train, batch_size=batch_size,
        patch_size_x=patch_size_x, patch_size_y=patch_size_y, 
    )
    
    start_train = time.time()
    
    losses = []
    for _ in tqdm(range(batch_per_epoch)):
        x_batch, y_batch = next(train_iter)
    
    end_train = time.time()
    train_loss = np.mean(losses)
    
    print('Epoch {}'.format(epoch), flush=True)
    
    start_val = time.time()
    
    losses = []
    dices = []

    for mscan, msegm in zip(x_val_padded, y_val_padded):
        msegm = np.array(msegm, dtype=np.float32)

        mscan_parts = medim.utils.divide(mscan, padding, n_parts)
        msegm_parts = medim.utils.divide(msegm, [0, 0, 0], n_parts)

        predicted = []
        true = []
        for mscan_part, msegm_part in tqdm(zip(mscan_parts, msegm_parts)):
            o = np.array(mscan_part[None, :])

            y_pred = np.array(msegm_part)

            predicted.append(y_pred)
            true.append(msegm_part)
    
        y_pred = combine_fast(predicted)
        y_true = combine_fast(true)
        
        #dices.append([dice_score(y_pred[k] > 0.5, y_true[k]) for k in range(n_classes)])
    
    end_val = time.time()
    
    val_loss = np.mean(losses)
    val_dices = np.mean(dices, axis=0)

    print('Time :', end_train - start_train, end_val - start_val)
    print('\n', flush=True)

In [None]:
y_preds = []

model.eval()
for mscan, msegm in tqdm(zip(x_val_padded, y_val)):
    msegm = np.array(msegm, dtype=np.float32)

    mscan_parts = medim.utils.divide(mscan, padding, n_parts)
    msegm_parts = medim.utils.divide(msegm, [0, 0, 0], n_parts)

    predicted = []
    for mscan_part, msegm_part in zip(mscan_parts, msegm_parts):
        o = np.array(mscan_part[None, :])

        y_pred = model(to_var(o, volatile=True))
        loss = F.binary_cross_entropy(y_pred, to_var(msegm_part[None, :], volatile=True))
        
        predicted.append(to_numpy(y_pred)[0])

    y_pred = medim.utils.combine(predicted, n_parts)
    y_preds.append(y_pred)

In [None]:
def get_dice_threshold(y_preds, my):
    thresholds = []
    for i in range(n_classes):
        ps = np.linspace(0, 1, 100)
        best_p = 0
        best_score = 0
        for p in ps:
            score = np.mean([dice_score(pred[i] > p, true[i]) for pred, true in zip(y_preds, my)])
            if score is np.nan or None:
                print('None')
                score = 1
            
            if score > best_score:
                best_p = p
                best_score = score
        thresholds.append(best_p)
        print(best_score)
    return thresholds

thresholds = get_dice_threshold(y_preds, y_val)
thresholds

In [None]:
y_preds_t = []

model.eval()
for mscan, msegm in tqdm(zip(x_test, y_test)):
    mscan = min_padding(mscan, padding)
    msegm = np.array(msegm, dtype=np.float32)

    mscan_parts = medim.utils.divide(mscan, padding, n_parts)
    msegm_parts = medim.utils.divide(msegm, [0, 0, 0], n_parts)

    predicted = []
    for mscan_part, msegm_part in zip(mscan_parts, msegm_parts):
        o = np.array(mscan_part[None, :])

        y_pred = model(to_var(o, volatile=True))
        loss = F.binary_cross_entropy(y_pred, to_var(msegm_part[None, :], volatile=True))
        
        predicted.append(to_numpy(y_pred)[0])

    y_pred = medim.utils.combine(predicted, n_parts)
    y_preds_t.append(y_pred)

In [None]:
np.mean([[dice_score(y_preds_t[i][k] > thresholds[k], y_test[i][k]) for k in range(n_classes)]
         for i in range(len(y_preds_t))], axis=0)

In [None]:
len(x_train), len(x_val), len(x_test)

In [None]:
fig, ax = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(15, 15))

i = 4
k = 90

cmap = cm.GnBu
fontsize = 30

ax[0, 0].set_title('Predicted segmentation', fontsize=fontsize)
ax[0, 0].imshow(y_preds_t[i][0, ..., k], cmap=cmap)
#plt.colorbar()
#plt.show()
ax[0, 1].set_title('Ground truth', fontsize=fontsize)
ax[0, 1].imshow(y_test[i][0, ..., k], cmap=cmap)

ax[1,0].set_title('Brain slice', fontsize=fontsize)
ax[1, 0].imshow(x_test[i][3, ..., k], cmap=cmap)
plt.tight_layout()
ax[0, 0].axis('off')
ax[1, 0].axis('off')
ax[0, 1].axis('off')
ax[1, 1].axis('off')
plt.show()
#plt.colorbar()
#plt.show()

In [10]:
def segm_log_loss(yo_pred, yo_true):
    y_pred = np.moveaxis(yo_pred, 0, 3).reshape((-1, n_classes))
    y_true = yo_true.flatten()
    return log_loss(y_true, y_pred)

In [None]:
# def pred_reshape(y):
#     x = y.permute(0, 2, 3, 4, 1)
#     return x.contiguous().view(-1, x.size()[-1])

# def loss_cross_entropy(y_pred, y_true):
#     return F.cross_entropy(pred_reshape(y_pred), y_true.view(-1))


coeff = to_var(np.array([1, 2, 3], dtype=np.float32))
epsilon = 1e-7

def dice_loss(y_pred, target):
    y_pred = y_pred.view(*y_pred.size()[:2], -1)
    target = target.view(*target.size()[:2], -1)
    
#     s = y_pred.size()
#     e = epsilon.expand(s[0], 1, s[2])
    dice_scores = 2 * (epsilon + (y_pred * target).sum(2)) / \
                  (y_pred.sum(2) + target.sum(2) + 2 * epsilon)
        
    dice_scores = dice_scores.mean(0)
    dice_scores = dice_scores.view(-1)

    return -torch.sum(dice_scores * coeff)

In [None]:
def predict_object(model, xo, n_parts_per_axis):
    model.eval()
    xo_parts = medim.utils.divide(xo, [0] + [*padding], n_parts_per_axis)

    xo_predicted = []
    for xo_part in xo_parts:
        xo_part = xo_part[None, :]

        y_pred = model(to_var(xo_part, volatile=True))

        xo_predicted.append(to_numpy(y_pred)[0])

    yo_pred = medim.utils.combine(xo_predicted, n_parts_per_axis)
    return yo_pred