Imports
============

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from dvn.dvn_fcn import DeepVesselNetFCN
from dvn.solver import Solver
from dvn.data_utils import SyntheticData

# Make float32 the default dtype for newly created floating-point tensors.
# torch.set_default_dtype is the supported replacement for the deprecated
# torch.set_default_tensor_type('torch.FloatTensor').
torch.set_default_dtype(torch.float32)
# Select the default device: the first CUDA GPU when one is available,
# otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")  # uncomment to force CPU execution

# Enable IPython's autoreload extension so that imported modules are
# re-imported automatically before a cell runs; edits to external source
# files then take effect without restarting the kernel.
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [2]:
# Report which accelerator is in use. Querying the CUDA device name raises on
# machines without a usable GPU, so fall back to a plain "cpu" label there.
device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu"
device_name

'TITAN Xp'

First, we load the data
============


In [3]:
# Patch size handed to both datasets.
patch_size = 64

# Build the training and validation datasets (in that order) from their
# respective folders, using the same patch size for both.
train_synthetic, val_synthetic = (
    SyntheticData(root_path=split_dir, patch_size=patch_size)
    for split_dir in ("./data/train/", "./data/val/")
)

print("DONE")

DONE


Visualize part of the data
============

In [4]:
# Number of samples in each split.
n_train, n_val = len(train_synthetic), len(val_synthetic)
print("Train size: %i" % n_train)
print("Validation size: %i" % n_val)

# Shapes of the first training sample: element 0 is the image,
# element 1 is the segmentation.
print("Img size: ", train_synthetic[0][0].shape)
print("Segmentation size: ", train_synthetic[0][1].shape)

Train size: 110
Validation size: 26
Img size:  torch.Size([1, 64, 64, 64])
Segmentation size:  torch.Size([64, 64, 64])


Set up data loaders
============

In [5]:
# Training loader: batches of 10 samples, reshuffled, one worker process.
train_loader = torch.utils.data.DataLoader(
    dataset=train_synthetic,
    batch_size=10,
    shuffle=True,
    num_workers=1,
)
# Validation loader: batches of 2 samples in fixed order, one worker process.
val_loader = torch.utils.data.DataLoader(
    dataset=val_synthetic,
    batch_size=2,
    shuffle=False,
    num_workers=1,
)

Train network
============

In [33]:
# Instantiate a fresh network and move it to the selected device.
model = DeepVesselNetFCN()
model.to(device)

# Optimizer hyper-parameters; only the ADAM set is passed to the solver below.
optim_args_SGD = dict(lr=1e-2, weight_decay=0.001)
optim_args_ADAM = dict(lr=1e-2, weight_decay=0.001)

# Train for a single epoch with Adam, logging every iteration (log_nth=1).
solver = Solver(optim=torch.optim.Adam, optim_args=optim_args_ADAM)
solver.train(model, train_loader, val_loader, num_epochs=1, log_nth=1)

START TRAIN
None


AttributeError: 'NoneType' object has no attribute 'abs'

Test network
============

In [21]:
# Inspect the gradient currently stored on the conv1 layer's weight
# (None means no gradient has been populated for it).
conv1_weight = model.conv1.weight
print(conv1_weight.grad)

None


Visualization of network outputs
============


In [None]:
# Show input / target / prediction side by side for the first few validation
# samples, one row per sample and one 2-D slice per volume.
num_example_imgs = 3
slice_idx = 5  # index along the first spatial axis of the slice that is displayed

# Never ask for more samples than the validation set holds.
n_shown = min(num_example_imgs, len(val_synthetic))

# Inference only: remember the current mode so it can be restored afterwards,
# then switch to eval mode for the forward passes.
was_training = model.training
model.eval()

plt.figure(figsize=(15, 5 * max(n_shown, 1)))
try:
    # No gradients are needed for visualization.
    with torch.no_grad():
        for i in range(n_shown):
            # Integer indexing yields one (image, segmentation) pair without a
            # batch axis; add one so the network input and the [0, ...]
            # indexing below line up.
            inputs, targets = val_synthetic[i]
            inputs = inputs.unsqueeze(0).to(device, dtype=torch.float)
            targets = targets.unsqueeze(0).to(device)

            # Call the module itself (not .forward) so registered hooks run.
            outputs = model(inputs)
            # Predicted class per voxel = index of the largest score along dim 1.
            preds = torch.argmax(outputs, dim=1)

            inputs = inputs.cpu().numpy()
            targets = targets.cpu().numpy()
            preds = preds.cpu().numpy()

            # inputs
            plt.subplot(n_shown, 3, i * 3 + 1)
            plt.axis('off')
            plt.imshow(inputs[0, 0, slice_idx])
            if i == 0:
                plt.title("Input image")

            # target
            plt.subplot(n_shown, 3, i * 3 + 2)
            plt.axis('off')
            plt.imshow(targets[0, slice_idx])
            if i == 0:
                plt.title("Target image")

            # pred (the original referenced an undefined name `pred` here)
            plt.subplot(n_shown, 3, i * 3 + 3)
            plt.axis('off')
            plt.imshow(preds[0, slice_idx])
            if i == 0:
                plt.title("Prediction image")
finally:
    # Put the model back into whatever mode it was in before this cell ran.
    model.train(was_training)