# Semantic Segmentation (2D)


This exercise demonstrates a very simple approach to *semantic segmentation* with convolutional neural networks. In *semantic segmentation*, we assign each pixel of the input image to one of several classes (here: background, cell interior, cell boundary), without distinguishing between different objects of the same class.

![](_images/task_semantic.png)

## Setup and imports

In [None]:
import numpy as np
import matplotlib
matplotlib.rcParams["image.interpolation"] = "none"  # show raw pixels without smoothing
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from glob import glob
from tqdm import tqdm
from datetime import datetime
from tifffile import imread
from pathlib import Path
import skimage
from skimage.segmentation import find_boundaries

import tensorflow as tf


from csbdeep.internals.nets import common_unet, custom_unet
from csbdeep.internals.blocks import unet_block, resnet_block

## Data


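First, we download some sample images and their corresponding annotation masks (a subset of the 2018 Data Science Bowl nuclei dataset, as provided with StarDist).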

In [None]:
from csbdeep.utils import download_and_extract_zip_file, normalize

download_and_extract_zip_file(
    url       = 'https://github.com/mpicbg-csbd/stardist/releases/download/0.1.0/dsb2018.zip',
    targetdir = 'data',
    verbose   = 1,
)

Next, we load the data, crop out a central 256×256 patch from each image and mask, and convert the instance annotation masks into 3-class masks (background / cell interior / cell boundary). Cropping is only for simplicity: it makes our life easier when all images have the same shape.

In [None]:

def crop(u,shape=(256,256)):
    """Crop central region of given shape"""
    return u[tuple(slice((s-m)//2,(s-m)//2+m) for s,m in zip(u.shape,shape))]

def to_3class_label(lbl, onehot=True):
    """Convert instance labeling to background/inner/outer mask"""
    b = find_boundaries(lbl,mode='outer')
    res = (lbl>0).astype(np.uint8)
    res[b] = 2
    if onehot:
        res = tf.keras.utils.to_categorical(res,num_classes=3).reshape(lbl.shape+(3,))
    return res

# load and crop out central patch (for simplicity)
X   = [normalize(crop(imread(x))) for x in sorted(glob('data/dsb2018/train/images/*.tif'))]
Y   = [to_3class_label(crop(imread(y))) for y in sorted(glob('data/dsb2018/train/masks/*.tif'))]

# convert to numpy arrays
X, Y = np.expand_dims(np.stack(X),-1), np.stack(Y)
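To see what `to_3class_label` produces, here is a minimal sketch on a hypothetical 12×12 toy labeling with two non-touching square cells. It re-implements the two helpers above under new names (`crop_center`, `to_3class_label_np`) and swaps in `np.eye` for `tf.keras.utils.to_categorical`, so it runs without TensorFlow. Note that with `mode='outer'`, `find_boundaries` marks the *background* pixels surrounding each object, so all object pixels stay in the interior class (according to the scikit-image documentation, pixels where two objects touch are marked as well).

```python
import numpy as np
from skimage.segmentation import find_boundaries

def crop_center(u, shape):
    """Crop the central region of the given shape (same logic as `crop` above)."""
    return u[tuple(slice((s - m) // 2, (s - m) // 2 + m) for s, m in zip(u.shape, shape))]

def to_3class_label_np(lbl, onehot=True):
    """Instance labels -> 0 = background, 1 = interior, 2 = boundary (NumPy-only variant)."""
    b = find_boundaries(lbl, mode='outer')
    res = (lbl > 0).astype(np.uint8)
    res[b] = 2
    if onehot:
        res = np.eye(3, dtype=np.float32)[res]   # one-hot encoding along a new last axis
    return res

# hypothetical toy instance labeling: two 3x3 "cells" with ids 1 and 2
lbl = np.zeros((12, 12), dtype=np.uint16)
lbl[2:5, 2:5] = 1
lbl[7:10, 7:10] = 2

classes = to_3class_label_np(lbl, onehot=False)
onehot = to_3class_label_np(lbl)
print(classes)
print(onehot.shape)                   # (12, 12, 3)
print(crop_center(lbl, (4, 4)))       # central 4x4 window, rows/cols 4..7
```

Each square keeps its 9 pixels as interior, and the ring of background pixels directly above, below, left and right of it becomes boundary.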

In [None]:
# plot an example image
i = 3
fig, (a0,a1) = plt.subplots(1,2,figsize=(15,5))
a0.imshow(X[i,...,0],cmap='gray');  
a0.set_title('input image')
a1.imshow(Y[i]);                    
a1.set_title('segmentation mask')
fig.suptitle("Example")
None;
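Because `Y[i]` has three channels with values in [0, 1], `imshow` renders it as an RGB image: background appears red, cell interior green and cell boundary blue. A quick sanity check when inspecting the masks is the fraction of pixels per class. The sketch below assumes one-hot masks shaped like `Y` and uses a small synthetic stand-in (`Y_demo`, hypothetical) so it runs on its own.

```python
import numpy as np

CLASS_NAMES = ("background", "interior", "boundary")  # channel order produced by to_3class_label

def class_fractions(Y):
    """Fraction of pixels per class for one-hot masks of shape (..., n_classes)."""
    counts = Y.reshape(-1, Y.shape[-1]).sum(axis=0)
    return counts / counts.sum()

# synthetic stand-in for Y: two identical 4x4 masks given as class indices, then one-hot encoded
idx = np.array([[[0, 0, 0, 0],
                 [0, 2, 2, 0],
                 [0, 2, 1, 0],
                 [0, 0, 0, 0]]] * 2)
Y_demo = np.eye(3, dtype=np.float32)[idx]

for name, frac in zip(CLASS_NAMES, class_fractions(Y_demo)):
    print(f"{name:10s} {frac:.3f}")
```

On the real data, `class_fractions(Y)` shows how the pixels are distributed over the three classes, which is worth keeping in mind since the loss averages over all pixels.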

<div class="alert alert-block alert-info"><h2>Exercise</h2> 
    

1)  Plot some more images. What kind of data is shown? How variable is it? Do the segmentation masks look reasonable? 
        
</div>


We now split the data into roughly 80% training and 20% validation images.

In [None]:
from csbdeep.data import shuffle_inplace

# shuffle data
shuffle_inplace(X, Y, seed=0)

# split into 80% training and 20% validation images
n_val = len(X) // 5
def split_train_val(a):
    return a[:-n_val], a[-n_val:]
X_train,       X_val       = split_train_val(X)
Y_train,       Y_val       = split_train_val(Y)

print(f'training   data: {len(X_train)} images and {len(Y_train)} masks')
print(f'validation data: {len(X_val)} images and {len(Y_val)} masks')
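The important detail when shuffling is that images and masks are permuted *identically*, so each image stays paired with its mask. Below is a NumPy-only sketch of the same shuffle-then-split idea (not csbdeep's implementation; the function `shuffle_and_split` and the toy data are hypothetical), where each "mask" is simply its image times 10 so the pairing can be checked.

```python
import numpy as np

def shuffle_and_split(X, Y, seed=0):
    """Shuffle X and Y with the same permutation, then hold out ~20% for validation."""
    assert len(X) == len(Y)
    perm = np.random.default_rng(seed).permutation(len(X))
    X, Y = X[perm], Y[perm]            # identical reordering keeps image/mask pairs aligned
    n_val = max(1, len(X) // 5)        # at least one validation image (X[:-0] would be empty)
    return (X[:-n_val], Y[:-n_val]), (X[-n_val:], Y[-n_val:])

# toy data: 10 one-pixel "images"; each "mask" is the image value times 10
X_demo = np.arange(10).reshape(10, 1, 1, 1)
Y_demo = X_demo * 10
(X_tr, Y_tr), (X_va, Y_va) = shuffle_and_split(X_demo, Y_demo)
print(len(X_tr), len(X_va))                                  # 8 2
print(np.all(Y_tr == X_tr * 10), np.all(Y_va == X_va * 10))  # pairs still match
```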

## Building a UNet 

We now construct a simple 3-class segmentation model based on a U-Net:

<img width=400 src="https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png"></img>

For the actual implementation, we use the function `custom_unet` from `csbdeep.internals.nets`.

In [None]:
from csbdeep.internals.nets import custom_unet

In [None]:
model = custom_unet(input_shape=(None,None,1), n_channel_out=3, kernel_size=(3,3), pool_size=(2,2), 
                    n_filter_base=32, last_activation='softmax')

model.summary()
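Since `input_shape=(None,None,1)`, the model accepts images of any size, but every 2×2 pooling step halves the spatial size, and the decoder has to upsample back to matching shapes for the skip connections. In practice, height and width should therefore be divisible by `2**n_depth`; our 256×256 crops satisfy this. For uncropped images of arbitrary size, one option is to pad before prediction and crop afterwards. A minimal sketch, assuming `custom_unet`'s default depth of 2 (check the signature of your installed csbdeep version); `pad_to_multiple` is a hypothetical helper:

```python
import numpy as np

def pad_to_multiple(img, multiple=4, mode="reflect"):
    """Pad the first two (spatial) axes at the bottom/right so both become divisible by `multiple`.

    Returns the padded image and the original (height, width) for cropping the prediction back.
    """
    h, w = img.shape[:2]
    pad_h = (-h) % multiple            # e.g. h = 250, multiple = 4 -> pad 2 rows
    pad_w = (-w) % multiple
    pad = [(0, pad_h), (0, pad_w)] + [(0, 0)] * (img.ndim - 2)   # leave channel axes untouched
    return np.pad(img, pad, mode=mode), (h, w)

n_depth, pool = 2, 2                   # assumed: custom_unet default depth, pool_size=(2,2) as above
img = np.random.default_rng(0).random((250, 333))
padded, (h, w) = pad_to_multiple(img, multiple=pool ** n_depth)
print(padded.shape)                    # (252, 336)
```

In the notebook, one could then call `model.predict(padded[np.newaxis, ..., np.newaxis])[0][:h, :w]` to obtain a prediction with the original image size.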

<div class="alert alert-block alert-info"><h2>Exercise</h2> 
    

1) What is the intuition behind the gray "skip connections"? 
    
2) Apply the (untrained) model to an example image (with `model.predict`). What is the output? How is it normalized?


</div>
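As a reference point for the exercise: the model was built with `last_activation='softmax'`, which is applied along the last (channel) axis. The NumPy sketch below (the toy `logits` are random, hypothetical values) shows what this does to a `(H, W, 3)` output: every pixel gets three non-negative values that sum to 1, and the ranking of the classes is unchanged.

```python
import numpy as np

def softmax(logits, axis=-1):
    """Numerically stable softmax along `axis` (the channel axis for (H, W, C) outputs)."""
    z = logits - logits.max(axis=axis, keepdims=True)   # subtracting the max avoids overflow in exp
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 4, 3))    # toy pre-activation output: 4x4 pixels, 3 classes
probs = softmax(logits)
print(probs.min() >= 0, np.allclose(probs.sum(axis=-1), 1.0))  # True True
```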

### Compiling the model 

We now compile the model, i.e. choose a loss function and an optimizer.

Since we have a per-pixel classification task with multiple output classes, we use the standard `categorical_crossentropy` loss. As optimizer, `Adam` with a learning rate on the order of `1e-4` to `1e-3` is a safe default (general reading tip: http://karpathy.github.io/2019/04/25/recipe/).

In [None]:
model.compile(loss=tf.keras.losses.categorical_crossentropy, optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4))
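For one pixel with one-hot target `y` and predicted probabilities `p`, categorical cross-entropy is `-sum_c y_c * log(p_c)`, i.e. minus the log-probability assigned to the true class. Keras applies it along the last (class) axis, so for our `(H, W, 3)` outputs it gives one value per pixel, and these are averaged into the reported loss. Keras also keeps predictions away from 0 before taking the log; the `eps=1e-7` below is an assumption mirroring that. A NumPy sketch on two hypothetical pixels:

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Per-pixel cross-entropy -sum_c y_true[c] * log(y_pred[c]) over the last (class) axis."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)    # avoid log(0)
    return -np.sum(y_true * np.log(y_pred), axis=-1)

# two pixels: true classes "interior" (1) and "boundary" (2)
y_true = np.array([[0., 1., 0.],
                   [0., 0., 1.]])
y_pred = np.array([[0.1, 0.8, 0.1],     # confident and correct -> small loss
                   [0.6, 0.3, 0.1]])    # only 0.1 on the true class -> large loss
per_pixel = categorical_crossentropy(y_true, y_pred)
print(per_pixel)          # approx. [0.223 2.303], i.e. [-log(0.8), -log(0.1)]
print(per_pixel.mean())   # the training loss averages such values over pixels and images
```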

Before training, we define some callbacks that log the training and validation loss, as well as a few example predictions, to TensorBoard.

In [None]:
from csbdeep.utils.tf import CARETensorBoardImage

timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")  # no ':' so the path is also valid on Windows
logdir = Path(f'models/1_semantic_segmentation_2D/{timestamp}')
logdir.mkdir(parents=True, exist_ok=True)
callbacks = []
callbacks.append(tf.keras.callbacks.TensorBoard(log_dir=logdir))
callbacks.append(CARETensorBoardImage(model=model, data=(X_val,Y_val),
                            log_dir=logdir/'images',
                            n_images=3))


### Ready to train!

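To monitor training, we start TensorBoard inside the notebook (alternatively, run `tensorboard --logdir models` in a terminal and open http://localhost:6006 in a browser).

In [None]:
%load_ext tensorboard
%tensorboard --logdir models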

In [None]:
model.fit(X_train, Y_train, validation_data=(X_val,Y_val),
         epochs=100, callbacks=callbacks, verbose=1)

### Predict

In [None]:
i=1

img  = X_val[i,..., 0]
mask = Y_val[i]
plt.imshow(img, cmap='gray');


In [None]:
mask_pred = model.predict(img[np.newaxis,...,np.newaxis])[0]
mask_pred.shape
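Besides visual inspection, the prediction can be compared with the ground truth pixel by pixel. One common measure (an addition here, not part of the original exercise) is the intersection over union (IoU) per class, computed after taking the most probable class per pixel with `argmax`. A NumPy sketch; `per_class_iou` and the toy arrays are hypothetical:

```python
import numpy as np

def per_class_iou(mask_true, mask_pred, n_classes=3):
    """IoU per class between one-hot ground truth and predicted probabilities (both (H, W, C))."""
    t = mask_true.argmax(axis=-1)    # true class index per pixel
    p = mask_pred.argmax(axis=-1)    # most probable predicted class per pixel
    ious = []
    for c in range(n_classes):
        inter = np.sum((t == c) & (p == c))
        union = np.sum((t == c) | (p == c))
        ious.append(inter / union if union > 0 else np.nan)   # nan: class absent in both
    return np.array(ious)

# toy 2x3 example: ground truth vs. a prediction with one wrong pixel
t_idx = np.array([[0, 1, 1],
                  [0, 2, 1]])
p_idx = np.array([[0, 1, 1],
                  [0, 1, 1]])   # the single boundary pixel is predicted as interior
eye = np.eye(3)
ious = per_class_iou(eye[t_idx], eye[p_idx])
print(ious)                     # IoUs 1.0, 0.75 and 0.0
```

With the variables above, `per_class_iou(mask, mask_pred)` would give the per-class IoU for the current validation image.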

In [None]:
from skimage.measure import label

# threshold inner (green) and find connected components
lbl_pred = label(mask_pred[...,1] > 0.7)

fig, ((a0,a1),(b0,b1)) = plt.subplots(2,2,figsize=(15,10))
a0.imshow(img,cmap='gray');       
a0.set_title('input image')
a1.imshow(mask);                  
a1.set_title('GT segmentation mask')
b0.axis('off')
b0.imshow(lbl_pred,cmap='tab20'); 
b0.set_title('label image (prediction)')
b1.imshow(mask_pred);             
b1.set_title('segmentation mask (prediction)')
fig.suptitle("Example")
None;
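The instance labels above come from thresholding the interior probability and labeling connected components. The toy sketch below (hypothetical probability values) shows why the boundary class matters for this step: two neighboring cells end up as separate labels only if the pixels between them are *not* predicted as interior.

```python
import numpy as np
from skimage.measure import label

# hypothetical interior-probability map: two 3x3 cells separated by one column
prob_inner = np.zeros((5, 9))
prob_inner[1:4, 1:4] = 0.9
prob_inner[1:4, 5:8] = 0.9

# case 1: the gap column (column 4) has low interior probability -> two components
sep = prob_inner.copy()
print(label(sep > 0.7).max())     # 2

# case 2: the gap column is also predicted as interior -> the two cells merge into one
merged = prob_inner.copy()
merged[1:4, 4] = 0.8
print(label(merged > 0.7).max())  # 1
```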

<div class="alert alert-block alert-info"><h2>Exercise</h2> 
    

Can you spot mistakes in the predicted label image? What could be the reason?
    
 

</div>