# Train Binary Segmentation
## Instructions to run
- Modify the configuration variables in the [Constants](#constants) section and the [Training configuration](#training-configuration) section as needed.
- Customize the image transformations pipeline in the [Image transformations/augmentation](#image-transformations-augmentation) section.
- Run all the cells
- *Note*: This notebook is not ideal for the free Google Colab runtime, as it is prone to disconnecting during long training sessions. It is recommended to run this notebook on an HPC cluster, on a computer with an NVIDIA (CUDA-capable) GPU, or with Colab Pro.

## Package imports

In [1]:
from pathlib import Path
import math
import os
import requests
from zipfile import ZipFile

import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import albumentations as A
if not hasattr(Image, 'Resampling'):
    Image.Resampling = Image

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow.keras.backend as K
from tensorflow.config import list_physical_devices
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, ReduceLROnPlateau
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras import optimizers
from tensorflow.keras.utils import custom_object_scope

from fl_tissue_model_tools.transforms import get_elastic_dual_transform
from fl_tissue_model_tools.preprocessing import get_batch_augmentor
from fl_tissue_model_tools import models, models_util
from fl_tissue_model_tools.helper import get_img_mask_paths

# List the GPUs that TensorFlow can see and warn when there are none
available_gpus = list_physical_devices('GPU')
if available_gpus:
    print(f"Available GPUS:\n{available_gpus}")
else:
    print("WARNING: TensorFlow isn't using a GPU.")

/home/bean/fl_tissue_model_tools
Available GPUS:
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


<a name="constants"></a>
## Constants

In [None]:
# Root folder for everything this notebook reads and writes
training_dir = "train_binary_segmentation"

# Inputs: the images and their binary masks
images_dir = os.path.join(training_dir, "images")
labels_dir = os.path.join(training_dir, "masks")

# Outputs: model checkpoints and the training logs for tensorboard
checkpoint_save_path = os.path.join(training_dir, "checkpoints")
log_save_path = os.path.join(training_dir, "logs")

# Seed for the reproducible train/validation shuffle and for the data generators
rand_seed = 1234

batch_size = 16

# (height, width) of the random crop window, cannot be larger than the image size
crop_window = (512, 512)

# Model input shape. Cropped images/masks are resampled to this shape.
# target_shape must be divisible by 32, and must be square if using the patch-blending segmentor
target_shape = (256, 256)
#target_shape = crop_window     # uncomment to disable resizing the crops

# Number of times to iterate over the samples each epoch.
# This can be set pretty high without overfitting, since crops are sampled from high-res images
# and heavy augmentations (including geometric transformations) are applied.
repeat_dset_n_times = 50

In [None]:
# Create the training directory (and any missing parent folders) if it doesn't exist
Path(training_dir).mkdir(parents=True, exist_ok=True)

### Hyperparameter search space

In [None]:
# Filter-count tuples to try during the grid search
filter_counts_options = [
    (16, 32, 64, 128),
    (32, 64, 128, 256),
    (64, 128, 256, 512),
]

# Initial learning rates to try. The listed values apply to a batch size of 16 and are
# scaled linearly with the configured batch size.
hp_search_initial_lr_options = np.array(
    [1e-4, 2.5e-4, 5e-4, 1e-3, 2.5e-3, 5e-3, 1e-2]
) * batch_size / 16

## Download demo training dataset

In [None]:
filename = "branching_training_data.zip"
url = f"https://github.com/fogg-lab/tissue-model-analysis-tools/raw/branching-script-update/sample_data/{filename}"
archive_save_path = os.path.join(training_dir, filename)

# Download the archive. Fail on HTTP errors instead of saving an error page as a "zip",
# and use a timeout so a stalled connection cannot hang the notebook forever.
response = requests.get(url, timeout=60)
response.raise_for_status()
data = response.content

# Write the archive to disk; the context manager guarantees the file is flushed and closed
# before it is reopened for extraction
with open(archive_save_path, 'wb') as archive_file:
    archive_file.write(data)

# Extract the archive
with ZipFile(archive_save_path, 'r') as data_archive:
    data_archive.extractall(training_dir)

# Delete the archive
os.remove(archive_save_path)

## Validate data paths

In [None]:
# Collect the (image path, mask path) pairs
image_mask_paths = get_img_mask_paths(images_dir, labels_dir)
if len(image_mask_paths) == 0:
    raise ValueError(f"No image/label pairs found in {images_dir} and {labels_dir}")
img_paths, mask_paths = zip(*image_mask_paths)

print(f"Found {len(image_mask_paths)} image/label pairs")

# Check that every pair is readable, has matching shapes, and that the mask is binary
for img_path, mask_path in image_mask_paths:
    image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

    # cv2.imread returns None (it does not raise) when a file is missing or cannot be decoded
    if image is None:
        raise ValueError(f"Could not read image {img_path}")
    if mask is None:
        raise ValueError(f"Could not read mask {mask_path}")

    if image.shape != mask.shape:
        raise ValueError(
            f"Image {img_path} and mask {mask_path} have different shapes: {image.shape} vs {mask.shape}"
        )

    # Only warn about unexpected mask values; this does not stop the notebook
    mask_values = np.unique(mask)
    if mask_values.tolist() not in ([0], [255], [0, 255]):
        print(f"Mask {mask_path} has unexpected values: {mask_values}")

## Data pipeline

### Get training and validation image paths

In [None]:
n_samples = len(img_paths)
if n_samples < 2:
    raise ValueError(
        f"Need at least 2 image/mask pairs for a train/validation split, found {n_samples}"
    )

# Hold out 20% of the pairs for validation, but always at least one: with n_val == 0 the
# negative slices below would leave the training set empty and put every pair in validation
n_val = max(1, int(n_samples * 0.2))

# Shuffle the data image/mask keeping pairs together

#indices = np.random.permutation(len(img_paths))    # non-seeded random shuffle
rs = np.random.RandomState(seed=rand_seed)
indices = rs.permutation(n_samples)

img_paths = [img_paths[i] for i in indices]
mask_paths = [mask_paths[i] for i in indices]

# The last n_val shuffled pairs form the validation set; the rest are used for training
train_img_paths = img_paths[:-n_val]
train_mask_paths = mask_paths[:-n_val]

val_img_paths = img_paths[-n_val:]
val_mask_paths = mask_paths[-n_val:]

### Compute sample weights & mean/std for training data

In [None]:
# Load the training masks and count how many mask values belong to each class
y_train_labels = models_util.load_y(train_mask_paths)

n_bg = np.sum(y_train_labels == 0)
n_fg = np.sum(y_train_labels == 1)
n_labeled = float(n_fg + n_bg)

# Class weight = total / (2 * class count), so the rarer class gets the larger weight
bg_weight = n_labeled / (2.0 * n_bg)
fg_weight = n_labeled / (2.0 * n_fg)
sample_weights = {0: bg_weight, 1: fg_weight}
sample_weights

In [None]:
# Get the mean and std of the training set images only, so that no statistics from the
# validation images leak into the normalization applied to both sets
x_train_imgs = models_util.load_x(train_img_paths)
im_mean = np.mean(x_train_imgs)
im_std = np.std(x_train_imgs)

im_mean, im_std

<a name="image-transformations-augmentation"></a>
### Image transformations/augmentation

In [None]:
def get_resizer(ds_shape):
    """Build a transform that resizes an image/mask pair to ``ds_shape``.

    Args:
        ds_shape: Target shape as (height, width), the same axis order used for
            ``crop_window`` and the model input shape.

    Returns:
        A function ``(image, mask) -> {'image': resized_image, 'mask': resized_mask}``.
    """
    # cv2.resize and PIL's Image.resize both expect the size as (width, height), so the
    # (height, width) shape is reversed here. For square shapes this makes no difference;
    # for non-square shapes it keeps the output from being transposed.
    size_wh = tuple(ds_shape[::-1])

    def ds_im_mask(image, mask):
        """Resize image with Lanczos interpolation and mask with nearest neighbor"""
        image = cv2.resize(image, size_wh, interpolation=cv2.INTER_LANCZOS4)
        # Nearest-neighbor resampling keeps the mask values unchanged (no blended values)
        mask = np.array(Image.fromarray(mask).resize(size_wh, resample=Image.Resampling.NEAREST))
        return {'image': image, 'mask': mask}
    return ds_im_mask

def get_normalizer(mean, std):
    """Build a transform that standardizes the image and passes the mask through.

    Args:
        mean: Value subtracted from the image.
        std: Value the centered image is divided by.

    Returns:
        A function ``(image, mask) -> {'image': normalized_float32_image, 'mask': mask}``.
    """
    def normalize_pair(image, mask):
        """Standardize the image with the given mean and std; the mask is returned as-is"""
        centered = image - mean
        scaled = centered / std
        return {'image': scaled.astype(np.float32), 'mask': mask}
    return normalize_pair

In [None]:
# Training pipeline: albumentations augmentations, then elastic deformation, then resizing to
# the model input shape and normalization. The steps are applied in the listed order.
train_transforms = [
    A.Compose([
        A.Rotate(p=0.5, border_mode=cv2.BORDER_CONSTANT, value=0, mask_value=0),
        A.RandomCrop(height=crop_window[0], width=crop_window[1]),
        A.Flip(p=0.5),
        A.RandomBrightnessContrast(p=0.7),
        A.OneOf([
            A.MultiplicativeNoise(p=0.5),
            A.AdvancedBlur(p=0.5)
        ], p=0.8),
    ]),     # Albumentations pipeline
    get_elastic_dual_transform(p=0.85),
    get_resizer(target_shape),
    get_normalizer(im_mean, im_std)
]

# Validation pipeline: only cropping and flips/90-degree rotations, followed by the same
# resizing and normalization as for training
val_transforms = [
    A.Compose([
        A.RandomCrop(height=crop_window[0], width=crop_window[1]),
        A.Flip(p=0.5),
        A.RandomRotate90(p=0.5)
    ]),
    get_resizer(target_shape),
    get_normalizer(im_mean, im_std)
]

# Wrap each list of transforms into a single augmentation function for the data generators
train_augmentor = get_batch_augmentor(train_transforms)
val_augmentor = get_batch_augmentor(val_transforms)

### Create the training and validation data generators

In [None]:
# A single seeded random state, shared by the training and the validation generator
rs = np.random.RandomState(seed=rand_seed)

# Generator over the training pairs, using the training augmentation pipeline
train_gen = models_util.BinaryMaskSequence(
    batch_size,
    train_img_paths,
    train_mask_paths,
    rs,
    models_util.load_x,
    models_util.load_y,
    augmentation_function=train_augmentor,
    sample_weights=sample_weights,
    repeat_n_times=repeat_dset_n_times,
    shuffle=True,
)

# Generator over the validation pairs, using the validation augmentation pipeline
val_gen = models_util.BinaryMaskSequence(
    batch_size,
    val_img_paths,
    val_mask_paths,
    rs,
    models_util.load_x,
    models_util.load_y,
    augmentation_function=val_augmentor,
    sample_weights=sample_weights,
    repeat_n_times=repeat_dset_n_times,
    shuffle=True,
)

### Test the training generator

In [None]:
# Show every image/mask pair of one training batch side by side
X, y, _ = train_gen[1]
# Iterate over the samples actually present in the batch instead of assuming there are
# exactly batch_size of them (a short batch would otherwise raise an IndexError)
for x_sample, y_sample in zip(X, y):
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(x_sample[:, :, 0], cmap='gray')
    ax[1].imshow(y_sample[:, :, 0], cmap='gray')
    plt.show()

### Test the validation generator

In [None]:
# Show every image/mask pair of one validation batch side by side
X, y, _ = val_gen[1]
# Iterate over the samples actually present in the batch instead of assuming there are
# exactly batch_size of them (a short batch would otherwise raise an IndexError)
for x_sample, y_sample in zip(X, y):
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(x_sample[:, :, 0], cmap='gray')
    ax[1].imshow(y_sample[:, :, 0], cmap='gray')
    plt.show()

<a name="training-configuration"></a>
## Training configuration

*We'll warm up to the initial learning rate and either use a cyclic schedule or reduce it when val_loss plateaus.*

Parameters you should customize:
- `n_epochs`
- `hp_search_epochs`
- `weight_decay`

Extra parameters you should customize for cyclic schedule:
- `cycle_lr_mult`
- `num_cycles`

Extra parameters you should customize for reduce lr on plateau:
- `patience`
- `lr_reduction_factor`

In [None]:
n_epochs = 50
# Each grid search trial runs for a fraction of the full training length, but never for
# zero epochs (n_epochs // 7 is 0 when n_epochs < 7)
hp_search_epochs = max(1, n_epochs // 7)

# Number of batches (optimizer steps) per epoch; the warmup lasts exactly one epoch
epoch_len = math.ceil(((len(train_img_paths)) * train_gen.repeat_n_times) / batch_size)
linear_warmup_steps = epoch_len

# use_cosine_decay_restarts: true to use a cyclic lr schedule, false for ReduceLROnPlateau
use_cosine_decay_restarts = True

weight_decay = 1e-4     # weight decay for AdamW

# The settings below are the same for every learning rate option, so they are computed once
# here instead of on every iteration of the loop
if use_cosine_decay_restarts:
    # cosine annealing parameters - timed so the last cycle ends at the end of training
    cycle_lr_mult = 0.5         # decrease initial lr at the end of each cycle (m_mul)
    num_cycles = 3              # number of full cycles

    # figure out what the first_decay_steps should be
    # start counting after the warmup steps
    total_steps = epoch_len * n_epochs - linear_warmup_steps

    # With t_mul=2 each cycle is twice as long as the previous one, so num_cycles cycles span
    # first_decay_steps * (2**num_cycles - 1) steps.
    # round up and add 1 to prevent an extra restart at the end
    first_decay_steps = math.ceil(total_steps / (2**num_cycles - 1)) + 1
else:
    # patience: number of epochs with no improvement after which learning rate will be reduced
    patience = 4
    lr_reduction_factor = 0.5
    reduce_lr = ReduceLROnPlateau(
        monitor='val_loss',
        factor=lr_reduction_factor,
        patience=patience,
        verbose=1,
        mode='min'
    )

# get optimizer options for grid search (vary the initial learning rate)
hp_search_optimizer_options = []
for initial_lr in hp_search_initial_lr_options:
    if use_cosine_decay_restarts:
        learning_rate = optimizers.schedules.CosineDecayRestarts(
            initial_learning_rate=initial_lr,
            first_decay_steps=first_decay_steps,
            t_mul=2.0,  # our first_decay_steps calculation assumes t_mul=2
            m_mul=cycle_lr_mult
        )
        learning_rate = optimizers.serialize(learning_rate)
    else:
        learning_rate = initial_lr

    # Warm up for linear_warmup_steps steps, then follow `learning_rate`
    lr_schedule = models_util.WarmupSchedule(warmup_steps=linear_warmup_steps,
                                             after_warmup_lr=learning_rate)

    optimizer = optimizers.experimental.AdamW(learning_rate=lr_schedule, weight_decay=weight_decay)
    # The grid search receives serialized optimizer configurations, not optimizer objects
    optimizer_config = optimizers.serialize(optimizer)

    hp_search_optimizer_options.append(optimizer_config)

## Hyperparameter search

In [None]:
# Loss, metric and callbacks handed to the grid search
loss = BinaryCrossentropy()
metrics = [models_util.mean_iou_coef_factory(thresh=0.5)]
# The plateau callback is only needed when the cyclic schedule is not used
callbacks = [reduce_lr] if not use_cosine_decay_restarts else []

# Grid over the filter count options and the optimizer configurations
gs = models.UNetXceptionGridSearch(
    save_dir="unet_grid_search",
    filter_counts_options=filter_counts_options,
    n_outputs=1,
    img_shape=target_shape,
    optimizer_cfg_options=hp_search_optimizer_options,
    loss=loss,
    output_act="sigmoid",
    metrics=metrics,
    callbacks=callbacks,
)

# WarmupSchedule is a custom class, so it is registered while the search runs
# (presumably the search rebuilds each optimizer from its serialized config - verify in models)
with custom_object_scope({'WarmupSchedule': models_util.WarmupSchedule}):
    gs.search(
        "mean_iou_coef",
        "max",
        train_gen,
        search_verbose=True,
        validation_data=val_gen,
        epochs=hp_search_epochs,
    )

In [None]:
# Get the best learning rate schedule.
# The optimizer configurations given to the grid search were produced by optimizers.serialize,
# so the learning rate stored in the best configuration may be a serialized schedule (a dict)
# rather than a callable. In that case rebuild the schedule object so that it can be evaluated
# at a given step; WarmupSchedule is a custom class and must be registered while doing so.
best_lr_cfg = gs.best_optimizer_cfg['config']['learning_rate']
with custom_object_scope({'WarmupSchedule': models_util.WarmupSchedule}):
    if isinstance(best_lr_cfg, dict):
        best_lr_schedule = optimizers.schedules.deserialize(best_lr_cfg)
    else:
        best_lr_schedule = best_lr_cfg

In [None]:
# Show the best hyperparameters found by the grid search
print("Best filter counts: ", gs.best_filter_counts)
print("Best optimizer: ", gs.best_optimizer_cfg)
# The schedule is evaluated at step `linear_warmup_steps`
# (presumably the first step after the warmup, i.e. the initial learning rate - TODO confirm)
print("Best initial learning rate: ", best_lr_schedule(linear_warmup_steps))
print("Best score: ", gs.best_score)
print("Best score index: ", gs.best_score_idx)

### Plot the best learning rate schedule

In [None]:
# Evaluate the schedule at every training step and plot it against the (fractional) epoch.
# If ReduceLROnPlateau is used, possible LR reductions after the warmup period are not plotted
n_steps = epoch_len * n_epochs
lr_each_step = list(map(best_lr_schedule, range(n_steps)))
# Step k (counting from 1) is k / epoch_len epochs into training
epoch_each_step = np.arange(1, n_steps + 1) / epoch_len

plt.plot(epoch_each_step, lr_each_step)
plt.ylabel('Learning rate')
plt.xlabel('Epoch')
plt.show()

## Train the model

In [None]:
K.clear_session()

# Build a fresh optimizer from the best configuration found by the grid search.
# (`optimizer_config` must not be used here: it only holds the last configuration that was
# built for the search, i.e. the last learning rate option, not the best one.)
with custom_object_scope({'WarmupSchedule': models_util.WarmupSchedule}):
    optimizer = optimizers.deserialize(gs.best_optimizer_cfg)

model = models.build_UNetXception(1, target_shape, filter_counts=gs.best_filter_counts,
                                  output_act="sigmoid")

metrics = [models_util.mean_iou_coef_factory(thresh=0.5)]

model.compile(optimizer=optimizer, loss=BinaryCrossentropy(), metrics=metrics)

# Save the best weights to <checkpoint_save_path>/best_model.h5, which is the file that the
# patch segmentor configuration saved at the end of this notebook points to
Path(checkpoint_save_path).mkdir(parents=True, exist_ok=True)
best_weights_path = os.path.join(checkpoint_save_path, "best_model.h5")

callbacks = [
    ModelCheckpoint(best_weights_path, save_best_only=True, save_weights_only=True),
    TensorBoard(log_dir=log_save_path, histogram_freq=1)
]

# The plateau callback only exists when the cyclic schedule is not used
if not use_cosine_decay_restarts:
    callbacks.append(reduce_lr)

h = model.fit(train_gen, validation_data=val_gen, epochs=n_epochs, callbacks=callbacks)

## Try out the model

In [None]:
# Run the trained model on one validation batch (the third element of the batch is unused here)
val_batch_num = 0
val_x, val_y, _ = val_gen[val_batch_num]
preds = model.predict(val_x)

In [None]:
# For every sample of the predicted batch show: input image, ground truth mask,
# raw prediction, and the prediction thresholded at 0.5.
# Iterating over the batch itself (instead of range(batch_size)) avoids an IndexError
# when the batch holds fewer than batch_size samples.
for sample_idx, (image, ground_truth, prediction) in enumerate(zip(val_x, val_y, preds)):
    fig, ax = plt.subplots(1, 4, figsize=(20, 5))
    ax[0].imshow(image[:, :, 0], cmap='gray')
    ax[0].set_title("Image")
    ax[1].imshow(ground_truth[:, :, 0], cmap='gray')
    ax[1].set_title("True Segmentation")
    ax[2].imshow(prediction[:, :, 0], cmap='gray')
    ax[2].set_title("Prediction")
    ax[3].imshow(np.greater(prediction, 0.5)[:, :, 0], cmap='gray')
    ax[3].set_title("Predicted Segmentation")
    plt.show()

In [None]:
# Assemble the patch segmentor configuration from the training settings and save it
filter_counts = gs.best_filter_counts
checkpoint_file = Path(checkpoint_save_path).joinpath("best_model.h5")
# Element-wise ratio between the model input shape and the crop window
downsample_ratio = np.divide(target_shape, crop_window)

cfg = dict(
    patch_shape=crop_window,
    checkpoint_file=checkpoint_file,
    filter_counts=filter_counts,
    ds_ratio=downsample_ratio,
    norm_mean=im_mean,
    norm_std=im_std,
    channels=1,
)
models_util.save_unet_patch_segmentor_cfg(cfg)