# U-Net segmentation test
## Dataset: https://www.kaggle.com/datasets/hamzamohiuddin/isbi-2012-challenge

In [1]:
# Run configuration.
# train_model=True trains and saves to save_path; False loads weights from load_path.
train_model = True
batch_size, max_epochs = 8, 200

# Output locations. checkpoint_dir is not referenced directly by the visible cells.
checkpoint_dir = './models/ISBI2012/ckpt/'
save_path = './models/ISBI2012/model.keras'
load_path = './models/ISBI2012/model.keras'

## Imports

In [2]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import gc, os, cv2

In [None]:
import tensorflow as tf
import tensorflow.keras.backend as K

In [None]:
from unet.utils import UNetHelper
from unet.losses import IoU, dice_loss, unet_sample_weights
from unet.augmentation import elastic_deformation, grid_deformation

In [5]:
import math
import warnings

def lcm(x, y):
    """Return the least common multiple of integers *x* and *y*.

    The result is always non-negative. If either argument is 0 the result is 0
    (as with ``math.lcm``), instead of dividing by ``gcd(0, 0) == 0``.
    """
    if x == 0 or y == 0:
        return 0
    # gcd is non-negative, so abs() of the product keeps the result non-negative.
    return abs(x * y) // math.gcd(x, y)

### Set custom options

In [6]:
# Print every DataFrame column instead of eliding the middle ones.
pd.set_option('display.max_columns', None)
# Optional, disabled by default:
#   pd.set_option('display.max_rows', None)
#   np.set_printoptions(threshold=sys.maxsize)  # needs `import sys`

In [7]:
# Silence TensorFlow log records below ERROR severity.
tf.get_logger().setLevel(level='ERROR')

## Strategy and Random Seed

In [8]:
def reset_seed(seed=42):
    """Reset the global random seeds via ``tf.keras.utils.set_random_seed``."""
    tf.keras.utils.set_random_seed(seed)


reset_seed()

In [None]:
# Choose a distribution strategy from the number of visible GPUs.
gpus = len(tf.config.list_physical_devices('GPU'))

if gpus == 0:
    # No GPU visible: targeting "/GPU:0" cannot be satisfied, so run on the CPU.
    strategy = tf.distribute.OneDeviceStrategy(device="/CPU:0")
elif gpus == 1:
    strategy = tf.distribute.OneDeviceStrategy(device="/GPU:0")
else:
    # Several GPUs: mirrored replicas, reducing via ReductionToOneDevice.
    strategy = tf.distribute.MirroredStrategy(cross_device_ops=tf.distribute.ReductionToOneDevice())

n_devices = strategy.num_replicas_in_sync
print(f'Using {n_devices} devices.')
print(f'Using {strategy.__class__.__name__}.')

## Data Loading

In [10]:
# Dataset root; the "{0}" placeholder receives a sub-path such as "/train/imgs/".
base_dir = 'cell_images/{0}'

# Images: one grayscale channel. The commented-out alternative was
# cv2.IMREAD_COLOR, which would also require a 3-channel img_shape.
img_shape = (512, 512, 1)
img_mode = cv2.IMREAD_GRAYSCALE

# Masks: single-channel 2-D arrays.
mask_shape = (512, 512)
mask_mode = cv2.IMREAD_GRAYSCALE

# 30 training samples followed by 30 test samples.
data_len = 60
data_type = 'float32'

In [None]:
# Preallocate arrays for every sample.
# Slots [0, data_len // 2) hold the train split; slots [data_len // 2, data_len)
# hold the test split.
half = data_len // 2
X = np.zeros((data_len, img_shape[0], img_shape[1], img_shape[2]), data_type)
y = np.zeros((data_len, mask_shape[0], mask_shape[1]), np.int32)
sample_weights = np.zeros((data_len, mask_shape[0], mask_shape[1]), data_type)


def _imread_checked(path, mode):
    """Read an image with OpenCV, raising instead of returning None on failure."""
    img = cv2.imread(path, mode)
    if img is None:  # cv2.imread reports missing/unreadable files by returning None
        raise FileNotFoundError(f"Could not read image: {path}")
    return img


for split, offset in (("train", 0), ("test", half)):
    img_dir = base_dir.format(f"/{split}/imgs/")
    lbl_dir = base_dir.format(f"/{split}/labels/")
    # os.listdir order is arbitrary. Sort both listings so image k is paired
    # with label k (assumes images and labels have matching, sortable names).
    img_names = sorted(os.listdir(img_dir))
    lbl_names = sorted(os.listdir(lbl_dir))
    if len(img_names) != len(lbl_names):
        raise ValueError(f"{split}: {len(img_names)} images but {len(lbl_names)} labels")
    if len(img_names) > half:
        raise ValueError(f"{split}: {len(img_names)} samples exceed the {half} slots reserved for it")

    for i, (image, mask) in enumerate(zip(img_names, lbl_names)):
        k = offset + i
        # Add a trailing channel axis and scale pixel values by 1/255.
        X[k] = np.expand_dims(_imread_checked(img_dir + image, img_mode), -1) / 255.
        msk = _imread_checked(lbl_dir + mask, mask_mode)
        #_, msk = cv2.threshold(msk, 0, 255, cv2.THRESH_OTSU)
        # Integer division maps 255 -> 1 and anything lower -> 0
        # (assumes 0/255 binary masks; TODO confirm for this dataset).
        y[k] = msk.astype(np.int32) // 255
        sample_weights[k] = unet_sample_weights(y[k], data_type=data_type)

gc.collect()
print('Done.')

## Data augmentation pipeline

In [12]:
@tf.function
def pipeline(X, y, w):
    """Randomly augment one (image, mask, weight-map) sample.

    Each of four augmentations is applied independently with probability 0.5:
    horizontal flip, vertical flip, grid deformation, elastic deformation.
    Inside @tf.function, AutoGraph turns the Python `if` on a tensor into tf.cond.

    Args:
        X: image, presumably (H, W, C) (the loading cell builds X[i] that way).
        y: label mask, presumably (H, W).
        w: per-pixel sample-weight map, presumably (H, W).

    Returns:
        [X, y, w] with y and w squeezed back to drop the added channel axis.
    """
    # Give the mask and weights a trailing channel axis so tf.image ops,
    # which expect 3-D/4-D input, treat them like the image.
    y = tf.expand_dims(y, axis=-1)
    w = tf.expand_dims(w, axis=-1)
    # Horizontal flip.
    if tf.random.uniform((), 0., 1.) >= 0.5:
        X = tf.image.flip_left_right(X)
        y = tf.image.flip_left_right(y)
        w = tf.image.flip_left_right(w)
    # Vertical flip.
    if tf.random.uniform((), 0., 1.) >= 0.5:
        X = tf.image.flip_up_down(X)
        y = tf.image.flip_up_down(y)
        w = tf.image.flip_up_down(w)
    # Grid deformation.
    # order=1 for the image and order=0 for mask/weights; presumably linear vs
    # nearest-neighbour interpolation, so labels stay discrete. TODO confirm
    # against unet.augmentation.
    # NOTE(review): the helper is called separately for X, y and w. Unless it
    # shares/seeds its random displacement field, the three outputs are warped
    # differently and the mask no longer aligns with the image. Verify.
    if tf.random.uniform((), 0., 1.) >= 0.5:
        grid_size = 5
        distort_limits = (-.5, .5)
        X = grid_deformation(X, distort_limits=distort_limits, grid_size=grid_size, order=1)
        y = grid_deformation(y, distort_limits=distort_limits, grid_size=grid_size, order=0)
        w = grid_deformation(w, distort_limits=distort_limits, grid_size=grid_size, order=0)
    # Elastic deformation
    # NOTE(review): same per-call randomness concern as the grid deformation above.
    if tf.random.uniform((), 0., 1.) >= 0.5:
        alpha = 50.
        sigma = 3.
        auto_kSize = True
        X = elastic_deformation(X, alpha=alpha, sigma=sigma, auto_kSize=auto_kSize, order=1)
        y = elastic_deformation(y, alpha=alpha, sigma=sigma, auto_kSize=auto_kSize, order=0)
        w = elastic_deformation(w, alpha=alpha, sigma=sigma, auto_kSize=auto_kSize, order=0)
    # tf.squeeze with no axis removes every size-1 dimension, not just the
    # channel axis added above. Consider axis=-1 if inputs could have other
    # unit dimensions.
    return [X, tf.squeeze(y), tf.squeeze(w)]

In [None]:
i = 15

# Re-seed before augmenting, intended to make the preview reproducible.
reset_seed(420)
tmp = pipeline(X[i], y[i], sample_weights[i])

# Side-by-side: original image vs. its augmented version.
fig, ax = plt.subplots(1, 2, figsize=(14, 7))
for panel, title, img in zip(ax, ("Original", "Augmented"), (X[i], tmp[0])):
    panel.set_title(title)
    panel.imshow(img, cmap="gray")
    panel.axis("off")
plt.show()

## Create TF Datasets 

In [None]:
# First half of the arrays -> training (with sample weights),
# second half -> validation (image and mask only).
train_ds = tf.data.Dataset.from_tensor_slices(
    (X[:data_len // 2], y[:data_len // 2], sample_weights[:data_len // 2])
)
val_ds = tf.data.Dataset.from_tensor_slices((X[data_len // 2:], y[data_len // 2:]))

# Pipeline steps: augment, shuffle the whole split, then repeat so the total
# sample count (lcm(batch_size, n_train)) is a multiple of batch_size, then
# batch and prefetch.
# `train_ds.cardinality()` is evaluated on the raw dataset, before the name is
# rebound by this assignment.
train_ds = (
    train_ds.map(pipeline, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(train_ds.cardinality(), reshuffle_each_iteration=True)
    .repeat(lcm(batch_size, data_len // 2) // (data_len // 2))
    .batch(batch_size, drop_remainder=False, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

# Validation: no augmentation, twice the batch size.
val_ds = val_ds.batch(2 * batch_size, drop_remainder=False)

# Distribute both pipelines across the replicas of `strategy`.
train_ds = strategy.experimental_distribute_dataset(train_ds)
val_ds = strategy.experimental_distribute_dataset(val_ds)
gc.collect()

## Train model

In [15]:
def _progbar_target(dataset):
    """Return the number of batches in *dataset*, or None when unknown or infinite.

    Handles `cardinality` both as a method (tf.data.Dataset) and as a plain
    attribute (how the original loop read it from the distributed dataset).
    Negative sentinels (unknown/infinite) become None, Progbar's "no target".
    """
    card = getattr(dataset, 'cardinality', None)
    if callable(card):
        card = card()
    if card is None:
        return None
    card = int(card)
    return card if card >= 0 else None


def _every(epoch, period):
    """True when *period* is a positive int and *epoch* is a multiple of it."""
    return type(period) is int and period > 0 and epoch % period == 0


def _plot_random_holdout(helper):
    """Plot a random holdout image, its mask and the model's predicted mask.

    Reads the notebook globals X, y and data_len. The holdout split is the
    second half of X/y.
    """
    plt.close()
    idx = np.random.choice(np.arange(data_len // 2, data_len, 1))
    prediction = helper.model(X[idx:idx + 1], training=False).numpy().argmax(axis=-1)
    image_list = [X[idx], y[idx], np.squeeze(prediction)]
    # Rescale non-uint8 arrays to 0-255 for display. Assumes X in [0, 1] and
    # 0/1 labels, as produced by the loading cell.
    image_list = [(255. * img).astype('uint8') if img.dtype != 'uint8' else img for img in image_list]
    fig, ax = plt.subplots(1, 3, figsize=(14, 28))
    for panel, title, img in zip(ax, ("Image", "Mask", "Predicted Mask"), image_list):
        panel.set_title(title)
        panel.imshow(img, cmap="gray")
        panel.axis('off')
    plt.show()


def train(helper, train_dataset, val_dataset=None, epochs=100, ckpt_every=10, plot_every=1):
    """Run a custom (distributed) training loop driven by a UNetHelper.

    Args:
        helper: object exposing model, optimizer, opt_schedule, checkpoint,
            checkpoint_dir, dist_train_step and dist_val_step.
        train_dataset: iterable of training batches.
        val_dataset: optional iterable of validation batches.
        epochs: number of epochs to run.
        ckpt_every: save a checkpoint every this many epochs. A non-int or
            non-positive value disables checkpointing.
        plot_every: plot a random holdout prediction every this many epochs.
            A non-int or non-positive value disables plotting.

    Returns:
        One entry per epoch: [loss, acc] from the last training batch, plus
        the per-batch mean val_loss and val_acc when val_dataset is given.

    Raises:
        ValueError: if the training or validation dataset yields no batches.
    """
    history = []
    progbar_target = _progbar_target(train_dataset)

    for epoch in range(1, epochs + 1):
        print(f'\nEpoch {epoch}/{epochs}')

        if helper.opt_schedule is not None:
            # The schedule is indexed by the 1-based epoch number.
            helper.optimizer.learning_rate = helper.opt_schedule(epoch)
        progbar = tf.keras.utils.Progbar(target=progbar_target)

        i = 0
        for i, batch in enumerate(train_dataset, start=1):
            loss, acc = helper.dist_train_step(batch)
            progbar.update(i, zip(['loss', 'acc'], [loss, acc]), finalize=False)
        if i == 0:
            raise ValueError('train_dataset yielded no batches.')

        lr = helper.optimizer.learning_rate.numpy()
        if val_dataset is not None:
            val_loss, val_acc, n_val = 0.0, 0.0, 0
            for n_val, batch in enumerate(val_dataset, start=1):
                vloss, vacc = helper.dist_val_step(batch)
                val_loss += vloss
                val_acc += vacc
            if n_val == 0:
                raise ValueError('val_dataset yielded no batches.')
            # Unweighted mean over batches: a smaller final batch counts fully.
            val_loss /= n_val
            val_acc /= n_val
            history.append([loss, acc, val_loss, val_acc])
            progbar.update(i, zip(['loss', 'acc', 'val_loss', 'val_acc', 'lr'],
                                  [loss, acc, val_loss, val_acc, lr]), finalize=True)
        else:
            history.append([loss, acc])
            progbar.update(i, zip(['loss', 'acc', 'lr'], [loss, acc, lr]), finalize=True)

        if _every(epoch, ckpt_every):
            helper.checkpoint.save(helper.checkpoint_dir)

        if _every(epoch, plot_every):
            _plot_random_holdout(helper)
    return history

In [None]:
# Training hyper-parameters (an earlier run used max_lr = 1.E-2).
max_lr = 1.E-3
# NOTE(review): the lr_decay_* values are not referenced by the visible cells.
lr_decay_start, lr_decay_rate, lr_decay_step = (2, 0.1, 3)
model_param = {"input_shape": img_shape,
               "dropout": 0.2,
               }

# Drop a helper left over from a previous run of this cell (if any).
try:
    del helper
except NameError:
    pass

reset_seed()
K.clear_session()

with strategy.scope():
    gc.collect()

    helper = UNetHelper(strategy=strategy,
                        model_param=model_param,
                        loss_func=tf.keras.losses.sparse_categorical_crossentropy,
                        optimizer=tf.keras.optimizers.SGD(learning_rate=max_lr, momentum=0.99),
                        #opt_schedule=tf.keras.optimizers.schedules.PiecewiseConstantDecay(boundaries=[5,], values=[1e-2, 1e-3]),
                        )

    if train_model:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            history = train(helper, train_ds, val_ds, max_epochs, ckpt_every=15, plot_every=30)
        # Make sure the target directory exists before writing the model file.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        helper.model.save(save_path)
    else:
        # Keras models have no `.load()` method. Restore the saved weights into
        # the freshly built architecture, which keeps the helper's wiring intact.
        helper.model.load_weights(load_path)

## Check predictions

In [None]:
# Predict on the holdout half (second half of X) with warnings suppressed.
with warnings.catch_warnings():
    warnings.simplefilter(action="ignore")
    pred = helper.model.predict(x=X[data_len // 2:])

In [None]:
# Report holdout metrics with fixed 4-decimal formatting. String formatting
# rounds for display directly, so no np.round + fuzz-factor workaround is needed.
y_holdout = y[data_len // 2:]
mean_iou = float(IoU(y_holdout, pred).numpy().mean())
mean_dice = float(dice_loss(y_holdout, pred).numpy().mean())
print(f"Average IoU on holdout set: {mean_iou:.4f}")
print(f"Average Dice loss on holdout set: {mean_dice:.4f}")

In [None]:
j = 0  # index within the holdout half

# Panels: input image, ground-truth mask, per-pixel argmax of the prediction.
image_list = [
    X[data_len // 2:][j],
    y[data_len // 2:][j],
    np.squeeze(pred[j].argmax(axis=-1)),
]
subtitles = ['Image', 'Mask', 'Predicted Mask']

fig, ax = plt.subplots(1, 3, figsize=(12, 36))
for i, img in enumerate(image_list):
    ax[i].imshow(img, cmap="gray")
    ax[i].set_title(subtitles[i])
    ax[i].axis('off')
plt.show()

In [None]:
# Multiply the input image by the predicted class map: pixels predicted as
# class 0 are zeroed (assumes a 0/1 prediction).
plt.imshow(image_list[0] * image_list[-1][..., np.newaxis], cmap="gray")
plt.axis('off')
plt.show()