_____

In [1]:
import os
import sys
sys.path.append("../utils/")
sys.path.append("../models/")
import warnings
warnings.filterwarnings("ignore")
from datetime import datetime

import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

from models import UNet, UResNet
# from max_model import ENet3D as ENet
from nn_utils import iterate_minibatches
from mulptiprocessing_utils import par_iterate_minibatches
from pytorch_utils import to_numpy, to_var, loss_cross_entropy, stochastic_step
from data_utils import load_files, random_nonzero_crops, augment, _reshape_to
from data_utils import combine, divide
from metrics import hausdorff, dice
from model_controller import Model_controller

%matplotlib inline

In [2]:
# Dataset roots for the three sites. These are relative paths, resolved
# against the kernel's current working directory.
PATH, PATH1, PATH2 = (
    '../data/Utr/Utrecht/',     # Utrecht
    '../data/Sing/Singapore/',  # Singapore
    '../data/Amst/GE3T/',       # Amsterdam (GE3T scanner)
)

In [3]:
def _load_site(site_path):
    """Load one site with ``load_files`` and build its two-channel input stack.

    Returns ``(masks, t1, flairs, brain_mask, brains)``. ``brains`` is the
    concatenation of ``[flairs, t1]`` along axis 1 (FLAIR first, then T1),
    cast to float32.
    """
    site_masks, site_t1, site_flairs, site_brain_mask = load_files(site_path)
    site_brains = np.concatenate([site_flairs, site_t1], axis=1).astype(np.float32)
    return site_masks, site_t1, site_flairs, site_brain_mask, site_brains


# Same variable names as before; the suffix 1/2 identifies the second/third site.
masks, t1, flairs, brain_mask, brains = _load_site(PATH)
masks1, t11, flairs1, brain_mask1, brains1 = _load_site(PATH1)
masks2, t12, flairs2, brain_mask2, brains2 = _load_site(PATH2)

  0%|          | 0/20 [00:00<?, ?it/s]

INFO: brain mask from hardcoded path  ../data/skull_stripping


100%|██████████| 20/20 [00:04<00:00,  4.18it/s]


INFO: data reshape to hardcoded shape new_shape=! (256, 256, 84)


  0%|          | 0/20 [00:00<?, ?it/s]

INFO: brain mask from hardcoded path  ../data/skull_stripping


100%|██████████| 20/20 [00:05<00:00,  3.85it/s]


INFO: data reshape to hardcoded shape new_shape=! (256, 256, 84)


  0%|          | 0/20 [00:00<?, ?it/s]

INFO: brain mask from hardcoded path  ../data/skull_stripping


100%|██████████| 20/20 [00:05<00:00,  3.97it/s]


INFO: data reshape to hardcoded shape new_shape=! (256, 256, 84)


_____

____

In [4]:
def get_train_split(get_train_split=20, n=15, rng=None):
    """Return a boolean mask that marks ``n`` randomly chosen items as training.

    Parameters
    ----------
    get_train_split : int
        Size of the pool (number of items/subjects). The name shadows the
        function inside its own body; it is kept for backward compatibility
        with keyword callers.
    n : int
        Number of items to mark True, drawn without replacement.
    rng : numpy.random.RandomState or numpy.random.Generator, optional
        Source of randomness. Defaults to the global ``np.random`` state,
        which reproduces the original draws for a given ``np.random.seed``.
        Pass a seeded RNG to make the split reproducible.

    Returns
    -------
    numpy.ndarray of bool, shape (get_train_split,)
        True for selected (training) items, False for held-out items.

    Raises
    ------
    ValueError
        Propagated from ``choice`` when ``n > get_train_split``.
    """
    if rng is None:
        rng = np.random
    mask = np.zeros(get_train_split, dtype=bool)
    # Direct index assignment replaces the original
    # ``np.in1d(np.array(range(N)), choice(...))`` membership test.
    # ``np.in1d`` is deprecated since NumPy 2.0, and this form is O(N).
    mask[rng.choice(get_train_split, n, replace=False)] = True
    return mask

In [5]:
# Three independent 15-of-20 training masks (presumably one per site).
# No seed is set here, so the masks differ on every run.
idx1, idx2, idx3 = (get_train_split() for _ in range(3))
idx1, idx2, idx3

(array([ True,  True,  True,  True, False,  True,  True,  True,  True,
        False,  True,  True, False, False,  True,  True,  True, False,
         True,  True], dtype=bool),
 array([False,  True, False,  True,  True,  True, False,  True,  True,
        False,  True,  True, False,  True,  True,  True,  True,  True,
         True,  True], dtype=bool),
 array([ True,  True, False,  True,  True,  True,  True, False, False,
         True, False,  True,  True,  True,  True,  True,  True, False,
         True,  True], dtype=bool))

# Single-site (Utrecht, from PATH) hold-out split driven by idx1. Boolean-mask
# indexing already returns fresh copies, so no np.concatenate wrapper is needed.
# NOTE: the K-fold loop below reassigns all four of these names.
X_train, X_test = brains[idx1], brains[~idx1]
Y_train, Y_test = masks[idx1], masks[~idx1]

___

In [6]:
# K-fold splitter over subject indices: 4 folds, shuffled with a fixed
# random_state so the same folds are produced on every run.
from sklearn.model_selection import KFold

skf = KFold(4, shuffle=True, random_state=0)

In [None]:
import pickle

In [None]:
# 4-fold cross-validation over all three sites. The same fold indices are
# applied to every site, which assumes each site holds the same number of
# subjects (the loading output above shows 20 per site).
# Figures and prediction dumps are written to FIG_DIR. Figures are closed after
# saving so dozens of them do not accumulate in memory; they are therefore not
# shown inline.
FIG_DIR = '../reports/figures/UntitledFolder/'
os.makedirs(FIG_DIR, exist_ok=True)

# `counter` (1-based fold number) and `controller` intentionally remain defined
# after the loop; later cells use `controller`.
for counter, (tr_idx, ts_idx) in enumerate(
        skf.split(range(len(brains)), range(len(masks))), start=1):
    X_train = np.concatenate([brains[tr_idx], brains1[tr_idx], brains2[tr_idx]])
    X_test = np.concatenate([brains[ts_idx], brains1[ts_idx], brains2[ts_idx]])
    Y_train = np.concatenate([masks[tr_idx], masks1[tr_idx], masks2[tr_idx]])
    Y_test = np.concatenate([masks[ts_idx], masks1[ts_idx], masks2[ts_idx]])

    # A fresh network per fold, so no weights carry over between folds.
    # Alternatives previously tried here: UResNet().cuda(), ENet(3, 2).cuda()
    model = UNet().cuda()
    controller = Model_controller(model)

    controller.init_train_procedure(batch_size=30, crops_shape=(52, 52, 40),
                                    context_shape=(52, 52, 40), lr=0.015,
                                    lr_decayer=lambda x: x*0.75, net='3D',
                                    decay_every_epoch=16, num_of_patches=450)

    controller.train(X_train, Y_train, epoches=100, X_val=X_test, y_val=Y_test,
                     hard_negative=lambda epoch: False)
    y = controller.predict(X_test)

    # np.savez would silently append '.npz' (producing 'y_dump_N.npy.npz').
    # np.save writes a genuine .npy file that matches the name.
    np.save(os.path.join(FIG_DIR, 'y_dump_' + str(counter) + '.npy'), y)

    # A voxel is predicted as class 1 when its channel-1 score beats both
    # channel 0 and channel 2.
    # NOTE(review): assumes predict returns (N, 3, ...) per-class scores — confirm.
    y = np.logical_and(y[:, 1, ...] > y[:, 0, ...],
                       y[:, 1, ...] > y[:, 2, ...])

    for num_img in range(len(Y_test)):
        # Show the slice (last axis) containing the most ground-truth ==1 voxels.
        t = np.argmax(np.sum(Y_test[num_img, 0, ...] == 1, axis=(-2, -3)))
        fig, (ax_pred, ax_true) = plt.subplots(1, 2)
        ax_pred.imshow(y[num_img, ..., t], cmap=plt.cm.binary_r)
        ax_pred.set_title('dice: ' + str(dice(y[num_img], Y_test[num_img, 0]))[:5])
        ax_true.imshow(Y_test[num_img, 0, ..., t], cmap=plt.cm.binary_r, vmax=2)
        fig.savefig(os.path.join(FIG_DIR, str(counter) + '_' + str(num_img) + '.png'))
        plt.close(fig)

    # Learning curves are one figure per fold. They were previously indented
    # inside the per-image loop, which redrew and re-saved them len(Y_test) times.
    fig, ax = plt.subplots()
    ax.plot(controller.res['train_loss'], label='train')
    ax.plot(controller.res['val_loss'], label='val')
    ax.set(ylabel='loss', title='fold {}'.format(counter))
    ax.legend()
    fig.savefig(os.path.join(FIG_DIR, str(counter) + '_tr_val_3D_UNet.png'))
    plt.close(fig)

    fig, ax = plt.subplots()
    ax.plot(controller.res['lr'], label='lr')
    ax.set(ylabel='lr', title='fold {}'.format(counter))
    ax.legend()
    fig.savefig(os.path.join(FIG_DIR, str(counter) + '_lr_3D_UNet.png'))
    plt.close(fig)
        

0.0542: 100%|██████████| 15/15 [00:34<00:00,  2.27s/it]


0 epoch: 
train loss 0.22762407933672268


0.0375: 100%|██████████| 3/3 [00:04<00:00,  1.38s/it]


val loss 0.03499444449941317


0.0895: 100%|██████████| 15/15 [00:36<00:00,  2.36s/it]
  0%|          | 0/3 [00:00<?, ?it/s]

1 epoch: 
train loss 0.08948013409972191


0.0450: 100%|██████████| 3/3 [00:04<00:00,  1.34s/it]


val loss 0.04373382901151975


0.0474: 100%|██████████| 15/15 [00:36<00:00,  2.37s/it]
  0%|          | 0/3 [00:00<?, ?it/s]

2 epoch: 
train loss 0.055058246354262035


0.0260: 100%|██████████| 3/3 [00:04<00:00,  1.34s/it]


val loss 0.02641392933825652


0.0354: 100%|██████████| 15/15 [00:36<00:00,  2.37s/it]
  0%|          | 0/3 [00:00<?, ?it/s]

3 epoch: 
train loss 0.04462571255862713


0.0237: 100%|██████████| 3/3 [00:04<00:00,  1.34s/it]


val loss 0.02657667485376199


0.0417: 100%|██████████| 15/15 [00:36<00:00,  2.38s/it]
  0%|          | 0/3 [00:00<?, ?it/s]

4 epoch: 
train loss 0.045262839272618295


0.0238: 100%|██████████| 3/3 [00:04<00:00,  1.34s/it]


val loss 0.023785225426157314


0.0406: 100%|██████████| 15/15 [00:36<00:00,  2.37s/it]


5 epoch: 
train loss 0.03884287799398104


0.0261: 100%|██████████| 4/4 [00:05<00:00,  1.34s/it]


val loss 0.023379947524517775


0.0360: 100%|██████████| 15/15 [00:36<00:00,  2.37s/it]


6 epoch: 
train loss 0.03840328343212605


0.0230: 100%|██████████| 3/3 [00:04<00:00,  1.34s/it]


val loss 0.0208376490821441


0.0419: 100%|██████████| 15/15 [00:36<00:00,  2.37s/it]


7 epoch: 
train loss 0.04123069482545058


0.0182: 100%|██████████| 3/3 [00:04<00:00,  1.34s/it]


val loss 0.021282924339175224


0.0350: 100%|██████████| 15/15 [00:36<00:00,  2.37s/it]
  0%|          | 0/3 [00:00<?, ?it/s]

8 epoch: 
train loss 0.03798244011898835


0.0168: 100%|██████████| 3/3 [00:04<00:00,  1.34s/it]


val loss 0.02335016553600629


0.0375: 100%|██████████| 15/15 [00:36<00:00,  2.36s/it]
  0%|          | 0/3 [00:00<?, ?it/s]

9 epoch: 
train loss 0.04112351089715958


0.0202: 100%|██████████| 3/3 [00:04<00:00,  1.34s/it]


val loss 0.027201912676294644


0.0304: 100%|██████████| 15/15 [00:36<00:00,  2.36s/it]
  0%|          | 0/3 [00:00<?, ?it/s]

10 epoch: 
train loss 0.03670464667181174


0.0158: 100%|██████████| 3/3 [00:04<00:00,  1.34s/it]


val loss 0.017589645460247993


0.0336: 100%|██████████| 15/15 [00:36<00:00,  2.36s/it]


11 epoch: 
train loss 0.03966647485891978


0.0150: 100%|██████████| 3/3 [00:04<00:00,  1.34s/it]


val loss 0.01916994558026393


0.0274: 100%|██████████| 15/15 [00:36<00:00,  2.37s/it]


12 epoch: 
train loss 0.03991304847101371


0.0240: 100%|██████████| 3/3 [00:04<00:00,  1.34s/it]


val loss 0.022894632692138355


0.0384: 100%|██████████| 15/15 [00:36<00:00,  2.36s/it]
  0%|          | 0/3 [00:00<?, ?it/s]

13 epoch: 
train loss 0.03919654736916224


0.0290: 100%|██████████| 3/3 [00:04<00:00,  1.34s/it]


val loss 0.023559870198369026


0.0329: 100%|██████████| 15/15 [00:36<00:00,  2.36s/it]


14 epoch: 
train loss 0.039921806876858076


0.0186: 100%|██████████| 3/3 [00:04<00:00,  1.34s/it]


val loss 0.018414750695228577


0.0302: 100%|██████████| 15/15 [00:36<00:00,  2.36s/it]
  0%|          | 0/3 [00:00<?, ?it/s]

15 epoch: 
train loss 0.037795241673787436


0.0252: 100%|██████████| 3/3 [00:04<00:00,  1.34s/it]


val loss 0.02362031675875187


0.0406: 100%|██████████| 15/15 [00:36<00:00,  2.36s/it]
  0%|          | 0/3 [00:00<?, ?it/s]

16 epoch: 
train loss 0.03890722716848056


0.0194: 100%|██████████| 3/3 [00:04<00:00,  1.34s/it]


val loss 0.01703173015266657


0.0805: 100%|██████████| 15/15 [00:36<00:00,  2.37s/it]
  0%|          | 0/3 [00:00<?, ?it/s]

17 epoch: 
train loss 0.08026901061336199


0.0389: 100%|██████████| 3/3 [00:04<00:00,  1.34s/it]


val loss 0.03608658164739609


0.0333: 100%|██████████| 15/15 [00:36<00:00,  2.36s/it]
  0%|          | 0/3 [00:00<?, ?it/s]

18 epoch: 
train loss 0.05269686082998912


0.0244: 100%|██████████| 3/3 [00:04<00:00,  1.34s/it]


val loss 0.026390161986152332


0.0434: 100%|██████████| 15/15 [00:36<00:00,  2.36s/it]


19 epoch: 
train loss 0.03972859506805738


0.0206: 100%|██████████| 3/3 [00:04<00:00,  1.34s/it]


val loss 0.022303874293963116


0.0397: 100%|██████████| 15/15 [00:36<00:00,  2.37s/it]
  0%|          | 0/3 [00:00<?, ?it/s]

20 epoch: 
train loss 0.04336763508617878


0.0205: 100%|██████████| 3/3 [00:04<00:00,  1.34s/it]


val loss 0.01998905713359515


0.0636: 100%|██████████| 15/15 [00:36<00:00,  2.37s/it]
  0%|          | 0/3 [00:00<?, ?it/s]

21 epoch: 
train loss 0.04462376634279887


0.0162: 100%|██████████| 3/3 [00:04<00:00,  1.34s/it]


val loss 0.020546478529771168


0.0369: 100%|██████████| 15/15 [00:36<00:00,  2.38s/it]


22 epoch: 
train loss 0.041371595859527585


0.0231: 100%|██████████| 3/3 [00:04<00:00,  1.34s/it]


val loss 0.022385617718100548


0.0456: 100%|██████████| 15/15 [00:36<00:00,  2.37s/it]
  0%|          | 0/3 [00:00<?, ?it/s]

23 epoch: 
train loss 0.037737915913263954


0.0232: 100%|██████████| 3/3 [00:04<00:00,  1.34s/it]


val loss 0.0212338554362456


0.0340: 100%|██████████| 15/15 [00:36<00:00,  2.35s/it]


24 epoch: 
train loss 0.04124908993641536


0.0218: 100%|██████████| 4/4 [00:05<00:00,  1.34s/it]


val loss 0.022224367130547762


0.0540:   7%|▋         | 1/15 [00:02<00:41,  2.98s/it]

___

In [None]:
import pickle

# Persist the weights held by whatever `controller` the K-fold loop above last
# assigned, i.e. the final fold, or the fold that was running when the loop was
# interrupted. Weights from the other folds are not saved.
# NOTE(review): the model was moved to the GPU with .cuda() before being handed
# to Model_controller. If controller.model is that object, the pickled tensors
# are CUDA tensors and unpickling presumably needs a GPU.
# torch.save / torch.load(map_location=...) is the usual alternative. Confirm
# before switching, because it changes the on-disk format for existing loaders.
MODEL_DUMP_PATH = '../models/dumps/model_unetcnn.pkl'
# open() raises FileNotFoundError if the dumps directory does not exist yet.
os.makedirs(os.path.dirname(MODEL_DUMP_PATH), exist_ok=True)
with open(MODEL_DUMP_PATH, 'wb') as f:
    pickle.dump(controller.model.state_dict(), f, protocol=pickle.HIGHEST_PROTOCOL)