
Segmentation in 2D using U-Nets with Delira - A very short introduction

Authors: Justus Schock, Alexander Moriz

Date: 17.12.2018

This example shows how to use the U-Net implementation in Delira with PyTorch.

Let's first set up the essential hyperparameters. We will use delira's Parameters class for this:

logger = None
import torch
from delira.training import Parameters
params = Parameters(fixed_params={
    "model": {
        "in_channels": 1,
        "num_classes": 4
    },
    "training": {
        "batch_size": 64, # batch size to use
        "num_epochs": 10, # number of epochs to train
        "optimizer_cls": torch.optim.Adam, # optimization algorithm to use
        "optimizer_params": {'lr': 1e-3}, # initialization parameters for this algorithm
        "losses": {"CE": torch.nn.CrossEntropyLoss()}, # the loss function
        "lr_sched_cls": None,  # the learning rate scheduling algorithm to use
        "lr_sched_params": {}, # the corresponding initialization parameters
        "metrics": {} # and some evaluation metrics
    }
})

Since we did not specify any metrics, only the CrossEntropyLoss will be calculated for each batch. Since segmentation is a pixel-wise classification task, this should be sufficient. We will train our network with a batch size of 64, using Adam as the optimizer.
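The hyperparameters above form a nested dictionary, and later steps read single values from it by key (for example `params.nested_get("batch_size")`). The following is a minimal sketch of such a lookup on a plain dict. The names `nested_get` and `config` are hypothetical stand-ins used only for illustration, and this is not delira's implementation.

```python
def nested_get(mapping, key):
    """Return the value stored under `key` anywhere inside nested dicts.

    Hypothetical helper for illustration; raises KeyError if the key is
    missing or occurs more than once.
    """
    matches = []

    def _collect(node):
        for k, v in node.items():
            if k == key:
                matches.append(v)
            if isinstance(v, dict):
                _collect(v)

    _collect(mapping)
    if len(matches) != 1:
        raise KeyError(f"expected exactly one {key!r}, found {len(matches)}")
    return matches[0]


# Plain-dict copy of a few of the hyperparameters from the tutorial.
config = {
    "model": {"in_channels": 1, "num_classes": 4},
    "training": {"batch_size": 64, "num_epochs": 10,
                 "optimizer_params": {"lr": 1e-3}},
}

batch_size = nested_get(config, "batch_size")
learning_rate = nested_get(config, "lr")
print(batch_size, learning_rate)
```

Rejecting ambiguous keys is a design choice of this sketch: a key that appears in two sub-dictionaries would otherwise silently return whichever one is found first.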

Logging and Visualization

To get a visualization of our results, we should monitor them somehow. For logging we will use Visdom. To start a visdom server you need to execute the following command inside an environment which has visdom installed:

visdom -port=9999

This starts a visdom server on port 9999 of your machine, so we can now configure our logging environment. To view your results, open http://localhost:9999 in your browser.

from trixi.logger import PytorchVisdomLogger
from delira.logging import TrixiHandler
import logging

logger_kwargs = {
    'name': 'ClassificationExampleLogger', # name of our logging environment
    'port': 9999 # port on which our visdom server is alive
}

logger_cls = PytorchVisdomLogger

# configure logging module (and root logger)
logging.basicConfig(level=logging.INFO,
                    handlers=[TrixiHandler(logger_cls, **logger_kwargs)])

# derive logger from root logger
# (don't do `logger = logging.Logger("...")` since this will create a new
# logger which is unrelated to the root logger)
logger = logging.getLogger("Test Logger")

Since a single visdom server can run multiple environments, we need to specify a (unique) name for our environment and tell the logger on which port it can find the visdom server.
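The comment in the logging snippet warns against creating the logger with `logging.Logger(...)`. The sketch below shows the difference using only the standard library. `ListHandler` is a hypothetical stand-in for `TrixiHandler` that collects messages in a list instead of sending them to visdom.

```python
import logging


class ListHandler(logging.Handler):
    """Hypothetical stand-in for TrixiHandler: stores messages in a list."""

    def __init__(self):
        super().__init__()
        self.messages = []

    def emit(self, record):
        self.messages.append(record.getMessage())


handler = ListHandler()
root = logging.getLogger()
root.addHandler(handler)
root.setLevel(logging.INFO)

# Derived from the root logger: records propagate up to the root handlers.
derived = logging.getLogger("Test Logger")
derived.info("reaches the handler")

# Created directly: not part of the logger hierarchy, so nothing propagates.
detached = logging.Logger("Detached Logger")
detached.info("never reaches the handler")

print(handler.messages)

# Clean up so the stand-in handler does not affect later logging.
root.removeHandler(handler)
```

Only the message sent through the logger obtained with `logging.getLogger` ends up in the handler attached to the root logger.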

Data Preparation


Next we will create a small train and validation set (in this case they will be the same to show the overfitting capability of the UNet).

Our data is a brain MR image kindly provided by FSL in their introduction.

We first download the data and extract the T1 image and the corresponding segmentation:

from io import BytesIO
from zipfile import ZipFile
from urllib.request import urlopen

resp = urlopen("")
zipfile = ZipFile(BytesIO(resp.read()))
#zipfile_list = zipfile.namelist()
img_file = zipfile.extract("ExBox3/T1_brain.nii.gz")
mask_file = zipfile.extract("ExBox3/T1_brain_seg.nii.gz")

Now we load the image and the mask (both are 3D), convert them to 32-bit floating point numpy arrays and ensure that they have the same shape (i.e. that for each voxel in the image, there is a voxel in the mask):

import SimpleITK as sitk
import numpy as np

# load image and mask
img = sitk.GetArrayFromImage(sitk.ReadImage(img_file))
img = img.astype(np.float32)
mask = sitk.GetArrayFromImage(sitk.ReadImage(mask_file))
mask = mask.astype(np.float32)

assert mask.shape == img.shape

Querying the unique values in the mask shows which labels are present:

np.unique(mask)

There are 4 classes (background and 3 types of tissue) in our sample.
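As a small illustration of how the number of classes can be read off a label map, the following sketch runs `np.unique` on a synthetic mask. The array `toy_mask` is made up for this example and is not the tutorial's data.

```python
import numpy as np

# Synthetic stand-in for the real mask: a tiny 2D label map with 4 labels.
toy_mask = np.array([[0, 0, 1, 1],
                     [0, 2, 2, 1],
                     [3, 3, 2, 0]], dtype=np.float32)

labels, counts = np.unique(toy_mask, return_counts=True)
num_classes = len(labels)

print(labels)       # distinct label values, sorted ascending
print(counts)       # number of pixels per label
print(num_classes)  # value to use for "num_classes"
```

The pixel counts per label also give a quick impression of class imbalance, which is common in segmentation masks where background dominates.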

Since we want to do a 2D segmentation, we extract a single slice out of the image and the mask (we choose slice 100 here) and plot it:

import matplotlib.pyplot as plt

# load single slice
img_slice = img[:, :, 100]
mask_slice = mask[:, :, 100]

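# plot image slice (left) and mask slice (right) side by side
plt.figure(1, figsize=(15,10))
plt.subplot(1, 2, 1)
plt.imshow(img_slice, cmap="gray")
plt.colorbar(fraction=0.046, pad=0.04)
plt.subplot(1, 2, 2)
plt.imshow(mask_slice, cmap="gray")
plt.colorbar(fraction=0.046, pad=0.04)
plt.show()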
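Which anatomical plane a slice shows depends on the axis that is fixed. According to SimpleITK's documented behaviour, `GetArrayFromImage` returns the array with axes in (z, y, x) order, so indexing the last axis fixes x. The sketch below uses a small synthetic array (`volume` is a made-up stand-in, not the MR image) to show how the slice shape follows from the indexed axis.

```python
import numpy as np

# Synthetic volume standing in for the MR image; shape is (2, 3, 4).
volume = np.arange(2 * 3 * 4, dtype=np.float32).reshape(2, 3, 4)

first_axis_slice = volume[1, :, :]  # fixes the first axis, keeps the last two
last_axis_slice = volume[:, :, 1]   # fixes the last axis, keeps the first two

print(first_axis_slice.shape)
print(last_axis_slice.shape)
```

Checking the shape of `img_slice` in the same way is a quick sanity test before plotting or cropping.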

To load the data, we have to use a Dataset. The following defines a very simple dataset, accepting an image slice, a mask slice and the number of samples. It always returns the same sample until num_samples samples have been returned.

from delira.data_loading import AbstractDataset

class CustomDataset(AbstractDataset):
    def __init__(self, img, mask, num_samples=1000):
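        super().__init__(None, None, None, None)
        # add a channel axis in front and always keep this single sample
        self.data = {"data": img.reshape(1, *img.shape), "label": mask.reshape(1, *mask.shape)}
        self.num_samples = num_samples

    def __getitem__(self, index):
        # every index maps to the same sample
        return self.data
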
    def __len__(self):
        return self.num_samples
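To try out the "same sample, repeated" idea without delira installed, here is a minimal plain-Python sketch. `RepeatedSampleDataset` is a hypothetical class for illustration only. Unlike the dataset above, it raises `IndexError` for out-of-range indices so that plain iteration terminates.

```python
import numpy as np


class RepeatedSampleDataset:
    """Hypothetical plain-Python dataset that repeats one (image, mask) pair."""

    def __init__(self, img, mask, num_samples=1000):
        # Add a leading channel axis: (H, W) -> (1, H, W).
        self.sample = {"data": img[np.newaxis], "label": mask[np.newaxis]}
        self.num_samples = num_samples

    def __getitem__(self, index):
        if not 0 <= index < self.num_samples:
            raise IndexError(index)
        return self.sample

    def __len__(self):
        return self.num_samples


toy_img = np.zeros((8, 8), dtype=np.float32)
toy_mask = np.ones((8, 8), dtype=np.float32)
toy_dataset = RepeatedSampleDataset(toy_img, toy_mask, num_samples=3)

# Iteration stops after num_samples items because __getitem__ raises IndexError.
shapes = [item["data"].shape for item in toy_dataset]
print(len(toy_dataset), shapes)
```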

Now, we can finally instantiate our datasets:

dataset_train = CustomDataset(img_slice, mask_slice, num_samples=10000)
dataset_val = CustomDataset(img_slice, mask_slice, num_samples=1)


For data augmentation, we will apply a few transformations:

from batchgenerators.transforms import RandomCropTransform, \
                                        ContrastAugmentationTransform, Compose
from batchgenerators.transforms.spatial_transforms import ResizeTransform
from batchgenerators.transforms.sample_normalization_transforms import MeanStdNormalizationTransform

transforms = Compose([
    RandomCropTransform(150, label_key="label"), # Perform Random Crops of Size 150 x 150 pixels
    ResizeTransform(224, label_key="label"), # Resample these crops back to 224 x 224 pixels
    ContrastAugmentationTransform(), # randomly adjust contrast
    MeanStdNormalizationTransform(mean=[img_slice.mean()], std=[img_slice.std()])]) # use concrete values since we only have one sample (have to estimate it over whole dataset otherwise)
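To make two of these steps concrete, the sketch below implements a random crop and a mean/std normalization in plain NumPy. It is an illustration under simplifying assumptions (a single 2D image, no batch or channel axes), not the batchgenerators implementation. `random_crop`, `normalize` and the synthetic arrays are hypothetical names introduced here.

```python
import numpy as np


def random_crop(img, mask, size, rng):
    """Cut the same random size x size window out of image and mask."""
    top = rng.integers(0, img.shape[0] - size + 1)
    left = rng.integers(0, img.shape[1] - size + 1)
    window = (slice(top, top + size), slice(left, left + size))
    return img[window], mask[window]


def normalize(img, mean, std):
    """Shift and scale intensities with the given mean and standard deviation."""
    return (img - mean) / std


rng = np.random.default_rng(seed=0)
toy_img = rng.normal(loc=100.0, scale=20.0, size=(64, 64)).astype(np.float32)
toy_mask = (toy_img > 100.0).astype(np.float32)

# Crop first, then normalize with statistics of the whole image.
crop_img, crop_mask = random_crop(toy_img, toy_mask, size=32, rng=rng)
normalized_crop = normalize(crop_img, toy_img.mean(), toy_img.std())

print(crop_img.shape, crop_mask.shape)
```

Image and mask must be cropped with the same window. Otherwise the labels no longer line up with the pixels they describe.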

With these transformations we can now wrap our datasets into datamanagers:

from delira.data_loading import DataManager, SequentialSampler, RandomSampler

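manager_train = DataManager(dataset_train, params.nested_get("batch_size"),
                            transforms=transforms,
                            sampler_cls=RandomSampler,
                            n_process_augmentation=4)

manager_val = DataManager(dataset_val, params.nested_get("batch_size"),
                          transforms=transforms,
                          sampler_cls=SequentialSampler,
                          n_process_augmentation=4)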
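The two sampler classes imported above differ in the order in which sample indices are drawn. The following sketch illustrates the general idea of sequential versus random index batches in plain Python. Both functions are hypothetical, and details such as drawing with replacement are assumptions of this sketch, not statements about delira's samplers.

```python
import random


def sequential_batches(num_samples, batch_size):
    """Yield index batches in order; the last batch may be smaller."""
    indices = list(range(num_samples))
    for start in range(0, num_samples, batch_size):
        yield indices[start:start + batch_size]


def random_batches(num_samples, batch_size, num_batches, seed=None):
    """Yield index batches drawn uniformly at random with replacement."""
    rng = random.Random(seed)
    for _ in range(num_batches):
        yield [rng.randrange(num_samples) for _ in range(batch_size)]


seq = list(sequential_batches(num_samples=10, batch_size=4))
rnd = list(random_batches(num_samples=10, batch_size=4, num_batches=3, seed=0))

print(seq)
print(rnd)
```

Sequential order is the usual choice for validation, where reproducibility matters, while random order is the usual choice for training.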


After we have done that, we can finally specify our experiment and run it. We will use the already implemented UNet2dPyTorch for this:

import warnings
warnings.simplefilter("ignore", UserWarning) # ignore UserWarnings raised by dependency code
warnings.simplefilter("ignore", FutureWarning) # ignore FutureWarnings raised by dependency code

from delira.training import PyTorchExperiment
from delira.training.train_utils import create_optims_default_pytorch
from delira.models.segmentation import UNet2dPyTorch

if logger is not None:
    logger.info("Init Experiment")
experiment = PyTorchExperiment(params, UNet2dPyTorch,
                               gpu_ids=[0], mixed_precision=True)

model = experiment.run(manager_train, manager_val)
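During the run, the CrossEntropyLoss configured in the parameters is evaluated on every batch. As a rough illustration of what that loss measures for segmentation, the sketch below computes the mean pixel-wise cross-entropy in plain NumPy. It is a simplified sketch, not the PyTorch implementation. The function name and the toy tensors are hypothetical.

```python
import numpy as np


def pixelwise_cross_entropy(logits, target):
    """Mean cross-entropy over all pixels.

    logits: float array of shape (N, C, H, W) with raw class scores
    target: int array of shape (N, H, W) with class indices in [0, C)
    """
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class at every pixel.
    picked = np.take_along_axis(log_probs, target[:, np.newaxis], axis=1)
    return float(-picked.mean())


num_classes = 4
target = np.array([[[0, 1], [2, 3]]])              # shape (1, 2, 2)
uniform_logits = np.zeros((1, num_classes, 2, 2))  # no class preferred
# Scaled one-hot logits: a confident, correct prediction at every pixel.
confident_logits = 20.0 * np.eye(num_classes)[target].transpose(0, 3, 1, 2)

loss_uniform = pixelwise_cross_entropy(uniform_logits, target)
loss_confident = pixelwise_cross_entropy(confident_logits, target)
print(loss_uniform, loss_confident)
```

With no class preferred, the loss equals log(4) for four classes. A network that overfits the single training slice should drive the loss towards zero, as in the confident case.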

See Also

For a more detailed explanation, have a look at
* the introduction tutorial
* the classification example
* the 3d segmentation example
* the generative adversarial example