# Setting up an image segmentation pipeline

In this example, we use the tf.data API to preprocess training data and then train a UNet on a binary segmentation task.

In [None]:
import os

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras

# Helper libraries
import numpy as np
import matplotlib.pyplot as plt

# Print the installed TensorFlow version (APIs used below differ between releases).
print(tf.__version__)

### Using tf.data.Dataset to transform inputs

We'll first need some utility functions to read the data. We load everything into NumPy arrays and then
create the tf.data pipeline from them. Note that for a large dataset, reading the data should instead be
integrated into the pipeline.

In [None]:
from skimage.external.tifffile import imread
from fnmatch import fnmatch
from filecmp import dircmp

def normalize(img):
    '''rescales an 8-bit style intensity image into float32 by dividing by 255.

    Parameters
    ----------
    img : numpy array (or anything with an ``astype`` method)
        image to rescale. It is not modified.

    Returns
    -------
    float32 copy of ``img`` divided by 255.

    '''
    rescaled = img.astype('float32')
    rescaled /= 255.
    return rescaled

def load_data(folder, raw_subfolder, mask_subfolder, pattern):
    '''loads pairs of images and binary annotations.

    A pair is formed by a file name that exists in both
    ``folder/raw_subfolder`` and ``folder/mask_subfolder`` and matches
    ``pattern``.

    Parameters
    ----------
    folder : str
        root directory of the data split.
    raw_subfolder : str
        sub-directory of ``folder`` holding the raw images.
    mask_subfolder : str
        sub-directory of ``folder`` holding the annotation masks.
    pattern : str
        ``fnmatch`` pattern the file names must match, e.g. ``'*.TIF'``.

    Returns
    -------
    (imgs, masks) : tuple of numpy arrays
        ``imgs`` holds the normalized (float32) raw images, ``masks`` the
        boolean masks (``mask >= 1``). Both get a trailing channel axis.

    Raises
    ------
    ValueError
        if no matching image/mask pair is found.

    '''
    raw_dir = os.path.join(folder, raw_subfolder)
    mask_dir = os.path.join(folder, mask_subfolder)

    # file names present in both folders (as files) that match the pattern.
    fnames = [fname for fname in dircmp(raw_dir, mask_dir).common_files
              if fnmatch(fname, pattern)]

    # without this check, unpacking an empty zip() below fails with an
    # unhelpful "not enough values to unpack" error.
    if not fnames:
        raise ValueError(
            'No matching image/mask pairs for pattern {!r} in {!r} and {!r}.'.format(
                pattern, raw_dir, mask_dir))

    imgs = [normalize(imread(os.path.join(raw_dir, fname))) for fname in fnames]
    masks = [imread(os.path.join(mask_dir, fname)) >= 1 for fname in fnames]

    # NOTE(review): stacking assumes all images (and all masks) share one
    # shape; confirm for each data split.
    return (np.asarray(imgs)[..., None], np.asarray(masks)[..., None])

Next, we define some transformations that we want to apply to the data before feeding it to the model.
Typical transforms are random crops (to train on patches) and data augmentations (such as adding noise to the input image).

In [None]:
def random_crop(patch_size):
    '''returns a patch sampling function that takes the same
    random patch from all tensors passed to it.

    If you just want to take a random patch from a *single*
    tensor, you should probably use tensorflow.image.random_crop

    Parameters
    ----------
    patch_size : sequence of int
        size of the patch along every axis of the inputs, e.g.
        ``(128, 128, 1)`` for rank-3 tensors.

    Returns
    -------
    callable
        ``_cropper(*inputs)`` returning a tuple with one cropped tensor
        per input, suitable for ``tf.data.Dataset.map``.

    '''
    def _cropper(*inputs):
        '''crops the same random patch out of every positional input.

        Only the shape of ``inputs[0]`` is used to draw the offset, so all
        inputs are expected to have that same shape. Along every axis the
        shape must be at least ``patch_size``.

        '''
        with tf.name_scope('random_patch'):

            shape = tf.shape(inputs[0])
            size = tf.convert_to_tensor(patch_size, name='patch_size')

            # fail with a readable message instead of an obscure modulo /
            # slice error when the patch does not fit into the input.
            fits = tf.debugging.assert_greater_equal(
                shape, size,
                message='random_crop: patch_size is larger than the input')
            with tf.control_dependencies([fits]):
                limit = shape - size + 1

            # integer tf.random.uniform needs a scalar maxval, so draw a large
            # integer per axis and wrap it into [0, limit) with a modulo.
            offset = tf.random.uniform(
                tf.shape(shape),
                dtype=tf.int32,
                maxval=tf.int32.max,
            ) % limit

            return tuple(tf.slice(value, offset, size)
                         for value in inputs)

    return _cropper


def random_axis_flip(axis, flip_prob=0.5):
    '''returns a function that reverses ``axis`` of all its inputs with
    probability ``flip_prob``; either every input is flipped or none is.

    '''
    def _flipper(*inputs):
        '''flips each input along ``axis`` based on one shared random draw.
        '''
        do_flip = tf.random.uniform(
            shape=[], minval=0, maxval=1, dtype=tf.float32) <= flip_prob

        flipped = []
        for tensor in inputs:
            # tf.cond calls both branch lambdas immediately, so the late
            # binding of ``tensor`` inside them is harmless.
            flipped.append(tf.cond(
                do_flip,
                lambda: tf.reverse(tensor, [axis]),  # pylint: disable = W0640
                lambda: tensor))                     # pylint: disable = W0640
        return tuple(flipped)

    return _flipper


def gaussian_noise(noise_mu, noise_sigma, key):
    '''returns a function that adds gaussian noise to one of its inputs.

    On every call, the noise level (sigma) is drawn from a normal
    distribution with mean ``noise_mu`` and standard deviation
    ``noise_sigma``, clipped at 0. Zero-mean gaussian noise with that sigma
    is then added to ``inputs[key]``; all other inputs pass through unchanged.

    Parameters
    ----------
    noise_mu : float
        mean of the distribution the noise sigma is drawn from.
    noise_sigma : float
        standard deviation of the distribution the noise sigma is drawn from.
    key : int
        position of the tensor to distort among the positional inputs.

    '''

    def _distorter(*inputs):
        '''adds noise to ``inputs[key]`` and returns all inputs as a tuple.
        '''
        inputs = list(inputs)
        sigma = tf.maximum(
            0., tf.random.normal(shape=[], mean=noise_mu, stddev=noise_sigma))

        image = inputs[key]
        noise = tf.random.normal(
            shape=tf.shape(image), mean=0, stddev=sigma)
        inputs[key] = image + noise
        # return a tuple like the other transforms, so the structure of the
        # dataset elements does not depend on how tf.data handles lists.
        return tuple(inputs)

    return _distorter

In [None]:
def create_dataset(*args, patch_size=None, batch_size=1, patches_per_image=1,
                   augmentations=None, shuffle_buffer=10, **kwargs):
    '''creates a batched, prefetched tf.data.Dataset of (image, mask) pairs.

    Parameters
    ----------
    *args, **kwargs
        forwarded to ``load_data`` (folder, raw_subfolder, mask_subfolder,
        pattern).
    patch_size : sequence of int, optional
        if given, a random patch of this size is cropped from every pair.
    batch_size : int
        number of elements per batch.
    patches_per_image : int
        how often each image is repeated per pass over the dataset. Since
        cropping happens after repeating, each repetition gets its own
        random patch.
    augmentations : iterable of callables, optional
        functions mapped over the (image, mask) pairs in the given order,
        after cropping.
    shuffle_buffer : int, optional
        size of the shuffle buffer (default 10). Pass 0 or None to
        disable shuffling.

    Returns
    -------
    tf.data.Dataset

    '''
    num_parallel_calls = tf.data.experimental.AUTOTUNE

    # actually create the dataset and transform it
    dataset = tf.data.Dataset.from_tensor_slices(load_data(*args, **kwargs))

    if patches_per_image >= 2:
        dataset = dataset.repeat(patches_per_image)

    if patch_size is not None:
        dataset = dataset.map(random_crop(patch_size),
                              num_parallel_calls=num_parallel_calls)

    # apply image augmentations.
    if augmentations is not None:
        for augmentation_fn in augmentations:
            dataset = dataset.map(
                augmentation_fn, num_parallel_calls=num_parallel_calls)

    if shuffle_buffer:
        dataset = dataset.shuffle(shuffle_buffer)

    # batching comes last, so all transforms above see unbatched elements.
    return dataset.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)

Let's iterate over the dataset to see if everything looks reasonable.

In [None]:
patch_size = (128, 128, 1)

# Augmentations run before batching (see create_dataset), on elements shaped
# like patch_size, i.e. (height, width, channels) -- assuming 2-D grayscale
# TIFs. The spatial axes are therefore 0 and 1; axis 2 is the singleton
# channel axis, and flipping it would be a no-op.
dataset = create_dataset('data/training/', 'raw', 'mask', '*.TIF',
                         patch_size=patch_size, batch_size=1,
                         augmentations=[random_axis_flip(0),
                                        random_axis_flip(1),
                                        gaussian_noise(0.1, 0.25, 0)])

# only look at a few samples instead of plotting the whole training set.
for img, annot in dataset.take(4):
    axarr = plt.subplots(1, 2, figsize=(8, 4))[1]
    axarr[0].imshow(img.numpy().squeeze(), cmap='Greys')
    axarr[1].imshow(annot.numpy().squeeze())
    plt.show()

We also create a separate dataset for validation. This time, we want to yield
batches of size 1 without any cropping or augmentation.

In [None]:
# Validation set: full images (no patch_size), no augmentations, batches of 1.
# NOTE(review): whole images go through the UNet, so their height/width must be
# divisible by its total pooling factor -- confirm for the validation images.
val_dataset = create_dataset('data/val/', 'raw', 'mask', '*.TIF', batch_size=1)

### Construct a UNet

In [None]:
def build_unet(n_levels, initial_features, in_channels=1, out_channels=1,
               n_blocks=2, kernel_size=3, pooling_size=2):
    '''a quick and dirty implementation of a unet.

    Parameters
    ----------
    n_levels : int
        number of resolution levels (``n_levels - 1`` pooling steps).
    initial_features : int
        number of filters at the highest resolution; doubled per level.
    in_channels : int
        number of channels of the input image.
    out_channels : int
        number of output channels. One channel uses a sigmoid activation,
        more channels use softmax.
    n_blocks : int
        number of convolution layers per level, on both the down and up path.
    kernel_size : int
        kernel size of the convolutions and transposed convolutions.
    pooling_size : int
        max-pooling factor per level, also used as the upsampling stride.

    Returns
    -------
    keras.Model
        fully convolutional model taking ``(batch, height, width, in_channels)``
        inputs of arbitrary height and width.

    Notes
    -----
    Height and width must be divisible by ``pooling_size ** (n_levels - 1)``.
    Otherwise the upsampled tensors do not match the skip connections, and the
    concatenation fails.

    '''
    inputs = keras.layers.Input(shape=(None, None, in_channels), name='img')
    x = inputs

    convparams = dict(kernel_size=kernel_size, activation='relu', padding='same')

    # downstream: convolutions, keep the result as skip connection, pool.
    skips = []
    for level in range(n_levels - 1):
        for _ in range(n_blocks):
            x = keras.layers.Conv2D(initial_features * 2 ** level, **convparams)(x)
        skips.append(x)
        x = keras.layers.MaxPool2D(pooling_size)(x)

    # lowest level
    for _ in range(n_blocks):
        x = keras.layers.Conv2D(initial_features * 2 ** (n_levels - 1), **convparams)(x)

    # upstream: upsample, concatenate the matching skip connection, convolve.
    for level in reversed(range(n_levels - 1)):
        x = keras.layers.Conv2DTranspose(initial_features * 2 ** level,
                                         strides=pooling_size,
                                         kernel_size=kernel_size,
                                         padding=convparams['padding'],
                                         activation=convparams['activation'])(x)
        x = keras.layers.Concatenate()([x, skips[level]])

        for _ in range(n_blocks):
            x = keras.layers.Conv2D(initial_features * 2 ** level, **convparams)(x)

    activation = 'sigmoid' if out_channels == 1 else 'softmax'
    x = keras.layers.Conv2D(out_channels, kernel_size=1, activation=activation, padding='same',
                            name='mask')(x)

    return keras.Model(inputs=[inputs], outputs=[x])


model = build_unet(2, 32)

In [None]:
# print the layer table, 110 characters wide.
model.summary(line_length=110)

In [None]:
# compile() takes no learning_rate argument; the learning rate belongs to the
# optimizer, so build the Adam optimizer explicitly with the intended value.
model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.01),
              loss='binary_crossentropy',
              metrics=['accuracy', ])

In [None]:
# Train for 20 epochs on the augmented patch dataset, logging to TensorBoard
# under 'unet-logs/' and evaluating on the full-size validation images.
model.fit(dataset, epochs=20, 
          callbacks=[keras.callbacks.TensorBoard('unet-logs/')], 
          validation_data=val_dataset)

### Predict with the trained model

In [None]:
for img, annotation in val_dataset:

    # apply the model to one full-size validation image (a batch of 1).
    # Calling the model directly is cheaper than model.predict for one batch.
    prob = model(img, training=False).numpy()

    # plotting: input image, predicted probabilities, ground-truth mask.
    axarr = plt.subplots(1, 3, figsize=(12, 5))[1]
    axarr[0].imshow(img.numpy().squeeze(), cmap='Greys')
    axarr[1].imshow(prob.squeeze())
    axarr[2].imshow(annotation.numpy().squeeze())
    for ax in axarr:
        ax.axis('off')
    plt.show()

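We can also predict on an image of a different size! The UNet is fully convolutional, so the image height and width only need to be divisible by the network's total pooling factor (2 for the two-level UNet built above).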

In [None]:
# load one validation image at its native size and add the batch and
# channel axes the model expects.
full_img = normalize(
    imread('data/val/raw/SIMCEPImages_A06_C23_F1_s11_w2.TIF'))[np.newaxis, ..., np.newaxis]

print(full_img.shape)

full_probs = model.predict(full_img)

In [None]:
# show the raw input and then its predicted probability map, one figure each.
for shown in (full_img, full_probs):
    plt.imshow(shown.squeeze())
    plt.show()