# U-Net cell segmentation on the DIC-C2DH-HeLa dataset

## Import modules

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import cv2 as cv
import gc, os

In [None]:
import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow_datasets as tfds

In [None]:
from unet.utils import UNetHelper
from unet.losses import IoU, dice_loss, unet_sample_weights
from unet.augmentation import elastic_deformation, grid_deformation

In [None]:
# Only let TensorFlow log records of level ERROR through.
tf.get_logger().setLevel("ERROR")

# Notebook configuration

In [None]:
# Run-wide settings, gathered in one place.
tf_dir = "TFData"     # root directory passed to tfds and used for cache/model paths
train_model = True    # True: train and save the models; False: load saved ones instead
max_epochs = 280      # number of epochs handed to train()
batch_size = 8        # training batch size (validation batches use twice this)

## Random seed

For resetting the seed when running the training loop multiple times

In [None]:
def reset_seed(seed=42):
    """Re-seed the random generators through `tf.keras.utils.set_random_seed`.

    Called before every training run so that re-executing a training cell
    starts from the same seed.

    Args:
        seed: Integer seed forwarded to `tf.keras.utils.set_random_seed`.
    """
    return tf.keras.utils.set_random_seed(seed)


reset_seed()

## Distributed training strategy

The strategy is chosen based on the hardware I have at my disposal: a single GPU at work, or two GPUs on Kaggle.

In [None]:
# Number of GPUs visible to TensorFlow.
gpus = len(tf.config.list_physical_devices("GPU"))

if gpus > 1:
    # Several GPUs: mirror the model on all of them and reduce on a single device.
    strategy = tf.distribute.MirroredStrategy(cross_device_ops=tf.distribute.ReductionToOneDevice())
else:
    # One GPU, or none at all: pin everything to a single device. Fall back to
    # the CPU instead of naming a GPU that is not there.
    strategy = tf.distribute.OneDeviceStrategy(device="/GPU:0" if gpus == 1 else "/CPU:0")

n_devices = strategy.num_replicas_in_sync
print(f"Using {n_devices} device(s).")
print(f"Using {strategy.__class__.__name__}.")

# Load the dataset

In [None]:
def process_img(img, mask):
    """
    Equalize one image with CLAHE, rescale it and compute the sample
    weights of its mask.

    Both arguments must expose `.numpy()` and `.get_shape()`; the function
    is wrapped in `tf.py_function` by the dataset pipeline.

    CLAHE (Contrast Limited Adaptive Histogram Equalization) runs with
    clipLimit=2.0 on an 8x8 tile grid.

    Returns a 3-tuple:
      * the equalized image divided by 255.0, as a float32 tensor with the
        shape of `img` (assumes 8-bit input -- TODO confirm),
      * `mask`, passed through untouched,
      * the output of `unet_sample_weights`, as a float32 tensor with the
        shape of `mask`.
    """
    equalizer = cv.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    equalized = equalizer.apply(np.squeeze(img.numpy()))
    weights = unet_sample_weights(mask.numpy(), data_type=np.float32)
    # Restore the trailing channel axis removed by np.squeeze above.
    scaled = np.expand_dims(equalized / 255.0, -1)
    image_out = tf.constant(scaled, dtype=tf.float32, shape=img.get_shape())
    weights_out = tf.constant(weights, dtype=tf.float32, shape=mask.get_shape())
    return (image_out, mask, weights_out)

In [None]:
def min_max(arr):
    """Linearly rescale `arr` so its minimum maps to 0 and its maximum to 1.

    Args:
        arr: Array-like of numbers; converted with `np.asarray`.

    Returns:
        A floating-point array with the shape of `arr`. A constant input
        (maximum == minimum) maps to all zeros instead of dividing by zero.
    """
    arr = np.asarray(arr)
    minimum, maximum = arr.min(), arr.max()
    span = maximum - minimum
    if span == 0:
        # Constant input: `arr - minimum` is already all zeros, so divide by
        # one to get the same output dtype without producing NaN/inf.
        span = 1
    return (arr - minimum) / span

In [None]:
img_shape = (512, 512, 1)
mask_shape = (512, 512)

hela_train = tfds.load("hela_train", data_dir=tf_dir)


def _preprocess_segment(segment, cache_name):
    """Run `process_img` over one recording, pin the static shapes and cache the result on disk."""
    processed = segment.map(lambda sample: tf.py_function(process_img,
                                                          inp=[sample['image'], sample['mask']],
                                                          Tout=[tf.float32, tf.int32, tf.float32]),
                            num_parallel_calls=tf.data.AUTOTUNE)
    # Re-assert the static shapes after the tf.py_function call.
    shaped = processed.map(lambda X, y, sw: (tf.ensure_shape(X, img_shape),
                                             tf.ensure_shape(y, mask_shape),
                                             tf.ensure_shape(sw, mask_shape)))
    return shaped.cache(f"{tf_dir}/TFCache/{cache_name}")


# Preprocess and cache both recordings, keeping two samples of each for the plots below.
example = []
for segment_id in ("01", "02"):
    hela_train[segment_id] = _preprocess_segment(hela_train[segment_id], f"{segment_id}_CLAHE_NORM")
    example += list(hela_train[segment_id].take(2))

In [None]:
n_examples = len(example)
fig, axes = plt.subplots(n_examples, 3, figsize=(10, 5 * n_examples))

# Column headers go on the first row only.
for col, title in enumerate(("Images", "Masks", "Sample weights")):
    axes[0, col].set_title(title)

# One row per example; each example is an (image, mask, sample weights) triple.
for row, ex in zip(axes, example):
    for ax, img in zip(row, ex):
        ax.imshow(min_max(img), cmap="gray")
        ax.axis("off")

fig.tight_layout(h_pad=-15.0)
plt.show()

# Data augmentation

In [None]:
@tf.function
def pipeline(X, y, w):
    """Randomly augment a batch of images, masks and sample weights.

    Each of the four augmentations (horizontal flip, vertical flip, grid
    deformation, elastic deformation) is applied when a single scalar
    uniform draw is >= 0.5, so one decision is shared by the whole call.

    Args:
        X: Image tensor that already has a trailing channel axis.
        y: Mask tensor without channel axis.
        w: Sample-weight tensor without channel axis.

    Returns:
        `[X, y, w]` after augmentation; `y` and `w` have their temporary
        channel axis removed again.
    """
    # Add channel axis
    y = tf.expand_dims(y, axis=-1)
    w = tf.expand_dims(w, axis=-1)
    # Horizontal flip
    if tf.random.uniform((), 0.0, 1.0) >= 0.5:
        X = tf.image.flip_left_right(X)
        y = tf.image.flip_left_right(y)
        w = tf.image.flip_left_right(w)
    # Vertical flip
    if tf.random.uniform((), 0.0, 1.0) >= 0.5:
        X = tf.image.flip_up_down(X)
        y = tf.image.flip_up_down(y)
        w = tf.image.flip_up_down(w)
    # NOTE(review): in both deformations below X, y and w go through separate
    # calls -- confirm that the helpers apply the same displacement to all three.
    # Grid deformation (`order` is presumably the interpolation order: 1 for
    # the image, 0 for mask and weights -- TODO confirm)
    if tf.random.uniform((), 0.0, 1.0) >= 0.5:
        grid_size = 5
        distort_limits = (-0.35, 0.35)
        X = grid_deformation(X, distort_limits=distort_limits, grid_size=grid_size, order=1)
        y = grid_deformation(y, distort_limits=distort_limits, grid_size=grid_size, order=0)
        w = grid_deformation(w, distort_limits=distort_limits, grid_size=grid_size, order=0)
    # Elastic deformation
    if tf.random.uniform((), 0.0, 1.0) >= 0.5:
        alpha = 100.0
        sigma = 5.0
        auto_kSize = True
        X = elastic_deformation(X, alpha=alpha, sigma=sigma, auto_kSize=auto_kSize, order=1)
        y = elastic_deformation(y, alpha=alpha, sigma=sigma, auto_kSize=auto_kSize, order=0)
        w = elastic_deformation(w, alpha=alpha, sigma=sigma, auto_kSize=auto_kSize, order=0)
    # Remove only the channel axis added above. A bare tf.squeeze would also
    # drop the batch axis whenever a batch holds a single sample.
    return [X, tf.squeeze(y, axis=-1), tf.squeeze(w, axis=-1)]

In [None]:
fig, axes = plt.subplots(len(example), 2, figsize=(8, 4 * len(example)))

axes[0,0].set_title("Original")
axes[0,1].set_title("Augmented")

for row, (tmp_X, tmp_y, tmp_w) in zip(axes, example):
    # Left: the image of this row's own example.
    row[0].imshow(min_max(tmp_X), cmap="gray")
    row[0].axis("off")
    # Right: the same example after augmentation. The pipeline expects a
    # leading batch axis, so add one and take the single image back out.
    augmented = pipeline(tf.expand_dims(tmp_X, 0),
                         tf.expand_dims(tmp_y, 0),
                         tf.expand_dims(tmp_w, 0))
    row[1].imshow(min_max(augmented[0][0]), cmap="gray")
    row[1].axis("off")

fig.tight_layout()
plt.show()

# Main training loop

In [None]:
def _is_due(every, epoch):
    """Return True when `every` is a non-zero plain int and `epoch` is a multiple of it."""
    if not isinstance(every, int) or isinstance(every, bool) or every == 0:
        return False
    return epoch % every == 0


def train(helper, train_dataset, val_dataset=None, examples=None, epochs=100, ckpt_every=10, plot_every=1, verbose=True):
    """Custom epoch loop built around the step functions of `helper`.

    Args:
        helper: Object providing `opt_schedule`, `optimizer`,
            `dist_train_step`, `dist_val_step`, `checkpoint`,
            `checkpoint_dir` and `model`.
        train_dataset: Iterable of training batches with a `cardinality`
            attribute or method (only used as progress-bar target).
        val_dataset: Optional iterable of validation batches.
        examples: Optional dataset of `(X, y)` pairs; its first element is
            plotted next to the model prediction every `plot_every` epochs.
        epochs: Number of epochs to run.
        ckpt_every: Save a checkpoint every this many epochs; anything that
            is not a non-zero int disables checkpointing.
        plot_every: Plot every this many epochs; anything that is not a
            non-zero int disables plotting.
        verbose: Show a Keras progress bar when True.

    Returns:
        A list with one entry per epoch: `[loss, acc]`, or
        `[loss, acc, val_loss, val_acc]` when `val_dataset` is given. `loss`
        and `acc` are the values returned by the last training step of the
        epoch; the validation values are means over the validation batches.

    Raises:
        ValueError: If `train_dataset` or `val_dataset` yields no batch.
    """
    history = []
    # `cardinality` is a method on tf.data.Dataset but may be a plain value on
    # a distributed dataset; accept both. Negative values mean "unknown" or
    # "infinite", for which the progress bar gets no target.
    ds_card = train_dataset.cardinality
    if callable(ds_card):
        ds_card = ds_card()
    ds_card = int(ds_card)
    if ds_card < 0:
        ds_card = None
    for epoch in range(1, epochs + 1):
        print(f'\nEpoch {epoch}/{epochs}')
        # Learning rate schedule
        if helper.opt_schedule is not None:
            helper.optimizer.learning_rate = helper.opt_schedule(epoch)
        # Create progress bar
        progbar = tf.keras.utils.Progbar(target=ds_card) if verbose else None
        # Run the training steps
        steps = 0
        for steps, batch in enumerate(train_dataset, start=1):
            loss, acc = helper.dist_train_step(batch)
            if verbose:
                progbar.update(steps, zip(['loss', 'acc'], [loss, acc]), finalize=False)
        if steps == 0:
            raise ValueError("train_dataset yielded no batches")
        # Run for the validation set
        if val_dataset is not None:
            val_loss, val_acc, val_steps = 0.0, 0.0, 0
            for val_batch in val_dataset:
                vloss, vacc = helper.dist_val_step(val_batch)
                val_loss += vloss
                val_acc += vacc
                val_steps += 1
            if val_steps == 0:
                raise ValueError("val_dataset yielded no batches")
            val_loss /= val_steps
            val_acc /= val_steps
            history.append([loss, acc, val_loss, val_acc])
            if verbose:
                # Finalize at the true step count so the bar does not end one step short.
                progbar.update(steps, zip(['loss', 'acc', 'val_loss', 'val_acc', 'lr'],
                                          [loss, acc, val_loss, val_acc, helper.optimizer.learning_rate.numpy()]), finalize=True)
        else:
            history.append([loss, acc])
            if verbose:
                progbar.update(steps, zip(['loss', 'acc', 'lr'], [loss, acc, helper.optimizer.learning_rate.numpy()]), finalize=True)
        # Save training checkpoint
        if _is_due(ckpt_every, epoch):
            helper.checkpoint.save(helper.checkpoint_dir)
        # Plot training progression with the selected examples
        if examples is not None and _is_due(plot_every, epoch):
            plt.close()
            X, y = list(examples.take(1))[0]
            image_list = [X.numpy()[0], y.numpy()[0], helper.model(X).numpy().argmax(axis=-1)[0]]
            # Scale everything that is not already uint8 by 255 for display.
            image_list = [(255.0 * img).astype('uint8') if img.dtype !='uint8' else img for img in image_list]
            fig, ax = plt.subplots(1, 3, figsize=(14, 28))
            ax[0].set_title("Image")
            ax[1].set_title("Mask")
            ax[2].set_title("Predicted Mask")
            for k in range(3):
                ax[k].imshow(image_list[k], cmap="gray")
                ax[k].axis("off")
            plt.show()
    return history

## Cross-validation

Nothing too fancy: a two-fold, GroupKFold-style split in which each recording ("01" and "02") forms one group.

In [None]:
max_lr = 1.E-3

lr_decay_start, lr_decay_rate, lr_decay_step = (2, 0.1, 3)

model_param = {"input_shape": img_shape,
               "dropout": 0.2}

oof_dice = []
oof_IoU = []

# Each entry is [training recording, validation recording].
fold = [["01", "02"], ["02", "01"]]

for i, (train_key, val_key) in enumerate(fold):
    # In case we're running this cell over and over again when searching hyperparameters
    try:
        del helper
    except NameError:
        # No helper left over from an earlier run: nothing to release.
        pass
    # Restore the random seed and clear the current TF graph
    reset_seed()
    K.clear_session()
    # Set the augmentation, batching and distribution of the dataset.
    # The augmentation .map() should come after both the .batch() and .cache()
    # for increased variety of augmented samples.
    training_size = hela_train[train_key].cardinality().numpy()
    # Repeating lcm(batch_size, training_size) / training_size times makes the
    # number of samples per epoch a multiple of the batch size.
    train_ds = hela_train[train_key].shuffle(training_size, reshuffle_each_iteration=True)\
                                    .repeat(np.lcm(batch_size, training_size) // (training_size))\
                                    .batch(batch_size, drop_remainder=False, num_parallel_calls=tf.data.AUTOTUNE)\
                                    .map(pipeline, num_parallel_calls=tf.data.AUTOTUNE)\
                                    .prefetch(tf.data.AUTOTUNE)
    dist_train = strategy.experimental_distribute_dataset(train_ds)
    # Same thing for the validation split (sample weights dropped, no augmentation)
    validation_size = hela_train[val_key].cardinality().numpy()
    val_ds = hela_train[val_key].map(lambda X, y, sw: (X, y))\
                                .cache()\
                                .batch(2 * batch_size, drop_remainder=False, num_parallel_calls=tf.data.AUTOTUNE)
    dist_val = strategy.experimental_distribute_dataset(val_ds)
    # GPU training
    gc.collect()
    model_path = f"{tf_dir}/models/HeLa/model_fold{i + 1}.keras"
    with strategy.scope():
        gc.collect()
        helper = UNetHelper(strategy=strategy,
                            model_param=model_param,
                            loss_func=tf.keras.losses.sparse_categorical_crossentropy,
                            optimizer=tf.keras.optimizers.SGD(learning_rate=max_lr, momentum=0.99),
                            #opt_schedule=tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries=[5,], values=[1e-2, 1e-3]),
                            )
        if train_model:
            train(helper, dist_train, dist_val, val_ds.rebatch(1), max_epochs, ckpt_every=60, plot_every=70)
            # Make sure the output directory exists before saving.
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            helper.model.save(model_path)
        else:
            # Keras models have no `load` method; restore the saved weights into
            # the freshly built model instead.
            # NOTE(review): assumes helper.model is a Keras model -- confirm.
            helper.model.load_weights(model_path)
    # Out-of-fold results
    pred = helper.model.predict(val_ds.map(lambda X, y: X))
    # Gather every validation mask into one batch to compare against `pred`.
    oof_true = list(val_ds.map(lambda X, y: y).rebatch(validation_size).take(1))[0]
    oof_dice.append(dice_loss(oof_true, pred).numpy().mean())
    oof_IoU.append(IoU(oof_true, pred).numpy().mean())

In [None]:
# Report the fold-averaged out-of-fold metrics with six decimals.
print(f"Average out-of-fold IoU: {np.mean(oof_IoU):.6f}")
print(f"Average out-of-fold dice loss: {np.mean(oof_dice):.6f}")

## Training with the entire dataset

Same as before, but this time for the entire training dataset

In [None]:
try:
    del helper
except NameError:
    # No helper left over from an earlier cell: nothing to release.
    pass

reset_seed()
K.clear_session()

# Both recordings together form the training set.
train_ds = hela_train["01"].concatenate(hela_train["02"])
# Plain integer, as in the cross-validation cell, so the NumPy arithmetic
# below does not mix in a tensor.
training_size = train_ds.cardinality().numpy()
train_ds = train_ds.shuffle(training_size, reshuffle_each_iteration=True)\
                   .repeat(2 * np.lcm(batch_size, training_size) // (training_size))\
                   .batch(batch_size, drop_remainder=False, num_parallel_calls=tf.data.AUTOTUNE)\
                   .map(pipeline, num_parallel_calls=tf.data.AUTOTUNE)\
                   .prefetch(tf.data.AUTOTUNE)
dist_train = strategy.experimental_distribute_dataset(train_ds)

gc.collect()
model_path = f"{tf_dir}/models/model_all.keras"
with strategy.scope():
    gc.collect()
    helper = UNetHelper(strategy=strategy,
                        model_param=model_param,
                        loss_func=tf.keras.losses.sparse_categorical_crossentropy,
                        optimizer=tf.keras.optimizers.SGD(learning_rate=max_lr, momentum=0.99),
                        #opt_schedule=tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries=[5,], values=[1e-2, 1e-3]),
                        )
    if train_model:
        # No validation split and no example plots this time.
        train(helper, dist_train, None, None, max_epochs, ckpt_every=60, plot_every=None, verbose=True)
        # Make sure the output directory exists before saving.
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        helper.model.save(model_path)
    else:
        # Keras models have no `load` method; restore the saved weights into
        # the freshly built model instead.
        # NOTE(review): assumes helper.model is a Keras model -- confirm.
        helper.model.load_weights(model_path)

## Submission

In [None]:
def process_test(img):
    """CLAHE-equalize one test image and divide it by 255.0.

    Uses the same CLAHE settings as the training preprocessing
    (clipLimit=2.0, 8x8 tile grid). `img` must expose `.numpy()` and
    `.get_shape()`; the result is a float32 tensor with the shape of `img`.
    """
    equalizer = cv.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    pixels = np.squeeze(img.numpy())
    scaled = equalizer.apply(pixels) / 255.0
    # Restore the trailing channel axis removed by np.squeeze.
    with_channel = np.expand_dims(scaled, -1)
    return tf.constant(with_channel, dtype=tf.float32, shape=img.get_shape())

In [None]:
hela_test = tfds.load("hela_test", data_dir=tf_dir)

# Both test recordings back to back, in the order "01" then "02".
test_frames = hela_test["01"].concatenate(hela_test["02"])
# Equalize and rescale every frame with process_test.
test_equalized = test_frames.map(lambda pair: tf.py_function(process_test, inp=[pair["image"]],
                                                             Tout=[tf.float32]),
                                 num_parallel_calls=tf.data.AUTOTUNE)
# Pin the static shape, then cache the processed frames on disk.
hela_sub = test_equalized.map(lambda X: tf.ensure_shape(X, img_shape))\
                         .cache(f"{tf_dir}/TFCache/SUBMISSION")
# Put the whole test set into a single batch.
hela_sub = hela_sub.batch(hela_sub.cardinality().numpy(), num_parallel_calls=tf.data.AUTOTUNE)

In [None]:
# Class index with the highest score for every pixel of every test frame.
sub_pred = helper.model.predict(hela_sub.rebatch(2 * batch_size)).argmax(axis=-1)

# cv.imwrite does not create missing directories and only reports failure
# through its return value, so create the folder first and check each write.
os.makedirs("Predictions", exist_ok=True)
for i, mask in enumerate(sub_pred):
    out_path = f"Predictions/pred{str(i).zfill(4)}.png"
    if not cv.imwrite(out_path, (mask * 255.0).astype("uint8")):
        raise OSError(f"Could not write {out_path}")

In [None]:
# First (and, given the batching above, only) batch of the submission dataset.
sub_batches = list(hela_sub.take(1))
X_t = sub_batches[0]

In [None]:
# Index of the test frame to display.
j = 7

subtitles = ["Image", "Predicted Mask"]
image_list = [X_t[j], sub_pred[j]]

# Input frame on the left, predicted mask on the right.
fig, ax = plt.subplots(1, 2, figsize=(12, 24))
for i in (0, 1):
    panel = ax[i]
    panel.imshow(image_list[i], cmap="gray")
    panel.set_title(subtitles[i])
    panel.axis("off")
plt.show()

In [None]:
# https://www.youtube.com/watch?v=dQw4w9WgXcQ