# Medical Deep Learning

---

In this sheet we will set up the foundations to load the data, create a fully-convolutional network (FCN) for image segmentation, and train it to solve multi-class segmentation problems for medical applications. 

In [None]:
# Download the AbdomenPreAffine dataset archive and extract it into the working
# directory (`unzip -o` overwrites existing files without prompting; output is silenced).
!wget https://cloud.imi.uni-luebeck.de/s/zFyEiJKNtaKKzS8/download -O AbdomenPreAffine.zip
!unzip -o AbdomenPreAffine.zip > /dev/null  # disable the output

In [None]:
# Fetch the exercise helper module and import the helpers used later in this notebook
# (init_weights, Plotter and the transforms ZeroPad / Crop / Scale / ToCuda).
!wget https://cloud.imi.uni-luebeck.de/s/Fd63J7xMLmkMEzb/download -O mdl_exercise1_utils.py

from mdl_exercise1_utils import init_weights, Plotter, ZeroPad, Crop, Scale, ToCuda

Let's get started with the code and run all the imports:

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from torch import Tensor
from tqdm import tqdm_notebook as tqdm

import nibabel as nib
import os
from collections import OrderedDict
from typing import Callable, Any, Optional, List

import matplotlib.pyplot as plt

The **training data** is loaded from the filesystem as:

In [None]:
# Case IDs of the volumes (files img00XX.nii.gz / label00XX.nii.gz in Training/).
# Training and validation IDs must be disjoint; otherwise validation data leaks into
# training and the validation loss becomes optimistic.
list_train = torch.tensor([2, 3, 5, 6, 8, 9, 21, 22, 24, 25, 27, 28, 30, 31, 33, 34, 36, 37, 39, 40]).long()
list_test = torch.tensor([1, 4, 7, 10, 23, 26, 29, 32, 35, 38]).long()
assert not set(list_train.tolist()) & set(list_test.tolist()), 'training and validation cases overlap'

vol_size = (128, 128, 128)  # every volume is resampled onto this grid
n_train = len(list_train)

# Pre-allocate GPU buffers; every entry is overwritten in the loop below.
imgs2 = torch.empty(n_train, 1, *vol_size, device='cuda')
segs2 = torch.empty(n_train, *vol_size, dtype=torch.long, device='cuda')
for i, case_id in enumerate(list_train.tolist()):
    case = str(case_id).zfill(2)
    img = nib.load('/content/AbdomenPreAffine/Training/img/img00' + case + '.nii.gz').get_fdata()
    # Trilinear resampling of the intensities, then the fixed /500 scaling of this pipeline.
    imgs2[i:i+1] = F.interpolate(torch.from_numpy(img).float().cuda().unsqueeze(0).unsqueeze(1),
                                 size=vol_size, mode='trilinear') / 500
    seg = nib.load('/content/AbdomenPreAffine/Training/label/label00' + case + '.nii.gz').get_fdata()
    # Nearest-neighbour resampling keeps the integer label IDs intact (no blending).
    segs2[i] = F.interpolate(torch.from_numpy(seg).float().cuda().unsqueeze(0).unsqueeze(1),
                             size=vol_size, mode='nearest').squeeze().long()
    print('Loaded', i+1, '/', n_train)

The **validation data** is loaded from the filesystem as:

In [None]:
# Load the validation cases listed in `list_test`; the count follows that list
# instead of being hard-coded.
vol_size = (128, 128, 128)  # every volume is resampled onto this grid
n_val = len(list_test)

# Pre-allocate GPU buffers; every entry is overwritten in the loop below.
imgs_val = torch.empty(n_val, 1, *vol_size, device='cuda')
segs_val = torch.empty(n_val, *vol_size, dtype=torch.long, device='cuda')
for i, case_id in enumerate(list_test.tolist()):
    case = str(int(case_id)).zfill(2)
    img = nib.load('/content/AbdomenPreAffine/Training/img/img00' + case + '.nii.gz').get_fdata()
    # Trilinear resampling of the intensities, then the fixed /500 scaling of this pipeline.
    imgs_val[i:i+1] = F.interpolate(torch.from_numpy(img).unsqueeze(0).unsqueeze(1).float(),
                                    size=vol_size, mode='trilinear').cuda() / 500
    seg = nib.load('/content/AbdomenPreAffine/Training/label/label00' + case + '.nii.gz').get_fdata()
    # Nearest-neighbour resampling keeps the integer label IDs intact (no blending).
    segs_val[i] = F.interpolate(torch.from_numpy(seg).unsqueeze(0).unsqueeze(1).float(),
                                size=vol_size, mode='nearest').squeeze().cuda().long()
    print('Loaded', i+1, '/', n_val)

In [None]:
# Add a channel axis to the label volumes: (N, D, H, W) -> (N, 1, D, H, W), so a single
# label volume matches the single-channel network output. Guarded so that re-running
# this cell does not keep adding dimensions.
if segs2.dim() == 4:
    segs2 = segs2.unsqueeze(1)
if segs_val.dim() == 4:
    segs_val = segs_val.unsqueeze(1)

All data is now loaded, so we can start implementing the affine augmentation.

In [None]:
class AugmentAffine(object):
  """Random 3D affine augmentation applied jointly to an image batch and its labels.

  Each call perturbs the identity affine matrix with Gaussian noise (one matrix per
  batch element), builds a sampling grid with ``F.affine_grid`` and resamples the
  image (trilinear) and the label map (nearest neighbour, so label IDs are never
  blended). Voxels sampled from outside the volume become 0.

  Args:
    strength: standard deviation scale of the random offsets added to the identity
      matrix; 0 reproduces the input exactly.
  """

  def __init__(self, strength=0.05):
    self.strength = strength

  def __call__(self, sample):
    """Augment a ``{'image': ..., 'label': ...}`` dict.

    Args:
      sample: dict with ``'image'`` of shape (B, C, D, H, W) and ``'label'`` of
        shape (B, 1, D, H, W) or (B, D, H, W) with the same B, D, H, W.

    Returns:
      A new dict with the same keys, where ``'image'`` and ``'label'`` are the
      augmented tensors (label keeps its original dtype and rank). The input dict
      is left unmodified.
    """
    image = sample['image']
    label = sample['label']
    B = image.shape[0]

    # Random affine matrices around the identity, one per batch element: (B, 3, 4).
    identity = torch.eye(3, 4, dtype=image.dtype, device=image.device).unsqueeze(0)
    offsets = torch.randn(B, 3, 4, dtype=image.dtype, device=image.device)
    affine_matrix = identity + self.strength * offsets

    # Resampling grid covering the full spatial extent of the image.
    meshgrid = F.affine_grid(affine_matrix, image.size(), align_corners=False)

    out = dict(sample)
    # 'bilinear' on 5D input performs trilinear interpolation.
    out['image'] = F.grid_sample(image, meshgrid, mode='bilinear', align_corners=False)

    # grid_sample needs a float tensor with a channel axis; 'nearest' copies existing
    # label values, so casting back to the original dtype is exact.
    add_channel = label.dim() == image.dim() - 1
    label_f = (label.unsqueeze(1) if add_channel else label).to(image.dtype)
    warped = F.grid_sample(label_f, meshgrid, mode='nearest', align_corners=False)
    if add_channel:
      warped = warped.squeeze(1)
    out['label'] = warped.to(label.dtype)

    return out
# Transform pipelines per split: random affine + GPU transfer for training, GPU transfer
# only for validation.
# NOTE(review): neither list is referenced by the training loop in this notebook —
# confirm whether the augmentation is meant to be applied there.
augmentation_training = [AugmentAffine(strength=0.1), ToCuda()]
augmentation_validate = [ToCuda()]


Next, we define a simplified fully-convolutional network (FCN) architecture:



In [None]:
class Net(nn.Module):
  """Simplified 3D fully-convolutional network producing one probability per voxel.

  Three blocks of (BatchNorm3d -> Conv3d 3x3x3 -> ReLU) x 2, each followed by a 2x
  max-pooling, reduce the spatial size by 8 (16 -> 32 -> 64 channels). A head of
  1x1x1 convolutions maps the 64 features to a single sigmoid channel, which is then
  upsampled (nearest neighbour) back to the input's spatial size.

  Args:
    unet: currently unused — no skip connections are implemented, so ``Net(True)``
      and ``Net(False)`` build the same network.
  """
  def __init__(self, unet=True):
    super().__init__()
    # input e.g. [20, 1, 128, 128, 128]
    self.block0 = nn.Sequential(nn.BatchNorm3d(1),
                                nn.Conv3d(1, 16, 3, padding=1),
                                nn.ReLU(),
                                nn.BatchNorm3d(16),
                                nn.Conv3d(16, 16, 3, padding=1),
                                nn.ReLU()
                               )
    self.mp01 = nn.MaxPool3d(2, 2)

    self.block1 = nn.Sequential(nn.BatchNorm3d(16),
                                nn.Conv3d(16, 32, 3, padding=1),
                                nn.ReLU(),
                                nn.BatchNorm3d(32),
                                nn.Conv3d(32, 32, 3, padding=1),
                                nn.ReLU()
                                )
    self.mp02 = nn.MaxPool3d(2, 2)

    self.block2 = nn.Sequential(nn.BatchNorm3d(32),
                                nn.Conv3d(32, 64, 3, padding=1),
                                nn.ReLU(),
                                nn.BatchNorm3d(64),
                                nn.Conv3d(64, 64, 3, padding=1),
                                nn.ReLU()
                                )
    self.mp03 = nn.MaxPool3d(2, 2)
    # after mp03 e.g. [20, 64, 16, 16, 16]

    # Classification head: 1x1x1 convolutions instead of linear layers.
    # (Attribute name kept as `upsample` for state_dict compatibility.)
    self.upsample = nn.Sequential(nn.Conv3d(64, 32, 1),
                                  nn.ReLU(),
                                  nn.Conv3d(32, 1, 1),
                                  nn.Sigmoid(),
                                  )
    # output e.g. [20, 1, 128, 128, 128]

  def forward(self, inputs):
    """Return per-voxel probabilities of shape (B, 1, *inputs.shape[2:])."""
    output0 = self.block0(inputs)   # (B, 16, D, H, W)
    output1 = self.mp01(output0)    # (B, 16, D/2, H/2, W/2)
    output2 = self.block1(output1)  # (B, 32, D/2, H/2, W/2)
    output3 = self.mp02(output2)    # (B, 32, D/4, H/4, W/4)
    output4 = self.block2(output3)  # (B, 64, D/4, H/4, W/4)
    output5 = self.mp03(output4)    # (B, 64, D/8, H/8, W/8), sizes floored by pooling
    probs = self.upsample(output5)  # (B, 1, D/8, H/8, W/8), values in (0, 1)

    # Resize to the exact input size rather than a fixed factor of 8: pooling floors
    # odd sizes, so a fixed factor gives a mismatched output when a side is not
    # divisible by 8. For divisible sizes the result is identical.
    return F.interpolate(probs, size=tuple(inputs.shape[2:]))

Let's set up the experiment.

In [None]:
# Hyper-parameters
n_epochs = 31
SEED = 0  # fixes the random weight initialisation so Restart & Run All is reproducible

# Visualise progress every 5th epoch
every_epoch = 5
bin_thresh = 0.3  # passed to Plotter (presumably the threshold for binarising predictions)
# NOTE(review): epochs 0, 5, ..., 30 all satisfy `epoch % every_epoch == 0`, i.e. 7
# plotted epochs, while Plotter receives n_epochs // every_epoch == 6 — confirm what
# Plotter expects as its first argument.
plotter = Plotter(n_epochs//every_epoch, z_slice=23, bin_thresh=bin_thresh)

# Seed the torch RNGs (CPU and CUDA) before the network weights are drawn.
torch.manual_seed(SEED)

# Network initialisation
net = Net(False).cuda()
net.apply(init_weights)

# Set up optimisation for training process
criterion = nn.BCELoss().cuda()
optimizer = torch.optim.Adam(net.parameters(), lr=0.0001)

# Save loss of each epoch
losses_training = []
losses_validate = []

# Show current progress and loss
progress = tqdm(range(n_epochs), desc='progress')

# Bundle images and labels of each split under the keys 'image' / 'label'.
data = {
  "image": imgs2,
  "label": segs2,
}

data_val = {
  "image": imgs_val,
  "label": segs_val,
}

######################
# MAIN TRAINING LOOP #
######################

for epoch in progress:

  ########################################
  #               TRAINING               #
  ########################################

  sum_loss = 0
  n_train_samples = len(data['image'])

  # Parameters must be trainable
  net.train()
  with torch.set_grad_enabled(True):

    # Process the training volumes one at a time (batch size 1).
    for i in range(n_train_samples):
      image = data['image'][i].unsqueeze(0)
      # NOTE(review): the label volumes hold integer class IDs (multi-class maps), but
      # BCELoss expects targets in [0, 1] — confirm whether they should be binarised
      # (e.g. a single class, or `label > 0`) before computing the loss.
      target = data['label'][i].unsqueeze(0).float()

      optimizer.zero_grad()             # clear gradients left from the previous step
      result = net(image)               # forward pass (net(...) also runs module hooks)
      loss = criterion(result, target)  # BCE loss
      loss.backward()                   # compute gradients
      optimizer.step()                  # update the model weights

      sum_loss += loss.item()

  # Mean loss over the samples actually processed in this epoch.
  losses_training.append(sum_loss / n_train_samples)

  if epoch % every_epoch == 0:
    plotter.add_training_sample(data, result, epoch)


  ########################################
  #              VALIDATION              #
  ########################################

  sum_loss = 0
  n_val_samples = len(data_val['image'])

  # Parameters must not be trainable
  net.eval()
  with torch.no_grad():

    # Evaluate every validation volume one at a time; no weight updates here.
    for i in range(n_val_samples):
      result = net(data_val['image'][i].unsqueeze(0))
      loss = criterion(result, data_val['label'][i].unsqueeze(0).float())
      sum_loss += loss.item()

  losses_validate.append(sum_loss / n_val_samples)

  if epoch % every_epoch == 0:
    plotter.add_validation_sample(data_val, result, epoch)

  progress.set_postfix(loss=losses_training[-1], val_loss=losses_validate[-1])
