The dataset used for training the network can be downloaded at:

http://www.isi.uu.nl/Research/Databases/DRIVE/download.php

Extract the data into the folder `datasets/DRIVE_SRC/` such that `test` and `training` (inside the archive) become subfolders of `datasets/DRIVE_SRC/`. Then run the script `datasets/drive/prepare.py` (creates hdf5 files).

To step through this notebook, you will need to install a number of packages. First of all, we will need PyTorch; please follow the installation instructions provided at http://pytorch.org/.

After successfully doing so, install the following additional packages:
- `visdom` : package for visualisations
- `tqdm` : package to display progress bars
- `ipywidgets` : Jupyter notebook widgets

(Use `conda` or `pip` / `pip3` depending on your local setup.)

You can skip execution of the next cell. 

(Executing the next cell enables presentation mode (navigate with arrow keys in cell mode); to get out of presentation mode, clear all cell output -- the menu becomes visible on hover)

In [None]:
%%html
<link rel="stylesheet" href="css/jupyter.css">
<link rel="stylesheet" href="css/presenter.css">
<link rel="stylesheet" href="css/cells.css">
<link rel="stylesheet" href="css/codemirror.css">

In [None]:
# package imports

# numpy
import numpy as np
np.random.seed(42)

# progress bars
from tqdm import tqdm_notebook as tqdm

# in case GPUs are used, limit to single device
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

# PyTorch imports
import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision

from torch import nn
from torch.autograd import Variable
from torchvision import transforms

dtype = torch.FloatTensor
# dtype = torch.cuda.FloatTensor  # Uncomment this to run on GPU

# matplotlib for plotting
import matplotlib.pyplot as plt
# Notebook-wide matplotlib defaults: fixed figure size, grayscale images
# without interpolation, and all axis spines and tick marks switched off
# (tick labels are drawn in white).
fig_size = (7, 7)
plt.rcParams.update({
    'axes.spines.left': False,
    'axes.spines.bottom': False,
    'axes.spines.top': False,
    'axes.spines.right': False,
    'figure.figsize': fig_size,
    'image.cmap': 'gray',
    'image.interpolation': 'none',
    'xtick.top': False,
    'xtick.bottom': False,
    'xtick.color': 'white',
    'ytick.left': False,
    'ytick.right': False,
    'ytick.color': 'white',
})
%matplotlib inline

# widgets
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

# visdom
import visdom

<span class="big"><b>UNet in PyTorch</b></span>

<br/>

<h2 style="line-height: 1.4em; font-size: 1.7em;">MIE Deep Learning Bootcamp, Berlin 2018</h2>

<span class='big'>Goal: UNet architecture</span>

![](img/u-net-architecture.png)

Ronneberger et al., 2015

# DRIVE database

![](img/task_new.png)

# DRIVE database

http://www.isi.uu.nl/Research/Databases/DRIVE/download.php

![](img/drive.png?3)

In [None]:
from datasets.drive.extract_patches import get_data_training, get_data_testing

# Patch size passed to the extraction helper; the training loop reshapes
# every patch to (3, img_size, img_size).
img_size = 64

# NOTE(review): N is presumably the number of patches to extract -- confirm
# in datasets/drive/extract_patches.py.
patches_imgs, patches_masks = get_data_training(N=2000, img_size=img_size)

# Pair every image patch with its mask so both are shuffled together.
dataset_train = list(zip(patches_imgs, patches_masks))

test_imgs, test_masks = get_data_testing()

In [None]:
def disp(i):
    """Show training patch ``i`` (left) next to its mask (right).

    Both arrays are transposed with ``.T`` (which reverses the axis order)
    before plotting; the mask additionally has its singleton axes squeezed
    away first.
    """
    # assumes patches are stored channels-first -- TODO confirm
    plt.figure(figsize=(10, 10))
    panels = (
        patches_imgs[i, :, :, :].T,
        patches_masks[i, :, :, :].squeeze().T,
    )
    for position, panel in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(panel)
    plt.show()


# Slider widget covering every valid patch index.
interact(disp, i=(0, len(patches_imgs) - 1));

In [None]:
def disp(i):
    """Show test image ``i`` (left) next to its mask (right).

    Both arrays are transposed with ``.T`` (which reverses the axis order)
    before plotting; the mask additionally has its singleton axes squeezed
    away first.
    """
    # assumes images are stored channels-first -- TODO confirm
    plt.figure(figsize=(10, 10))
    panels = (
        test_imgs[i, :, :, :].T,
        test_masks[i, :, :, :].squeeze().T,
    )
    for position, panel in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(panel)
    plt.show()


# Slider widget covering every valid test-image index.
interact(disp, i=(0, len(test_imgs) - 1));

In [None]:
from random import shuffle

def batch_generator(dataset, batch_size=5):
    """Yield shuffled mini-batches of ``(imgs, masks)`` drawn from ``dataset``.

    A fresh random order is drawn each time a new generator is started.
    Only full batches are produced: the trailing
    ``len(dataset) % batch_size`` samples of the shuffled order are dropped.

    The input sequence itself is never reordered -- a list of indices is
    shuffled instead -- so the caller's data keeps its order and immutable
    sequences (e.g. tuples) are accepted as well.

    Args:
        dataset: indexable sequence of ``(img, mask)`` pairs.
        batch_size: number of pairs per batch.

    Yields:
        ``(imgs, masks)``, each a tuple holding ``batch_size`` items.
    """
    order = list(range(len(dataset)))
    shuffle(order)
    n_full_batches = len(dataset) // batch_size
    for batch_idx in range(n_full_batches):
        chunk = order[batch_size * batch_idx:batch_size * (batch_idx + 1)]
        # Transpose the list of pairs into one tuple of images and one of masks.
        imgs, masks = zip(*(dataset[j] for j in chunk))
        yield imgs, masks

In [None]:
# Try the generator out with batches of 250 (img, mask) pairs.
bgen = batch_generator(dataset_train, batch_size=250)

In [None]:
# Pull one batch; the cell output is the number of images it contains.
imgs, masks = next(bgen)
len(imgs)

<span class='big'>Unet architecture</span>

![](img/u-net-architecture.png)

In [None]:
import torch.nn as nn

class UnetConv(nn.Module):
    """Convolution followed by batch normalisation and an optional activation.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.
        kernel: kernel size handed to ``nn.Conv2d``.
        stride: stride handed to ``nn.Conv2d``.
        padding: padding handed to ``nn.Conv2d``.
        act: module applied after the normalisation; ``None`` skips it.
    """

    def __init__(self,
                 in_channels, out_channels,
                 kernel=3, stride=1, padding=1,
                 act=nn.ReLU()):
        super(UnetConv, self).__init__()

        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=kernel,
                              stride=stride,
                              padding=padding)
        self.norm = nn.BatchNorm2d(out_channels)
        # The default activation object is created once, when the class is
        # defined, and is therefore shared by all layers relying on it.
        self.act = act

    def forward(self, inputs):
        """Return ``act(norm(conv(inputs)))``, or the normalised map if ``act`` is ``None``."""
        normalised = self.norm(self.conv(inputs))
        if self.act is None:
            return normalised
        return self.act(normalised)

In [None]:
class UnetDown(nn.Module):
    """Down-sampling step: max-pooling with window size ``kernel``."""

    def __init__(self, kernel=2):
        super(UnetDown, self).__init__()
        self.down = nn.MaxPool2d(kernel_size=kernel)

    def forward(self, inputs):
        """Return the max-pooled version of ``inputs``."""
        pooled = self.down(inputs)
        return pooled

In [None]:
class UnetUp(nn.Module):
    """Up-sampling step: pad, transposed convolution, batch norm, ReLU.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the transposed convolution.
        kernel: kernel size of the transposed convolution.
        stride: stride of the transposed convolution.
        padding: amounts handed to ``F.pad`` and applied to the input
            *before* the transposed convolution (which itself uses padding 0).
    """

    def __init__(self,
                 in_channels, out_channels,
                 kernel=2, stride=2, padding=(0, 0, 0, 0)):
        super(UnetUp, self).__init__()

        self.padding = padding
        self.deconv = nn.ConvTranspose2d(in_channels, out_channels,
                                         kernel_size=kernel,
                                         stride=stride,
                                         padding=0)
        self.norm = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU()

    def forward(self, inputs):
        """Return the padded, up-sampled, normalised and rectified feature map."""
        padded = F.pad(inputs, self.padding)
        upsampled = self.deconv(padded)
        return self.act(self.norm(upsampled))

<span class='big'>Unet architecture</span>

![](img/u-net-architecture.png)

In [None]:
class UnetConc(nn.Module):
    """Concatenate two feature maps along dimension 1, with optional dropout.

    Args:
        dropout: dropout probability applied to the concatenated tensor.
            ``False``, ``None`` or a value ``<= 0`` disables dropout;
            ``True`` selects a rate of 0.5.

    Raises:
        ValueError: raised by ``nn.Dropout`` when the rate is greater than 1.
    """

    def __init__(self, dropout=0.5):
        super(UnetConc, self).__init__()

        if dropout is True:
            # A bare "enabled" flag keeps meaning the 0.5 rate it always had.
            dropout = 0.5

        enabled = (dropout is not None
                   and dropout is not False
                   and dropout > 0.)
        # Hand the requested rate to the layer; it used to be ignored, so
        # every enabled instance silently ran with p=0.5.
        self.dropout = nn.Dropout(p=dropout) if enabled else None

    def forward(self, inputs1, inputs2):
        """Return ``inputs1`` and ``inputs2`` joined along dimension 1."""
        merged = torch.cat([inputs1, inputs2], 1)

        if self.dropout is not None:
            merged = self.dropout(merged)

        return merged

In [None]:
class Unet(nn.Module):
    """U-Net-style encoder/decoder mapping 3-channel images to a 1-channel map.

    All intermediate feature maps have 32 channels. The encoder halves the
    resolution five times with max-pooling; the decoder doubles it five times
    with transposed convolutions. After each of the first four up-sampling
    stages the decoder features are concatenated with the pooled encoder
    features of the same resolution (``down4`` ... ``down1``), which is why
    the convolutions following a concatenation take 64 input channels.

    The output passes through a sigmoid, so every value lies in (0, 1).

    Input height and width should be multiples of 32 so that the decoder
    resolutions line up with the encoder tensors they are concatenated with.

    NOTE: the first decoder stage applies ``conv7``. Earlier revisions applied
    ``conv6`` a second time there and never used ``conv7``, so checkpoints
    trained with such a revision hold untrained ``conv7`` weights and should
    be retrained.
    """

    def __init__(self):
        super(Unet, self).__init__()

        # Encoder convolutions (conv1-conv6) followed by decoder convolutions.
        # Layers fed by a skip concatenation take 32 + 32 = 64 channels.
        self.conv1 = UnetConv(3, 32)
        self.conv2 = UnetConv(32, 32)
        self.conv3 = UnetConv(32, 32)
        self.conv4 = UnetConv(32, 32)
        self.conv5 = UnetConv(32, 32)
        self.conv6 = UnetConv(32, 32)
        self.conv7 = UnetConv(32, 32)
        self.conv8 = UnetConv(64, 32)
        self.conv9 = UnetConv(32, 32)
        self.conv10 = UnetConv(64, 32)
        self.conv11 = UnetConv(32, 32)
        self.conv12 = UnetConv(64, 32)
        self.conv13 = UnetConv(32, 32)
        self.conv14 = UnetConv(64, 32)
        self.conv15 = UnetConv(32, 32)
        # Final projection to one channel; no activation, the sigmoid is
        # applied in forward().
        self.conv16 = UnetConv(32, 1, act=None)

        # Pooling and concatenation hold no learned weights, so a single
        # instance of each is reused at every level.
        self.down = UnetDown()

        self.up1 = UnetUp(32, 32)
        self.up2 = UnetUp(32, 32)
        self.up3 = UnetUp(32, 32)
        self.up4 = UnetUp(32, 32)
        self.up5 = UnetUp(32, 32)

        self.conc = UnetConc()

    def forward(self, x):
        """Run the network on ``x`` and return the sigmoid-activated output."""
        # ---- encoder: convolve, then halve the resolution ----
        conv1 = self.conv1(x)

        down1 = self.down(conv1)
        conv2 = self.conv2(down1)

        down2 = self.down(conv2)
        conv3 = self.conv3(down2)

        down3 = self.down(conv3)
        conv4 = self.conv4(down3)

        down4 = self.down(conv4)
        conv5 = self.conv5(down4)

        down5 = self.down(conv5)
        conv6 = self.conv6(down5)

        # ---- decoder: double the resolution, convolve, merge with encoder ----
        up1 = self.up1(conv6)
        conv7 = self.conv7(up1)  # fixed: this stage used to re-apply conv6
        conc1 = self.conc(conv7, down4)
        conv8 = self.conv8(conc1)

        up2 = self.up2(conv8)
        conv9 = self.conv9(up2)
        conc2 = self.conc(conv9, down3)
        conv10 = self.conv10(conc2)

        up3 = self.up3(conv10)
        conv11 = self.conv11(up3)
        conc3 = self.conc(conv11, down2)
        conv12 = self.conv12(conc3)

        up4 = self.up4(conv12)
        conv13 = self.conv13(up4)
        conc4 = self.conc(conv13, down1)
        conv14 = self.conv14(conc4)

        # Last stage is back at input resolution; it has no skip connection.
        up5 = self.up5(conv14)
        conv15 = self.conv15(up5)
        conv16 = self.conv16(conv15)

        # Functional sigmoid avoids building a new module on every call.
        outputs = torch.sigmoid(conv16)

        return outputs

In [None]:
# Build the network; leaving `net` as the last expression makes the notebook
# print its module tree.
net = Unet()
net

In [None]:
# Sanity check: feed one random 3 x 64 x 64 input through the network and
# display the size of the result.
inputs = Variable(torch.rand(1, 3, 64, 64))
net(inputs).size()

In [None]:
# Start the visdom server from a shell before running this cell:
#   python -m visdom.server -port 9000

import visdom

# Client talking to the server on port 9000.
vis = visdom.Visdom(port=9000)

In [None]:
# Move the network to the GPU when one is available. The tensor type used to
# build inputs and targets has to follow the network: leaving `dtype` at the
# CPU type would feed CPU tensors to a CUDA model and fail in the forward pass.
if torch.cuda.is_available():
    net.cuda()
    dtype = torch.cuda.FloatTensor
else:
    print('CUDA not available')

In [None]:
# Mean squared error between the network output and the target mask.
criterion = nn.MSELoss()

In [None]:
# Adam over all network parameters, with explicitly chosen learning rate,
# moment-decay coefficients (betas) and epsilon.
optimizer = optim.Adam(net.parameters(), lr=0.01, 
                       betas=(0.9, 0.995), eps=1e-05)

In [None]:
batch_size = 32
n_epochs = 50

iteration = 0
# Open the loss window with a NaN placeholder trace; the real loss values are
# appended to it one iteration at a time.
vline = vis.line(X=np.asarray([-1, -1]),
                 Y=np.asarray([np.nan, np.nan]))

# Make sure dropout and batch-norm layers are in training mode, even if the
# network was switched to eval mode by an earlier cell.
net.train()

for epoch in tqdm(range(n_epochs)):
    # A new generator per epoch re-shuffles the training pairs.
    bgen = batch_generator(dataset_train, batch_size)

    for imgs, masks in bgen:

        # Stack the per-sample arrays of the batch into single arrays.
        imgs = np.asarray(imgs).reshape(batch_size, 3, img_size, img_size)
        masks = np.asarray(masks).reshape(batch_size, 1, img_size, img_size)

        inputs = Variable(torch.from_numpy(imgs).type(dtype))
        targets = Variable(torch.from_numpy(masks).type(dtype))

        optimizer.zero_grad()
        pred = net(inputs)

        loss = criterion(pred, targets)
        loss.backward()

        optimizer.step()

        iteration += 1
        # `loss.data[0]` fails once the loss is a 0-dim tensor (PyTorch >= 0.4);
        # use `.item()` where it exists and fall back for older releases.
        loss_value = loss.item() if hasattr(loss, 'item') else loss.data[0]
        current_loss = np.asarray([loss_value])

        # `update='append'` supersedes the deprecated `vis.updateTrace`.
        vis.line(X=np.asarray([iteration]),
                 Y=current_loss,
                 win=vline,
                 update='append')

In [None]:
# Deserialize the checkpoint onto the CPU first, so that weights saved during
# a GPU run can also be restored on a machine without CUDA; load_state_dict
# then copies the values into the network's existing parameters.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
net.load_state_dict(torch.load('weights/32_epochs_state.pkl',
                               map_location=lambda storage, loc: storage))

In [None]:
# Predict on the first three test images. Inference has to run in eval mode:
# otherwise the dropout layers keep zeroing activations at random and
# batch-norm normalises with the statistics of this tiny batch. The previous
# mode is restored afterwards so a later training run is unaffected.
was_training = net.training
net.eval()
pred = net(Variable(torch.from_numpy(test_imgs[:3,:,:,:]).type(dtype)))
net.train(was_training)
pred.size()

In [None]:
# Copy the predictions to host memory before converting them: `.numpy()` is
# only available for CPU tensors, while `.cpu()` is a no-op for tensors that
# already live on the CPU.
outputs = pred.data.cpu().numpy()

def disp(i):
    """Show prediction ``i`` (left) next to the ground-truth mask (right).

    Both arrays are transposed with ``.T`` (which reverses the axis order)
    and have their singleton axes squeezed away before plotting.
    """
    plt.figure(figsize=(10,10))
    plt.subplot(1,2,1)
    plt.imshow(outputs[i,:,:,:].T.squeeze());
    plt.subplot(1,2,2)
    plt.imshow(test_masks[i,:,:,:].T.squeeze());
    plt.show()

# Slider widget covering every available prediction.
interact(disp, i=(0, len(outputs)-1));