# WMH Segmentation Challenge

A notebook that can build, train and run a model to tackle the White Matter Hyperintensities (WMH) segmentation challenge.

A few options can be set to change the method used (3D slices, preprocessed images, ...).

## Colab integration

This cell makes it easy to run the notebook in Colab.

With it, the notebook can be imported into Colab and the cell will download everything needed to run it.

Constants at the top of the cell control the execution (e.g. to clone a git branch other than master).

In [None]:
RUNNING_IN_COLAB = True

if RUNNING_IN_COLAB:
    REPO_URL = 'https://github.com/nicomem/imed-project.git'
    BRANCH   = 'final-touches'
    REPO_DIR = 'imed-project'
    DATA_URL = 'https://drive.google.com/uc?id=1onHHWIhkhN5xYMit0rhhtVXlJrAlzCit'
    
    from pathlib import Path

    %cd /content

    # Download the repository
    if not Path(REPO_DIR).is_dir():
        !git clone --branch {BRANCH} --depth=1 -- {REPO_URL} {REPO_DIR}
    
    %cd {REPO_DIR}

    # Install requirements
    !pip install -r requirements.txt | grep -v 'Requirement already satisfied'
    !pip install gdown | grep -v 'Requirement already satisfied'
    
    import gdown
    if not Path('data.zip').is_file():
        gdown.download(DATA_URL, 'data.zip', quiet=False)
    
    if not Path('data').is_dir():
        !unzip -q -- data.zip
    
    %cd notebooks
    %ls

## Imports

In [None]:
# 3rd-party imports
import numpy as np
import nibabel as nib
import tensorflow as tf
import matplotlib.pyplot as plt
import keras_unet
import skimage

from tensorflow import keras

## Get dataset & split train/test

Get the dataset files and split them into train/validation/test sets.

The split is done by scan, so that different slices of the same scan will not be in multiple sets.
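A minimal sketch of a scan-level split, assuming each scan has an identifier and using illustrative 70/15/15 ratios. `split_by_scan` is a hypothetical helper, not the `get_dataset` function from `utils.load_data`, whose ratios and shuffling may differ. Because whole scans are assigned to a set, every slice inherits the set of its scan and no scan can leak across sets.

```python
import numpy as np


def split_by_scan(scan_ids, val_ratio=0.15, test_ratio=0.15, seed=0):
    """Split scan identifiers (not slices) into train/val/test lists.

    Hypothetical helper: the ratios and seed are illustrative assumptions.
    """
    ids = list(scan_ids)
    rng = np.random.default_rng(seed)
    # Shuffle whole scans, never individual slices
    shuffled = [ids[i] for i in rng.permutation(len(ids))]

    n_test = int(round(len(ids) * test_ratio))
    n_val = int(round(len(ids) * val_ratio))

    test = shuffled[:n_test]
    val = shuffled[n_test:n_test + n_val]
    train = shuffled[n_test + n_val:]
    return train, val, test


# Usage: 20 hypothetical scan names split into three disjoint groups
scans = [f"scan_{i:02d}" for i in range(20)]
train, val, test = split_by_scan(scans)
```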

The data itself is not loaded here: only the nibabel image objects are created, which read just the header information (slice shape, etc.).

In [None]:
from utils.load_data import *

train_nib, val_nib, test_nib = get_dataset('../data', verbose=True)

print('\n{...}_nib keys:', train_nib.keys())
print('train_nib:', [len(v) for v in train_nib.values()])
print('val_nib:',   [len(v) for v in val_nib.values()])
print('test_nib:',  [len(v) for v in test_nib.values()])

## Load train & analyze

The train set data is loaded here.

The `SlicesSequence` object inherits from the Keras `Sequence` class, which can be used to lazily load the data when fitting the model.

However, if there is enough RAM, data loading can be sped up by loading all slices at once, which is what `CachedSlicesSequence` does here.
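The difference between the two approaches can be sketched with two small classes that follow the same protocol as `keras.utils.Sequence` (`__len__`, and `__getitem__` returning one batch). `LazyBatches` and `CachedBatches` are hypothetical stand-ins for `SlicesSequence` and `CachedSlicesSequence`, assuming a `load_slice` callable that reads one slice from disk; they avoid Keras entirely so the idea stays visible, and they drop the incomplete trailing batch.

```python
import numpy as np


class LazyBatches:
    """Sequence-like sketch: each batch is read from disk when requested."""

    def __init__(self, n_slices, batch_size, load_slice):
        self.n_slices = n_slices
        self.batch_size = batch_size
        self.load_slice = load_slice  # assumed callable: slice index -> 2D array

    def __len__(self):
        # Number of full batches; the incomplete trailing batch is dropped
        return self.n_slices // self.batch_size

    def __getitem__(self, i):
        if i < 0:  # allow seq[-1] like a list
            i += len(self)
        start = i * self.batch_size
        return np.stack([self.load_slice(k)
                         for k in range(start, start + self.batch_size)])


class CachedBatches:
    """Reads every slice once up front, then serves batches from RAM."""

    def __init__(self, lazy):
        self.batch_size = lazy.batch_size
        self.X = np.stack([lazy.load_slice(k) for k in range(lazy.n_slices)])

    def __len__(self):
        return len(self.X) // self.batch_size

    def __getitem__(self, i):
        if i < 0:
            i += len(self)
        start = i * self.batch_size
        return self.X[start:start + self.batch_size]
```

Each `lazy[i]` call reads `batch_size` slices again on every epoch, while `CachedBatches` pays the full reading cost once and then keeps everything in memory.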

We also reshape the slices by cropping or padding them to the same shape, so that they can easily be transferred to the model (all slices in a batch must have the same shape).
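A minimal center crop/pad sketch, assuming 2D slices and constant (zero) padding. `crop_or_pad` is a hypothetical helper; the actual `SlicesSequence` implementation may align or pad the slices differently.

```python
import numpy as np


def crop_or_pad(img, target_height, target_width, pad_value=0):
    """Center-crop and/or pad a 2D slice to (target_height, target_width)."""
    out = np.full((target_height, target_width), pad_value, dtype=img.dtype)
    h, w = img.shape
    # Extent of the region shared by the source and the output
    ch, cw = min(h, target_height), min(w, target_width)
    # Offsets into the source (cropping) and into the output (padding)
    src_y, src_x = (h - ch) // 2, (w - cw) // 2
    dst_y, dst_x = (target_height - ch) // 2, (target_width - cw) // 2
    out[dst_y:dst_y + ch, dst_x:dst_x + cw] = img[src_y:src_y + ch, src_x:src_x + cw]
    return out
```

For example, a 240x300 slice is padded by 8 rows above and below and cropped by 22 columns on each side to reach 256x256.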

In [None]:
# Control whether to use a 3D model and whether to add a preprocessing phase
ENABLE_3D = False
ENABLE_PREPROCESSING = False

In [None]:
# In 3D mode, the batch size must be reduced or else we run out of GPU memory
batch_size = 16 if ENABLE_3D else 32
radius_3D = 1 if ENABLE_3D else 0
preprocess = ENABLE_PREPROCESSING

target_height = 256
target_width = 256
shuffle = True
num_classes = 1

img_size = (target_height, target_width)
num_channels_per_slice = 2 + preprocess

# (current slice + radius before + radius after) * number of channels per slice
total_num_channels = (radius_3D * 2 + 1) * num_channels_per_slice

slices_seq_kwargs = {
    'target_height': target_height,
    'target_width': target_width,
    'slices3D_radius': radius_3D,
    'batch_size': batch_size,
    'shuffle': shuffle
}

# Create a lazy-loading sequence
train_seq_uncached = SlicesSequence(train_nib, **slices_seq_kwargs)

# Load all slices to speed up the training
# Remove the slices where no wmh is found,
# this leads to faster and more stable training
train_seq = CachedSlicesSequence(train_seq_uncached,
                                 preprocess=preprocess,
                                 remove_no_wmh=True)

print('Number of trainable slices:', len(train_seq.indexes))
print('Number of batches:', len(train_seq))
print('Batch size:', train_seq.batch_size)
print('Slices not trained per epoch:', len(train_seq.indexes) - len(train_seq) * train_seq.batch_size)
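The bookkeeping in the cell above can be reproduced with plain NumPy. This is a sketch under two assumptions: `remove_no_wmh=True` keeps exactly the slices whose mask has at least one positive pixel, and the trailing incomplete batch is dropped (which is what the "Slices not trained per epoch" formula implies). `keep_slices_with_wmh`, `batch_bookkeeping` and `total_channels` are hypothetical helpers.

```python
import numpy as np


def keep_slices_with_wmh(X, Y):
    """Keep the slices whose mask has at least one WMH pixel
    (assumed behavior of remove_no_wmh=True)."""
    has_wmh = Y.reshape(len(Y), -1).any(axis=1)
    return X[has_wmh], Y[has_wmh]


def batch_bookkeeping(n_slices, batch_size):
    """(full batches per epoch, slices left out per epoch), assuming the
    incomplete trailing batch is dropped."""
    return n_slices // batch_size, n_slices % batch_size


def total_channels(radius_3D, preprocess):
    # (current slice + radius before + radius after) * channels per slice,
    # with 2 channels (T1, FLAIR) plus 1 when preprocessing is enabled
    return (2 * radius_3D + 1) * (2 + int(preprocess))
```

For instance, 1000 slices with a batch size of 32 give 31 batches and leave 8 slices out of each epoch.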

In [None]:
# The sequence can be accessed by batch like a list
x,y = train_seq[0]
print(x.dtype, y.dtype)
print(x.shape, y.shape)

x,y = train_seq[-1]
print(x.dtype, y.dtype)
print(x.shape, y.shape)

In [None]:
# Plot sample `i_data` of the first and last batches: each input channel, then the WMH mask
i_data = 10

for batch_index in (0, -1):
    x,y = train_seq[batch_index]
    plt.figure(figsize=(12, 6))
    for i in range(total_num_channels):
        plt.subplot(1, total_num_channels+1, i+1)
        plt.imshow(x[i_data,...,i])
    plt.subplot(1, total_num_channels+1, total_num_channels+1)
    plt.imshow(y[i_data])
    plt.show()

## Prepare the model

The model is a U-Net, built with the `keras_unet` library to avoid writing all the boilerplate.

It takes the T1 and FLAIR images of a slice (plus, depending on the options above, the neighboring slices and an extra preprocessed channel per slice) and outputs, for each pixel, the probability of WMH.

These probabilities can be converted to boolean values by applying a threshold.
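A minimal sketch of the thresholding step, assuming the model output has shape `(batch, height, width, 1)` as produced by a single-class sigmoid U-Net. `to_mask` is a hypothetical helper and 0.5 is an illustrative default rather than a tuned value; with a strict `>` comparison, a probability of exactly 0.5 is counted as background.

```python
import numpy as np


def to_mask(probabilities, threshold=0.5):
    """Turn per-pixel WMH probabilities into a boolean mask."""
    return probabilities > threshold


# A fake model output for one 2x2 slice, shaped (batch, height, width, 1)
probs = np.array([[0.1, 0.7],
                  [0.5, 0.9]])[None, ..., None]
mask = to_mask(probs)
# Number of predicted WMH pixels per slice
wmh_pixels = mask.reshape(len(mask), -1).sum(axis=1)
```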

In [None]:
from keras_unet.models import custom_unet

input_shape = (target_height, target_width, total_num_channels)
keras.backend.clear_session()
model = custom_unet(
    input_shape,
    num_classes=1,
    use_batch_norm=True,
    filters=32,
    num_layers=3,
    dropout=0.1,
    output_activation='sigmoid'
)
model.summary()

## Train the model

The validation set is loaded here in the same way as the training set.

It is used to tweak hyper-parameters to improve metrics and decrease overfitting.

In [None]:
val_seq_uncached = SlicesSequence(val_nib, **slices_seq_kwargs)

# Do not remove the no-wmh slices because this is the validation set
# Modifying this set would result in adding a bias to the metrics
val_seq = CachedSlicesSequence(val_seq_uncached,
                               preprocess=preprocess)
len(val_seq)

In [None]:
class Dice(keras.metrics.Metric):
    '''Dice score 2*TP / (2*TP + FP + FN), accumulated over batches.'''
    def __init__(self, name='dice', threshold=0.5, **kwargs):
        super().__init__(name=name, **kwargs)
        
        # True negatives do not appear in the Dice formula, so they are not tracked
        self.TP = keras.metrics.TruePositives(thresholds=threshold)
        self.FP = keras.metrics.FalsePositives(thresholds=threshold)
        self.FN = keras.metrics.FalseNegatives(thresholds=threshold)
        
    def update_state(self, y_true, y_pred, sample_weight=None):
        self.TP.update_state(y_true, y_pred, sample_weight=sample_weight)
        self.FP.update_state(y_true, y_pred, sample_weight=sample_weight)
        self.FN.update_state(y_true, y_pred, sample_weight=sample_weight)
    
    def result(self):
        TP = self.TP.result()
        FP = self.FP.result()
        FN = self.FN.result()

        # divide_no_nan returns 0 instead of NaN when neither the ground truth
        # nor the prediction contains a positive pixel
        return tf.math.divide_no_nan(2 * TP, 2 * TP + FP + FN)
    
    def reset_states(self):
        self.TP.reset_states()
        self.FP.reset_states()
        self.FN.reset_states()


d = Dice()
d.update_state([0,1,1,1], [0,0,0,1])
print(d.result().numpy())

d.update_state([0,1,1,1], [0.2,0.4,0.45,0.6])
print(d.result().numpy())
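As a cross-check, the same score can be computed directly in NumPy. Since |A| = TP + FN and |B| = TP + FP for the ground-truth mask A and predicted mask B, 2·TP / (2·TP + FP + FN) equals the familiar 2|A∩B| / (|A| + |B|), and true negatives never enter it. `dice_numpy` is a hypothetical helper; it uses the same strict `y_pred > threshold` rule as the Keras confusion-matrix metrics and, by choice, returns 0.0 when both masks are empty (NaN or 1.0 are other common conventions). On the two inputs above it gives 0.5 for each, which matches what the cumulative metric should print (TP = 2, FP = 0, FN = 4 after both updates).

```python
import numpy as np


def dice_numpy(y_true, y_pred, threshold=0.5):
    """Dice = 2*TP / (2*TP + FP + FN) for binary masks (hypothetical helper)."""
    t = np.asarray(y_true).astype(bool)
    p = np.asarray(y_pred) > threshold
    tp = np.sum(t & p)
    fp = np.sum(~t & p)
    fn = np.sum(t & ~p)
    denom = 2 * tp + fp + fn
    # Both masks empty: return 0.0 here (other conventions exist)
    return 0.0 if denom == 0 else float(2 * tp / denom)
```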

In [None]:
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
    loss="binary_crossentropy",
    metrics=[
        Dice(name='dice'),
        tf.keras.metrics.Recall(name='recall'),
        tf.keras.metrics.Precision(name='precision')
    ]
)

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint

checkpoint_filename = f"model_radius-{radius_3D}_{'prepro' if preprocess else 'no-prepro'}.h5"
print(checkpoint_filename)
callback_checkpoint = ModelCheckpoint(
    checkpoint_filename, 
    verbose=1, 
    monitor='val_dice',
    mode='max',
    save_best_only=True,
)

In [None]:
history = model.fit(
    train_seq,
    epochs=50,
    validation_data=val_seq,
    shuffle=shuffle,
    callbacks=[callback_checkpoint]
)

In [None]:
# Load the best model checkpoint
model.load_weights(checkpoint_filename)

In [None]:
# Display the learning curves with a rolling average (to make the plots easier to analyse)

def rolling_average(x, w):
    return np.convolve(x, np.ones(w), 'valid') / w

ravg_w = 3

plot_cols = ['loss', 'precision', 'recall', 'dice']
plt.figure(figsize=(20,6))
for i, col in enumerate(plot_cols):
    plt.subplot(1, len(plot_cols), i+1)

    plt.plot(rolling_average(history.history[col], ravg_w), label='train')
    plt.plot(rolling_average(history.history[f'val_{col}'], ravg_w), label='val')

    plt.title(col)
    plt.xlabel('Epoch')
    plt.ylabel(col)
    plt.legend()

plt.show()
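With mode `'valid'`, `np.convolve` keeps only the `len(x) - w + 1` fully overlapping windows, so each smoothed curve has `ravg_w - 1` fewer points than there are epochs, and its first point already averages epochs 1 to `ravg_w`. A small sketch that pairs each smoothed value with the (1-based) epoch at which its window ends; `rolling_epochs` is a hypothetical helper and `rolling_average` is copied from the cell above.

```python
import numpy as np


def rolling_average(x, w):
    return np.convolve(x, np.ones(w), 'valid') / w


def rolling_epochs(n_epochs, w):
    """1-based epoch at which each 'valid' window of width w ends."""
    return np.arange(w, n_epochs + 1)


smoothed = rolling_average([1, 2, 3, 4, 5], 3)
epochs = rolling_epochs(5, 3)
```

In the plotting loop this could be used as `plt.plot(rolling_epochs(len(h), ravg_w), rolling_average(h, ravg_w))` with `h = history.history[col]`, so that the x-axis shows actual epoch numbers.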

In [None]:
def check_model_with_set(model, seq, N = 10, bool_threshold = None):
    '''
    Display the model results on a random sample of data.
    
    Parameters:
    -----------
    model:
        The trained model.
    seq: CachedSlicesSequence
        The dataset to check (its slices must be loaded in `seq.X` and `seq.Y`).
    N: int
        The number of samples to check.
    bool_threshold: Optional[float]
        The threshold to apply to the model results.
        Must be in the range: [0.0, 1.0].
        Set to None to display the raw results (probabilities).
    '''
    
    # Pick a random sample of the dataset (without repetition)
    i_samples = np.random.choice(np.arange(0, len(seq.X)), size=N, replace=False)
    
    # Predict the samples
    X = seq.X[i_samples]
    Y_gt = seq.Y[i_samples]
    Y_pred = model.predict(X)

    # Compare against None so that a threshold of 0.0 is still applied
    # (np.bool was removed from recent NumPy versions, use the builtin bool)
    if bool_threshold is not None:
        Y_pred = (Y_pred > bool_threshold).astype(bool)

    # Helper function to reshape the images, in case there are extra singleton dimensions
    reshape_img = lambda img: np.reshape(img, (target_height, target_width))
        
    # Compare some predictions to the ground truth
    plt.figure(figsize=(20,5*N))
    for i in range(N):
        x_T1    = reshape_img(X[i,...,0])
        x_FLAIR = reshape_img(X[i,...,1])
        y_gt    = reshape_img(Y_gt[i])
        y_pred  = reshape_img(Y_pred[i])

        plt.subplot(N, 4, 4*i+1)
        plt.imshow(x_T1)
        plt.title('T1')

        plt.subplot(N, 4, 4*i+2)
        plt.imshow(x_FLAIR)
        plt.title('FLAIR')

        plt.subplot(N, 4, 4*i+3)
        plt.imshow(y_gt)
        plt.title('Ground Truth (wmh)')

        plt.subplot(N, 4, 4*i+4)
        plt.imshow(y_pred)
        plt.title('Predicted (wmh)')

In [None]:
check_model_with_set(model, val_seq)

## Evaluating the model

Evaluate the model on the test set.

This must only be done once the model gives good results on the validation set.

**The test set must not be used to tweak hyper-parameters** (otherwise, this defeats the purpose of the test set).

In [None]:
# Fetch the test set (keep the no-wmh slices and apply the same preprocessing as for training)
test_seq_uncached = SlicesSequence(test_nib, **slices_seq_kwargs)
test_seq = CachedSlicesSequence(test_seq_uncached,
                                preprocess=preprocess)
print('Number of slices in test set:', len(test_seq.indexes))

In [None]:
model.evaluate(test_seq)

In [None]:
check_model_with_set(model, test_seq, N=20, bool_threshold=0.5)