In [None]:
import pickle
import pickle
import gzip
import numpy as np
import cv2
import segmentation_models as sm
sm.set_framework('tf.keras')
sm.framework()
from tensorflow import keras

from keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Lambda
from keras.layers import Conv2DTranspose, Concatenate, Input, GlobalAveragePooling2D, Multiply, Reshape, Dense, Add
from keras.models import Model
from keras.backend import int_shape

from sklearn.model_selection import train_test_split
from skimage.filters import threshold_yen, threshold_multiotsu
from skimage.restoration import denoise_nl_means

# Helpers

In [None]:
# this function just loads a zipped pickle file and unpacks it
def load_zipped_pickle(filename):
    """Read a gzip-compressed pickle file and return the object stored in it.

    Args:
        filename: path of the gzip file, opened in binary read mode.

    Returns:
        Whatever object ``pickle.load`` reconstructs from the file.

    Note:
        Unpickling executes code embedded in the file, so only use this on
        files from a trusted source.
    """
    with gzip.open(filename, 'rb') as handle:
        return pickle.load(handle)
# this function saves an object as a zipped pickle file
def save_zipped_pickle(obj, filename):
    """Pickle ``obj`` (protocol 2) into a gzip-compressed file.

    Args:
        obj: any picklable object.
        filename: destination path; an existing file is overwritten because
            the file is opened in ``'wb'`` mode.
    """
    with gzip.open(filename, 'wb') as handle:
        pickle.Pickler(handle, 2).dump(obj)

# rescaling the image to size (x,y) by cubic interpolation
def rescale_image(img, x, y):
    """Resize ``img`` with bicubic interpolation.

    Args:
        img: image array accepted by ``cv2.resize``.
        x, y: target size, forwarded to OpenCV as ``dsize=(x, y)``
            (OpenCV documents ``dsize`` as (width, height)).

    Returns:
        The resized image as returned by ``cv2.resize``.
    """
    target_size = (x, y)
    return cv2.resize(img, dsize=target_size, interpolation=cv2.INTER_CUBIC)

# blur the image with either a median or a gaussian kernel, median was
# found to perform better
def blur_image(img, mode="median"):
    """Blur ``img`` with a 3x3 median or Gaussian kernel.

    Args:
        img: image array accepted by the OpenCV blur functions.
        mode: ``"median"`` (default) for ``cv2.medianBlur`` with aperture 3,
            or ``"gauss"`` for ``cv2.GaussianBlur`` with a 3x3 kernel and
            sigma 1.

    Returns:
        The blurred image.

    Raises:
        ValueError: if ``mode`` is neither ``"median"`` nor ``"gauss"``.
            (Previously this case failed with an UnboundLocalError.)
    """
    if mode == "median":
        return cv2.medianBlur(img, 3)
    if mode == "gauss":
        return cv2.GaussianBlur(img, (3, 3), 1)
    raise ValueError(
        f"unknown blur mode {mode!r}; expected 'median' or 'gauss'"
    )

def denoise_img(img, rep=1):
    """Apply non-local-means denoising to ``img`` one or more times.

    Args:
        img: image passed to ``denoise_nl_means`` (default parameters).
        rep: number of passes; any value not greater than 1 results in
            exactly one pass.

    Returns:
        The image after the final denoising pass.
    """
    passes = rep if rep > 1 else 1
    for _ in range(passes):
        img = denoise_nl_means(img)
    return img

# build one threshold mask per image (multi-Otsu or Yen), used as an extra input channel
def produce_mask_layer(array, method='mo'):
    """Compute a threshold-based region map for every image in ``array``.

    Args:
        array: iterable of images.
        method: ``'mo'`` for a 3-class multi-Otsu split (each pixel is
            labelled with its class index via ``np.digitize``), or ``'yen'``
            for a binary mask of the pixels above Yen's threshold.

    Returns:
        A list with one integer array per input image.

    Raises:
        ValueError: if ``method`` is not ``'mo'`` or ``'yen'``. Previously an
            unknown method silently produced an empty list.
    """
    if method not in ('mo', 'yen'):
        raise ValueError(
            f"unknown threshold method {method!r}; expected 'mo' or 'yen'"
        )

    regions = []
    for img in array:
        if method == 'mo':
            thresholds = threshold_multiotsu(img, classes=3)
            regions.append(np.digitize(img, bins=thresholds))
        else:
            threshold = threshold_yen(img)
            regions.append(np.array(img > threshold, dtype='int'))

    return regions

# the idea is to add 2 channels, one containing a blurred layer and one with
# a threshold layer; the median blur was found to work well. The thresholding
# may still be adapted
def preprocess_data(dataset, method='simple'):
    """Turn single-channel frames into 3-channel network inputs.

    Args:
        dataset: sequence/array of 2-D frames.
        method: one of
            ``'simple'``            - the frame repeated three times;
            ``'denoise_multiotsu'`` - [frame, median-blurred frame,
                                       multi-Otsu mask of the denoised frame];
            ``'denoise_yen'``       - same, with a Yen threshold mask.

    Returns:
        Array with the three channels stacked on a new last axis.

    Raises:
        ValueError: for an unknown ``method`` (previously the function fell
            through and returned ``None``).
    """
    if method == 'simple':
        return np.stack([dataset, dataset, dataset], axis=-1)

    # Maps the public method name to the thresholding mode of
    # produce_mask_layer; both variants share the rest of the pipeline.
    threshold_modes = {'denoise_multiotsu': 'mo', 'denoise_yen': 'yen'}
    if method not in threshold_modes:
        raise ValueError(
            f"unknown preprocessing method {method!r}; expected 'simple', "
            "'denoise_multiotsu' or 'denoise_yen'"
        )

    # The mask is computed on the denoised frames (3 denoising passes) ...
    denoised = [denoise_img(frame, 3) for frame in dataset]
    # ... while the blurred channel is computed from the raw frames, exactly
    # as in the original implementation.
    blurred = [blur_image(frame) for frame in dataset]
    thresh = produce_mask_layer(denoised, threshold_modes[method])
    return np.stack([dataset, blurred, thresh], axis=-1)

## Model

In [None]:
# Number of filters in the first U-Net level; build_unet multiplies this
# by 2, 4, 8 and 16 for the deeper levels.
n_filts = 64

# Image size fed to the network. 128x128 was found to be quick and to work
# very well. The size does not really affect the mask output; there are
# papers reporting that the aspect ratio has little impact on the resulting
# mask, even when it differs from 1:1.
im_size = [128, 128]

# The input is 128x128x3 because three channels are used: instead of RGB
# they hold the normal, blurred and threshold images.
input_shape = (*im_size, 3)

In [None]:
import tensorflow
from keras.layers.convolutional import UpSampling2D
from keras.layers import Dropout

# as suggested in https://arxiv.org/pdf/2006.04868.pdf (DoubleU-Net), we use
# a squeeze-and-excite block (TODO: visualize what exactly it does).
# the double unet didn't really work with vgg as encoder, so only the squeeze
# excite block is kept at the moment.
#
# it squeezes the input, which is a conv2d output, by global average pooling,
# puts it into a dense layer with filters // ratio units (ratio defaults to 16)
# and then expands it again to one weight per channel. These weights are
# multiplied with the input
def squeeze_excite_block(inputs, ratio=16):
    """Re-weight the channels of ``inputs`` with a squeeze-and-excite gate.

    The feature map is reduced to one value per channel by global average
    pooling, passed through a bottleneck Dense layer (ReLU) and a Dense layer
    with one sigmoid output per channel, and the resulting weights are
    multiplied with the input.

    Args:
        inputs: Keras tensor whose last axis is the channel axis.
        ratio: channel reduction factor of the bottleneck layer.

    Returns:
        Tensor with the same shape as ``inputs``.
    """
    channel_axis = -1
    filters = inputs.shape[channel_axis]

    # With fewer channels than ``ratio`` the integer division yields 0, which
    # is not a valid Dense width; keep at least one bottleneck unit.
    bottleneck_units = max(filters // ratio, 1)

    se = GlobalAveragePooling2D()(inputs)
    se = Reshape((1, 1, filters))(se)
    se = Dense(bottleneck_units, activation='relu', kernel_initializer='he_normal', use_bias=False)(se)
    se = Dense(filters, activation='sigmoid', kernel_initializer='he_normal', use_bias=False)(se)

    return Multiply()([inputs, se])


# Lambda layer that repeats the elements of a tensor along axis 3
# by a factor of rep.
# If the tensor has shape (None, 256, 256, 3), the layer returns a tensor of
# shape (None, 256, 256, 6) for rep=2.
def repeat_elem(tensor, rep):
    """Repeat every element of ``tensor`` ``rep`` times along axis 3.

    Implemented as a Keras ``Lambda`` layer wrapping
    ``keras.backend.repeat_elements``; ``rep`` is handed to the layer through
    its ``arguments`` dict.
    """
    repeat_layer = Lambda(
        lambda x, repnum: keras.backend.repeat_elements(x, repnum, axis=3),
        arguments={'repnum': rep},
    )
    return repeat_layer(tensor)



# resize the down layer feature map into the same dimension as the up layer feature map
# using 1x1 conv
# :return: the gating feature map with the same dimension of the up layer feature map
def gating_signal(input, out_size):
    """Project ``input`` to ``out_size`` channels for use as a gating signal.

    Applies a 1x1 convolution followed by batch normalisation and ReLU.

    Args:
        input: feature map of the deeper (coarser) decoder level.
        out_size: number of output channels.

    Returns:
        The gating feature map.
    """
    gate = Conv2D(out_size, (1, 1), padding='same')(input)
    gate = BatchNormalization()(gate)
    return Activation('relu')(gate)


# the convolution block consists of 2 layers of 2D convolution and is
# used in every layer of the encoder and decoder
def conv_block(input, num_filters, size=3, dropout=0):
    """Two Conv-BN-ReLU stages followed by squeeze-excite and optional dropout.

    Args:
        input: input feature map.
        num_filters: number of filters of both convolutions.
        size: convolution kernel size.
        dropout: dropout rate; dropout is only added when this is > 0.

    Returns:
        The output feature map of the block.
    """
    features = input
    # Both stages are identical; the second consumes the output of the first.
    for _ in range(2):
        features = Conv2D(num_filters, size, padding="same")(features)
        features = BatchNormalization()(features)
        features = Activation("relu")(features)

    features = squeeze_excite_block(features)

    if dropout > 0:
        features = Dropout(dropout)(features)

    return features


# residual convolution from
# https://arxiv.org/ftp/arxiv/papers/1802/1802.06955.pdf
def res_conv_block(input, num_filters, size=3, dropout=0):
    """Residual variant of the double-convolution block.

    Main path: Conv-BN-ReLU, Conv-BN, squeeze-excite, optional dropout.
    Shortcut path: 1x1 Conv-BN on the block input. Both paths are added and
    passed through a final ReLU.

    Args:
        input: input feature map.
        num_filters: number of filters of every convolution in the block.
        size: side length of the square kernels on the main path.
        dropout: dropout rate; dropout is only added when this is > 0.

    Returns:
        The output feature map of the block.
    """
    kernel = (size, size)

    main = Conv2D(num_filters, kernel, padding="same")(input)
    main = BatchNormalization()(main)
    main = Activation("relu")(main)

    # The second stage deliberately has no activation: ReLU is applied only
    # after the shortcut has been added.
    main = Conv2D(num_filters, kernel, padding="same")(main)
    main = BatchNormalization()(main)
    main = squeeze_excite_block(main)

    if dropout > 0:
        main = Dropout(dropout)(main)

    # 1x1 projection so the shortcut has the same channel count as the main path.
    skip = Conv2D(num_filters, kernel_size=(1, 1), padding="same")(input)
    skip = BatchNormalization()(skip)

    merged = Add()([skip, main])
    return Activation("relu")(merged)

def attention_block(x, gating, inter_shape):
    """Gate the skip connection ``x`` with the signal ``gating``.

    Steps, as implemented below:
      1. ``x`` is downsampled by a stride-2 2x2 convolution to ``inter_shape``
         channels.
      2. ``gating`` is projected by a 1x1 convolution and brought to the same
         spatial size with a transposed convolution.
      3. Both are added, passed through ReLU, reduced to one channel by a 1x1
         convolution and squashed with a sigmoid.
      4. The one-channel map is upsampled to the spatial size of ``x``,
         repeated over the channels of ``x`` and multiplied with ``x``.
      5. A 1x1 convolution and batch normalisation produce the result.

    Args:
        x: skip-connection feature map (4-D; axes 1-2 are read as the spatial
            size and axis 3 as the channel count).
        gating: gating feature map from the coarser level.
        inter_shape: number of intermediate channels.

    Returns:
        The attention-weighted feature map with as many channels as ``x``.
    """
    x_dims = int_shape(x)
    gate_dims = int_shape(gating)

    theta_x = Conv2D(inter_shape, (2, 2), strides=(2, 2), padding="same")(x)
    theta_dims = int_shape(theta_x)

    phi_g = Conv2D(inter_shape, (1, 1), padding="same")(gating)
    # Stride chosen so the gating map reaches the spatial size of theta_x.
    gate_strides = (theta_dims[1] // gate_dims[1], theta_dims[2] // gate_dims[2])
    g_upsampled = Conv2DTranspose(inter_shape, (3, 3), strides=gate_strides, padding="same")(phi_g)

    combined = Add()([g_upsampled, theta_x])
    combined = Activation('relu')(combined)
    psi = Conv2D(1, (1, 1), padding='same')(combined)
    coefficients = Activation('sigmoid')(psi)

    # Bring the coefficients back to the resolution of x ...
    coeff_dims = int_shape(coefficients)
    up_factor = (x_dims[1] // coeff_dims[1], x_dims[2] // coeff_dims[2])
    coefficients = UpSampling2D(size=up_factor)(coefficients)

    # ... and copy them to every channel of x before weighting.
    coefficients = repeat_elem(coefficients, x_dims[3])
    weighted = Multiply()([coefficients, x])

    projected = Conv2D(x_dims[3], (1, 1), padding='same')(weighted)
    return BatchNormalization()(projected)

# the encoder block takes the convolution block and adds maxpooling
# to connect to the next layer
def encoder_block(input, num_filters, args):
    """One encoder level: convolution block plus 2x2 max pooling.

    Args:
        input: input feature map.
        num_filters: number of filters of the convolution block.
        args: tuple ``(res, att, drop)``; ``res`` selects the residual block,
            ``drop`` is the dropout rate, ``att`` is not used here.

    Returns:
        ``(features, pooled)``: the block output (used as skip connection)
        and its max-pooled version (input of the next level).
    """
    res, att, drop = args
    block_fn = res_conv_block if res else conv_block
    features = block_fn(input, num_filters, dropout=drop)
    # Pooling halves the resolution for the next, wider level.
    pooled = MaxPool2D((2, 2))(features)
    return features, pooled

# the decoder makes the up convolution and then convolution again
def decoder_block(input, skip_features, num_filters, args):
    """One decoder level: upsample, merge with the skip connection, convolve.

    Args:
        input: feature map coming from the level below.
        skip_features: encoder output of the same level.
        num_filters: number of filters of the convolution block.
        args: tuple ``(res, att, drop)``; ``res`` selects the residual block,
            ``att`` enables soft attention on the skip connection and
            ``drop`` is the dropout rate.

    Returns:
        The output feature map of the level.
    """
    res, att, drop = args

    if att:
        # Soft attention: the skip features are gated by the coarser input
        # before being concatenated with the upsampled input.
        gate = gating_signal(input, num_filters)
        attended = attention_block(skip_features, gate, num_filters)
        upsampled = UpSampling2D(size=(2, 2))(input)
        merged = Concatenate()([upsampled, attended])
    else:
        upsampled = UpSampling2D(size=(2, 2), data_format="channels_last")(input)
        merged = Concatenate()([upsampled, skip_features])

    block_fn = res_conv_block if res else conv_block
    return block_fn(merged, num_filters, dropout=drop)

# defining the unet
def build_unet(input_shape, res=False, att=False, drop=0):
    """Assemble the U-Net.

    Layout: batch-normalised input, four encoder levels with
    ``n_filts * (1, 2, 4, 8)`` filters, a bridge with ``n_filts * 16``
    filters, four decoder levels mirroring the encoder, and a 1x1 sigmoid
    convolution producing a single-channel mask.

    Args:
        input_shape: shape of one input sample, without the batch axis.
        res: use residual convolution blocks.
        att: use attention gates on the skip connections.
        drop: dropout rate forwarded to every convolution block.

    Returns:
        The (uncompiled) Keras model named ``"U-net"``.
    """
    args = (res, att, drop)
    inputs = Input(input_shape, dtype=tensorflow.float32)
    x = BatchNormalization()(inputs)

    # Encoder: keep every level's pre-pooling output for the skip connections.
    skips = []
    for multiplier in (1, 2, 4, 8):
        skip, x = encoder_block(x, n_filts * multiplier, args)
        skips.append(skip)

    # Bridge: convolutions on the lowest level, no pooling afterwards because
    # the decoder upsamples from here.
    bridge_fn = res_conv_block if res else conv_block
    x = bridge_fn(x, n_filts * 16, dropout=drop)

    # Decoder: deepest skip connection first.
    for multiplier, skip in zip((8, 4, 2, 1), reversed(skips)):
        x = decoder_block(x, skip, n_filts * multiplier, args)

    outputs = Conv2D(1, (1, 1), padding="same", activation="sigmoid")(x)
    return Model(inputs, outputs, name="U-net")

## Data loading

In [None]:
# load data
# Each .npy file holds a pickled Python object; indexing the loaded array
# with [()] extracts that object.
# NOTE(review): allow_pickle=True unpickles arbitrary objects - only load
# files from a trusted source.
# NOTE(review): train and test are both read from 'sample.npy' here -
# presumably a placeholder for the real files; confirm.
train_data = np.load('sample.npy', allow_pickle=True)[()]
test_data = np.load('sample.npy', allow_pickle=True)[()]

In [None]:
# Scale down the video data: convert to float32 and divide by 255
# (assumes 8-bit intensities - TODO confirm). The dicts are modified in place.
for split in (train_data, test_data):
    split['video'] = split['video'].astype('float32') / 255.0

# Artificially augment: each name now refers to a list holding ten references
# to the same sample dict. Re-running this cell without reloading the data
# fails, because the names no longer refer to dicts.
train_data, test_data = [train_data] * 10, [test_data] * 10

In [None]:
# Build the per-frame training lists (images, masks, boxes).
X_train_full, y_train_full, boxes_train_full = [], [], []

# Every annotated frame is added on its own, regardless of which patient
# (sample) it comes from.
for item in train_data:
    frames = item['frames']
    # The first three entries of 'frames' index the annotated video frames.
    for k in range(3):
        frame_idx = frames[k]
        X_train_full.append(item['video'][:, :, frame_idx])
        y_train_full.append(item['label'][:, :, frame_idx])
        # The box is per sample, so it is repeated for every frame to keep
        # the three lists aligned.
        boxes_train_full.append(item['box'])

# Convert the boolean masks/boxes into 0/1 uint8 arrays.
y_train_full = [np.array(mask, dtype=np.uint8) for mask in y_train_full]
boxes_train_full = [np.array(box, dtype=np.uint8) for box in boxes_train_full]

# Resize everything to the configured network input size.
X_train_full = [rescale_image(img, im_size[0], im_size[1]) for img in X_train_full]
y_train_full = [rescale_image(mask, im_size[0], im_size[1]) for mask in y_train_full]
boxes_train_full = [rescale_image(box, im_size[0], im_size[1]) for box in boxes_train_full]

## Preprocessing

The preprocessing was a very important step in making this pipeline achieve a good score.
The idea is to add additional channels by masks, generated using different threshold methods.

In [None]:
# Build the 3-channel inputs: raw frame, blurred frame, Yen threshold mask.
X_processed = preprocess_data(X_train_full, method='denoise_yen')
# Add a trailing channel axis to the masks and boxes.
# NOTE: these two lines rebind their inputs, so running the cell twice adds
# a second axis.
y_train_full = np.expand_dims(y_train_full, -1)
boxes_train_full = np.expand_dims(boxes_train_full, -1)

## Training

### Augmentation

We used the keras ImageDataGenerator to augment the dataset with shifts, zoom, and shear.
This helped with the low amount of data. Note however, that the flipping was deactivated.
This is because the model should not start to segment the left mitral valve!

In [None]:
def build_generators(batch_size, X_train, X_test, y_train, y_test):
    """Create endlessly augmenting (image, mask) generators.

    Images and masks are augmented by two ``ImageDataGenerator`` objects
    configured with the same random transformations and the same seed, so
    that each mask receives the same transformation as its image and the
    loss is computed on matching pairs. Masks are additionally re-binarised
    after augmentation.

    Args:
        batch_size: number of samples per yielded batch.
        X_train, y_train: training images and masks.
        X_test, y_test: validation images and masks (augmented the same way).

    Returns:
        ``(train_gen, val_gen)``: generators yielding ``(images, masks)``
        batch tuples.
    """
    from keras.preprocessing.image import ImageDataGenerator

    # Fixed seed shared by all four flows.
    seed = 24

    # Random rotation, shift, shear and zoom. Flipping is disabled on
    # purpose: it helped the model, and the test set contains no flipped
    # images anyway.
    shared_aug = dict(
        rotation_range=60,
        width_shift_range=0.3,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=False,
        vertical_flip=False,
        fill_mode='constant',
    )

    def binarize(x):
        # Map every positive value back to 1, everything else to 0.
        return np.where(x > 0, 1, 0).astype(x.dtype)

    image_data_generator = ImageDataGenerator(**shared_aug)
    image_generator = image_data_generator.flow(X_train, seed=seed, batch_size=batch_size)
    valid_img_generator = image_data_generator.flow(X_test, seed=seed, batch_size=batch_size)

    mask_data_generator = ImageDataGenerator(**shared_aug, preprocessing_function=binarize)
    mask_generator = mask_data_generator.flow(y_train, seed=seed, batch_size=batch_size)
    valid_mask_generator = mask_data_generator.flow(y_test, seed=seed, batch_size=batch_size)

    def paired(images, masks):
        # Yield (image_batch, mask_batch) tuples that Keras can consume directly.
        yield from zip(images, masks)

    train_gen = paired(image_generator, mask_generator)
    val_gen = paired(valid_img_generator, valid_mask_generator)
    return train_gen, val_gen

Train test split

In [None]:
# Random hold-out split with a fixed seed; no test_size is given, so
# scikit-learn's default ratio applies.
# NOTE(review): an earlier cell replicates the sample ten times, so the
# hold-out part presumably contains frames identical to training frames -
# confirm this is only meant for the demo data.
X_train, X_test, y_train, y_test = train_test_split(X_processed, y_train_full, random_state=0)

In [None]:
# Plain U-Net: build_unet's defaults leave residual blocks, attention gates
# and dropout switched off.
model = build_unet(input_shape=input_shape)

Configuration

In [None]:
# Training hyper-parameters.
lr = 3e-4  # learning rate passed to the Adam optimizer
batch_size = 8  # samples per generator batch
steps_per_epoch = 80  # batches per epoch; also used as validation_steps
epochs = 1  # number of training epochs

Compile with adam and binary focal jaccard loss

In [None]:
# Adam optimizer with the configured learning rate; the loss and the IoU
# metric both come from segmentation_models (imported as sm).
model.compile(optimizer=keras.optimizers.Adam(learning_rate = lr), loss=sm.losses.binary_focal_jaccard_loss, metrics=[sm.metrics.iou_score])

In [None]:
# Augmenting generators for training and validation; both yield
# (images, masks) batches.
train_generator, val_generator = build_generators(batch_size, X_train, X_test, y_train, y_test)

In [None]:
# Model.fit accepts Python generators directly; Model.fit_generator is
# deprecated in TF2 Keras. The generators never terminate, so both step
# counts must be given explicitly.
model.fit(train_generator,
          validation_data=val_generator,
          steps_per_epoch=steps_per_epoch,
          validation_steps=steps_per_epoch,
          epochs=epochs)

## Post Processing

The prediction was converted to a binary mask via thresholding. The threshold value was found to be very important, as it also
determines the "sharpness" of the mask. A high threshold led to much better scores than just using 0.5.

In [None]:
# Predict on the held-out frames and binarise the predictions with a
# deliberately high threshold of 0.8.
y_pred = model.predict(X_test)
y_pred_thresholded = y_pred >= 0.8

In [None]:
# Intersection over union between ground-truth and thresholded prediction,
# accumulated over all held-out frames.
intersection = np.logical_and(y_test, y_pred_thresholded)
union = np.logical_or(y_test, y_pred_thresholded)
union_pixels = np.sum(union)
# If both masks are completely empty the union is 0; treat that as perfect
# agreement (IoU = 1.0) instead of dividing by zero and getting NaN.
iou_score = np.sum(intersection) / union_pixels if union_pixels else 1.0
print("IoU score is: ", iou_score)

The frames were furthermore processed by anisotropic diffusion and erosion.

## Result

An example prediction is shown below, where the mask is colored red. The video is in slow-motion (slowed down by factor of 3) to better see the movements.

![SegmentLocal](example.gif "segment")

![](example.gif)