In [22]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
%matplotlib inline

import time
import pickle
from imp import reload
from os.path import join

import pandas as pd
import seaborn as sns
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from data_loader import DataLoader
import medim
reload(medim);

In [3]:
raw_path = '/home/mount/neuro-t01-hdd/Brats2017/data/raw/'

data_loader = DataLoader(raw_path)
patients = data_loader.patients

In [4]:
processed_path = '/mount/export/Brats2017/data/processed'
mscans = []
segmentations = []

for patient in tqdm(patients):
    filename = join(processed_path, patient)
    
    mscans.append(np.load(filename+'_mscan.npy'))
    segmentations.append(np.load(filename+'_segmentation.npy'))

100%|██████████| 285/285 [00:50<00:00,  4.48it/s]


In [139]:

#indices = np.arange(len(inputs))
#np.random.shuffle(indices)

(4, 140, 172, 145)

In [125]:
patch_size_x = np.array([25, 25, 25])
patch_size_y = np.array([9, 9, 9])

assert np.all(patch_size_x % 2 == patch_size_y % 2)
patch_size_pad = (patch_size_y - patch_size_y) // 2

n_mods = 4
n_classes = 4

def make_unif_batch_iter(mscans, segms, batch_size):    
    n = len(mscans)
    max_spatial_idx = np.array([list(s.shape[1:]) for s in mscans]) - patch_size_x + 1
    
    x_batch = np.zeros((batch_size, n_mods, *patch_size_x))
    y_batch = np.zeros((batch_size, *patch_size_y), dtype=np.int64)
    
    while True:
        idx = np.random.randint(n, size=batch_size)
        start_idx = np.random.rand(batch_size, 3) * max_spatial_idx[idx]
        start_idx = np.int32(np.floor(start_idx))
        for i in range(batch_size):
            s = start_idx[i]
            slices = [...] + [slice(s[k], s[k]+patch_size_x[k]) for k in range(3)]
            x_batch[i] = mscans[idx[i]][slices]
            
            s = start_idx[i] + patch_size_pad
            slices = [slice(s[k], s[k]+patch_size_y[k]) for k in range(3)]
            y_batch[i] = segms[idx[i]][slices]
        yield np.array(x_batch, dtype=np.float32), np.array(y_batch)

In [135]:
def build_model(n_chans_for_each_layer, kernel_size):
    fe = []
    n_chans_prev = 4
    for n_chans in n_chans_for_each_layer[:-1]:
        c = nn.Conv3d(n_chans_prev, n_chans, kernel_size, bias=False)
        bn = nn.BatchNorm3d(n_chans)
        a = nn.ReLU()
        fe.extend([c, bn, a])
        n_chans_prev = n_chans
        
    n_chans = n_chans_for_each_layer[-1]
    c = nn.Conv3d(n_chans_prev, n_chans, kernel_size=1, bias=False)
    bn = nn.BatchNorm3d(n_chans)
    #a = nn.Softmax()
    fe.extend([c, bn])

    return nn.Sequential(*fe)


class Model(torch.nn.Module):
    def __init__(self, n_chans_for_each_layer, kernel_size):
        super().__init__()
        self.model = build_model(n_chans_for_each_layer, kernel_size)
    
    def forward(self, input):
        model = self.model(input)
        return model

kernel_size = 5
n_chans_for_each_layer = [30, 40, 40, 50, 4]

model = Model(n_chans_for_each_layer, kernel_size).cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

model

Model (
  (model): Sequential (
    (0): Conv3d(4, 30, kernel_size=(5, 5, 5), stride=(1, 1, 1), bias=False)
    (1): BatchNorm3d(30, eps=1e-05, momentum=0.1, affine=True)
    (2): ReLU ()
    (3): Conv3d(30, 40, kernel_size=(5, 5, 5), stride=(1, 1, 1), bias=False)
    (4): BatchNorm3d(40, eps=1e-05, momentum=0.1, affine=True)
    (5): ReLU ()
    (6): Conv3d(40, 40, kernel_size=(5, 5, 5), stride=(1, 1, 1), bias=False)
    (7): BatchNorm3d(40, eps=1e-05, momentum=0.1, affine=True)
    (8): ReLU ()
    (9): Conv3d(40, 50, kernel_size=(5, 5, 5), stride=(1, 1, 1), bias=False)
    (10): BatchNorm3d(50, eps=1e-05, momentum=0.1, affine=True)
    (11): ReLU ()
    (12): Conv3d(50, 4, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
    (13): BatchNorm3d(4, eps=1e-05, momentum=0.1, affine=True)
  )
)

In [7]:
def pred_reshape(y):
    x = y.permute(0, 2, 3, 4, 1)
    return x.contiguous().view(-1, x.size()[-1])

def to_var(x):
    return Variable(torch.from_numpy(x))

def to_numpy(x):
    return x.cpu().data.numpy()

In [136]:
n_epoch = 100
batch_per_epoch = 40
batch_size = 128


for epoch in range(n_epoch):
    
    train_iter = make_unif_batch_iter(mscans, segmentations, batch_size)
    val_iter = make_unif_batch_iter(mscans, segmentations, batch_size)
    
    start = time.time()
    
    model.train()
    losses = []
    weights = []
    accs = []
    
    for _ in range(batch_per_epoch):
        x_batch, y_batch = next(train_iter)

        y_pred = model(to_var(x_batch))

        optimizer.zero_grad()
        loss = F.cross_entropy(pred_reshape(y_pred), to_var(y_batch.flatten()).long())
        loss.backward()
        optimizer.step()

        acc = np.mean(y_batch.flatten() == np.argmax(to_numpy(pred_reshape(y_pred)), axis=1))
        accs.append(acc)
        losses.append(to_numpy(loss))
        weights.append(len(x_batch))
            
    train_loss = np.average(np.array(losses).flatten(), weights=weights)
    train_acc = np.average(np.array(accs).flatten(), weights=weights)
    
    model.eval()
    losses = []
    weights = []
    accs = []
    
    for _ in range(batch_per_epoch):
        x_batch, y_batch = next(val_iter)
        y_pred = model(to_var(x_batch))

        loss = F.cross_entropy(pred_reshape(y_pred), to_var(y_batch.flatten()).long())

        acc = np.mean(y_batch.flatten() == np.argmax(to_numpy(pred_reshape(y_pred)), axis=1))
        accs.append(acc)
        losses.append(to_numpy(loss))
        weights.append(len(x_batch))
    
    end = time.time()
    
    val_loss = np.average(np.array(losses).flatten(), weights=weights)
    val_acc = np.average(np.array(accs).flatten(), weights=weights)
    
    print('Epoch {}'.format(epoch))
    print('Train:', train_loss, train_acc)
    print('Val  :', val_loss, val_acc)
    print('Time :', end - start)
    print('\n')

Epoch 0
Train: 1.30008297861 0.806465942215
Val  : 1.26333920658 0.862997792353
Time : 32.0606734752655


Epoch 1
Train: 1.20353299379 0.902681595079
Val  : 1.2509970814 0.844320934071
Time : 31.98053741455078


Epoch 2
Train: 1.14871447682 0.905302104767
Val  : 1.21247553527 0.729046371313
Time : 31.961509704589844


Epoch 3
Train: 1.10454013646 0.904559435014
Val  : 1.04486145079 0.927730356224
Time : 31.928505897521973


Epoch 4
Train: 1.02997083515 0.934721686385
Val  : 0.983822830021 0.93451565715
Time : 31.971956968307495


Epoch 5
Train: 0.965610679984 0.931845046725
Val  : 0.928353792429 0.950739990569
Time : 31.958980083465576


Epoch 6
Train: 0.942370587587 0.932113233025
Val  : 0.965880177915 0.934508155436
Time : 31.95565629005432


Epoch 7
Train: 0.880980320275 0.938540327075
Val  : 0.847003127635 0.942912219222
Time : 31.957747220993042


Epoch 8
Train: 0.836619347334 0.942498017404
Val  : 0.874436448514 0.938904428155
Time : 31.955031156539917


Epoch 9
Train: 0.81980569

KeyboardInterrupt: 

In [3]:
def dice_score(y_pred, target):
    """Dice score for binary segmentation on 3d scan"""
    return 2 * np.sum(y_pred * target) / (np.sum(y_pred) + np.sum(target))

In [103]:
def dice_loss(y_pred, target):
    
    def sum(array, dim):
        array = array.sum(dim)
        return array.view(*array.size()[:dim])
    
    coeff = torch.from_numpy(np.array([1, 2, 3], dtype=np.float32))
    
    y_pred = y_pred.view(*y_pred.size()[:2], -1)
    target = target.view(*target.size()[:2], -1)
    
    dice_scores = 2 * sum(y_pred * target, 2) / \
                  (sum(y_pred, 2) + sum(target, 2))
        
    dice_scores = dice_scores.mean(0)
    dice_scores = dice_scores.view(dice_scores.size()[1])

    return torch.sum(coeff * dice_scores)

In [None]:
dice_loss(to_car)

In [123]:
tmp = np.float32(np.random.randint(2, size=(10, 3, 8, 8, 8)))
tmp

array([[[[[ 1.,  0.,  0., ...,  0.,  0.,  0.],
          [ 0.,  0.,  1., ...,  0.,  1.,  0.],
          [ 0.,  0.,  0., ...,  1.,  0.,  1.],
          ..., 
          [ 1.,  1.,  1., ...,  0.,  0.,  0.],
          [ 1.,  0.,  0., ...,  1.,  1.,  0.],
          [ 1.,  0.,  1., ...,  1.,  1.,  1.]],

         [[ 1.,  1.,  0., ...,  1.,  1.,  1.],
          [ 0.,  0.,  0., ...,  1.,  1.,  0.],
          [ 1.,  0.,  1., ...,  1.,  0.,  0.],
          ..., 
          [ 0.,  0.,  1., ...,  1.,  0.,  1.],
          [ 1.,  1.,  1., ...,  0.,  1.,  1.],
          [ 0.,  1.,  1., ...,  1.,  0.,  1.]],

         [[ 0.,  1.,  1., ...,  0.,  0.,  0.],
          [ 1.,  1.,  0., ...,  0.,  1.,  1.],
          [ 1.,  0.,  1., ...,  1.,  1.,  0.],
          ..., 
          [ 1.,  0.,  0., ...,  0.,  0.,  0.],
          [ 0.,  1.,  0., ...,  1.,  0.,  0.],
          [ 1.,  0.,  0., ...,  0.,  0.,  1.]],

         ..., 
         [[ 0.,  0.,  0., ...,  1.,  1.,  0.],
          [ 1.,  0.,  1., ...,  1.,  0

In [127]:
dice_loss(Variable(torch.from_numpy((tmp))), torch.from_numpy((tmp)))

AssertionError: 

In [108]:
example = torch.rand(1, 3, 2, 2, 2)

In [76]:
example.size()

torch.Size([1, 3, 2, 2, 2])

In [77]:
example = example.view(*example.size()[:2], -1)

In [78]:
example.size()

torch.Size([1, 3, 8])

In [85]:
example = example.sum(2)

In [86]:
example = sum(example, 2)

In [87]:
example.size()

torch.Size([1, 3])

In [69]:
example = example.mean(0)

In [88]:
example = example.view(*example.size()[:2])

In [72]:
example[0]


 3.3946
 4.8465
 3.9411
[torch.FloatTensor of size 3]

In [89]:
example.size()

torch.Size([1, 3])

In [90]:
example * example


 13.2408  16.8009  29.3152
[torch.FloatTensor of size 1x3]

In [104]:
dice_loss(example, example)

28.07962679862976

In [25]:
a = torch.cuda.device(0)

In [30]:
!ls /mount/export

animals  Brats2017  dsb2017  ndsb2017
