In [1]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
%matplotlib inline

import time
import pickle
from imp import reload
from os.path import join

from sklearn.model_selection import KFold
import pandas as pd
import seaborn as sns
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from data_loader import DataLoader
import medim
reload(medim);

In [8]:
torch.cuda.

('64bit', 'ELF')

In [2]:
raw_path = '/home/mount/neuro-t01-hdd/Brats2017/data/raw/'

data_loader = DataLoader(raw_path)
patients = data_loader.patients

n_classes = 3

In [3]:
processed_path = '/mount/export/Brats2017/data/processed'
mscans = []
msegms = []

for patient in tqdm(patients):
    filename = join(processed_path, patient)
    
    mscans.append(np.load(filename+'_mscan.npy'))
    msegms.append(np.load(filename+'_segmentation.npy'))

100%|██████████| 285/285 [00:46<00:00,  3.55it/s]


In [4]:
n_splits = 50

cv = KFold(n_splits, shuffle=True, random_state=17)
train, val = next(cv.split(mscans))

def extract(l, idx):
    return [l[i] for i in idx]

mscans_train, mscans_val = extract(mscans, train), extract(mscans, val)
msegms_train, msegms_val = extract(msegms, train), extract(msegms, val)

In [10]:
def build_model(n_chans_for_each_layer, kernel_size):
    fe = []
    n_chans_prev = 4
    for n_chans in n_chans_for_each_layer[:-1]:
        c = nn.Conv3d(n_chans_prev, n_chans, kernel_size, bias=False)
        bn = nn.BatchNorm3d(n_chans)
        a = nn.ReLU()
        fe.extend([c, bn, a])
        n_chans_prev = n_chans
        
    n_chans = n_chans_for_each_layer[-1]
    c = nn.Conv3d(n_chans_prev, n_chans, kernel_size=1, bias=False)
    bn = nn.BatchNorm3d(n_chans)
    fe.extend([c, bn])

    return nn.Sequential(*fe)


class Model(torch.nn.Module):
    def __init__(self, n_chans_for_each_layer, kernel_size):
        super().__init__()
        self.model = build_model(n_chans_for_each_layer, kernel_size)
    
    def forward(self, input):
        model = self.model(input)
        return model

kernel_size = 3
n_chans_for_each_layer = [16, 16, 32, 32, 64, 64, 8, n_classes]

patch_size_x = np.array([25, 25, 25])
patch_size_y = patch_size_x - 2*(len(n_chans_for_each_layer) - 1)

model = Model(n_chans_for_each_layer, kernel_size).cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

model

Model (
  (model): Sequential (
    (0): Conv3d(4, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), bias=False)
    (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True)
    (2): ReLU ()
    (3): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), bias=False)
    (4): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True)
    (5): ReLU ()
    (6): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), bias=False)
    (7): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True)
    (8): ReLU ()
    (9): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), bias=False)
    (10): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True)
    (11): ReLU ()
    (12): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), bias=False)
    (13): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True)
    (14): ReLU ()
    (15): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), bias=False)
    (16): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True)
    (17): ReLU ()
    (18): Conv3d(64

In [7]:
padding = (patch_size_x - patch_size_y) // 2

def min_padding(mscan, padding):
    padding = np.array([0] + list(padding))
    padding = np.repeat(padding[:, None], 2, axis=1)
    
    return np.pad(mscan, padding, mode='minimum')

In [11]:
def pred_reshape(y):
    x = y.permute(0, 2, 3, 4, 1)
    return x.contiguous().view(-1, x.size()[-1])

def loss_cross_entropy(y_pred, y_true):
    return F.cross_entropy(pred_reshape(y_pred), y_true.view(-1))

def loss_binary_entropy(y_pred, y_true):
    return F.binary_cross_entropy(y_pred, y_true)

def to_var(x, volatile=False):
    return Variable(torch.from_numpy(x), volatile=volatile).cuda()

def to_numpy(x):
    return x.cpu().data.numpy()

coeff = to_var(np.array([1, 2, 3], dtype=np.float32))
epsilon = 1e-7

def dice_loss(y_pred, target):
    y_pred = y_pred.view(*y_pred.size()[:2], -1)
    target = target.view(*target.size()[:2], -1)
    
#     s = y_pred.size()
#     e = epsilon.expand(s[0], 1, s[2])
    dice_scores = 2 * (epsilon + (y_pred * target).sum(2)) / \
                  (y_pred.sum(2) + target.sum(2) + 2 * epsilon)
        
    dice_scores = dice_scores.mean(0)
    dice_scores = dice_scores.view(-1)

    return -torch.sum(dice_scores * coeff)

def dice_score(y_pred, target):
    """Dice score for binary segmentation on 3d scan"""
    return 2 * np.sum(y_pred * target) / (np.sum(y_pred) + np.sum(target))

In [12]:
n_epoch = 100
batch_per_epoch = 40
batch_size = 128

for epoch in range(n_epoch):
    train_iter = medim.batch_iter.patch.uniform(
        mscans_train, msegms_train, batch_size=batch_size,
        patch_size_x=patch_size_x, patch_size_y=patch_size_y, 
    )
    
    start_train = time.time()
    
    model.train()
    losses = []
    weights = []
    for _ in range(batch_per_epoch):
        x_batch, y_batch = next(train_iter)

        y_pred = model(to_var(x_batch))

        optimizer.zero_grad()
        loss = loss_binary_entropy(y_pred, to_var(y_batch))
        
        loss.backward()
        optimizer.step()

        losses.append(to_numpy(loss))
        weights.append(len(x_batch))
            
    train_loss = np.average(np.array(losses).flatten(), weights=weights)
    #train_acc = np.average(np.array(accs).flatten(), weights=weights)
    
    end_train = time.time()
    
    start_val = time.time()
    
    model.eval()
    losses = []
    dices = []
    for mscan, segm in tqdm(zip(mscans_val, msegms_val)):
        x_batch = min_padding(mscan, padding)[None, :]
        y_batch = np.array(segm[None, :], dtype=np.float32)
        y_pred = model(to_var(x_batch, volatile=True))

        loss = loss_binary_entropy(y_pred, to_var(y_batch))

        y_pred = to_numpy(y_pred)
        dices.append([dice_score(y_pred[0, k] > 0.5, y_batch[0, k]) for k in range(n_classes)])
        losses.append(to_numpy(loss))
    
    end_val = time.time()
    
    val_loss = np.mean(np.array(losses).flatten())
    val_dices = np.mean(np.array(dices), axis=0)
    
    print('Epoch {}'.format(epoch))
    print('Train:', train_loss)
    print('Val  :', val_loss, val_dices)
    print('Time :', end_train - start_train, end_val - start_val)
    print('\n')

6it [00:08,  1.46s/it]


Epoch 0
Train: nan
Val  : nan [ 0.19910945  0.40198349  0.02444046]
Time : 14.028355360031128 8.625139236450195




KeyboardInterrupt: 

In [12]:
def dice_score(y_pred, target):
    intersec = np.sum(y_pred == target)
    return 2 * intersec / 
    
    

SyntaxError: invalid syntax (<ipython-input-12-023c72b2002e>, line 3)

In [13]:
plt.imshow(to_numpy(model(to_var(x_batch, volatile=True)))[0, 0, ..., 10])

RuntimeError: cuda runtime error (2) : out of memory at /py/conda-bld/pytorch_1493680494901/work/torch/lib/THC/THCGeneral.c:833

In [133]:
np.sum(e == np.zeros_like(e)) / np.prod(e.shape)

1.0

In [131]:
e = np.argmax(to_numpy(pred_reshape(y_pred)), axis=1)

In [134]:
e

array([0, 0, 0, ..., 0, 0, 0])

In [55]:
def one_hot(x):
    enc = np.eye(n_classes)[x]
    return np.rollaxis(enc, 4, 1)

In [62]:
np.unique(y_batch, return_counts=True)

(array([0, 2]), array([6621,  669]))

In [67]:
np.all(one_hot(y_batch).sum(1) == 1)

True

In [113]:
a = one_hot(y_batch)

In [114]:
np.mean(np.argmax(a, axis=1) == y_batch)

1.0

In [112]:
a.shape

(10, 4, 9, 9, 9)

In [115]:
b = pred_reshape(to_var(a))

In [116]:
b.size()

torch.Size([7290, 4])

In [123]:
F.cross_entropy(b*10, to_var(y_batch.flatten()).long())

Variable containing:
1.00000e-04 *
  1.3619
[torch.cuda.DoubleTensor of size 1 (GPU 0)]

In [8]:
def to_vec(y):
    return y.reshape((1,))

In [79]:
y = np.arange(8).reshape((2,2,2))

In [80]:
x = torch.randn(2, 3, 5)

In [46]:
s = x.size()

In [37]:
a = x.permute(2, 0, 1).size()

In [54]:
s

torch.Size([2, 3, 5])

In [43]:
np.prod(a[:2])

10

In [52]:
x = x.view(int(np.prod((x.size()[:2]))), x.size()[-1])

In [53]:
x.size()

torch.Size([6, 5])