## Imports

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.autograd import Variable
import os
import niwidgets as nw
import nibabel as nib

from dvn.dvn_fcn import DeepVesselNetFCN
from dvn.solver import Solver
from dvn.data_utils import MRAData
from dvn import misc as ms
from dvn import patchify_unpatchify as pu

torch.set_default_tensor_type('torch.FloatTensor')
# set up default cuda device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings(action='ignore')

## MRA dataset

Test obtained model with MRA data by using overlapping patches and then concatenating and unpatchfying

In [2]:
patch_size = 128

# Load all volumes
all_patients = MRAData(root_path="./mra/", patch_size=patch_size)

# Split into train and validation (leave some for test later)
train_mra = all_patients[0:30]
train_loader = torch.utils.data.DataLoader(train_mra, batch_size=2, shuffle=True, num_workers=1)

val_mra = all_patients[30:44]
val_loader = torch.utils.data.DataLoader(val_mra, batch_size=2, shuffle=True, num_workers=1)
print("Done")

Done


### Check data

In [3]:
print("Train size: %i" % len(train_mra))
print("Validation size: %i" % len(val_mra))
print("Img size: ", train_mra[0][0].size())
print("Segmentation size: ", train_mra[0][1].size())

Train size: 30
Validation size: 14
Img size:  torch.Size([1, 128, 128, 128])
Segmentation size:  torch.Size([128, 128, 128])


## Load model and train

In [4]:
model = DeepVesselNetFCN(batchnorm=True, dropout=True)
model = torch.load("models/deepvesselnet_final_batchnorm+dropout.model")

optim_args_SGD = {"lr": 2e-2, "weight_decay": 0.0005, "momentum": 0.9, "nesterov": True}

solver = Solver(optim_args=optim_args_SGD, optim=torch.optim.SGD)
solver.train(model, train_loader, val_loader, log_nth=5, num_epochs=50)

START TRAIN
[Iteration 5/750] TRAIN loss: -0.265
[Iteration 10/750] TRAIN loss: -1.376
[Iteration 15/750] TRAIN loss: -0.006
[Epoch 1/50] TRAIN acc/loss/dice: 0.985/-0.006/0.496
[Epoch 1/50] VAL   acc/loss/dice: 0.460/0.000/0.496
[Iteration 20/750] TRAIN loss: 0.000
[Iteration 25/750] TRAIN loss: -0.127
[Iteration 30/750] TRAIN loss: -2.136
[Epoch 2/50] TRAIN acc/loss/dice: 0.938/-2.136/0.484
[Epoch 2/50] VAL   acc/loss/dice: 0.880/0.000/0.484
[Iteration 35/750] TRAIN loss: 0.000
[Iteration 40/750] TRAIN loss: -2.092
[Iteration 45/750] TRAIN loss: -0.201
[Epoch 3/50] TRAIN acc/loss/dice: 0.923/-0.201/0.498
[Epoch 3/50] VAL   acc/loss/dice: 0.878/0.000/0.498
[Iteration 50/750] TRAIN loss: -0.224
[Iteration 55/750] TRAIN loss: 0.000
[Iteration 60/750] TRAIN loss: -2.133
[Epoch 4/50] TRAIN acc/loss/dice: 0.970/-2.133/0.663
[Epoch 4/50] VAL   acc/loss/dice: 0.744/0.000/0.663
[Iteration 65/750] TRAIN loss: -0.178
[Iteration 70/750] TRAIN loss: -1.988
[Iteration 75/750] TRAIN loss: -0.225
[E

[Iteration 560/750] TRAIN loss: -0.232
[Iteration 565/750] TRAIN loss: -0.177
[Iteration 570/750] TRAIN loss: -2.018
[Epoch 38/50] TRAIN acc/loss/dice: 0.971/-2.018/0.493
[Epoch 38/50] VAL   acc/loss/dice: 0.735/0.000/0.493
[Iteration 575/750] TRAIN loss: 0.000
[Iteration 580/750] TRAIN loss: -0.184
[Iteration 585/750] TRAIN loss: -2.216
[Epoch 39/50] TRAIN acc/loss/dice: 0.954/-2.216/0.488
[Epoch 39/50] VAL   acc/loss/dice: 0.585/0.000/0.488
[Iteration 590/750] TRAIN loss: -2.178
[Iteration 595/750] TRAIN loss: -0.233
[Iteration 600/750] TRAIN loss: -0.007
[Epoch 40/50] TRAIN acc/loss/dice: 0.933/-0.007/0.482
[Epoch 40/50] VAL   acc/loss/dice: 0.729/0.000/0.482
[Iteration 605/750] TRAIN loss: -2.247
[Iteration 610/750] TRAIN loss: -0.174
[Iteration 615/750] TRAIN loss: 0.000
[Epoch 41/50] TRAIN acc/loss/dice: 0.942/0.000/0.489
[Epoch 41/50] VAL   acc/loss/dice: 0.796/0.000/0.489
[Iteration 620/750] TRAIN loss: -0.234
[Iteration 625/750] TRAIN loss: -2.166
[Iteration 630/750] TRAIN los

## Save model

In [5]:
model.save("models/deepvesselnet_mra_final_pretrain.model")


Saving model... models/deepvesselnet_mra_final_pretrain.model


In [6]:
# model = DeepVesselNetFCN(batchnorm=True, dropout=True)
# model = torch.load("models/deepvesselnet_mra_final_no_pretrain.model")
all_patients_test = MRAData(root_path="./mra/", mode="test", transform="none")

test_num = 3
model.eval()
for i in range(test_num):
    volume, segmentation = all_patients_test[i]
    volume = volume.unsqueeze(0)
    MRA_path = all_patients_test.name

    output = ms.test(model, volume)

    dice = ms.dice_coeff(output, segmentation, pred=True).detach().cpu().numpy()
    print("Dice coefficient of output: ", dice)
    print("Num seg pixels: ", np.argwhere(segmentation.detach().cpu().numpy() == 1).size)
    print("Num output pixels: ", np.argwhere(output.detach().cpu().numpy() == 1).size)

#     model.writer("Test dice coefficient", dice, i)
    
print("FINISH TEST")

Dice coefficient of output:  0.000363429
Num seg pixels:  4470
Num output pixels:  169813137
Dice coefficient of output:  0.001598643
Num seg pixels:  172056
Num output pixels:  169769484
Dice coefficient of output:  0.00055015384
Num seg pixels:  10563
Num output pixels:  169787445
FINISH TEST


## Visualize MRA and output

In [7]:
MRA_affine = nib.load(MRA_path).affine

MRA_widget = nw.NiftiWidget(MRA_path)
MRA_widget.nifti_plotter()

seg_path = MRA_path.replace("raw", "seg")
seg_widget = nw.NiftiWidget(seg_path)
seg_widget.nifti_plotter()

save_name = MRA_path.replace("./mra/raw/", "")
out_img = nib.Nifti1Image(output.detach().cpu().numpy(), MRA_affine)
nib.save(out_img, os.path.join("mra_out", save_name))

test_widget = nw.NiftiWidget(os.path.join("mra_out", save_name))
test_widget.nifti_plotter()

<Figure size 432x288 with 0 Axes>

interactive(children=(IntSlider(value=191, continuous_update=False, description='x', max=383), IntSlider(value…

<Figure size 432x288 with 0 Axes>

interactive(children=(IntSlider(value=191, continuous_update=False, description='x', max=383), IntSlider(value…

<Figure size 432x288 with 0 Axes>

interactive(children=(IntSlider(value=191, continuous_update=False, description='x', max=383), IntSlider(value…