Imports
============

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from dvn.dvn_fcn import DeepVesselNetFCN
from dvn.solver import Solver
from dvn.data_utils import SyntheticData

# Default float dtype for newly created tensors. `torch.set_default_tensor_type`
# is deprecated (PyTorch >= 2.1); `set_default_dtype` is the supported way to
# request the CPU float32 default that 'torch.FloatTensor' selected.
torch.set_default_dtype(torch.float32)

# Fix RNG seeds so weight initialisation and DataLoader shuffling are
# reproducible across Restart & Run All. (If SyntheticData samples patches
# with its own RNG, that may need seeding too — TODO confirm.)
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the first CUDA GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Automatically reload edited external modules (e.g. the local `dvn` package)
# before running code, so library edits take effect without restarting the kernel.
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

Load the data
============


In [2]:
# Edge length of the 3D patches cut from each volume
# (the sample shapes printed below are 64x64x64 for this setting).
patch_size = 64

# Build the training and validation datasets, in that order, from their split directories.
train_synthetic, val_synthetic = (
    SyntheticData(root_path=split_dir, patch_size=patch_size)
    for split_dir in ("./data/train/", "./data/val/")
)

print("DONE")

DONE


Inspect dataset sizes and sample shapes
============

In [3]:
print("Train size: %i" % len(train_synthetic))
print("Validation size: %i" % len(val_synthetic))

# Index the dataset only once: every `train_synthetic[0]` call runs __getitem__
# again, reloading the volume. If patches are sampled randomly, a second call
# could also return a different patch than the first — TODO confirm.
first_sample = train_synthetic[0]
print("Img size: ", first_sample[0].size())
print("Segmentation size: ", first_sample[1].size())

Train size: 110
Validation size: 26
Img size:  torch.Size([1, 64, 64, 64])
Segmentation size:  torch.Size([64, 64, 64])


Build data loaders and network
============

In [4]:
# Batch sizes for the two splits.
train_batch_size, val_batch_size = 10, 2

# Training batches are reshuffled every epoch; the validation order stays fixed.
train_loader = torch.utils.data.DataLoader(
    train_synthetic, batch_size=train_batch_size, shuffle=True, num_workers=1
)
val_loader = torch.utils.data.DataLoader(
    val_synthetic, batch_size=val_batch_size, shuffle=False, num_workers=1
)

# Instantiate the network and move its parameters to the selected device.
# The trailing expression lets the notebook display the module's repr.
model = DeepVesselNetFCN()
model.to(device)

DeepVesselNetFCN(
  (conv1): Conv3d(1, 5, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (conv2): Conv3d(5, 10, kernel_size=(5, 5, 5), stride=(1, 1, 1), padding=(2, 2, 2))
  (conv3): Conv3d(10, 20, kernel_size=(5, 5, 5), stride=(1, 1, 1), padding=(2, 2, 2))
  (conv4): Conv3d(20, 50, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (fcn1): Conv3d(50, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (softmax): Softmax(dim=1)
  (relu): ReLU(inplace=True)
  (sigmoid): Sigmoid()
)

Train network
============

In [None]:
# Optimizer hyper-parameters forwarded to the Solver.
# NOTE(review): weight_decay=0.1 is unusually strong, and the logged train loss
# barely moves (0.957 -> 0.958). Consider a much smaller value — TODO confirm.
optimizer_args = {"lr": 0.001, "weight_decay": 0.1}
solver = Solver(optim_args=optimizer_args)

# Train for 10 epochs, logging every 5th iteration.
solver.train(model, train_loader, val_loader, num_epochs=10, log_nth=5)


START TRAIN
[Iteration 5/110] TRAIN loss: 0.957
[Iteration 10/110] TRAIN loss: 0.958
[Epoch 1/10] TRAIN acc/loss: 0.947/0.958


Test network
============

Visualization of network outputs
============


In [None]:
num_example_imgs = 3
slice_idx = 5  # index along the first spatial axis of each 3D patch to display
column_titles = ("Input image", "Target image", "Prediction image")

# Inference only: eval mode plus no_grad, so no autograd graph is built.
model.eval()
fig, axes = plt.subplots(num_example_imgs, 3, figsize=(15, 5 * num_example_imgs), squeeze=False)

with torch.no_grad():
    # A DataLoader cannot be sliced (`train_loader[:n]` raises TypeError);
    # zipping with range() stops after the first `num_example_imgs` batches.
    # NOTE(review): this shows training batches; use val_loader to inspect
    # generalisation instead.
    for i, (inputs, targets) in zip(range(num_example_imgs), train_loader):
        inputs = inputs.to(device, dtype=torch.float)

        # Call the module (not .forward) so registered hooks still run.
        outputs = model(inputs)
        # Per-voxel class label = argmax over the class (channel) dimension.
        preds = outputs.argmax(dim=1)

        # Show the first sample of each batch at slice `slice_idx`.
        images = (
            inputs[0, 0, slice_idx].cpu().numpy(),
            targets[0, slice_idx].long().numpy(),
            preds[0, slice_idx].cpu().numpy(),
        )
        for ax, img, title in zip(axes[i], images, column_titles):
            ax.axis("off")
            ax.imshow(img)
            if i == 0:
                ax.set_title(title)

plt.show()