# Week 14 - Deep Learning for Image Segmentation: U-Net

## Biomedical Computer Vision Group (BMCV) <br> BioQuant, IPMB, Heidelberg University

Image segmentation is an important task in biomedical image analysis that enables many downstream tasks, such as cell counting or tracking. Image segmentation assigns a specific class to each pixel of an image. In recent years, deep-learning-based methods have often yielded better segmentation results than classical methods.

Fully convolutional neural networks classify all pixels of an entire image in a single forward pass. The [U-Net Architecture](https://link.springer.com/chapter/10.1007/978-3-319-24574-4_28) (Ronneberger et al., MICCAI, 2015) builds upon this idea and was proposed for the segmentation of biomedical images. Even today, the U-Net and its extensions are widely used in many applications and often achieve state-of-the-art performance.

This notebook comprises the following sections:

1) [Loading and Visualizing the Data](#1---data)
2) [Implementing the U-Net](#2---2d-u-net)
3) [Data Slicing and Augmentation](#3---data-slicing-and-augmentation)
4) [Training, Validation and Testing](#4---training-validation-and-testing)

Here you can jump to the exercises:

[Exercise 1](#exercise-1-u-net-decoder), 
[Exercise 2](#exercise-2),
[Exercise 3](#exercise-3), 
[Exercise 4](#exercise-4), 
[Exercise 5](#exercise-5)

# 0 - Import Packages and Define Helper Functions

In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import skimage.io
import scipy.ndimage

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision.transforms import Compose

from skimage.transform import resize


In [None]:
# Seed NumPy and PyTorch so that runs are reproducible
random_seed = 1
torch.manual_seed(random_seed)
np.random.seed(random_seed)
# Ask cuDNN for deterministic algorithms
torch.backends.cudnn.deterministic = True

# Run the U-Net on the GPU when one is available, otherwise on the CPU
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f'Network will run on "{device}".')

In [None]:
def plot_images(*args):
    """Show several images side by side in one row.

    Each positional argument is an ``(image, title)`` pair (only the first two
    entries are used). Images are drawn in grayscale with the axes hidden.
    """
    # squeeze=False keeps ``axes`` 2-D even for a single image; plt.subplots(1, 1)
    # would otherwise return a bare Axes object that cannot be indexed.
    fig, axes = plt.subplots(1, len(args), figsize=(20, 20), squeeze=False)
    for axis, item in zip(axes[0], args):
        axis.imshow(item[0], cmap='gray')
        axis.set_title(item[1])
        axis.axis('off')
    plt.show()

def plot_results(*args):
    """Plot one or more loss curves over the epoch index in a single figure.

    Each positional argument is a ``(values, label)`` pair.
    """
    for curve in args:
        values, label = curve[0], curve[1]
        plt.plot(values, label=label)

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid()
    plt.show()

# 1 - Data

The data used in this notebook show HeLa cells on flat glass, acquired with differential interference contrast (DIC) microscopy (dataset `DIC-C2DH-HeLa` of the Cell Tracking Challenge).

In [None]:
import re
import io
import zipfile
from pathlib import Path

import requests

def create_url(mode, dataset):
    """
    Build the download URL of a Cell Tracking Challenge (CTC) dataset.

    The CTC server hosts a training and a test archive per dataset. ``mode``
    "train" selects the "training" archive; any other value is used verbatim.
    """
    if mode == "train":
        folder = f"{mode}ing"
    else:
        folder = mode
    return f"http://data.celltrackingchallenge.net/{folder}-datasets/{dataset}.zip"

def download_dataset(base_dir, dataset_name, mode="train"):
    """
    Download and extract a CTC dataset unless it is already present.

    Parameters
    ----------
    base_dir : pathlib.Path
        Root folder; the archive is extracted into ``base_dir / mode``.
    dataset_name : str
        CTC dataset name, e.g. "DIC-C2DH-HeLa".
    mode : str
        "train" or "test"; forwarded to ``create_url``.

    Returns
    -------
    pathlib.Path
        ``base_dir / mode / dataset_name``.

    Raises
    ------
    requests.HTTPError
        If the server answers with an error status.
    """
    data_dir = base_dir / mode
    dataset_dir = data_dir / dataset_name

    if dataset_dir.exists():
        print(f"Dataset {dataset_name} already exists.")
        return dataset_dir

    # Only the parent folder is created up front. Creating dataset_dir before the
    # download succeeded would make a failed download look "already downloaded"
    # on the next run. Extracting into data_dir assumes the archive contains a
    # top-level ``dataset_name`` folder (as the original extraction did) - TODO confirm.
    data_dir.mkdir(parents=True, exist_ok=True)
    download_url = create_url(mode, dataset_name)

    print(f"Downloading {mode} data from {dataset_name} ({download_url}) to {dataset_dir}.")
    # The timeout bounds connecting and each wait for data, not the total download time
    r = requests.get(download_url, timeout=60)
    # Unlike an assert, this check is not stripped when Python runs with -O
    r.raise_for_status()

    with zipfile.ZipFile(io.BytesIO(r.content)) as archive:
        archive.extractall(data_dir)

    print("Download finished.")
    return dataset_dir

In [None]:
# Root folder that all downloaded data is stored under
base_dir = Path("exercise_data")
dataset_name = "DIC-C2DH-HeLa"

# download_dataset returns the directory the dataset lives in (downloaded now or earlier)
dataset_dir = download_dataset(base_dir=base_dir, dataset_name=dataset_name)

In [None]:
def sample_frames(frames, n, random_seed):
    """Draw ``n`` distinct entries of ``frames`` with a generator seeded by ``random_seed``.

    The result is a tuple in draw order.
    """
    generator = np.random.default_rng(random_seed)
    picked = generator.choice(frames, n, replace=False)
    return tuple(picked)


def create_dataset(
    dataset_dir, sequence, suffix="ST", n=None, random_seed=11, min_cells=0, ignore = None,
):
    """
    Load images and masks from a dataset directory.

    Loads n randomly selected images, together with their masks, from a given
    sequence of a CTC dataset.

    Parameters
    ----------
    dataset_dir : pathlib.Path
        Dataset folder containing e.g. "01" and "01_ST"/"01_GT".
    sequence : str
        Sequence name, e.g. "01".
    suffix : str
        Masks are read from ``dataset_dir / f"{sequence}_{suffix}" / "SEG"``;
        for "ERR_SEG" the "SEG" level is dropped.
    n : int or None
        Number of frames to sample. ``None`` or a value larger than the number
        of available frames loads all of them.
    random_seed : int
        Seed for the frame sampling.
    min_cells : int or None
        Keep only frames whose mask has more than ``min_cells`` distinct values.
        The background value counts as one of them.
    ignore : iterable of str or None
        Frame numbers (3-digit strings) to exclude, e.g. frames already used
        for another split. NumPy arrays are accepted as well.

    Returns
    -------
    images, masks : np.ndarray
        Stacked images and masks, in file-name order.
    frames_sampled : tuple of str
        The selected frame numbers, in sampling order. This is not necessarily
        the order of ``images``/``masks``.

    Raises
    ------
    ValueError
        If no frame is left after filtering.
    """
    # create directory names
    img_dir = dataset_dir / sequence
    msk_dir = dataset_dir / f"{sequence}_{suffix}" / "SEG"
    if suffix == "ERR_SEG":
        msk_dir = msk_dir.parent

    min_cells = min_cells or 0
    # ``ignore or []`` would raise "truth value is ambiguous" for NumPy arrays
    ignore = [] if ignore is None else list(ignore)

    # find out which frames have provided masks
    mask_paths = sorted(msk_dir.glob("*.tif"))
    # Any non-empty mask has at least one unique value, so masks only need to be
    # read when a positive min_cells threshold can actually exclude a frame.
    frames = [
        re.search(r"\d{3}$", fp.stem).group()
        for fp in mask_paths
        if min_cells <= 0 or len(np.unique(skimage.io.imread(fp))) > min_cells
    ]

    # setdiff1d also sorts and de-duplicates. Explicit str dtypes keep it working
    # when either side is empty (an empty list would otherwise be a float array).
    frames = np.setdiff1d(np.asarray(frames, dtype=str), np.asarray(ignore, dtype=str))

    if len(frames) == 0:
        raise ValueError(
            f"No frames left in {msk_dir} (min_cells={min_cells}, {len(ignore)} frames ignored)."
        )

    # use all available frames or randomly sample n frames that should be used/loaded
    frames_sampled = (
        tuple(frames)
        if (n is None) or (n > len(frames))
        else sample_frames(frames, n, random_seed)
    )

    # print info on the data that will be loaded
    if (n is not None) and (n > len(frames)):
        print(
            f"Number of requested frames ({n}) exceeds the number of frames in sequence ({len(frames)}). Loading all {len(frames)} frames."
        )
    elif n is not None:
        print(
            f"Loading {len(frames_sampled)} / {len(frames)} images from sequence {sequence} of {dataset_dir.name}."
        )
        print(f"Loaded frames {sorted(frames_sampled)}")

    # load images and masks; str.endswith accepts the tuple of frame numbers directly
    images = [
        skimage.io.imread(fp)
        for fp in sorted(img_dir.glob("*.tif"))
        if fp.stem.endswith(frames_sampled)
    ]
    masks = [
        skimage.io.imread(fp)
        for fp in mask_paths
        if fp.stem.endswith(frames_sampled)
    ]

    return np.stack(images), np.stack(masks), frames_sampled




In [None]:
# Validation and training frames both come from sequence 01 ("ST" masks); the
# training call excludes the validation frames so the two sets never overlap.
x_val, y_val, frames_val = create_dataset(dataset_dir, "01", suffix="ST", n=5)
x_train, y_train, _ = create_dataset(dataset_dir, "01", suffix="ST", n=10, ignore=frames_val)
# The test set uses every annotated frame of sequence 02 ("GT" masks)
x_test, y_test, _ = create_dataset(dataset_dir, "02", suffix="GT")


In [None]:
# Colormap with black first (label 0) followed by one random RGB colour per label
# value, so that neighbouring cell instances are easy to tell apart.
random_state = np.random.RandomState(0)
colors = [(0, 0, 0)] + [
    tuple(random_state.random() for _ in range(3))
    for _ in range(len(np.unique(y_train[0])))
]
random_cmap = matplotlib.colors.LinearSegmentedColormap.from_list('random_cmap', colors)

f, ax = plt.subplots(1, 2, figsize=(20, 20))
for panel, title in zip(ax, ('Raw Image', 'Ground Truth')):
    panel.set_title(title)
    panel.axis('off')
ax[0].imshow(x_train[0], cmap='gray')
ax[1].imshow(y_train[0], cmap=random_cmap, interpolation='nearest')
plt.show()

# 2 - 2D U-Net

The U-Net architecture has an encoder-decoder structure with long-range skip connections from the encoder to the decoder. In this section, we first construct the individual layers and then assemble them into the whole U-Net.

#### 2.1 U-Net Encoder

In [None]:
class EncoderLayer(nn.Module):
    """One U-Net encoder stage: optional 2x2 max pooling, then two conv + ReLU (+ dropout) steps.

    Parameters
    ----------
    ch_in, ch_out : int
        Input and output channel counts.
    kernel_size, padding : int
        Forwarded to both ``nn.Conv2d`` layers (stride 1).
    pooling : bool
        Downsample by 2 with ``nn.MaxPool2d`` before the convolutions.
    dropout : float
        Dropout probability after each ReLU; 0 inserts ``nn.Identity`` instead.
    """

    def __init__(self, ch_in, ch_out, kernel_size, padding, pooling, dropout):
        super().__init__()
        self.pooling = nn.MaxPool2d(2) if pooling else None

        def conv_relu_drop(channels_in):
            # Keep conv / ReLU / dropout in this order: the position inside
            # ``self.block`` determines the state_dict keys (block.0, block.3).
            return [
                nn.Conv2d(channels_in, ch_out, kernel_size=kernel_size, stride=1, padding=padding),
                nn.ReLU(),
                nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
            ]

        self.block = nn.Sequential(*conv_relu_drop(ch_in), *conv_relu_drop(ch_out))

    def forward(self, x):
        if self.pooling is not None:
            x = self.pooling(x)
        return self.block(x)
    

## Exercise 1: U-Net Decoder

Below you will find the decoder class. As you have learned in the lecture, there are different options for implementing the skip connections. The code for skip connections by concatenation is already given. Based on this, your task is to complete the missing code for

1) additive skip-connections
2) no skip-connections

The places where code has to be inserted are marked with `# your code goes here`.


In [None]:
class DecoderLayer(nn.Module):
    """One U-Net decoder stage: upsample by 2, merge the encoder skip features, then two conv + ReLU (+ dropout) steps.

    Parameters
    ----------
    ch_in, ch_out : int
        Channels of the incoming (coarser) features and of the output. The skip
        features are expected to have ``ch_out`` channels.
    kernel_size, padding : int
        Forwarded to both ``nn.Conv2d`` layers of ``self.block``.
    dropout : float
        Dropout probability after each ReLU; 0 inserts ``nn.Identity``.
    skip_mode : {'concat', 'add', 'none'}
        Merge the skip features by channel concatenation, by element-wise
        addition, or not at all.
    upsampling_mode : {'transpose', 'interpolate'}
        Either a learned 2x2 transposed convolution, or bilinear upsampling
        followed by a 1x1 convolution mapping ``ch_in`` to ``ch_out`` channels.
    cropping : bool
        Center-crop the skip features to the upsampled size. This is needed
        with 'valid' convolutions, where decoder maps are smaller.
    """

    def __init__(self, ch_in, ch_out, kernel_size, padding, dropout, skip_mode='concat', upsampling_mode='transpose', cropping=False):
        super(DecoderLayer, self).__init__()

        assert upsampling_mode in ['transpose', 'interpolate'], f'Upsampling has to be either "transpose" or "interpolate" but got "{upsampling_mode}"'
        assert skip_mode in ['concat', 'add', 'none'], f'Skip-connection has to be either "none", "add" or "concat" but got "{skip_mode}"'

        self.cropping = cropping
        self.skip_mode = skip_mode
        self.upsampling_mode = upsampling_mode

        if self.upsampling_mode == 'transpose':
            self.up = nn.ConvTranspose2d(ch_in, ch_out, kernel_size=2, stride = 2)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            # 1x1 convolution equalizes the number of channels
            self.conv = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)

        # Concatenation doubles the channels entering the first convolution. With
        # addition or no skip, the upsampled features already have ch_out channels.
        ch_hidden = 2 * ch_out if self.skip_mode == 'concat' else ch_out

        self.block = nn.Sequential(
            nn.Conv2d(ch_hidden, ch_out, kernel_size=kernel_size, stride=1, padding=padding),
            nn.ReLU(),
            nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
            nn.Conv2d(ch_out, ch_out, kernel_size=kernel_size, stride=1, padding=padding),
            nn.ReLU(),
            nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        )

    def crop(self, x, cropping_size):
        """Remove ``cropping_size[i]`` pixels from both ends of spatial axis ``i`` of an NCHW tensor.

        The end index is computed from the size rather than as ``-cropping_size``,
        because ``x[0:-0]`` would be empty when nothing has to be cropped.
        """
        crop_h, crop_w = int(cropping_size[0]), int(cropping_size[1])
        return x[:, :, crop_h:x.shape[2] - crop_h, crop_w:x.shape[3] - crop_w]

    def forward(self, x, skip_features):
        x = self.up(x)
        if self.upsampling_mode == 'interpolate':
            x = self.conv(x)

        if self.cropping:
            cropping_size = (torch.tensor(skip_features.shape[2:]) - torch.tensor(x.shape[2:])) // 2
            skip_features = self.crop(skip_features, cropping_size)
            # An odd size difference leaves one surplus row/column; trim it so shapes match
            skip_features = skip_features[:, :, :x.shape[2], :x.shape[3]]

        if self.skip_mode == 'concat':
            x = torch.cat((x, skip_features), 1)
        elif self.skip_mode == 'add':
            x = x + skip_features
        # 'none': the skip features are ignored

        return self.block(x)

#### 2.3 U-Net Architecture

In [None]:
class UNet2d(nn.Module):
    """2-D U-Net built from ``encoder_layer`` / ``decoder_layer`` stages.

    Parameters
    ----------
    input_dim, output_dim : int
        Channels of the input image and of the output. In this notebook the
        output is fed to ``nn.CrossEntropyLoss`` as class scores.
    encoder_layer, decoder_layer : type
        Stage classes, called with the keyword arguments used below.
    hidden_dims : sequence of int
        Channels per encoder stage. The last entry is the bottleneck, and the
        decoder mirrors the sequence in reverse. A tuple default avoids a
        shared mutable default.
    kernel_size : int
        Convolution kernel size of every stage.
    padding_mode : {'same', 'valid'}
        'same' pads by ``kernel_size // 2``. 'valid' uses no padding and makes
        the decoder center-crop the skip features.
    skip_mode, upsampling_mode : str
        Forwarded to every decoder stage.
    dropout : float
        Dropout probability, used only in the last (bottleneck) encoder stage,
        and only when it is not also the first stage.
    """
    def __init__(self,
                 input_dim,
                 output_dim,
                 encoder_layer=EncoderLayer,
                 decoder_layer=DecoderLayer,
                 hidden_dims=(64, 128, 256, 512, 1024),
                 kernel_size=3,
                 padding_mode='valid',
                 skip_mode='concat',
                 upsampling_mode='transpose',
                 dropout=0,
                 ):

        super(UNet2d, self).__init__()

        hidden_dims = list(hidden_dims)
        assert len(hidden_dims) > 0, 'UNet2d requires at least one hidden layer'
        assert padding_mode in ['same', 'valid'], f'Padding mode has to be either "same" or "valid" but got "{padding_mode}"'

        self.padding_mode = padding_mode

        cropping = padding_mode == 'valid'
        padding = 0 if padding_mode == 'valid' else kernel_size // 2

        # Encoder: the first stage sees the raw input and does not pool; every later
        # stage halves the resolution first. Dropout only in the last stage.
        last = len(hidden_dims) - 1
        encoder = []
        for i, ch_out in enumerate(hidden_dims):
            is_first = i == 0
            encoder.append(encoder_layer(
                input_dim if is_first else hidden_dims[i - 1],
                ch_out,
                kernel_size=kernel_size,
                padding=padding,
                pooling=not is_first,
                dropout=dropout if (i == last and not is_first) else 0,
            ))
        self.encoder = nn.ModuleList(encoder)

        # Decoder: walk the channel widths back from the bottleneck to the first stage
        hidden_dims_rev = hidden_dims[::-1]
        self.decoder = nn.ModuleList([
            decoder_layer(ch_in, ch_out, kernel_size=kernel_size, padding=padding, dropout=0,
                          skip_mode=skip_mode, upsampling_mode=upsampling_mode, cropping=cropping)
            for ch_in, ch_out in zip(hidden_dims_rev[:-1], hidden_dims_rev[1:])
        ])

        # Final 1x1 convolution maps the first-stage features to output_dim channels
        self.final_conv = nn.Conv2d(hidden_dims[0], output_dim, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        skip_features = []
        for encoder_layer in self.encoder:
            x = encoder_layer(x)
            skip_features.append(x)

        # The bottleneck output is not a skip connection; the remaining features are
        # consumed from the deepest to the shallowest stage.
        for decoder_layer, skip in zip(self.decoder, reversed(skip_features[:-1])):
            x = decoder_layer(x, skip)

        return self.final_conv(x)



#### 2.4 Weight Initialization

In [None]:
def weights_init(m):
    """Kaiming-normal (fan_in, ReLU) initialisation for the weights of 2-D convolutions.

    Intended for ``model.apply(weights_init)``. Both ``nn.Conv2d`` and
    ``nn.ConvTranspose2d`` weights are re-initialised; biases and all other
    module types are left untouched.
    """
    if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        return
    nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')

# 3 - Data Slicing and Augmentation

#### 3.1 Data Slicing

To process arbitrarily large images, they can be sliced into overlapping patches that are smaller than the image itself. This sliding-window approach limits the memory needed per forward pass, since the whole image never has to be processed at once (because of the overlap, some pixels are processed more than once).

In [None]:
def get_slices(dataset_shape, patch_size, stride):
    """Return the windows of a 2-D sliding-window tiling.

    Windows of size ``patch_size`` start every ``stride`` pixels along each
    axis. If the regular grid stops short of the border, one extra window
    flush with the border is added, so every pixel is covered.

    Parameters
    ----------
    dataset_shape : (int, int)
        Image height and width.
    patch_size : (int, int)
        Window height and width; must not exceed the image size.
    stride : (int, int)
        Step between window starts; must be positive.

    Returns
    -------
    list of (slice, slice)
        Row/column slices in row-major order of the window grid.

    Raises
    ------
    ValueError
        If a patch dimension exceeds the image or a stride is not positive.
    """
    x, y = dataset_shape
    p_x, p_y = patch_size
    s_x, s_y = stride

    def window_starts(size, patch, step):
        """Start offsets along one axis: a regular grid plus, if needed, a border-aligned start."""
        # Without these checks an empty grid surfaced as an IndexError on starts[-1]
        if patch > size:
            raise ValueError(f'Patch size {patch} exceeds the image size {size}; patches cannot be larger than the image.')
        if step <= 0:
            raise ValueError(f'Stride has to be positive but got {step}.')
        starts = [int(s) for s in np.arange(0, size - patch + 1, step)]
        if starts[-1] + patch < size:
            starts.append(int(size - patch))
        return starts

    x_starts = window_starts(x, p_x, s_x)
    y_starts = window_starts(y, p_y, s_y)
    return [
        (slice(start_x, start_x + p_x), slice(start_y, start_y + p_y))
        for start_x in x_starts
        for start_y in y_starts
    ]

#### 3.2 Data Transforms

To bring the data into a format suited for PyTorch, we define transformations that are applied whenever a sample is loaded.

In [None]:
class ToFloat():
    """Cast an array to ``float32``; ``astype`` always returns a new array."""

    def __call__(self, x):
        converted = x.astype(np.float32)
        return converted
    
class ToTensor():
    """Convert a 2-D or 3-D NumPy array to a ``torch.Tensor``.

    A 2-D (H, W) array gets a leading channel axis of size 1; a 3-D array is
    passed through unchanged in shape (NOTE(review): assumed channel-first).
    The array is copied first because ``torch.from_numpy`` rejects the
    negative-stride views produced by ``np.rot90``/``np.flip``.

    Raises
    ------
    ValueError
        For arrays that are neither 2-D nor 3-D, instead of silently returning ``None``.
    """

    def __call__(self, x):
        if x.ndim == 2:
            return torch.from_numpy(x.copy()).unsqueeze(0)
        if x.ndim == 3:
            return torch.from_numpy(x.copy())
        raise ValueError(f'ToTensor expects a 2-D or 3-D array but got shape {x.shape}.')
        

class Normalize():
    """Min-max scale an array.

    When both ``min`` and ``max`` are given, these fixed bounds are used (e.g.
    0/255). Otherwise the array's own minimum and maximum are used. ``eps`` is
    added to the range to avoid division by zero for constant inputs.
    """

    def __init__(self, min=None, max=None, eps=1e-8):
        self.min = min
        self.max = max
        self.eps = eps

    def __call__(self, x):
        use_fixed_bounds = self.min is not None and self.max is not None
        low = self.min if use_fixed_bounds else x.min()
        high = self.max if use_fixed_bounds else x.max()
        return (x - low) / (high - low + self.eps)
        

# The ground-truth masks visualised above are instance segmentations (one integer
# value per cell). Collapsing every positive value to 1 gives the binary
# foreground/background mask used for semantic segmentation.
class Instance2Semantic():
    """Map an instance label image to a binary integer mask (1 where the label is positive, else 0)."""

    def __call__(self, x):
        foreground = np.greater(x, 0)
        return foreground.astype(int)
    



#### 3.3 Data Augmentation

Neural networks require large annotated datasets in order to perform well. However, acquiring annotated datasets for biomedical image segmentation typically requires expert knowledge and is therefore costly and time-consuming. One way to cope with the lack of annotated data is to artificially enlarge the dataset by augmenting the existing labeled samples. This section introduces several classes for augmenting images during training.

In [None]:
class RandomRotate90():
    """With probability ``execution_probability``, rotate by k * 90 degrees, k drawn uniformly from {1, 2, 3}.

    Each call draws one uniform number from ``random_state``, plus k when the
    rotation is applied.
    """

    def __init__(self, random_state, execution_probability=0.0):
        self.random_state = random_state
        self.execution_probability = execution_probability

    def __call__(self, x):
        if self.random_state.random() >= self.execution_probability:
            return x
        quarter_turns = self.random_state.randint(1, 4)
        return np.rot90(x, quarter_turns)

class RandomContrast():
    """With probability ``execution_probability``, multiply by a factor from U[max(0, 1 - scale), 1 + scale].

    The result is clipped to [0, 255].
    """

    def __init__(self, random_state, scale, execution_probability=0.0):
        self.random_state = random_state
        self.execution_probability = execution_probability
        self.scale = scale

    def __call__(self, x):
        if not self.random_state.random() < self.execution_probability:
            return x
        lower = max(0, 1 - self.scale)
        factor = self.random_state.uniform(lower, 1 + self.scale)
        return np.clip(x * factor, 0, 255)

class RandomBrightness():
    """With probability ``execution_probability``, add an offset from U[-255 * scale, 255 * scale].

    The result is clipped to [0, 255].
    """

    def __init__(self, random_state, scale, execution_probability=0.0):
        self.random_state = random_state
        self.execution_probability = execution_probability
        self.scale = scale

    def __call__(self, x):
        if self.random_state.random() >= self.execution_probability:
            return x
        offset = 255 * self.scale * self.random_state.uniform(-1, 1)
        return np.clip(x + offset, 0, 255)
    
class RandomElasticDeformation():
    """Random smooth elastic deformation of a 2-D array.

    With probability ``execution_probability``:
    - a coarse ``grid_size`` x ``grid_size`` grid of Gaussian displacements
      (standard deviation ``sigma`` pixels) is drawn per axis;
    - the grids are resized to the array shape with cubic interpolation;
    - the array is resampled at the displaced pixel positions with
      ``scipy.ndimage.map_coordinates`` (spline ``order``, reflective borders).
    """

    def __init__(self, random_state, grid_size, sigma, order, execution_probability=0.0):
        self.random_state = random_state
        self.grid_size = grid_size
        self.sigma = sigma
        self.order = order
        self.execution_probability = execution_probability

    def __call__(self, x):
        if self.random_state.random() < self.execution_probability:
            h, w = x.shape

            # dx is drawn before dy; keep this order so a given seed yields the same deformation
            dx = self.random_state.randn(self.grid_size, self.grid_size) * self.sigma
            dy = self.random_state.randn(self.grid_size, self.grid_size) * self.sigma

            dx = resize(dx, x.shape, preserve_range=True, order=3)
            dy = resize(dy, x.shape, preserve_range=True, order=3)

            # indexing='ij' yields (h, w) grids that match dx/dy. The default 'xy'
            # yields (w, h) grids, which only broadcast for square inputs.
            rows, cols = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')

            # dy displaces along axis 0 (rows), dx along axis 1 (columns)
            coordinates = rows + dy, cols + dx

            # NOTE(review): prefilter=False with order > 1 treats x as spline
            # coefficients, which slightly smooths the result - confirm intended.
            return scipy.ndimage.map_coordinates(x, coordinates, order=self.order, mode='reflect', prefilter=False)
        return x



In [None]:
# Apply each augmentation once (execution_probability=1) to the first training image
img = x_train[0]

rot90 = RandomRotate90(
    random_state=np.random.RandomState(random_seed),
    execution_probability=1,
)(img)
contrast = RandomContrast(
    random_state=np.random.RandomState(random_seed),
    scale=0.5,
    execution_probability=1,
)(img)
brightness = RandomBrightness(
    random_state=np.random.RandomState(random_seed),
    scale=1,
    execution_probability=1,
)(img)
deform = RandomElasticDeformation(
    random_state=np.random.RandomState(random_seed),
    grid_size=3,
    sigma=25,
    order=3,
    execution_probability=1,
)(img)

plot_images(
    [img, 'Original image'],
    [rot90, 'Random 90° rotation'],
    [contrast, 'Random contrast'],
    [brightness, 'Random brightness'],
    [deform, 'Random elastic deformation'],
)


## Exercise 2

Based on the existing classes for data augmentation, write a class for randomly flipping an image vertically or horizontally with a certain execution probability. Test your implementation by applying it to an image from the training dataset. Use the provided templates. 

(BONUS) Additionally, implement a class for randomly rotating an image by an angle drawn from a pre-defined range.

HINT: You can use pre-implemented functions such as `np.flip` and `scipy.ndimage.rotate`.

## Solution 2

In [None]:
class RandomFlip():
    """With probability ``execution_probability``, mirror a 2-D array along a random axis.

    The axis is drawn uniformly from {0, 1} (vertical or horizontal flip). Each
    call draws one uniform number from ``random_state``, plus the axis when the
    flip is applied. The result is a view created by ``np.flip``.
    """

    def __init__(self, random_state, execution_probability=0.0):
        self.random_state = random_state
        self.execution_probability = execution_probability

    def __call__(self, x):
        if self.random_state.random() < self.execution_probability:
            axis = self.random_state.randint(0, 2)
            return np.flip(x, axis)
        return x

class RandomRotate():
    """With probability ``execution_probability``, rotate by a random angle.

    The angle (in degrees) is drawn uniformly from ``angle_range = [angle_min, angle_max]``,
    and ``scipy.ndimage.rotate`` is applied with the configured ``mode``,
    ``order`` and ``reshape``. With ``reshape=False`` the output keeps the
    input shape; ``order=0`` (nearest neighbour) keeps label values unchanged.
    """

    def __init__(self, random_state, angle_range, execution_probability=0.0, mode='mirror', order=0, reshape=False):
        self.random_state = random_state
        self.angle_range = angle_range  # [angle_min, angle_max] in degrees
        self.execution_probability = execution_probability
        self.mode = mode
        self.order = order
        self.reshape = reshape

    def __call__(self, x):
        if self.random_state.random() < self.execution_probability:
            angle_min, angle_max = self.angle_range
            angle = self.random_state.uniform(angle_min, angle_max)
            return scipy.ndimage.rotate(x, angle, reshape=self.reshape, order=self.order, mode=self.mode)
        return x

In [None]:
# Apply the two new augmentations once (execution_probability=1) to the same image
flip = RandomFlip(
    random_state=np.random.RandomState(random_seed),
    execution_probability=1,
)(img)
rot = RandomRotate(
    random_state=np.random.RandomState(random_seed),
    angle_range=[-45, 45],
    execution_probability=1,
    mode='constant',
)(img)

plot_images(
    [img, 'Original image'],
    [flip, 'Random horizontal/vertical flip'],
    [rot, 'Random rotation'],
)

# 4 - Training, Validation and Testing

#### 4.1 Dataset

Next, we define a dataset that cuts an image and its label into patches and applies the transformations to each patch.

In [None]:
class UNetDataset(Dataset):
    """Patch dataset over a single 2-D image/label pair.

    Unless ``padding`` is ``None``, both arrays are reflect-padded by
    ``padding = (pad_rows, pad_cols)`` on each side. The padded pair is then
    tiled with ``get_slices(shape, patch_size, stride)``.

    ``image_shape`` keeps the unpadded image shape. ``__getitem__`` returns the
    transformed ``(image_patch, label_patch)``. When ``phase == 'test'``, the
    patch's slices are appended so predictions can be stitched back together.
    """

    def __init__(self,
                 image,
                 label,
                 patch_size,
                 stride,
                 padding,
                 image_transformer,
                 label_transformer,
                 phase,
                 ):
        self.phase = phase
        self.image_transformer = image_transformer
        self.label_transformer = label_transformer

        self.image_shape = image.shape
        if padding is not None:
            pad_width = ((padding[0], padding[0]), (padding[1], padding[1]))
            image = np.pad(image, pad_width=pad_width, mode='reflect')
            label = np.pad(label, pad_width=pad_width, mode='reflect')
        self.image = image
        self.label = label

        self.slices = get_slices(self.image.shape, patch_size, stride)

    def __len__(self):
        return len(self.slices)

    def __getitem__(self, index):
        window = self.slices[index]
        x = self.image_transformer(self.image[window])
        y = self.label_transformer(self.label[window])
        return (x, y, window) if self.phase == 'test' else (x, y)
        
# Collate function for the DataLoader: the default collate cannot batch the
# slice objects returned by the dataset in the 'test' phase.
def custom_collate(batch):
    """Collate a list of samples into a batch.

    - tensors are stacked along a new leading batch axis;
    - slices are returned as the given sequence, unchanged;
    - tuples/lists (one entry per field) are collated field by field and
      returned as a list.

    Raises
    ------
    TypeError
        For any other element type, instead of silently returning ``None``.
    """
    first = batch[0]
    if isinstance(first, torch.Tensor):
        return torch.stack(batch, 0)
    if isinstance(first, slice):
        return batch
    if isinstance(first, (tuple, list)):
        return [custom_collate(field) for field in zip(*batch)]
    raise TypeError(f'custom_collate cannot batch elements of type {type(first).__name__}.')

        

def get_dataloader(x, y, batch_size, patch_size, stride, padding, image_transforms, label_transforms, phase):
    """Build a DataLoader over patches of every image/label pair ``x[i]``/``y[i]``.

    One ``UNetDataset`` is created per index of the first axis, and they are
    concatenated. Only the 'train' phase shuffles. Samples are loaded in the
    main process (``num_workers=0``) and batched with ``custom_collate``.
    """
    assert phase in ['train', 'validate', 'test'], f'Phase has to be either "train", "validate" or "test" but got "{phase}"'
    per_image = [
        UNetDataset(x[i], y[i], patch_size, stride, padding, image_transforms, label_transforms, phase)
        for i in range(x.shape[0])
    ]
    return DataLoader(
        dataset=ConcatDataset(per_image),
        batch_size=batch_size,
        shuffle=(phase == 'train'),
        num_workers=0,
        collate_fn=custom_collate,
    )



#### 4.2 - Evaluation Metrics

To evaluate the U-Net, we use the Dice coefficient or the intersection over union (IoU).

In [None]:
def Dice(pred, gt):
    """Mean Dice coefficient over a batch of binary masks.

    Parameters
    ----------
    pred, gt : array-like, at least 3-D
        Binary (0/1 or boolean) masks. The score is computed per sample over
        axes 1 and 2, then averaged over axis 0.

    Returns
    -------
    float
        Mean of ``2 |pred & gt| / (|pred| + |gt|)``. A sample whose prediction
        and ground truth are both empty scores 1.0 (perfect agreement) instead
        of NaN.
    """
    numerator = 2 * np.sum(pred * gt, axis=(1, 2))
    denominator = np.sum(pred, axis=(1, 2)) + np.sum(gt, axis=(1, 2))
    # Start from 1.0 and only divide where the denominator is positive (avoids 0/0)
    scores = np.ones(np.shape(denominator), dtype=float)
    np.divide(numerator, denominator, out=scores, where=denominator > 0)
    return np.mean(scores)

def IoU(pred, gt):
    """Mean intersection over union over a batch of binary masks.

    Parameters
    ----------
    pred, gt : array-like, at least 3-D
        Binary (0/1 or boolean) masks. The score is computed per sample over
        axes 1 and 2, then averaged over axis 0.

    Returns
    -------
    float
        Mean of ``|pred & gt| / |pred | gt|``. A sample whose prediction and
        ground truth are both empty scores 1.0 instead of NaN.
    """
    intersection = np.sum(pred * gt, axis=(1, 2))
    union = np.sum(pred, axis=(1, 2)) + np.sum(gt, axis=(1, 2)) - intersection
    # Start from 1.0 and only divide where the union is non-empty (avoids 0/0)
    scores = np.ones(np.shape(union), dtype=float)
    np.divide(intersection, union, out=scores, where=union > 0)
    return np.mean(scores)

#### 4.3 - Training and Validation Loop

In [None]:
class UNetTrainer():
    """Minimal training/validation loop for a segmentation network.

    Batches are ``(input, target)`` pairs. Targets carry a channel axis of size 1
    and are turned into class indices with ``target[:, 0].long()``. When the
    network output is spatially smaller than the target (e.g. 'valid'
    convolutions), the target is center-cropped to the output size.

    Per-epoch mean losses are stored in ``training_losses`` and
    ``validation_losses``; ``total_epochs`` and ``total_iterations`` count
    completed epochs and optimizer steps.
    """
    def __init__(self, model, optimizer, lr_scheduler, loss_function, device):
        self.model = model.to(device)
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.loss_function = loss_function
        self.device = device
        self.total_epochs = 0
        self.total_iterations = 0

        self.training_losses = []
        self.validation_losses = []

    @staticmethod
    def _crop_to_output(target, output):
        """Center-crop ``target`` spatially to the size of ``output``; a no-op when sizes match.

        The end index is derived from the output size, so zero and odd size
        differences are handled (``x[0:-0]`` would be empty).
        """
        crop_h = (target.shape[2] - output.shape[2]) // 2
        crop_w = (target.shape[3] - output.shape[3]) // 2
        return target[:, :, crop_h:crop_h + output.shape[2], crop_w:crop_w + output.shape[3]]

    def _batch_loss(self, data):
        """Move one ``(input, target)`` batch to the device and return the model's loss on it."""
        inputs, target = data
        inputs = inputs.to(self.device)
        target = target.to(self.device)

        output = self.model(inputs)
        target = self._crop_to_output(target, output)
        return self.loss_function(output, target[:, 0].long())

    def train(self, epochs, dataloader_train, dataloader_val):
        print('Starting training...')
        self.model.train()
        for epoch in range(epochs):
            training_loss = 0
            for data in dataloader_train:
                loss = self._batch_loss(data)
                training_loss += loss.item()

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                self.total_iterations += 1

            validation_loss = self.validate(dataloader_val)

            if self.lr_scheduler is not None:
                self.lr_scheduler.step(validation_loss)

            mean_training_loss = training_loss / len(dataloader_train)
            mean_validation_loss = validation_loss / len(dataloader_val)
            print(f'Epoch: [{epoch + 1}/{epochs}],',
                  f'Total iterations: {self.total_iterations},',
                  f'Training Loss: {mean_training_loss},',
                  f'Validation Loss: {mean_validation_loss}',)

            self.total_epochs += 1

            self.training_losses.append(mean_training_loss)
            self.validation_losses.append(mean_validation_loss)

    def validate(self, dataloader):
        """Return the summed loss over ``dataloader``; the model is left in train mode afterwards."""
        self.model.eval()
        validation_loss = 0
        try:
            # No gradients are needed for validation; skipping graph construction saves memory and time
            with torch.no_grad():
                for data in dataloader:
                    validation_loss += self._batch_loss(data).item()
        finally:
            self.model.train()
        return validation_loss
    


#### 4.4 - Testing

In [None]:
def unpad(img, padding_size):
    """Strip ``padding_size = (pad_rows, pad_cols)`` pixels from both sides of the last two axes.

    Works for (H, W) and (C, H, W) arrays. A padding of 0 leaves that axis
    unchanged; the end index is computed from the size because slicing with
    ``-0`` would produce an empty axis.
    """
    pad_h, pad_w = int(padding_size[0]), int(padding_size[1])
    height, width = img.shape[-2], img.shape[-1]
    return img[..., pad_h:height - pad_h, pad_w:width - pad_w]

def predict(model, x_test, y_test, patch_size, stride, padding, image_transforms, label_transforms, eval_metric = None):
    """Segment one image by sliding-window inference and optionally score the result.

    The image is reflect-padded by ``padding`` (if given) and tiled into
    patches. Network outputs of overlapping patches are summed and divided by
    the number of patches covering each pixel; the arg-max over the channel
    axis gives the class map. When the output is smaller than its input patch
    ('valid' convolutions), it is placed at the patch center.

    Parameters
    ----------
    model : nn.Module
        Segmentation network; inference runs on the device of its parameters.
    x_test, y_test : np.ndarray
        A single 2-D image and its mask (positive values = foreground).
    patch_size, stride, padding :
        Tiling parameters forwarded to ``get_dataloader``. ``padding`` is a
        (rows, cols) pair or ``None``.
    image_transforms, label_transforms :
        Transforms applied to the test patches.
    eval_metric : callable or None
        Called as ``eval_metric(result[None], (y_test > 0)[None])``.

    Returns
    -------
    result : np.ndarray
        Predicted class index per pixel, with the unpadded image shape.
        Pixels no output patch covered get class 0.
    eval_score
        The metric value, or ``None`` without ``eval_metric``.
    """
    dataloader = get_dataloader(x=x_test[None], y=y_test[None], batch_size=1, patch_size=patch_size, stride=stride, padding=padding, image_transforms=image_transforms, label_transforms=label_transforms, phase='test')
    image_size = dataloader.dataset.datasets[0].image_shape

    pad_h, pad_w = (0, 0) if padding is None else (int(padding[0]), int(padding[1]))
    padded_shape = (image_size[0] + 2 * pad_h, image_size[1] + 2 * pad_w)

    # Run on the device of the model's parameters (CPU for a parameter-free model)
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    prediction = None  # (channels, H, W); allocated once the number of output channels is known
    normalization = np.zeros(padded_shape)  # number of output patches covering each pixel

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for inputs, _, patch_slices in dataloader:
                output = model(inputs.to(model_device)).cpu().numpy()
                if prediction is None:
                    prediction = np.zeros((output.shape[1],) + padded_shape)

                # Offset of the output inside its input patch: 0 when sizes match,
                # otherwise half the size reduction (the output is centered)
                off_h = (inputs.shape[2] - output.shape[2]) // 2
                off_w = (inputs.shape[3] - output.shape[3]) // 2
                for b in range(inputs.shape[0]):
                    row_slice, col_slice = patch_slices[0][b], patch_slices[1][b]
                    region = (
                        slice(row_slice.start + off_h, row_slice.start + off_h + output.shape[2]),
                        slice(col_slice.start + off_w, col_slice.start + off_w + output.shape[3]),
                    )
                    normalization[region] += 1
                    prediction[(slice(None),) + region] += output[b]
    finally:
        # Restore the caller's mode rather than forcing train mode
        model.train(was_training)

    # Remove the padding on every padded axis to get back to the image size
    core = (slice(pad_h, pad_h + image_size[0]), slice(pad_w, pad_w + image_size[1]))
    prediction = prediction[(slice(None),) + core]
    normalization = normalization[core]

    # Average overlapping outputs; uncovered pixels keep all-zero scores (-> class 0)
    averaged = np.zeros_like(prediction)
    np.divide(prediction, normalization, out=averaged, where=normalization > 0)
    result = averaged.argmax(0)

    if eval_metric is not None:
        eval_score = eval_metric(result[None], (y_test > 0)[None])
    else:
        eval_score = None

    return result, eval_score



# Experiments

In [None]:
# Transformations and augmentations for the training images and labels.
# Every geometric augmentation appears in both pipelines, at the same position, and
# each instance owns its own RandomState(random_seed). An image transform and its
# label twin therefore draw the same random numbers as long as both are called once
# per sample, so an image patch and its label patch receive the same
# rotation/flip/deformation. Keep the two lists paired when editing them.
# Intensity augmentations (contrast, brightness) and Normalize touch only the image.
# NOTE(review): Normalize(min=0, max=255) and the 0..255 clipping in RandomContrast /
# RandomBrightness assume 8-bit intensities - confirm for this dataset.
execution_probability = 0.1 # Probability with which each augmentation is applied (drawn independently per transform)
image_transforms_train = Compose([
    ToFloat(),
    RandomRotate90(random_state=np.random.RandomState(random_seed), execution_probability=execution_probability),
    RandomRotate(random_state=np.random.RandomState(random_seed), angle_range=[-10,10], execution_probability=execution_probability, order=2),
    RandomFlip(random_state=np.random.RandomState(random_seed), execution_probability=execution_probability),
    RandomElasticDeformation(random_state=np.random.RandomState(random_seed), grid_size=3, sigma=10, order=3, execution_probability=execution_probability),
    RandomContrast(random_state=np.random.RandomState(random_seed), scale=0.1, execution_probability=execution_probability),
    RandomBrightness(random_state=np.random.RandomState(random_seed), scale=0.1, execution_probability=execution_probability),
    Normalize(min=0, max=255),
    ToTensor()
])
# Label pipeline: binary mask first, then the geometric twins of the image pipeline.
# RandomRotate uses order=0 (nearest neighbour) so label values are not interpolated.
# NOTE(review): the elastic deformation uses order=1 on the integer mask; confirm that
# scipy's conversion of interpolated values back to the integer dtype is acceptable,
# or switch to order=0.
label_transforms_train = Compose([
    Instance2Semantic(),
    RandomRotate90(random_state=np.random.RandomState(random_seed), execution_probability=execution_probability),
    RandomRotate(random_state=np.random.RandomState(random_seed), angle_range=[-10,10], execution_probability=execution_probability, order=0),
    RandomFlip(random_state=np.random.RandomState(random_seed), execution_probability=execution_probability),
    RandomElasticDeformation(random_state=np.random.RandomState(random_seed), grid_size=3, sigma=10, order=1, execution_probability=execution_probability),
    ToTensor()
])



# Transformations for the validation images and labels: no augmentation, only
# type conversion, scaling by the fixed 0..255 range and tensor conversion.
image_transforms_val = Compose([
    ToFloat(),
    Normalize(min=0, max=255),
    ToTensor()
])
label_transforms_val = Compose([
    Instance2Semantic(),
    ToTensor()
])



# Transformations for the test images and labels (same as for validation).
image_transforms_test = Compose([
    ToFloat(),
    Normalize(min=0, max=255),
    ToTensor()
])
label_transforms_test = Compose([
    Instance2Semantic(),
    ToTensor()
])


In [None]:
#Training of the original U-Net
# Builds the training/validation data loaders, a U-Net configured close to the 2015
# architecture ('valid' convolutions, concatenation skips, transposed-convolution
# upsampling), an SGD optimizer and an LR scheduler, and wraps them in a UNetTrainer.
# The variables defined here are notebook globals reused by later cells.

#Network parameters
input_channel = 1
output_channel = 2                      # two classes; predict() takes the argmax over the output channels
hidden_dims = [32,64,128,256,512]
kernel_size = 3                         
padding_mode = 'valid'                  # possible padding modes: 'same', 'valid'
skip_mode = 'concat'                    # possible skip modes: 'none', 'add', 'concat'
upsampling_mode = 'transpose'           # possible upsampling modes: 'interpolate', 'transpose'
dropout = 0.1                           # note: dropout is only applied to the bottleneck features
encoder_layer = EncoderLayer            
decoder_layer = DecoderLayer            

#Dataset parameters
batch_size = 4                         
patch_size_original_train = [332,332]   # patches cannot be larger than the image itself        
stride_original_train = [128,128]       # stride should be smaller than the patches 
padding_original_train = [0,0]          # when using 'valid' convolution, padding can be used to cope with the border problem

#Training parameters
loss_function = nn.CrossEntropyLoss()   
eval_metric = Dice
learning_rate = 1e-3

#Preparing dataloader
dataloader_train = get_dataloader(x=x_train, y=y_train, batch_size=batch_size, patch_size=patch_size_original_train, stride=stride_original_train, padding=padding_original_train, image_transforms=image_transforms_train, label_transforms=label_transforms_train, phase='train')
dataloader_val = get_dataloader(x=x_val, y=y_val, batch_size=batch_size, patch_size=patch_size_original_train, stride=stride_original_train, padding=padding_original_train, image_transforms=image_transforms_val, label_transforms=label_transforms_val, phase='validate')

#Initializing the U-Net
unet = UNet2d(input_channel,output_channel,encoder_layer,decoder_layer, hidden_dims, kernel_size, padding_mode, skip_mode, upsampling_mode, dropout)
# Re-initialize every submodule with the notebook's weights_init (nn.Module.apply recurses into children).
unet.apply(weights_init)
optimizer = optim.SGD(unet.parameters(), lr=learning_rate, momentum=0.99)
# NOTE(review): with patience=50 and epochs=50 in the next cell, the learning rate is
# presumably never reduced if UNetTrainer steps the scheduler once per epoch -- confirm.
lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=50, verbose=True)

trainer_original = UNetTrainer(unet, optimizer, lr_scheduler, loss_function, device)

In [None]:
# Train the original U-Net and plot the loss histories recorded by the trainer.
epochs = 50

trainer_original.train(epochs, dataloader_train, dataloader_val)

plot_results(
    [trainer_original.training_losses, 'Training loss'],
    [trainer_original.validation_losses, 'Validation loss']
)

In [None]:
# Tiled inference on every test image. Each image is padded by 92 px per side before
# patching (presumably to compensate for the border lost by the 'valid' convolutions);
# predict() removes the padding again before taking the argmax.
patch_size_original_test= [332,332]
stride_original_test = [128,128]
padding_original_test = [92,92]

# One Dice score per test image; averaged below.
test_scores_original = []

for t in range(len(x_test)):
    prediction, dice_score = predict(trainer_original.model, x_test[t], y_test[t], patch_size=patch_size_original_test, stride=stride_original_test, padding=padding_original_test, image_transforms = image_transforms_test, label_transforms=label_transforms_test, eval_metric=eval_metric)
    test_scores_original.append(dice_score)

print(f'Mean dice score for the test dataset using the original U-Net: {np.mean(test_scores_original)}')

In [None]:
# Visual check: input image, prediction of the original U-Net, and the binarized ground truth.
idx = 0

prediction_original, dice_score = predict(trainer_original.model, x_test[idx], y_test[idx], patch_size=patch_size_original_test, stride=stride_original_test, padding=padding_original_test, image_transforms = image_transforms_test, label_transforms=label_transforms_test, eval_metric=eval_metric)

plot_images(
    [x_test[idx], 'Original image'],
    [prediction_original, f'Prediction original U-Net \n Dice = {np.round(dice_score,4)}'],
    # Ground truth is shown as a foreground mask (any label > 0), matching how predict() scores it.
    [y_test[idx]>0, 'Ground truth']
)

## Exercise 3

The U-Net we have introduced so far is closely based on the original U-Net proposed in 2015. There are several options to improve the performance of the network. Try the following steps. You can either apply one step at a time to see the individual improvements, or change everything at once. 

1) Add batch normalization layers to the encoder and decoder. Since we are using a rather small training dataset, set `track_running_stats=False` for improved results; this way, the batch normalization statistics are computed for each batch individually.
2) Use 'same' instead of 'valid' padding for the convolutions
3) Replace the transposed convolution with bilinear interpolation
4) Replace the SGD optimizer with a more advanced optimizer such as Adam 

Report your findings.

## Solution 3

In [None]:
class EncoderLayerBN(EncoderLayer):
    """Encoder layer with batch normalization (Exercise 3, step 1).

    Accepts the same constructor arguments as ``EncoderLayer`` and runs its
    initializer; this subclass only overrides ``self.block``.
    """
    def __init__(self, ch_in, ch_out, kernel_size, padding, pooling, dropout):
        super(EncoderLayerBN, self).__init__(ch_in, ch_out, kernel_size, padding, pooling, dropout)

        # your code goes here
        # TODO(exercise 3): rebuild self.block with nn.BatchNorm2d layers added
        # (use track_running_stats=False, as suggested in the exercise text).
        # As written, the block is an empty nn.Sequential, which returns its input
        # unchanged -- presumably making the layer skip its convolutions if the
        # parent's forward() applies self.block; confirm against EncoderLayer.
        self.block = nn.Sequential()


class DecoderLayerBN(DecoderLayer):
    """Decoder layer with batch normalization (Exercise 3, step 1).

    Accepts the same constructor arguments as ``DecoderLayer`` and runs its
    initializer; this subclass only overrides ``self.block``.
    """
    def __init__(self, ch_in, ch_out, kernel_size, padding, dropout, skip_mode='concat', upsampling_mode='transpose', cropping=False):
        super(DecoderLayerBN, self).__init__(ch_in, ch_out, kernel_size, padding, dropout, skip_mode, upsampling_mode, cropping)

        # Number of channels entering the block once the skip connection has been merged:
        # concatenation doubles the channel count (presumably upsampled features and skip
        # features carry ch_out channels each), while 'add' and 'none' keep ch_out.
        if self.skip_mode == 'concat':
            ch_hidden = ch_out + ch_out
        elif self.skip_mode == 'add':
            ch_hidden = ch_out
        elif self.skip_mode == 'none':
            ch_hidden = ch_out

        # your code goes here
        # TODO(exercise 3): rebuild self.block (input: ch_hidden channels) with
        # nn.BatchNorm2d layers added, using track_running_stats=False.
        self.block = nn.Sequential()


In [None]:
#Training of the adapted U-Net
# Exercise 3 template: same data as the original U-Net, but with batch-normalized
# encoder/decoder layers, 256x256 patches and a smaller learning rate. Fill in the
# placeholders below according to steps 2-4 of the exercise.

#Network parameters
input_channel = 1
output_channel = 2
hidden_dims = [32,64,128,256,512]
kernel_size = 3                         
# Exercise 3, step 2: choose the convolution padding ('same' or 'valid').
padding_mode = # your code goes here
skip_mode = 'concat'
# Exercise 3, step 3: choose the upsampling ('interpolate' or 'transpose').
upsampling_mode = # your code goes here
dropout = 0
encoder_layer = EncoderLayerBN
decoder_layer = DecoderLayerBN

#Dataset parameters
batch_size = 4
patch_size_train = [256,256]
stride_train = [128,128]       
padding_train = [0,0]

#Training parameters
loss_function = nn.CrossEntropyLoss()
eval_metric = Dice
learning_rate = 1e-4


#Preparing dataloader
dataloader_train = get_dataloader(x=x_train, y=y_train, batch_size=batch_size, patch_size=patch_size_train, stride=stride_train, padding=padding_train, image_transforms=image_transforms_train, label_transforms=label_transforms_train, phase='train')
dataloader_val = get_dataloader(x=x_val, y=y_val, batch_size=batch_size, patch_size=patch_size_train, stride=stride_train, padding=padding_train, image_transforms=image_transforms_val, label_transforms=label_transforms_val, phase='validate')

#Initializing the U-Net
# NOTE: unlike the original U-Net cell, weights_init is not applied here.
unet = UNet2d(input_channel,output_channel,encoder_layer, decoder_layer, hidden_dims, kernel_size, padding_mode, skip_mode, upsampling_mode, dropout)
# Exercise 3, step 4: create the optimizer for unet.parameters() with lr=learning_rate.
optimizer = # your code goes here
lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=20, verbose=True)

trainer_adapted = UNetTrainer(unet, optimizer, lr_scheduler, loss_function, device)

In [None]:
# Train the adapted U-Net and compare its loss histories with the original U-Net.
epochs = 50
trainer_adapted.train(epochs, dataloader_train, dataloader_val)


plot_results(
    [trainer_original.training_losses, 'Training loss original U-Net'],
    [trainer_adapted.training_losses, 'Training loss adapted U-Net']
)

plot_results(
    [trainer_original.validation_losses, 'Validation loss original U-Net'],
    [trainer_adapted.validation_losses, 'Validation loss adapted U-Net']
)


In [None]:
# Tiled inference with the adapted U-Net. No extra border padding is used here
# (presumably because the adapted network uses size-preserving 'same' convolutions).
patch_size_test = [256,256]
stride_test = [128,128]
padding_test = [0,0]

# One Dice score per test image; averaged below next to the original U-Net's score.
test_scores_adapted = []

for t in range(len(x_test)):
    prediction, dice_score = predict(trainer_adapted.model, x_test[t], y_test[t], patch_size=patch_size_test, stride=stride_test, padding=padding_test, image_transforms = image_transforms_test, label_transforms=label_transforms_test, eval_metric = eval_metric)
    test_scores_adapted.append(dice_score)

print(f'Mean dice score for the test dataset using the original U-Net: {np.mean(test_scores_original)}')
print(f'Mean dice score for the test dataset using the adapted U-Net: {np.mean(test_scores_adapted)}')

In [None]:
# Side-by-side comparison of the original and adapted U-Net on one test image.
# Each model is run with the patch/stride/padding settings it was evaluated with above.
idx = 0

prediction_original, dice_score_original = predict(trainer_original.model, x_test[idx], y_test[idx], patch_size=patch_size_original_test, stride=stride_original_test, padding=padding_original_test, image_transforms = image_transforms_test, label_transforms=label_transforms_test, eval_metric=eval_metric)
prediction_adapted, dice_score_adapted = predict(trainer_adapted.model, x_test[idx], y_test[idx], patch_size=patch_size_test, stride=stride_test, padding=padding_test, image_transforms = image_transforms_test, label_transforms=label_transforms_test, eval_metric=eval_metric)

plot_images(
    [x_test[idx], 'Original Image'],
    [prediction_original, f'Prediction original U-Net \n Dice = {np.round(dice_score_original, 4)}'],
    [prediction_adapted, f'Prediction adapted U-Net \n Dice = {np.round(dice_score_adapted, 4)}'],
    [y_test[idx]>0, 'Ground truth']
)

## Exercise 4

To see the impact of data augmentation, train the same (adapted) model without applying any augmentation techniques. What do you observe?

## Solution 4

In [None]:
# your code goes here
# TODO(exercise 4): training image transforms without any random augmentation.
# Note: an empty Compose applies no transform at all, so the type conversion,
# normalization and tensor conversion steps would be skipped as well.
image_transforms_train_no_aug = Compose([])

# your code goes here
# TODO(exercise 4): training label transforms without any random augmentation.
label_transforms_train_no_aug = Compose([])

In [None]:
#Training of the adapted U-Net without data augmentation
# Same adapted configuration (batch norm, 'same' padding, interpolation upsampling,
# Adam), except that the training loader uses the no-augmentation transforms.

#Network parameters
input_channel = 1
output_channel = 2
hidden_dims = [32,64,128,256,512]
kernel_size = 3                         
padding_mode = 'same'
skip_mode = 'concat'
upsampling_mode = 'interpolate'
dropout = 0
encoder_layer = EncoderLayerBN          
decoder_layer = DecoderLayerBN          

#Dataset parameters
batch_size = 4                          
patch_size_train = [256,256]
stride_train = [128,128]
padding_train = [0,0]

#Training parameters
loss_function = nn.CrossEntropyLoss()
eval_metric = Dice
learning_rate = 1e-4


#Preparing dataloader
# Only the training loader changes; validation still uses the (non-random) validation transforms.
dataloader_train = get_dataloader(x=x_train, y=y_train, batch_size=batch_size, patch_size=patch_size_train, stride=stride_train, padding=padding_train, image_transforms=image_transforms_train_no_aug, label_transforms=label_transforms_train_no_aug, phase='train')
dataloader_val = get_dataloader(x=x_val, y=y_val, batch_size=batch_size, patch_size=patch_size_train, stride=stride_train, padding=padding_train, image_transforms=image_transforms_val, label_transforms=label_transforms_val, phase='validate')

#Initializing the U-Net
unet = UNet2d(input_channel,output_channel,encoder_layer, decoder_layer, hidden_dims, kernel_size, padding_mode, skip_mode, upsampling_mode, dropout)
optimizer = optim.Adam(unet.parameters(), lr = learning_rate)
lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=20, verbose=True)

trainer_adapted_no_aug = UNetTrainer(unet, optimizer, lr_scheduler, loss_function, device)

In [None]:
# Train without augmentation and compare loss histories with the augmented adapted U-Net.
epochs = 50
trainer_adapted_no_aug.train(epochs, dataloader_train, dataloader_val)


plot_results(
    [trainer_adapted.training_losses, 'Training loss adapted U-Net'],
    [trainer_adapted_no_aug.training_losses, 'Training loss adapted U-Net w/o augmentation']
)

plot_results(
    [trainer_adapted.validation_losses, 'Validation loss adapted U-Net'],
    [trainer_adapted_no_aug.validation_losses, 'Validation loss adapted U-Net w/o augmentation']
)

In [None]:
# Tiled inference with the model trained without augmentation (same tiling as the adapted U-Net).
patch_size_test = [256,256]
stride_test = [128,128]
padding_test = [0,0]

# One Dice score per test image.
test_scores_adapted_no_aug = []

for t in range(len(x_test)):
    prediction, dice_score = predict(trainer_adapted_no_aug.model, x_test[t], y_test[t], patch_size=patch_size_test, stride=stride_test, padding=padding_test, image_transforms = image_transforms_test, label_transforms=label_transforms_test, eval_metric=eval_metric)
    test_scores_adapted_no_aug.append(dice_score)

print(f'Mean dice score for the test dataset using the adapted U-Net with data augmentation: {np.mean(test_scores_adapted)}')
print(f'Mean dice score for the test dataset using the adapted U-Net without augmentation: {np.mean(test_scores_adapted_no_aug)}')

In [None]:
# Side-by-side comparison of the adapted U-Net trained with and without augmentation.
idx = 0

prediction_adapted, dice_score_adapted = predict(trainer_adapted.model, x_test[idx], y_test[idx], patch_size=patch_size_test, stride=stride_test, padding=padding_test, image_transforms = image_transforms_test, label_transforms=label_transforms_test, eval_metric=eval_metric)
prediction_adapted_no_aug, dice_score_adapted_no_aug = predict(trainer_adapted_no_aug.model, x_test[idx], y_test[idx], patch_size=patch_size_test, stride=stride_test, padding=padding_test, image_transforms = image_transforms_test, label_transforms=label_transforms_test, eval_metric=eval_metric)

plot_images(
    [x_test[idx], 'Original image'],
    [prediction_adapted, f'Prediction adapted U-Net \n Dice = {np.round(dice_score_adapted, 4)}'],
    [prediction_adapted_no_aug, f'Prediction adapted U-Net w/o augmentation \n Dice = {np.round(dice_score_adapted_no_aug, 4)}'],
    [y_test[idx]>0, 'Ground truth']
)

## Exercise 5

In [Exercise 1](#exercise-1-u-net-decoder), besides concatenation, you implemented two additional options for the skip-connections ('add' and 'none'). Train a network with one of these options.

## Solution 5

In [None]:
#Training of the adapted U-Net with an alternative skip-connection mode ('add' or 'none')
# Same adapted configuration as before; only skip_mode differs (Exercise 5).

#Network parameters
input_channel = 1
output_channel = 2
hidden_dims = [32,64,128,256,512]
kernel_size = 3                         
padding_mode = 'same'                   
# Exercise 5: choose 'add' or 'none' (possible skip modes: 'none', 'add', 'concat').
skip_mode = # your code goes here                     
upsampling_mode = 'interpolate'         
dropout = 0                             
encoder_layer = EncoderLayerBN          
decoder_layer = DecoderLayerBN          

#Dataset parameters
batch_size = 4                          
patch_size_train = [256,256]            
stride_train = [128,128]                
padding_train = [0,0]                   

#Training parameters
loss_function = nn.CrossEntropyLoss()
eval_metric = Dice
learning_rate = 1e-4


#Preparing dataloader
dataloader_train = get_dataloader(x=x_train, y=y_train, batch_size=batch_size, patch_size=patch_size_train, stride=stride_train, padding=padding_train, image_transforms=image_transforms_train, label_transforms=label_transforms_train, phase='train')
dataloader_val = get_dataloader(x=x_val, y=y_val, batch_size=batch_size, patch_size=patch_size_train, stride=stride_train, padding=padding_train, image_transforms=image_transforms_val, label_transforms=label_transforms_val, phase='validate')

#Initializing the U-Net
unet = UNet2d(input_channel,output_channel,encoder_layer, decoder_layer, hidden_dims, kernel_size, padding_mode, skip_mode, upsampling_mode, dropout)
optimizer = optim.Adam(unet.parameters(), lr = learning_rate)
lr_scheduler = ReduceLROnPlateau(optimizer, mode = 'min', factor = 0.1, patience = 20, verbose = True)

# The "no_skip" name is kept for the later cells even when skip_mode is 'add'.
trainer_adapted_no_skip = UNetTrainer(unet, optimizer, lr_scheduler, loss_function, device)

In [None]:
# Train the variant with the alternative skip mode and compare its loss histories with
# the adapted U-Net (concatenation skips). The legend uses the actual skip_mode chosen
# in the previous cell ('add' still has skip-connections), consistent with the score
# printout and the prediction plot further below.
epochs = 50
trainer_adapted_no_skip.train(epochs, dataloader_train, dataloader_val)


plot_results(
    [trainer_adapted.training_losses, 'Training loss adapted U-Net'],
    [trainer_adapted_no_skip.training_losses, f'Training loss adapted U-Net with "{skip_mode}" skip-connections']
)

plot_results(
    [trainer_adapted.validation_losses, 'Validation loss adapted U-Net'],
    [trainer_adapted_no_skip.validation_losses, f'Validation loss adapted U-Net with "{skip_mode}" skip-connections']
)

In [None]:
# Tiled inference with the alternative-skip-mode model (same tiling as the adapted U-Net).
patch_size_test = [256,256]
stride_test = [128,128]
padding_test = [0,0]

# One Dice score per test image.
test_scores_adapted_no_skip = []

for t in range(len(x_test)):
    prediction, dice_score = predict(trainer_adapted_no_skip.model, x_test[t], y_test[t], patch_size=patch_size_test, stride=stride_test, padding=padding_test, image_transforms = image_transforms_test, label_transforms=label_transforms_test, eval_metric=eval_metric)
    test_scores_adapted_no_skip.append(dice_score)

print(f'Mean dice score for the test dataset using the adapted U-Net with concatenation skip-connections: {np.mean(test_scores_adapted)}')
print(f'Mean dice score for the test dataset using the adapted U-Net with "{skip_mode}" skip-connections: {np.mean(test_scores_adapted_no_skip)}')

In [None]:
# Side-by-side comparison of concatenation skips vs. the alternative skip mode on one test image.
idx = 0

prediction_adapted, eval_score_adapted = predict(trainer_adapted.model, x_test[idx], y_test[idx], patch_size=patch_size_test, stride=stride_test, padding=padding_test, image_transforms = image_transforms_test, label_transforms=label_transforms_test, eval_metric=eval_metric)
prediction_adapted_no_skip, eval_score_adapted_no_skip = predict(trainer_adapted_no_skip.model, x_test[idx], y_test[idx], patch_size=patch_size_test, stride=stride_test, padding=padding_test, image_transforms = image_transforms_test, label_transforms=label_transforms_test, eval_metric=eval_metric)


plot_images(
    [x_test[idx], 'Original image'],
    [prediction_adapted, f'Prediction adapted U-Net \n Dice = {np.round(eval_score_adapted, 4)}'],
    [prediction_adapted_no_skip, f'Prediction adapted U-Net with {skip_mode} skip-connections \n Dice = {np.round(eval_score_adapted_no_skip, 4)}'],
    [y_test[idx]>0, 'Ground truth']
)