# Imports

Import all required libraries in this cell

In [3]:
import os
import math
import numpy as np
import nibabel as nib
import niwidgets as nw
import matplotlib.pyplot as plt

from itertools import product
from collections import OrderedDict
from sklearn.metrics import f1_score, precision_recall_curve

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from torch.utils.tensorboard import SummaryWriter
from torch.autograd import Variable

import torchio as tio

# set up default cuda device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('device:', device)

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings(action='ignore')


  from .autonotebook import tqdm as notebook_tqdm


device: cuda:0


Define necessary variables related to train dataset and val dataset

In [4]:
dataset_path = 'dataset'
training_set_path = os.path.join(dataset_path, 'train')
print('training_set_path:', training_set_path)
validation_set_path = os.path.join(dataset_path, 'val')
print('validation_set_path:', validation_set_path)

training_images_path = os.path.join(training_set_path, 'images')
print('training_images_path:', training_images_path)
training_labels_path = os.path.join(training_set_path, 'labels')
print('training_labels_path:', training_labels_path)
training_images = os.listdir(training_images_path)
print('training_images:', training_images)
training_labels = os.listdir(training_labels_path)
print('training_labels:', training_labels)

validation_images_path = os.path.join(validation_set_path, 'images')
print('validation_images_path:', validation_images_path)
validation_labels_path = os.path.join(validation_set_path, 'labels')
print('validation_labels_path:', validation_labels_path)
validation_images = os.listdir(validation_images_path)
print('validation_images:', validation_images)
validation_labels = os.listdir(validation_labels_path)
print('validation_labels:', validation_labels)


training_set_path: dataset\train
validation_set_path: dataset\val
training_images_path: dataset\train\images
training_labels_path: dataset\train\labels
training_images: ['liver1.nii', 'liver10.nii', 'liver11.nii', 'liver12.nii', 'liver13.nii', 'liver14.nii', 'liver15.nii', 'liver16.nii', 'liver17.nii', 'liver18.nii', 'liver2.nii', 'liver3.nii', 'liver4.nii', 'liver5.nii', 'liver6.nii', 'liver7.nii', 'liver8.nii', 'liver9.nii']
training_labels: ['liver1.nii', 'liver10.nii', 'liver11.nii', 'liver12.nii', 'liver13.nii', 'liver14.nii', 'liver15.nii', 'liver16.nii', 'liver17.nii', 'liver18.nii', 'liver2.nii', 'liver3.nii', 'liver4.nii', 'liver5.nii', 'liver6.nii', 'liver7.nii', 'liver8.nii', 'liver9.nii']
validation_images_path: dataset\val\images
validation_labels_path: dataset\val\labels
validation_images: ['liver19.nii', 'liver20.nii']
validation_labels: ['liver19.nii', 'liver20.nii']


# Define Custom Dataset

## SyntheticData2 (custom)

In [5]:
class SyntheticData2(data.Dataset):
    """ 
    Class defined to handle the synthetic dataset
    derived from pytorch's Dataset class.
    """

    def __init__(self, root_path, patch_size=64, mode='train'):
        self.TAG = '[SyntheticData2]'
        self.root_dir_name = root_path
        self.raw_dir_name = os.path.join(self.root_dir_name, 'images')
        self.seg_dir_name = os.path.join(self.root_dir_name, 'labels')
        self.mode = mode
        self.patch_size = patch_size
        print(self.TAG, '[raw_dir_name]', self.raw_dir_name)
        print(self.TAG, '[seg_dir_name]', self.seg_dir_name)
        
        # Finding first and last file names in directory
        self.file_names = os.listdir(self.raw_dir_name)
        self.file_names.sort(key=lambda name: int(name[5:-4]))

        self.transforms = tio.Compose([
            # tio.OneOf({
            #     tio.RandomAffine(): 0.5,
            #     tio.RandomElasticDeformation(): 0.5,
            # }, p=0.5),
            # tio.RandomFlip(axes=(0, 1, 2)),
            tio.ZNormalization(),
        ])

        self.patches_raw = []
        self.patches_seg = []
        for image_name in self.file_names:
            raw, seg = self.get_patches_from_nii(image_name)
            if self.mode == 'test':
                # print(f'[{self.mode}]', image_name, raw.shape, seg.shape)
                raw, seg = raw.unsqueeze(0), seg.unsqueeze(0)
            self.patches_raw += raw
            self.patches_seg += seg
        
        print(self.TAG, '[patches_raw]', len(self.patches_raw), self.patches_raw[0].shape)
        print(self.TAG, '[patches_seg]', len(self.patches_seg), self.patches_seg[0].shape)

    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()
            return [self[i] for i in range(index)]
        elif isinstance(index, slice):
            return [self[i] for i in range(*index.indices(len(self)))]
        elif isinstance(index, int):
            raw = self.patches_raw[index]
            seg = self.patches_seg[index]
            if self.mode == 'test':
                # print(self.TAG, '[__getitem__]', raw.shape, seg.shape)
                raw = raw.unsqueeze(0)
            raw = self.transforms(raw)
            return raw, seg
        else:
            raise TypeError('Invalid argument type.')

    def __len__(self):
        return len(self.patches_raw)

    def get_image_path(self, index):
        return (os.path.join(self.raw_dir_name, self.file_names[index]), 
                os.path.join(self.seg_dir_name, self.file_names[index]))

    def get_patches_from_nii(self, image_name):
        raw_img_name = os.path.join(self.raw_dir_name, image_name)
        seg_img_name = os.path.join(self.seg_dir_name, image_name)
        
        # Load proxy so image not loaded into memory
        raw_proxy = nib.load(raw_img_name)
        seg_proxy = nib.load(seg_img_name)

        # Get dataobj of proxy
        raw_data = np.asarray(raw_proxy.dataobj).astype(np.int32)
        seg_data = np.asarray(seg_proxy.dataobj).astype(np.int32)
        
        if self.mode == 'test':
            return torch.from_numpy(raw_data), torch.from_numpy(seg_data)
        
        # This is where the patch is defined
        raw_patch = torch.from_numpy(raw_data). \
            unfold(2, self.patch_size, self.patch_size). \
            unfold(1, self.patch_size, self.patch_size). \
            unfold(0, self.patch_size, self.patch_size)
        raw_patch = raw_patch.contiguous(). \
            view(-1, 1, self.patch_size, self.patch_size, self.patch_size)

        seg_patch = torch.from_numpy(seg_data). \
            unfold(2, self.patch_size, self.patch_size). \
            unfold(1, self.patch_size, self.patch_size). \
            unfold(0, self.patch_size, self.patch_size)
        seg_patch = seg_patch.contiguous(). \
            view(-1, self.patch_size, self.patch_size, self.patch_size)
        
        seg_patch_area = seg_patch.sum(axis=[1, 2, 3])
        useful_patches = np.where(seg_patch_area > 0)[0]

        return raw_patch[useful_patches], seg_patch[useful_patches]
        # return raw_patch, seg_patch


## Create train/val dataset objects using SyntheticData class and train/val dataloaders

In [6]:
patch_size = 64
train_synthetic = SyntheticData2(root_path=training_set_path, patch_size=patch_size)
# print(len(train_synthetic[0]), train_synthetic[0][0].shape, train_synthetic[0][1].shape)
val_synthetic = SyntheticData2(root_path=validation_set_path, patch_size=patch_size)
# print(len(val_synthetic[0]), val_synthetic[0][0].shape, val_synthetic[0][1].shape)

print("Train size:", len(train_synthetic))
print("Validation size:", len(val_synthetic))
print("Img size:", train_synthetic[0][0].size(), train_synthetic[0][0].dtype)
print("Segmentation size:", train_synthetic[0][1].size(), train_synthetic[0][1].dtype)

train_loader = data.DataLoader(train_synthetic, batch_size=2, shuffle=True, num_workers=0)
val_loader = data.DataLoader(val_synthetic, batch_size=2, shuffle=False, num_workers=0)
print('train_loader (total batches):', len(train_loader))
print('val_loader (total batches):', len(val_loader))


[SyntheticData2] [raw_dir_name] dataset\train\images
[SyntheticData2] [seg_dir_name] dataset\train\labels
[SyntheticData2] [patches_raw] 269 torch.Size([1, 64, 64, 64])
[SyntheticData2] [patches_seg] 269 torch.Size([64, 64, 64])
[SyntheticData2] [raw_dir_name] dataset\val\images
[SyntheticData2] [seg_dir_name] dataset\val\labels
[SyntheticData2] [patches_raw] 37 torch.Size([1, 64, 64, 64])
[SyntheticData2] [patches_seg] 37 torch.Size([64, 64, 64])
Train size: 269
Validation size: 37
Img size: torch.Size([1, 64, 64, 64]) torch.float32
Segmentation size: torch.Size([64, 64, 64]) torch.int32
train_loader (total batches): 135
val_loader (total batches): 19


# Define Models

## Conv3d_CrossHair

In [7]:
class Conv3d_CrossHair(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        self.convD = nn.Conv3d(in_channels, out_channels, (1, kernel_size, kernel_size), stride, (0, padding, padding))
        self.convH = nn.Conv3d(in_channels, out_channels, (kernel_size, 1, kernel_size), stride, (padding, 0, padding))
        self.convW = nn.Conv3d(in_channels, out_channels, (kernel_size, kernel_size, 1), stride, (padding, padding, 0))

    def __call__(self, x):
        outD = self.convD(x)
        outH = self.convH(x)
        outW = self.convW(x)
        # print('[outD]', outD.shape)
        # print('[outH]', outH.shape)
        # print('[outW]', outW.shape)
        out = outD + outH + outW
        return out


## DeepVesselNetFCN

In [8]:
class DeepVesselNetFCN(nn.Module):
    def __init__(self, nchannels=1, nlabels=2, dim=3, batchnorm=True, dropout=False):
        super().__init__()
        self.nchannels = nchannels
        self.nlabels = nlabels
        self.dims = dim
        self.batchnorm = batchnorm
        self.dropout = dropout
        
        # 1st layer
        self.conv1 = Conv3d_CrossHair(in_channels=self.nchannels, out_channels=5, kernel_size=3, padding=1)
        self.batchnorm1 = nn.BatchNorm3d(5)
        self.dropout1 = nn.Dropout(p=0.25)
        # 2nd layer
        self.conv2 = Conv3d_CrossHair(in_channels=5, out_channels=10, kernel_size=5, padding=2)
        self.batchnorm2 = nn.BatchNorm3d(10)
        self.dropout2 = nn.Dropout(p=0.25)
        # 3rd layer
        self.conv3 = Conv3d_CrossHair(in_channels=10, out_channels=20, kernel_size=5, padding=2)
        self.batchnorm3 = nn.BatchNorm3d(20)
        self.dropout3 = nn.Dropout(p=0.25)
        # 4th layer
        self.conv4 = Conv3d_CrossHair(in_channels=20, out_channels=50, kernel_size=3, padding=1)
        self.batchnorm4 = nn.BatchNorm3d(50)
        self.dropout4 = nn.Dropout(p=0.25)
        # fully convolutional layer
        self.fcn1 = nn.Conv3d(in_channels=50, out_channels=self.nlabels, kernel_size=1)
        # Softmax layer
        self.softmax = nn.Softmax(dim=1)
        # non linearities
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

        # initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

        for param in self.parameters():
            param.requires_grad = True

    def forward(self, x):
        # 1st layer
        x = self.conv1(x)
        x = self.batchnorm1(x) if self.batchnorm else x
        x = self.relu(x)
        x = self.dropout1(x) if self.dropout else x
        # 2nd layer
        x = self.conv2(x)
        x = self.batchnorm2(x) if self.batchnorm else x
        x = self.relu(x)
        x = self.dropout2(x) if self.dropout else x
        # 3rd layer
        x = self.conv3(x)
        x = self.batchnorm3(x) if self.batchnorm else x
        x = self.relu(x)
        x = self.dropout3(x) if self.dropout else x
        # 4th layer
        x = self.conv4(x)
        x = self.batchnorm4(x) if self.batchnorm else x
        x = self.relu(x)
        x = self.dropout4(x) if self.dropout else x
        # 5th layer
        x = self.fcn1(x)
        x = self.sigmoid(x)
        # classification layer
        x = self.softmax(x)
        return x

    def save(self, path):
        print('Saving model... %s' % path)
        torch.save(self, path)


## DeepVesselNetFCN2

In [9]:
class DeepVesselNetFCN2(nn.Module):
    def __init__(self, nchannels=1, nlabels=2, dim=3, batchnorm=True, dropout=False):
        super().__init__()
        self.nchannels = nchannels
        self.nlabels = nlabels
        self.dims = dim
        self.batchnorm = batchnorm
        self.dropout = dropout

        # 1st layer
        self.conv1 = Conv3d_CrossHair(in_channels=self.nchannels, out_channels=5, kernel_size=3, padding=1)
        self.batchnorm1 = nn.BatchNorm3d(5)
        self.dropout1 = nn.Dropout(p=0.25)
        # 2nd layer
        self.conv2_1 = Conv3d_CrossHair(in_channels=5, out_channels=5, kernel_size=3, padding=1)
        self.batchnorm2_1 = nn.BatchNorm3d(5)
        self.conv2_2 = Conv3d_CrossHair(in_channels=5, out_channels=10, kernel_size=3, padding=1)
        self.batchnorm2_2 = nn.BatchNorm3d(10)
        self.dropout2 = nn.Dropout(p=0.25)
        # 3rd layer
        self.conv3_1 = Conv3d_CrossHair(in_channels=10, out_channels=10, kernel_size=3, padding=1)
        self.batchnorm3_1 = nn.BatchNorm3d(10)
        self.conv3_2 = Conv3d_CrossHair(in_channels=10, out_channels=20, kernel_size=3, padding=1)
        self.batchnorm3_2 = nn.BatchNorm3d(20)
        self.dropout3 = nn.Dropout(p=0.25)
        # 4th layer
        self.conv4 = Conv3d_CrossHair(in_channels=20, out_channels=50, kernel_size=3, padding=1)
        self.batchnorm4 = nn.BatchNorm3d(50)
        self.dropout4 = nn.Dropout(p=0.25)
        # fully convolutional layer
        self.fcn1 = nn.Conv3d(in_channels=50, out_channels=self.nlabels, kernel_size=1)
        # softmax layer
        self.softmax = nn.Softmax(dim=1)
        # non linearities
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        
        # initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

        for param in self.parameters():
            param.requires_grad = True

    def forward(self, x):
        # 1st layer
        x = self.conv1(x)
        x = self.batchnorm1(x) if self.batchnorm else x
        x = self.relu(x)
        x = self.dropout1(x) if self.dropout else x
        # 2nd layer
        x = self.conv2_1(x)
        x = self.batchnorm2_1(x) if self.batchnorm else x
        x = self.relu(x)
        x = self.dropout2(x) if self.dropout else x
        x = self.conv2_2(x)
        x = self.batchnorm2_2(x) if self.batchnorm else x
        x = self.relu(x)
        x = self.dropout2(x) if self.dropout else x
        # 3rd layer
        x = self.conv3_1(x)
        x = self.batchnorm3_1(x) if self.batchnorm else x
        x = self.relu(x)
        x = self.dropout3(x) if self.dropout else x
        x = self.conv3_2(x)
        x = self.batchnorm3_2(x) if self.batchnorm else x
        x = self.relu(x)
        x = self.dropout3(x) if self.dropout else x
        # 4rd layer
        x = self.conv4(x)
        x = self.batchnorm4(x) if self.batchnorm else x
        x = self.relu(x)
        x = self.dropout4(x) if self.dropout else x
        # 5th layer
        x = self.fcn1(x)
        x = self.sigmoid(x)
        # classification layer
        x = self.softmax(x)
        return x

    def save(self, path):
        print('Saving model... %s' % path)
        torch.save(self, path)


## DeepVesselNet-UNet

In [10]:
class DeepVesselNet_UNet(nn.Module):
    def __init__(self, nchannels=1, nlabels=2, dim=3, init_features=32, debug=False):
        super(DeepVesselNet_UNet, self).__init__()
        self.TAG = '[DeepVesselNet_UNet]'
        self.nchannels = nchannels
        self.nlabels = nlabels
        self.dims = dim
        self.init_features = init_features
        self.debug = debug
        
        # encoder layers
        self.encoder1 = DeepVesselNet_UNet._block(nchannels, init_features, name='enc1')
        self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.encoder2 = DeepVesselNet_UNet._block(init_features, init_features * 2, name='enc2')
        self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.encoder3 = DeepVesselNet_UNet._block(init_features * 2, init_features * 4, name='enc3')
        self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.encoder4 = DeepVesselNet_UNet._block(init_features * 4, init_features * 8, name='enc4')
        self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2)
        # bottleneck layer
        self.bottleneck = DeepVesselNet_UNet._block(init_features * 8, init_features * 16, name='bottleneck')
        # decoder layers
        self.upconv4 = nn.ConvTranspose3d(init_features * 16, init_features * 8, kernel_size=2, stride=2)
        self.decoder4 = DeepVesselNet_UNet._block((init_features * 8) * 2, init_features * 8, name='dec4')
        self.upconv3 = nn.ConvTranspose3d(init_features * 8, init_features * 4, kernel_size=2, stride=2)
        self.decoder3 = DeepVesselNet_UNet._block((init_features * 4) * 2, init_features * 4, name='dec3')
        self.upconv2 = nn.ConvTranspose3d(init_features * 4, init_features * 2, kernel_size=2, stride=2)
        self.decoder2 = DeepVesselNet_UNet._block((init_features * 2) * 2, init_features * 2, name='dec2')
        self.upconv1 = nn.ConvTranspose3d(init_features * 2, init_features, kernel_size=2, stride=2)
        self.decoder1 = DeepVesselNet_UNet._block(init_features * 2, init_features, name='dec1')
        # output layer
        self.conv = nn.Conv3d(in_channels=init_features, out_channels=nlabels, kernel_size=1)

        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        if self.debug: print(self.TAG, '[x]', x.shape)
        enc1 = self.encoder1(x)
        if self.debug: print(self.TAG, '[enc1]', enc1.shape)
        enc2 = self.encoder2(self.pool1(enc1))
        if self.debug: print(self.TAG, '[enc2]', enc2.shape)
        enc3 = self.encoder3(self.pool2(enc2))
        if self.debug: print(self.TAG, '[enc3]', enc3.shape)
        enc4 = self.encoder4(self.pool3(enc3))
        if self.debug: print(self.TAG, '[enc4]', enc4.shape)

        bottleneck = self.bottleneck(self.pool4(enc4))
        if self.debug: print(self.TAG, '[bottleneck]', bottleneck.shape)

        dec4 = self.upconv4(bottleneck)
        if self.debug: print(self.TAG, '[upconv4]', dec4.shape)
        dec4 = torch.cat((dec4, enc4), dim=1)
        if self.debug: print(self.TAG, '[cat4]', dec4.shape)
        dec4 = self.decoder4(dec4)
        if self.debug: print(self.TAG, '[dec4]', dec4.shape)
        dec3 = self.upconv3(dec4)
        if self.debug: print(self.TAG, '[upconv3]', dec3.shape)
        dec3 = torch.cat((dec3, enc3), dim=1)
        if self.debug: print(self.TAG, '[cat3]', dec3.shape)
        dec3 = self.decoder3(dec3)
        if self.debug: print(self.TAG, '[dec3]', dec3.shape)
        dec2 = self.upconv2(dec3)
        if self.debug: print(self.TAG, '[upconv2]', dec2.shape)
        dec2 = torch.cat((dec2, enc2), dim=1)
        if self.debug: print(self.TAG, '[cat2]', dec2.shape)
        dec2 = self.decoder2(dec2)
        if self.debug: print(self.TAG, '[dec2]', dec2.shape)
        dec1 = self.upconv1(dec2)
        if self.debug: print(self.TAG, '[upconv1]', dec1.shape)
        dec1 = torch.cat((dec1, enc1), dim=1)
        if self.debug: print(self.TAG, '[cat1]', dec1.shape)
        dec1 = self.decoder1(dec1)
        if self.debug: print(self.TAG, '[dec1]', dec1.shape)
        return self.softmax(self.sigmoid(self.conv(dec1)))

    @staticmethod
    def _block(in_channels, features, name):
        return nn.Sequential(
            OrderedDict(
                [
                    (f'{name}_conv1', Conv3d_CrossHair(in_channels=in_channels, out_channels=features, kernel_size=3, padding=1)),
                    (f'{name}_norm1', nn.BatchNorm3d(num_features=features)),
                    (f'{name}_relu1', nn.ReLU(inplace=True)),
                    (f'{name}_conv2', Conv3d_CrossHair(in_channels=features, out_channels=features, kernel_size=3, padding=1)),
                    (f'{name}_norm2', nn.BatchNorm3d(num_features=features)),
                    (f'{name}_relu2', nn.ReLU(inplace=True)),
                ]
            )
        )

    def save(self, path):
        print('Saving model... %s' % path)
        torch.save(self, path)


## DeepVesselNet-VNet

In [11]:
def passthrough(x, **kwargs):
    return x

def ELUCons(elu, nchan):
    if elu:
        return nn.ELU(inplace=True)
    else:
        return nn.PReLU(nchan)

# normalization between sub-volumes is necessary for good performance
class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm):
    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'.format(input.dim()))
        # super(ContBatchNorm3d, self)._check_input_dim(input)

    def forward(self, input):
        self._check_input_dim(input)
        return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, True, self.momentum, self.eps)


class LUConv(nn.Module):
    def __init__(self, nchan, elu):
        super(LUConv, self).__init__()
        self.conv1 = Conv3d_CrossHair(nchan, nchan, kernel_size=5, padding=2)
        self.bn1 = ContBatchNorm3d(nchan)
        self.relu1 = ELUCons(elu, nchan)

    def forward(self, x):
        out = self.relu1(self.bn1(self.conv1(x)))
        return out


def _make_nConv(nchan, depth, elu):
    layers = []
    for _ in range(depth):
        layers.append(LUConv(nchan, elu))
    return nn.Sequential(*layers)


class InputTransition(nn.Module):
    def __init__(self, outChans, elu, debug=False):
        super(InputTransition, self).__init__()
        self.TAG = '[InputTransition]'
        self.debug = debug
        self.conv1 = Conv3d_CrossHair(1, 16, kernel_size=5, padding=2)
        self.bn1 = ContBatchNorm3d(16)
        self.relu1 = ELUCons(elu, 16)

    def forward(self, x):
        # do we want a PRELU here as well?
        out = self.bn1(self.conv1(x))
        if self.debug: print(self.TAG, '[out]', out.shape)
        # split input in to 16 channels
        x16 = torch.cat((x, x, x, x, x, x, x, x, x, x, x, x, x, x, x, x), 1)
        if self.debug: print(self.TAG, '[x16]', x16.shape)
        out = self.relu1(torch.add(out, x16))
        return out


class DownTransition(nn.Module):
    def __init__(self, inChans, nConvs, elu, dropout=False):
        super(DownTransition, self).__init__()
        outChans = 2 * inChans
        self.down_conv = Conv3d_CrossHair(inChans, outChans, kernel_size=2, stride=2)
        self.bn1 = ContBatchNorm3d(outChans)
        self.do1 = passthrough
        self.relu1 = ELUCons(elu, outChans)
        self.relu2 = ELUCons(elu, outChans)
        if dropout:
            self.do1 = nn.Dropout3d()
        self.ops = _make_nConv(outChans, nConvs, elu)

    def forward(self, x):
        down = self.relu1(self.bn1(self.down_conv(x)))
        out = self.do1(down)
        out = self.ops(out)
        out = self.relu2(torch.add(out, down))
        return out


class UpTransition(nn.Module):
    def __init__(self, inChans, outChans, nConvs, elu, dropout=False):
        super(UpTransition, self).__init__()
        self.up_conv = nn.ConvTranspose3d(inChans, outChans // 2, kernel_size=2, stride=2)
        self.bn1 = ContBatchNorm3d(outChans // 2)
        self.do1 = passthrough
        self.do2 = nn.Dropout3d()
        self.relu1 = ELUCons(elu, outChans // 2)
        self.relu2 = ELUCons(elu, outChans)
        if dropout:
            self.do1 = nn.Dropout3d()
        self.ops = _make_nConv(outChans, nConvs, elu)

    def forward(self, x, skipx):
        out = self.do1(x)
        skipxdo = self.do2(skipx)
        out = self.relu1(self.bn1(self.up_conv(out)))
        xcat = torch.cat((out, skipxdo), 1)
        out = self.ops(xcat)
        out = self.relu2(torch.add(out, xcat))
        return out


class OutputTransition(nn.Module):
    def __init__(self, inChans, elu, nll):
        super(OutputTransition, self).__init__()
        self.conv1 = Conv3d_CrossHair(inChans, 2, kernel_size=5, padding=2)
        self.bn1 = ContBatchNorm3d(2)
        self.relu1 = ELUCons(elu, 2)
        self.conv2 = nn.Conv3d(2, 2, kernel_size=1)
        # if nll:
        #     self.softmax = F.log_softmax
        # else:
        #     self.softmax = F.softmax
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # convolve 32 down to 2 channels
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.conv2(out)
        out = self.sigmoid(out)
        out = self.softmax(out)
        return out


class DeepVesselNet_VNet(nn.Module):
    # the number of convolutions in each layer corresponds
    # to what is in the actual prototxt, not the intent
    def __init__(self, elu=True, nll=False, debug=False):
        super(DeepVesselNet_VNet, self).__init__()
        self.TAG = '[DeepVesselNet_VNet]'
        self.debug = debug
        self.in_tr = InputTransition(16, elu, debug=debug)
        self.down_tr32 = DownTransition(16, 1, elu)
        self.down_tr64 = DownTransition(32, 2, elu)
        self.down_tr128 = DownTransition(64, 3, elu, dropout=True)
        self.down_tr256 = DownTransition(128, 2, elu, dropout=True)
        self.up_tr256 = UpTransition(256, 256, 2, elu, dropout=True)
        self.up_tr128 = UpTransition(256, 128, 2, elu, dropout=True)
        self.up_tr64 = UpTransition(128, 64, 1, elu)
        self.up_tr32 = UpTransition(64, 32, 1, elu)
        self.out_tr = OutputTransition(32, elu, nll)

    def forward(self, x):
        if self.debug: print(self.TAG, '[x]', x.shape)
        out16 = self.in_tr(x)
        if self.debug: print(self.TAG, '[out16]', out16.shape)
        out32 = self.down_tr32(out16)
        if self.debug: print(self.TAG, '[out32]', out32.shape)
        out64 = self.down_tr64(out32)
        if self.debug: print(self.TAG, '[out64]', out64.shape)
        out128 = self.down_tr128(out64)
        if self.debug: print(self.TAG, '[out128]', out128.shape)
        out256 = self.down_tr256(out128)
        if self.debug: print(self.TAG, '[out256]', out256.shape)
        out = self.up_tr256(out256, out128)
        if self.debug: print(self.TAG, '[out][1]', out.shape)
        out = self.up_tr128(out, out64)
        if self.debug: print(self.TAG, '[out][2]', out.shape)
        out = self.up_tr64(out, out32)
        if self.debug: print(self.TAG, '[out][3]', out.shape)
        out = self.up_tr32(out, out16)
        if self.debug: print(self.TAG, '[out][4]', out.shape)
        out = self.out_tr(out)
        if self.debug: print(self.TAG, '[out][5]', out.shape)
        return out

    def save(self, path):
        print('Saving model... %s' % path)
        torch.save(self, path)


# Define Loss Function

In [12]:
def _categorical_crossentropy(target, output, from_logits=False, dim=-1, debug=False):
    TAG = '[_categorical_crossentropy]'
    output_dimensions = list(range(len(output.shape)))
    if dim != -1 and dim not in output_dimensions:
        raise ValueError(
            '{}{}{}'.format(
                'Unexpected channels dim {}. '.format(dim),
                'Expected to be -1 or one of the axes of `output`, ',
                'which has {} dimensions.'.format(len(output.get_shape()))))
    # Note: tf.nn.softmax_cross_entropy_with_logits
    # expects logits, Keras expects probabilities.
    if not from_logits:
        # scale preds so that the class probas of each sample sum to 1
        if debug: print(TAG, '[output]', output.shape, output.min(), output.max())
        output = output / torch.sum(output, dim, True)
        if debug: print(TAG, '[output]', output.shape, output.min(), output.max())
        # manual computation of crossentropy
        _epsilon = torch.tensor(1e-7, dtype=output.dtype).to(device)
        output = torch.clamp(output, _epsilon, 1. - _epsilon)
        if debug: print(TAG, '[output]', output.shape, output.min(), output.max())
        return target * torch.log(output)
    # else:
    #     return tf.nn.softmax_cross_entropy_with_logits(labels=target,logits=output)

def weighted_categorical_crossentropy_with_fpr(dim=1, from_logits=False, classes=2, threshold=0.5, debug=False):
    def loss(y_pred, y_true):
        TAG = '[loss]'
        if debug: print(TAG, '[y_pred]', y_pred.shape, y_pred.min(), y_pred.max())
        y_true = F.one_hot(y_true.long(), num_classes=2).permute(0, 4, 1, 2, 3)
        if debug: print(TAG, '[y_true]', y_true.shape, y_true.min(), y_true.max())
        L = _categorical_crossentropy(target=y_true, output=y_pred, dim=dim, from_logits=from_logits, debug=debug)
        if debug: print(TAG, '[L]', L.shape, L.min(), L.max(), L.sum())
        y_true_p = torch.argmax(y_true, dim=dim)
        if debug: print(TAG, '[y_true_p]', y_true_p.shape, y_true_p.min(), y_true_p.max(), y_true_p.dtype)
        y_pred_bin = (y_pred >= threshold).to(y_true.dtype) if from_logits else torch.argmax(y_pred, dim=dim)
        if debug: print(TAG, '[y_pred_bin]', y_pred_bin.shape, y_pred_bin.min(), y_pred_bin.max(), y_pred_bin.dtype)
        y_pred_probs = y_pred if from_logits else torch.max(y_pred, dim=dim)[0]
        if debug: print(TAG, '[y_pred_probs]', y_pred_probs.shape, y_pred_probs.min(), y_pred_probs.max(), y_pred_probs.dtype)
        _epsilon = 1e-7
        pred_dtype = y_pred.dtype
        # select indices where ground truth voxels are background
        y_true_neg = (y_true_p == 0).to(pred_dtype)
        if debug: print(TAG, '[y_true_neg]', y_true_neg.shape, y_true_neg.min(), y_true_neg.max(), y_true_neg.sum())
        # select indices where ground truth voxels are vessel
        y_true_pos = (y_true_p == 1).to(pred_dtype)
        if debug: print(TAG, '[y_true_pos]', y_true_pos.shape, y_true_pos.min(), y_true_pos.max(), y_true_pos.sum())
        # calculate loss weight for background class
        wa_neg = 1. / (torch.sum(y_true_neg) + _epsilon)
        if debug: print(TAG, '[wa_neg]', wa_neg)
        # calculate loss weight for vessel class
        wa_pos = 1. / (torch.sum(y_true_pos) + _epsilon)
        if debug: print(TAG, '[wa_pos]', wa_pos)
        # calculate L1 (as given in paper)
        L1 = torch.sum(wa_neg * L * y_true_neg) + torch.sum(wa_pos * L * y_true_pos)
        if debug: print(TAG, '[L1]', L1)
        # select indices where ground truth voxels are vessel but model predicted background (false negative)
        false_neg = (y_true_p != 0).to(pred_dtype) * (y_pred_bin == 0).to(pred_dtype)
        if debug: print(TAG, '[false_neg]', false_neg.shape, false_neg.min(), false_neg.max(), false_neg.sum())
        # select indices where ground truth voxels are background but model predicted vessel (false positive)
        false_pos = (y_true_p != 1).to(pred_dtype) * (y_pred_bin == 1).to(pred_dtype)
        if debug: print(TAG, '[false_pos]', false_pos.shape, false_pos.min(), false_pos.max(), false_pos.sum())
        # calculate gamma for false negative predictions (as given in paper)
        gamma_neg = 0.5 + (torch.sum(torch.abs((false_neg * y_pred_probs) - 0.5)) / (torch.sum(false_neg) + _epsilon))
        if debug: print(TAG, '[gamma_neg]', gamma_neg)
        # calculate gamma for false positive predictions (as given in paper)
        gamma_pos = 0.5 + (torch.sum(torch.abs((false_pos * y_pred_probs) - 0.5)) / (torch.sum(false_pos) + _epsilon))
        if debug: print(TAG, '[gamma_pos]', gamma_pos)
        # calculate loss weight for false negative predictions
        wb_neg = wa_neg * gamma_neg
        if debug: print(TAG, '[wb_neg]', wb_neg)
        # calculate loss weight for false positive predictions
        wb_pos = wa_pos * gamma_pos
        if debug: print(TAG, '[wb_pos]', wb_pos)
        # calculate L2 (as given in paper)
        L2 = torch.sum(wb_neg * L * false_neg) + torch.sum(wb_pos * L * false_pos)
        if debug: print(TAG, '[L2]', L2)
        # calculate total loss (as given in paper)
        total_loss = L1 + L2
        if debug: print(TAG, '[total_loss]', total_loss)
        return total_loss
    return loss

def weighted_categorical_crossentropy_with_fpr_2(dim=1, from_logits=False, classes=2, threshold=0.5, debug=False):
    def loss(y_pred, y_true):
        TAG = '[loss]'
        if debug: print(TAG, '[y_pred]', y_pred.shape, y_pred.min(), y_pred.max())
        y_true = F.one_hot(y_true.long(), num_classes=2).permute(0, 4, 1, 2, 3)
        if debug: print(TAG, '[y_true]', y_true.shape, y_true.min(), y_true.max())
        L = _categorical_crossentropy(target=y_true, output=y_pred, dim=dim, from_logits=from_logits)
        if debug: print(TAG, '[L]', L.shape, L.min(), L.max(), L.sum())
        y_true_p = torch.argmax(y_true, dim=dim)
        if debug: print(TAG, '[y_true_p]', y_true_p.shape, y_true_p.min(), y_true_p.max(), y_true_p.dtype)
        y_pred_bin = (y_pred >= threshold).to(y_true.dtype) if from_logits else torch.argmax(y_pred, dim=dim)
        if debug: print(TAG, '[y_pred_bin]', y_pred_bin.shape, y_pred_bin.min(), y_pred_bin.max(), y_pred_bin.dtype)
        y_pred_probs = y_pred if from_logits else torch.max(y_pred, dim=dim)[0]
        if debug: print(TAG, '[y_pred_probs]', y_pred_probs.shape, y_pred_probs.min(), y_pred_probs.max(), y_pred_probs.dtype)
        _epsilon = 1e-7
        C = 0
        for c in range(classes):
            c_true = (y_true_p == c).to(y_pred.dtype)
            if debug: print(TAG, c, '[c_true]', c_true.shape, c_true.dtype, c_true.min(), c_true.max())
            w = 1. / (torch.sum(c_true) + _epsilon)
            C += torch.sum(L * c_true * w)
            # Calc. FP Rate Correction
            c_false_p = (y_true_p != c).to(y_pred.dtype) * (y_pred_bin == c).to(y_pred.dtype) # Calculate false predictions
            if debug: print(TAG, c, '[c_false_p]', c_false_p.shape, c_false_p.dtype, c_false_p.min(), c_false_p.max())
            gamma = 0.5 + (torch.sum(torch.abs((c_false_p * y_pred_probs) - 0.5)) / (torch.sum(c_false_p) + _epsilon)) # Calculate Gamme
            wc = w * gamma # gamma / |Y+|
            C += torch.sum(L * c_false_p * wc) # Add FP Correction
        return C
    return loss


In [13]:
def weighted_categorical_crossentropy_with_fpr(dim=1, from_logits=False, classes=2, threshold=0.5, debug=False):
    def loss(y_pred, y_true):
        TAG = '[loss]'
        _epsilon = 1e-7
        true_dtype = y_true.dtype
        dim_123 = (1, 2, 3)

        if debug: print(TAG, '[y_pred]', y_pred.shape, y_pred.min(), y_pred.max())
        y_true = F.one_hot(y_true.long(), num_classes=2).permute(0, 4, 1, 2, 3)
        if debug: print(TAG, '[y_true]', y_true.shape, y_true.min(), y_true.max())

        # true binary volume
        y_true_p = torch.argmax(y_true, dim=dim)
        if debug: print(TAG, '[y_true_p]', y_true_p.shape, y_true_p.min(), y_true_p.max(), y_true_p.dtype)
        # predicted binary volume
        y_pred_bin = (y_pred >= threshold).to(true_dtype) if from_logits else torch.argmax(y_pred, dim=dim)
        if debug: print(TAG, '[y_pred_bin]', y_pred_bin.shape, y_pred_bin.min(), y_pred_bin.max(), y_pred_bin.dtype)

        # select indices where ground truth voxels are background
        y_true_pos = (y_true_p == 1).to(true_dtype)
        if debug: print(TAG, '[y_true_pos]', y_true_pos.shape, y_true_pos.min(), y_true_pos.max(), y_true_pos.sum(dim=dim_123))
        y_true_neg = (y_true_p == 0).to(true_dtype)
        if debug: print(TAG, '[y_true_neg]', y_true_neg.shape, y_true_neg.min(), y_true_neg.max(), y_true_neg.sum(dim=dim_123))
        # select indices where ground truth voxels are vessel

        y_prob_pos = y_pred[:, 1]
        if debug: print(TAG, '[y_prob_pos]', y_prob_pos.shape, y_prob_pos.min(), y_prob_pos.max(), y_prob_pos.dtype)
        y_prob_neg = y_pred[:, 0]
        if debug: print(TAG, '[y_prob_neg]', y_prob_neg.shape, y_prob_neg.min(), y_prob_neg.max(), y_prob_neg.dtype)

        log_prob = - y_true * torch.log(y_pred)
        if debug: print(TAG, '[log_prob]', log_prob.shape, log_prob.min(), log_prob.max(), log_prob.sum(dim=(2, 3, 4)))
        # if debug: print(TAG, '[log_prob[y_true_neg]]', log_prob[0, 0][y_true_neg[0] == 1].sum(), log_prob[0, 0].sum())
        # if debug: print(TAG, '[log_prob[y_true_pos]]', log_prob[0, 1][y_true_pos[0] == 1].sum(), log_prob[0, 1].sum())

        # calculate loss weight for vessel class
        wa_pos = 1. / (torch.sum(y_true_pos, dim=dim_123) + _epsilon)
        if debug: print(TAG, '[wa_pos]', wa_pos)
        # calculate loss weight for background class
        wa_neg = 1. / (torch.sum(y_true_neg, dim=dim_123) + _epsilon)
        if debug: print(TAG, '[wa_neg]', wa_neg)

        # calculate L1 (as given in paper)
        L1 = wa_pos * torch.sum(log_prob[:, 1]) + wa_neg * torch.sum(log_prob[:, 0])
        if debug: print(TAG, '[L1]', L1)

        # select indices where ground truth voxels are background but model predicted vessel (false positive)
        false_pos = (y_true_p == 0).to(true_dtype) * (y_pred_bin == 1).to(true_dtype)
        if debug: print(TAG, '[false_pos]', false_pos.shape, false_pos.min(), false_pos.max(), false_pos.sum(dim=dim_123))
        # select indices where ground truth voxels are vessel but model predicted background (false negative)
        false_neg = (y_true_p == 1).to(true_dtype) * (y_pred_bin == 0).to(true_dtype)
        if debug: print(TAG, '[false_neg]', false_neg.shape, false_neg.min(), false_neg.max(), false_neg.sum(dim=dim_123))

        # calculate gamma for false positive predictions (as given in paper)
        gamma_pos = 0.5 + (torch.sum(torch.abs((false_pos * y_prob_neg) - 0.5), dim=dim_123) / (torch.sum(false_pos, dim=dim_123) + _epsilon))
        if debug: print(TAG, '[gamma_pos]', gamma_pos)
        # calculate gamma for false negative predictions (as given in paper)
        gamma_neg = 0.5 + (torch.sum(torch.abs((false_neg * y_prob_pos) - 0.5), dim=dim_123) / (torch.sum(false_neg, dim=dim_123) + _epsilon))
        if debug: print(TAG, '[gamma_neg]', gamma_neg)

        # calculate loss weight for false positive predictions
        wb_pos = wa_pos * gamma_pos
        if debug: print(TAG, '[wb_pos]', wb_pos)
        # calculate loss weight for false negative predictions
        wb_neg = wa_neg * gamma_neg
        if debug: print(TAG, '[wb_neg]', wb_neg)

        # calculate L2 (as given in paper)
        L2 = wb_pos * torch.sum(log_prob[:, 0] * false_pos) + wb_neg * torch.sum(log_prob[:, 1] * false_neg)
        if debug: print(TAG, '[L2]', L2)

        # calculate total loss (as given in paper)
        total_loss = L1 + L2
        if debug: print(TAG, '[total_loss]', total_loss)

        total_mean_loss = torch.mean(total_loss)
        if debug: print(TAG, '[total_mean_loss]', total_mean_loss)

        return total_mean_loss

    return loss


# Define Metrics

In [14]:
def dice_score(y_true, y_pred):
    return f1_score(y_true.flatten(), y_pred.flatten())

def dice_information(y_true, y_pred):
    prec, rec, thres = precision_recall_curve(y_true.flatten(), y_pred.flatten())
    f1 = (2. * prec * rec) / (prec + rec)
    ind = np.argmax(f1)
    return prec[ind], rec[ind], f1[ind], thres[ind]

def threshold_accuracy(y_true, y_pred, threshold=0.5):
    pred = (y_pred >= threshold).astype(np.int32)
    return np.mean(y_true.astype(np.int32) == pred)

def categorical_accuracy(y_true, y_pred, axis=-1):
    return np.mean(np.argmax(y_true, axis=axis) == np.argmax(y_pred, axis=axis))

def dice(y_true, y_pred, smooth=1):
    intersection = np.sum(y_true * y_pred, axis=list(range(1, len(y_true.shape))))
    union = np.sum(y_true, axis=list(range(1, len(y_true.shape)))) + np.sum(y_pred, axis=list(range(1, len(y_true.shape))))
    return np.mean((2. * intersection + smooth) / (union + smooth), axis=0)

def dice_coeff(outputs, targets, smooth=1, pred=False, debug=False):
    if debug: print('[dice_coeff]', '[outputs, targets]', outputs.shape, targets.shape)
    if pred:
        if debug: print('if')
        pred = outputs
    else:
        if debug: print('else')
        _, pred = torch.max(outputs, 1)
        if debug: print('[dice_coeff]', '[pred]', np.unique(pred.detach().cpu().numpy()))
    
    if debug: print('[dice_coeff]', '[pred, targets]', pred.shape, targets.shape)
    pred = F.one_hot(pred.long(), num_classes=2)
    targets = F.one_hot(targets.long(), num_classes=2)
    if debug: print('[dice_coeff]', '[pred, targets]', pred.shape, targets.shape)

    dim = tuple(range(1, len(pred.shape) - 1))
    if debug: print('[dice_coeff]', '[dim]', dim)
    intersection = torch.sum(targets * pred, dim=dim, dtype=torch.float)
    if debug: print('[dice_coeff]', '[intersection]', intersection)
    union = torch.sum(targets, dim=dim, dtype=torch.float) + torch.sum(pred, dim=dim, dtype=torch.float)
    if debug: print('[dice_coeff]', '[union]', union)
    if debug: print('[dice_coeff]', '[intersection, union]', torch.mean(intersection), torch.mean(union))
    dice = torch.mean((2. * intersection + smooth) / (union + smooth), dtype=torch.float)
    if debug: print('[dice_coeff]', '[dice]', dice)
    return dice


# Define Solver

In [15]:
class Solver(object):

    default_optim_args = {'lr': 0.01, 'weight_decay': 0.}

    def __init__(self, optim=torch.optim.SGD, optim_args={}, loss_func=weighted_categorical_crossentropy_with_fpr()):
        self.TAG = '[Solver]'
        print(self.TAG)
        optim_args_merged = self.default_optim_args.copy()
        optim_args_merged.update(optim_args)
        self.optim_args = optim_args_merged
        self.optim = optim
        self.loss_func = loss_func
        self.best_train_dice = -1
        self.best_val_dice = -1
        self.best_train_model = None
        self.best_val_model = None

        self._reset_histories()
        self.writer = SummaryWriter()

    def _reset_histories(self):
        """Resets train and val histories for the accuracy and the loss. """
        self.train_loss_history = []
        self.train_dice_history = []
        self.train_accu_history = []
        self.val_loss_history = []
        self.val_dice_history = []
        self.val_accu_history = []
    
    def cpu_np(self, tensor):
        return tensor.detach().cpu().numpy()

    def train(self, model, train_loader, val_loader, num_epochs=10, log_nth=0, model_save_path=''):
        """
        Train a given model with the provided data.
        Inputs:
        - model: object initialized from a torch.nn.Module
        - train_loader: train data (currently using nonsense data)
        - val_loader: val data (currently using nonsense data)
        - num_epochs: total number of epochs
        - log_nth: log training accuracy and loss every nth iteration
        """

        optim = self.optim(model.parameters(), **self.optim_args)
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.9)
        self._reset_histories()
        iter_per_epoch = len(train_loader)
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        model.to(device)

        print('START TRAIN')

        for epoch in range(num_epochs):
            ##########################
            ######## TRAINING ########
            ##########################
            model.train()
            train_loss_scores = []
            train_dice_scores = []
            train_accu_scores = []

            for i, (inputs, targets) in enumerate(train_loader, 1):
                # if i > 24: break
                inputs, targets = inputs.to(device, dtype=torch.float), targets.to(device, dtype=torch.long)
                
                optim.zero_grad()
                outputs = model(inputs)
                loss = self.loss_func(outputs, targets)
                loss.backward()
                optim.step()

                preds = torch.argmax(outputs, 1)
                self.train_loss_history.append(self.cpu_np(loss))
                train_loss_scores.append(self.cpu_np(loss))
                train_dice_scores.append(self.cpu_np(dice_coeff(outputs, targets)))
                train_accu_scores.append(np.mean(self.cpu_np(preds == targets)))

                if log_nth and i % log_nth == 0:
                    train_loss_mean = np.mean(train_loss_scores[-log_nth:])
                    train_dice_mean = np.mean(train_dice_scores[-log_nth:])
                    train_accu_mean = np.mean(train_accu_scores[-log_nth:])
                    log_string = f'[TRAIN][Iter][{i + epoch * iter_per_epoch}/{iter_per_epoch * num_epochs}]'
                    log_string += f' loss: {train_loss_mean:.4f}'
                    log_string += f', dice: {train_dice_mean:.4f}'
                    log_string += f', accu: {train_accu_mean:.4f}'
                    print(log_string)
                    self.writer.add_scalar('Loss/Train', train_loss_mean, i + epoch * iter_per_epoch)
                    self.writer.add_scalar('Dice/Train', train_dice_mean, i + epoch * iter_per_epoch)
                    self.writer.add_scalar('Accuracy/Train', train_accu_mean, i + epoch * iter_per_epoch)
            
            self.writer.add_scalar('Learning-Rate', lr_scheduler.get_lr()[0], epoch + 1)
            lr_scheduler.step()
            train_loss_epoch = np.mean(train_loss_scores)
            train_dice_epoch = np.mean(train_dice_scores)
            train_accu_epoch = np.mean(train_accu_scores)
            self.train_dice_history.append(train_dice_epoch)
            self.train_accu_history.append(train_accu_epoch)

            if log_nth:
                log_string = f'[TRAIN][Epoch][{epoch + 1}/{num_epochs}]'
                log_string += f' loss: {train_loss_epoch:.4f}'
                log_string += f', dice: {train_dice_epoch:.4f}'
                log_string += f', accu: {train_accu_epoch:.4f}'
                print(log_string)
                model.save(model_save_path)

            ##########################
            ####### VALIDATION #######
            ##########################
            model.eval()
            val_loss_scores = []
            val_dice_scores = []
            val_accu_scores = []

            for i, (inputs, targets) in enumerate(val_loader, 1):
                # if i > 24: break
                inputs, targets = inputs.to(device, dtype=torch.float), targets.to(device, dtype=torch.long)

                outputs = model(inputs)
                loss = self.loss_func(outputs, targets)

                preds = torch.argmax(outputs, 1)
                self.train_loss_history.append(self.cpu_np(loss))
                val_loss_scores.append(self.cpu_np(loss))
                val_dice_scores.append(self.cpu_np(dice_coeff(outputs, targets)))
                val_accu_scores.append(np.mean(self.cpu_np(preds == targets)))

            val_loss_epoch = np.mean(val_loss_scores)
            val_dice_epoch = np.mean(val_dice_scores)
            val_accu_epoch = np.mean(val_accu_scores)
            self.val_dice_history.append(val_dice_epoch)
            self.val_accu_history.append(val_accu_epoch)
            
            if log_nth:
                log_string = f'[VAL][Epoch][{epoch + 1}/{num_epochs}]'
                log_string += f' loss: {val_loss_epoch:.4f}'
                log_string += f', dice: {val_dice_epoch:.4f}'
                log_string += f', accu: {val_accu_epoch:.4f}'
                print(log_string)
                self.writer.add_scalar('Loss/Val', val_loss_epoch, epoch + 1)
                self.writer.add_scalar('Dice/Val', val_dice_epoch, epoch + 1)
                self.writer.add_scalar('Accuracy/Val', val_accu_epoch, epoch + 1)

        print("FINISH")


# Train Network

Create model object and view it's architecture

In [16]:
model = DeepVesselNetFCN(batchnorm=True, dropout=True).to(device)
# model = DeepVesselNetFCN2(batchnorm=True, dropout=True).to(device)
# model = DeepVesselNet_UNet().to(device)
# model = DeepVesselNet_VNet().to(device)
print(model)

DeepVesselNetFCN(
  (conv1): Conv3d_CrossHair(
    (convD): Conv3d(1, 5, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))
    (convH): Conv3d(1, 5, kernel_size=(3, 1, 3), stride=(1, 1, 1), padding=(1, 0, 1))
    (convW): Conv3d(1, 5, kernel_size=(3, 3, 1), stride=(1, 1, 1), padding=(1, 1, 0))
  )
  (batchnorm1): BatchNorm3d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout1): Dropout(p=0.25, inplace=False)
  (conv2): Conv3d_CrossHair(
    (convD): Conv3d(5, 10, kernel_size=(1, 5, 5), stride=(1, 1, 1), padding=(0, 2, 2))
    (convH): Conv3d(5, 10, kernel_size=(5, 1, 5), stride=(1, 1, 1), padding=(2, 0, 2))
    (convW): Conv3d(5, 10, kernel_size=(5, 5, 1), stride=(1, 1, 1), padding=(2, 2, 0))
  )
  (batchnorm2): BatchNorm3d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout2): Dropout(p=0.25, inplace=False)
  (conv3): Conv3d_CrossHair(
    (convD): Conv3d(10, 20, kernel_size=(1, 5, 5), stride=(1, 1, 1), padding=(0, 2, 2))

Define optimizer parameters and Solver object

In [17]:
model_save_path = 'models/deepvesselnet-fcn-01.model'
total_epochs = 20
optim_args = {'lr': 0.01, 'weight_decay': 0.005, 'momentum': 0.9, 'nesterov': True}
# optim_args = {'lr': 1e-3, 'weight_decay': 0.01}
optim = torch.optim.SGD
# optim = torch.optim.Adam
print('optim_args:', optim_args)
solver = Solver(optim_args=optim_args, optim=optim, loss_func=weighted_categorical_crossentropy_with_fpr())


optim_args: {'lr': 0.01, 'weight_decay': 0.005, 'momentum': 0.9, 'nesterov': True}
[Solver]


Start training

In [18]:
# model = torch.load('models/deepvesselnet-02.model').to(device)
solver.train(model, train_loader, val_loader, log_nth=5, num_epochs=total_epochs, model_save_path=model_save_path)
model.save(model_save_path)

START TRAIN
[TRAIN][Iter][5/2700] loss: 653.2379, dice: 0.4326, accu: 0.8318
[TRAIN][Iter][10/2700] loss: 66.5997, dice: 0.5006, accu: 0.9880
[TRAIN][Iter][15/2700] loss: 872.5706, dice: 0.5314, accu: 0.9846
[TRAIN][Iter][20/2700] loss: 18.8451, dice: 0.4953, accu: 0.9793
[TRAIN][Iter][25/2700] loss: 59.5376, dice: 0.5057, accu: 0.9925
[TRAIN][Iter][30/2700] loss: 30.5637, dice: 0.4965, accu: 0.9839
[TRAIN][Iter][35/2700] loss: 266.0258, dice: 0.4977, accu: 0.9759
[TRAIN][Iter][40/2700] loss: 481.1839, dice: 0.5139, accu: 0.9821
[TRAIN][Iter][45/2700] loss: 107.8097, dice: 0.4974, accu: 0.9774
[TRAIN][Iter][50/2700] loss: 71.4449, dice: 0.5048, accu: 0.9943
[TRAIN][Iter][55/2700] loss: 145.8036, dice: 0.4972, accu: 0.9718
[TRAIN][Iter][60/2700] loss: 216.8252, dice: 0.5154, accu: 0.9917
[TRAIN][Iter][65/2700] loss: 7.4229, dice: 0.4963, accu: 0.9790
[TRAIN][Iter][70/2700] loss: 44.4985, dice: 0.5024, accu: 0.9935
[TRAIN][Iter][75/2700] loss: 24.3509, dice: 0.4974, accu: 0.9841
[TRAIN][

# Save output as nifty format

In [None]:
import numpy as np
import torch
from torch.nn import functional as F
from matplotlib.lines import Line2D
from torch.autograd import Variable
from itertools import product

def patchify(volume, patch_size, step):
    """

    :param volume:
    :param patch_size:
    :param step:
    :return:
    """
    assert len(volume.shape) == 4

    _, v_h, v_w, v_d = volume.shape

    s_h, s_w, s_d = step

    _, p_h, p_w, p_d = patch_size

    # Calculate the number of patch in each axis
    n_w = np.ceil(1.0*(v_w-p_w)/s_w+1)
    n_h = np.ceil(1.0*(v_h-p_h)/s_h+1)
    n_d = np.ceil(1.0*(v_d-p_d)/s_d+1)

    n_w = int(n_w)
    n_h = int(n_h)
    n_d = int(n_d)

    pad_w = (n_w - 1) * s_w + p_w - v_w
    pad_h = (n_h - 1) * s_h + p_h - v_h
    pad_d = (n_d - 1) * s_d + p_d - v_d
    # print(volume.shape, (0, pad_h, 0, pad_w, 0, pad_d))
    volume = F.pad(volume, (0, pad_d, 0, pad_w, 0, pad_h), 'constant')
    # print(volume.shape)
    patches = torch.zeros((n_h, n_w, n_d,)+patch_size, dtype=volume.dtype)

    for i, j, k in product(range(n_h), range(n_w), range(n_d)):
        patches[i, j, k] = volume[:, (i * s_h):(i * s_h) + p_h, (j * s_w):(j * s_w) + p_w, (k * s_d):(k * s_d) + p_d]

    return patches


def unpatchify(patches, step, imsize, scale_factor):
    """

    :param patches:
    :param step:
    :param imsize:
    :param scale_factor:
    :return:
    """
    assert len(patches.shape) == 7

    c, r_h, r_w, r_d = imsize
    s_h, s_w, s_d = tuple(scale_factor*np.array(step))

    n_h, n_w, n_d, _, p_h, p_w, p_d = patches.shape

    v_w = (n_w - 1) * s_w + p_w
    v_h = (n_h - 1) * s_h + p_h
    v_d = (n_d - 1) * s_d + p_d

    volume = torch.zeros((c, v_h, v_w, v_d), dtype=patches.dtype)
    divisor = torch.zeros((c, v_h, v_w, v_d), dtype=patches.dtype)
#     print(volume.shape, imsize)

    for i, j, k in product(range(n_h), range(n_w), range(n_d)):
        patch = patches[i, j, k]
        volume[:, (i * s_h):(i * s_h) + p_h, (j * s_w):(j * s_w) + p_w, (k * s_d):(k * s_d) + p_d] += patch
        divisor[:, (i * s_h):(i * s_h) + p_h, (j * s_w):(j * s_w) + p_w, (k * s_d):(k * s_d) + p_d] += 1
    volume /= divisor
    return volume[:, 0:r_h, 0:r_w, 0:r_d]

def test(model, volume, patch_size=64, stride=60, device=torch.cpu):
    model.eval()

    patch_size = (1, patch_size, patch_size, patch_size)
    stride = (stride, stride, stride)

    patches = patchify(volume, patch_size, stride)
    patch_shape = patches.shape
    patches = patches.view((-1,) + patch_size)
    patches = patches.cuda().type(torch.cuda.FloatTensor) if device.type == 'cuda' else patches.type(torch.FloatTensor)

    output = torch.zeros((0, ) + patch_size[1:]).type(torch.FloatTensor)

    batch_size = 5 # user input
    num = int(np.ceil(1.0 * patches.shape[0] / batch_size))

    for i in range(num):
        model_output = model.forward(patches[batch_size*i:batch_size*i + batch_size])

        _, preds = torch.max(model_output, 1)
        preds = preds.type(torch.FloatTensor)
        # preds = preds.cuda().type(torch.cuda.FloatTensor) if device.type == 'cuda' else preds.cpu().type(torch.FloatTensor)

        output = torch.cat((output, preds), 0)

    new_shape = patch_shape
    output = unpatchify(output.view(new_shape), stride, volume.shape, 1)
    output = output.squeeze(0)
    
    return output


In [22]:
model = torch.load('models/deepvesselnet-fcn2-01.model').to(device)

# create training dataset
test_synthetic = SyntheticData2(root_path=training_set_path, patch_size=patch_size, mode='test')
# index of input image
idx = 1
# get actual volume and its segmentation mask
volume, segmentation = test_synthetic[idx]
print('volume:', volume.shape)
# get input image and segmentation image paths
raw_path, seg_path = test_synthetic.get_image_path(idx)
print('raw_path:', raw_path)
print('seg_path:', seg_path)
# get affine transformation matrix of image loaded above
raw_affine = nib.load(raw_path).affine
print('raw_affine:\n', raw_affine)

# get model output on loaded image
output = test(model, volume, device=device)
print('[output]', output.shape)
dice = dice_coeff(output, segmentation, pred=True).detach().cpu().numpy()
print("Dice coefficient of output: ", dice)
print("Num seg pixels: ", np.argwhere(segmentation.detach().cpu().numpy() == 1).size)
print("Num output pixels: ", np.argwhere(output.detach().cpu().numpy() == 1).size)

# save output in local disk space
save_path = os.path.join('saved_output', os.path.split(raw_path)[-1])
print('save_path:', save_path)
out_img = nib.Nifti1Image(output.detach().cpu().numpy(), raw_affine)
nib.save(out_img, save_path)
print('output saved at:', save_path)


[SyntheticData2] [raw_dir_name] dataset\train\images
[SyntheticData2] [seg_dir_name] dataset\train\labels


MemoryError: Unable to allocate 67.0 MiB for an array with shape (292, 292, 206) and data type int32

In [None]:
# plot raw image
raw_widget = nw.NiftiWidget(raw_path)
raw_widget.nifti_plotter()

In [None]:
# plot segmentation image
seg_widget = nw.NiftiWidget(seg_path)
seg_widget.nifti_plotter()

In [None]:
# plot output image
test_widget = nw.NiftiWidget(save_path)
test_widget.nifti_plotter()

---