# U-Net for myelin segmentation

Simple U-Net approach for myelin segmentation, based on a simplified version of the PyTorch U-Net exercise from webinar 3.

## Dependencies

In [None]:
%load_ext tensorboard

In [None]:
import os
from functools import partial
from glob import glob

import imageio
import napari
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn
from torch.nn import functional as F

## Check the data

Check that the data is correct in napari.

In [None]:
# TODO adapt this to where you have stored the data
root_dir = os.path.expanduser("~/Work/data/dl-course-2022/prepared_data_v1")

# Collect the file paths of all train images and labels using glob.
# glob does not sort the file paths, so we sort both lists to make sure that
# the i-th image and the i-th label belong together.
# Both patterns use "*.tif" so that images and labels are selected by the same rule.
train_image_paths = sorted(glob(os.path.join(root_dir, "train", "images", "*.tif")))
train_label_paths = sorted(glob(os.path.join(root_dir, "train", "labels", "*.tif")))
# fail early with a clear message instead of running into an error further down
assert train_image_paths, f"no training images found in {root_dir}"
assert len(train_image_paths) == len(train_label_paths), \
    f"found {len(train_image_paths)} images but {len(train_label_paths)} labels"

# load images and labels into memory
train_images = [imageio.imread(path) for path in train_image_paths]
train_labels = [imageio.imread(path) for path in train_label_paths]

In [None]:
def view_image_and_labels(image, label):
    """Open a new napari viewer and add `image` as image layer and `label` as labels layer."""
    viewer = napari.Viewer()
    viewer.add_image(image)
    viewer.add_labels(label)

In [None]:
# Open one napari viewer per training image, all of them at once.
# Have a quick look that each label image matches its image!
for image, label in zip(train_images, train_labels):
    view_image_and_labels(image=image, label=label)

## Implement the training pipeline and model

In [None]:
def normalize_image(image):
    """Rescale the image linearly so that its values span the range from zero to one.

    The minimum of the image is mapped to 0 and the maximum to 1 (up to the
    small epsilon in the denominator). A constant image is mapped to all zeros.

    Arguments:
      image: numpy array (or array-like) with the image intensities
    Returns:
      float32 numpy array with the same shape as the input
    """
    eps = 1.0e-7
    # compute in floating point, so that integer images are not truncated
    image = np.asarray(image, dtype="float32")
    min_ = image.min()
    max_ = image.max()
    # divide by the value range (max - min) and not by the maximum alone,
    # otherwise an image with non-zero minimum never reaches 1;
    # eps avoids a division by zero for constant images
    normalized_image = (image - min_) / (max_ - min_ + eps)
    # float32 matches the dtype of the one-hot encoded target
    return normalized_image.astype("float32", copy=False)

def one_hot_encoding(labels, n_classes):
    """Transform a label image into one binary mask per class (one hot encoding).

    Channel i of the result is the mask for the label id i + 1. Zero is treated
    as background and does not get a channel. For example, the 2d label image
    with label ids [1, 2]
      [[0, 0, 1],
       [0, 1, 1],
       [0, 2, 2]]
    is transformed into an image with two channels:
      channel 0 (label 1)    channel 1 (label 2)
      [[0, 0, 1],            [[0, 0, 0],
       [0, 1, 1],             [0, 0, 0],
       [0, 0, 0]]             [0, 1, 1]]

    Arguments:
      labels: numpy array with the label ids
      n_classes: number of classes, i.e. number of channels of the result
    Returns:
      float32 numpy array of shape (n_classes,) + labels.shape
    """
    target = np.zeros((n_classes,) + labels.shape, dtype="float32")
    # give the class ids one trailing singleton axis per label dimension, so that
    # a single comparison broadcasts every class id against the whole label image
    class_ids = np.arange(1, n_classes + 1)
    class_ids = class_ids[(slice(None),) + (None,) * labels.ndim]
    target[...] = labels[None] == class_ids
    return target

In [None]:
# any PyTorch dataset class should inherit from torch.utils.data.Dataset
class MyelinDataset(Dataset):
    """A PyTorch dataset that provides random patches of the images and labels.

    Arguments:
      images: sequence of images (numpy arrays)
      labels: sequence of label images; labels[i] must have the same shape as images[i]
      patch_shape: shape of the patches returned by this dataset, e.g. (512, 512)
      transform: optional callable applied to both image and target, e.g. for data
        augmentation; it is called as transform(image, target) and must return both
    """
    def __init__(self, images, labels, patch_shape, transform=None):
        self.images = images
        self.labels = labels
        assert len(self.images) == len(self.labels)

        # the patch shape is the image size returned by this dataset, e.g. 512x512
        self.patch_shape = patch_shape

        # determine the number of samples and classes from the data
        self.n_samples = len(images)
        self.n_classes = self.compute_n_classes(self.labels)

        # the transformation applied to the input image, here we just use normalization
        self.image_transform = normalize_image
        # the transformation applied to the input labels,
        # here we transform the label image into multi-channel binary masks
        self.label_transform = partial(one_hot_encoding, n_classes=self.n_classes)
        # transformations applied to both images and labels.
        # this can for example be used for data augmentation
        self.transform = transform

    # compute the number of classes in our labels
    def compute_n_classes(self, labels):
        """Return the number of distinct non-zero label values over all label images."""
        # compute the unique values per image first and merge them afterwards,
        # so that we don't need to copy all label images into one big array
        unique_label_values = np.unique(np.concatenate([np.unique(lab) for lab in labels]))
        # zero is the background and not counted as a class; counting the non-zero
        # values (instead of subtracting one) is also correct if no background is present
        # note: the label transform assumes that the class ids are 1, ..., n_classes
        return int(np.count_nonzero(unique_label_values))

    # get the total number of samples
    def __len__(self):
        return self.n_samples

    # fetch the training sample given its index
    def __getitem__(self, idx):
        """Return a random patch of image `idx` and the corresponding target.

        Raises:
          ValueError: if the patch shape is larger than the image
        """
        image = self.images[idx]
        labels = self.labels[idx]
        assert image.shape == labels.shape

        if any(pshape > shape for shape, pshape in zip(image.shape, self.patch_shape)):
            raise ValueError(
                f"patch shape {tuple(self.patch_shape)} does not fit into image of shape {image.shape}"
            )

        # sample a random patch for this image
        # np.random.randint excludes the upper bound, so we need to add one to
        # also sample the last valid start coordinate (shape - pshape);
        # this also makes patches of the full image size possible
        start_coordinates = [np.random.randint(0, shape - pshape + 1)
                             for shape, pshape in zip(image.shape, self.patch_shape)]
        patch = tuple(slice(start, start + pshape)
                      for start, pshape in zip(start_coordinates, self.patch_shape))

        # get the image and labels from the patch
        input_ = np.asarray(image[patch])
        target = np.asarray(labels[patch])

        # apply the transformations
        if self.image_transform is not None:
            input_ = self.image_transform(input_)
        if self.label_transform is not None:
            target = self.label_transform(target)
        if self.transform is not None:
            input_, target = self.transform(input_, target)

        # make sure the input has a channel axis
        if input_.ndim == 2:
            input_ = input_[None]

        return input_, target

Now let's load the dataset and visualize it with a simple function:

In [None]:
# Shape of the random patches that the dataset cuts out of the training images.
# (512, 512) is a small patch shape for testing; choose a bigger one in your
# experiments, e.g. patch_shape = (1024, 1024)
patch_shape = (512, 512)
train_dataset = MyelinDataset(images=train_images, labels=train_labels, patch_shape=patch_shape)

In [None]:
def show_sample_from_dataset(dataset):
    """Print the shapes of a random sample from the dataset and display it in napari.

    Arguments:
      dataset: the dataset to sample from; must support len() and indexing
        and return an (image, target) pair for each index
    """
    idx = np.random.randint(0, len(dataset))  # take a random sample
    # index the dataset that was passed in, not the global train_dataset,
    # so that this function also works for other datasets (e.g. validation)
    image, target = dataset[idx]
    print(image.shape)
    print(target.shape)
    viewer = napari.Viewer()
    viewer.add_image(image)
    viewer.add_image(target)

In [None]:
# display a random patch from the training dataset to check the dataset output
show_sample_from_dataset(train_dataset)

In [None]:
# the unet is copied directly from the webinar exercise
class UNet(nn.Module):
    """ 2D UNet with four encoder / decoder levels and skip connections.

    The input is pooled four times by a factor of two, so its spatial size must be
    divisible by 16 for the upsampled features to match the skip connections.

    Arguments:
      in_channels: number of input channels
      out_channels: number of output channels
      final_activation: activation applied to the network output,
        must be None (no activation) or a torch module
    """

    def __init__(self, in_channels=1, out_channels=1,
                 final_activation=None):
        super().__init__()

        # the depth (= number of encoder / decoder levels) is hard-coded to 4
        self.depth = 4

        # the final activation must either be None or a Module
        if final_activation is not None:
            assert isinstance(final_activation, nn.Module), "Activation must be torch module"

        # number of feature channels per encoder level and in the base block
        features = (16, 32, 64, 128)
        base_features = 256
        # (in_channels, out_channels) for each level of the encoder
        encoder_channels = list(zip((in_channels,) + features[:-1], features))
        # (in_channels, out_channels) for each level of the decoder, starting at the base:
        # (256, 128), (128, 64), (64, 32), (32, 16)
        decoder_channels = list(zip((base_features,) + features[:0:-1], features[::-1]))

        # all lists of layers with parameters must be wrapped into a nn.ModuleList,
        # otherwise their parameters are not registered with the model
        self.encoder = nn.ModuleList(
            [self._conv_block(n_in, n_out) for n_in, n_out in encoder_channels]
        )
        # the base convolution block at the lowest resolution
        self.base = self._conv_block(features[-1], base_features)
        # the decoder blocks get the upsampled features concatenated with the
        # skip connection as input, hence twice the number of output channels
        self.decoder = nn.ModuleList(
            [self._conv_block(n_in, n_out) for n_in, n_out in decoder_channels]
        )

        # the pooling layers; we use 2x2 MaxPooling
        self.poolers = nn.ModuleList([nn.MaxPool2d(2) for _ in range(self.depth)])
        # the upsampling layers
        self.upsamplers = nn.ModuleList(
            [self._upsampler(n_in, n_out) for n_in, n_out in decoder_channels]
        )

        # the output conv is not followed by a non-linearity here,
        # because the final activation is applied afterwards
        self.out_conv = nn.Conv2d(features[0], out_channels, 1)
        self.activation = final_activation

    # _conv_block and _upsampler are helper functions to construct the model.
    # Encapsulating them like this makes it easy to re-use the model
    # implementation with different architecture elements.

    def _conv_block(self, in_channels, out_channels):
        """Convolutional block for a single level of the encoder / decoder:
        two 3x3 convolutions (padded to keep the spatial size), each followed by a ReLU.
        """
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        return nn.Sequential(*layers)

    def _upsampler(self, in_channels, out_channels):
        """Upsampling by a factor of two via a transposed 2d convolution."""
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, input):
        """Apply the network to the input and return the prediction."""
        x = input

        # encoder path: remember the output of each level for the skip connections
        skip_connections = []
        for conv_block, pooler in zip(self.encoder, self.poolers):
            x = conv_block(x)
            skip_connections.append(x)
            x = pooler(x)

        x = self.base(x)

        # decoder path: the skip connections are used in reverse order,
        # i.e. the lowest resolution first
        decoder_levels = zip(self.upsamplers, self.decoder, reversed(skip_connections))
        for upsampler, conv_block, skip in decoder_levels:
            upsampled = upsampler(x)
            x = conv_block(torch.cat((upsampled, skip), dim=1))

        # apply output conv and activation (if given)
        x = self.out_conv(x)
        if self.activation is None:
            return x
        return self.activation(x)

In [None]:
# we use the dice coefficient as loss function
class DiceLoss(nn.Module):
    """ Loss based on the dice coefficient, averaged over the channels.

    The dice coefficient of two sets represented as vectors a, b can be
    computed as 2 * |a b| / (a^2 + b^2). A perfect match corresponds to a dice
    score of 1 and a complete miss to a dice score of 0, so we use 1 - dice score
    as loss; lower values then correspond to a better solution.

    Arguments:
      eps: lower bound for the denominator, avoids a division by zero
    """
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, prediction, target):
        """Compute the loss for prediction and target, which must have the same shape.

        The second axis is treated as the channel axis; the dice score is computed
        for each channel independently, over all other axes.
        """
        assert prediction.shape == target.shape, f"{prediction.shape}, {target.shape}"
        n_channels = prediction.shape[1]

        # go over the matching channels of prediction and target
        total_loss = 0.0
        for pred, trgt in zip(prediction.unbind(dim=1), target.unbind(dim=1)):
            overlap = torch.sum(pred * trgt)
            squares = torch.sum(pred * pred) + torch.sum(trgt * trgt)
            dice_score = 2 * overlap / squares.clamp(min=self.eps)
            total_loss = total_loss + (1.0 - dice_score)

        # normalize the loss by the number of channels
        return total_loss / n_channels

## Training

Implement and run training for the model.

In [None]:
# apply training for one epoch
def train(model, loader, optimizer, loss_function,
          epoch, log_interval=100, log_image_interval=20, tb_logger=None,
          device=None):
    """Train the model for one epoch.

    Arguments:
      model: the model to train
      loader: the data loader that yields the (input, target) batches
      optimizer: the optimizer that updates the model parameters
      loss_function: called as loss_function(prediction, target), returns the loss
      epoch: index of the current epoch, used for logging
      log_interval: the loss is printed every `log_interval` batches
      log_image_interval: input, target and prediction are logged to tensorboard
        every `log_image_interval` steps
      tb_logger: optional tensorboard writer; nothing is logged to tensorboard if None
      device: the device to train on; by default the device of the model parameters
    """
    if device is None:
        # use the device the model lives on instead of relying on a global variable;
        # models without parameters fall back to the cpu
        first_parameter = next(model.parameters(), None)
        device = torch.device("cpu") if first_parameter is None else first_parameter.device

    # set the model to train mode
    model.train()
    n_batches = len(loader)
    # iterate over the batches of this epoch
    for batch_id, (x, y) in enumerate(loader):
        # move input and target to the active device (either cpu or gpu)
        x, y = x.to(device), y.to(device)

        # zero the gradients for this iteration
        optimizer.zero_grad()

        # apply model, calculate loss and run backwards pass
        prediction = model(x)
        loss = loss_function(prediction, y)
        loss.backward()
        optimizer.step()

        # log to console
        if batch_id % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                  epoch, batch_id * len(x),
                  len(loader.dataset),
                  100. * batch_id / n_batches, loss.item()))

        # log to tensorboard
        if tb_logger is not None:
            # the global step counts the batches over all epochs
            step = epoch * n_batches + batch_id
            tb_logger.add_scalar(tag='train_loss', scalar_value=loss.item(), global_step=step)
            # check if we log images in this iteration
            if step % log_image_interval == 0:
                tb_logger.add_images(tag='input', img_tensor=x.to('cpu'), global_step=step)
                tb_logger.add_images(tag='target', img_tensor=y.to('cpu'), global_step=step)
                # detach the prediction, the logged image does not need gradients
                tb_logger.add_images(tag='prediction', img_tensor=prediction.detach().to('cpu'),
                                     global_step=step)


This time we will use a GPU to train faster. Please make sure that your notebook is running on a GPU; if none is available, the code below falls back to the CPU.

In [None]:
# select the device for training: the gpu if one is available, otherwise the cpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("GPU is available" if device.type == "cuda" else "GPU is not available")

In [None]:
# start a tensorboard writer that writes its logs to the directory runs/Unet
logger = SummaryWriter('runs/Unet')
# notebook magic (not plain python): display the tensorboard for the log directory `runs`
# this needs the tensorboard extension, which is loaded via %load_ext tensorboard
%tensorboard --logdir runs

In [None]:
# build a unet with sigmoid activation to normalize predictions to [0, 1]
# the number of output channels is taken from the dataset, so that the prediction
# always has one channel per class of the one-hot encoded target
net = UNet(1, train_dataset.n_classes, final_activation=nn.Sigmoid())
# move the model to the training device (the GPU if available)
net = net.to(device)

# create the loader from the training dataset
batch_size = 1  # the batch size used for training
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
# TODO create the validation dataset and validation loader

# define the loss function and optimizer
loss_function = DiceLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1.0e-4)

# TODO: define a metric to be used for validation

# train for a number of epochs
# during the training you can inspect the
# predictions in the tensorboard
n_epochs = 25
for epoch in range(n_epochs):
    # run training for this epoch
    train(net, train_loader, optimizer, loss_function, epoch, tb_logger=logger)
    # the global step at the end of this epoch, to be used for logging the validation
    # it is counted in batches (len(train_loader)), like the step used for logging in train
    step = (epoch + 1) * len(train_loader)
    # validate
    # TODO: implement the validation here