# 2D model without preprocessing

## Colab integration

This cell makes it easy to run the notebook in Colab. Set `RUNNING_IN_COLAB` to `True` to enable it.

Once the notebook is imported into Colab, the cell downloads everything needed to run it: the repository, its requirements, and the dataset.

The constants at the top of the cell can be edited to change its behavior (e.g. to clone a git branch other than `master`).

In [None]:
RUNNING_IN_COLAB = False

if RUNNING_IN_COLAB:
    REPO_URL = 'https://github.com/nicomem/imed-project.git'
    BRANCH   = 'master'
    REPO_DIR = 'imed-project'
    DATA_URL = 'https://drive.google.com/uc?id=1onHHWIhkhN5xYMit0rhhtVXlJrAlzCit'
    
    from pathlib import Path

    %cd /content

    # Download the repository
    if not Path(REPO_DIR).is_dir():
        !git clone --branch {BRANCH} --depth=1 -- {REPO_URL} {REPO_DIR}
    
    %cd {REPO_DIR}

    # Install requirements
    !pip install -r requirements.txt | grep -v 'Requirement already satisfied'
    !pip install gdown | grep -v 'Requirement already satisfied'
    
    import gdown
    if not Path('data.zip').is_file():
        gdown.download(DATA_URL, 'data.zip', quiet=False)
    
    if not Path('data').is_dir():
        !unzip -q -- data.zip
    
    %cd notebooks
    %ls

## Imports

In [None]:
# 3rd-party imports
import numpy as np
import nibabel as nib
import tensorflow as tf
import matplotlib.pyplot as plt
import keras_unet
import cv2

from tensorflow import keras

## Get dataset & split train/test

Get the dataset files and split them into train/validation/test sets.

The split is done by scan, so that different slices of the same scan will not be in multiple sets.

The data is not loaded here: only the nibabel objects are created, and these only load metadata about the scans (slice shapes, etc.).
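
To illustrate what a scan-level split means, here is a minimal sketch. It is not the code of `get_dataset`: the helper `split_by_scan`, its ratios, and the scan identifiers are all hypothetical and only serve the example. The point is that the split is decided on scans, so every slice of a scan ends up in the same set.

```python
import random

def split_by_scan(scan_ids, val_ratio=0.15, test_ratio=0.15, seed=0):
    """Split scan identifiers (not slices) into train/val/test lists.

    Hypothetical helper: the ratios and the seed are illustrative defaults.
    """
    ids = sorted(set(scan_ids))        # deterministic starting order
    random.Random(seed).shuffle(ids)   # reproducible shuffle

    n_test = int(round(len(ids) * test_ratio))
    n_val = int(round(len(ids) * val_ratio))

    test_ids = ids[:n_test]
    val_ids = ids[n_test:n_test + n_val]
    train_ids = ids[n_test + n_val:]
    return train_ids, val_ids, test_ids

# Made-up scan identifiers; the slices of a scan follow their scan.
scans = [f'scan_{i:02d}' for i in range(20)]
train_ids, val_ids, test_ids = split_by_scan(scans)
print(len(train_ids), len(val_ids), len(test_ids))
```

Because each identifier lands in exactly one list, no scan can leak slices from the training set into the validation or testing set.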

In [None]:
from utils.load_data import get_dataset, SlicesSequence, CachedSlicesSequence

train_nib, val_nib, test_nib = get_dataset('../data', verbose=True)

print('\n{...}_nib keys:', train_nib.keys())
print('train_nib:', [len(v) for v in train_nib.values()])
print('val_nib:',   [len(v) for v in val_nib.values()])
print('test_nib:',  [len(v) for v in test_nib.values()])

## Load train & analyze

The train set data is loaded here.

The `SlicesSequence` class inherits from the Keras `Sequence` class, which can be used to lazily load the data while fitting the model.

However, if we have enough RAM, we can speed up data loading by loading all slices at once, which is what `CachedSlicesSequence` does here.
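
As a rough idea of what such caching involves, the sketch below reads every batch of an indexable sequence once and concatenates the results into two arrays. It is an assumption about the general approach, not the actual implementation of `CachedSlicesSequence`; the helper `cache_batches` and the toy data are hypothetical.

```python
import numpy as np

def cache_batches(seq):
    """Concatenate every batch of an indexable sequence into two arrays.

    Hypothetical helper: `seq[i]` is assumed to return an (x, y) pair.
    """
    xs, ys = [], []
    for i in range(len(seq)):
        x, y = seq[i]
        xs.append(x)
        ys.append(y)
    return np.concatenate(xs, axis=0), np.concatenate(ys, axis=0)

# Toy stand-in for a lazy sequence: 3 batches, the last one smaller.
toy_seq = [
    (np.zeros((4, 8, 8, 2)), np.zeros((4, 8, 8))),
    (np.zeros((4, 8, 8, 2)), np.zeros((4, 8, 8))),
    (np.zeros((2, 8, 8, 2)), np.zeros((2, 8, 8))),
]
X, Y = cache_batches(toy_seq)
print(X.shape, Y.shape)  # (10, 8, 8, 2) (10, 8, 8)
```

The trade-off is memory: all slices live in RAM at once, in exchange for not reading the files again at every epoch.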

We also reshape the slices by cropping or padding them to a common shape, so that they can easily be passed to the model (every slice in a batch must have the same shape).
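
A minimal sketch of cropping or padding a 2D slice to a target shape is given below. The helper `crop_or_pad` is hypothetical, and centering the crop and padding with zeros are assumptions: `SlicesSequence` may handle this differently.

```python
import numpy as np

def crop_or_pad(img, target_height, target_width):
    """Center-crop or zero-pad a 2D slice to (target_height, target_width)."""
    h, w = img.shape

    # Center-crop along each axis that is too large
    top = max((h - target_height) // 2, 0)
    left = max((w - target_width) // 2, 0)
    img = img[top:top + target_height, left:left + target_width]

    # Zero-pad along each axis that is too small
    h, w = img.shape
    pad_h = target_height - h
    pad_w = target_width - w
    return np.pad(img, ((pad_h // 2, pad_h - pad_h // 2),
                        (pad_w // 2, pad_w - pad_w // 2)))

# A slice that is too tall and too narrow ends up with the target shape.
print(crop_or_pad(np.ones((300, 200)), 256, 256).shape)  # (256, 256)
```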

In [None]:
batch_size = 32
shuffle = True
target_height = 256
target_width = 256
img_size = (target_height, target_width)
num_classes = 1

train_seq_uncached = SlicesSequence(train_nib,
                                    target_height=target_height,
                                    target_width=target_width, 
                                    batch_size=batch_size,
                                    shuffle=shuffle)
train_seq = CachedSlicesSequence(train_seq_uncached)

print('Number of batches:', len(train_seq))
print('Total number of slices:', len(train_seq.Y))

In [None]:
# The sequence can be accessed by batch like a list
x,y = train_seq[0]
print('Batch X:', x.shape)
print('Batch Y:', y.shape)
print('A slice (T1, FLAIR, wmh):',
      x[0,...,0].shape, x[0,...,1].shape, y[0].shape)

print('---')

x,y = train_seq[-1]
print('Batch X:', x.shape)
print('Batch Y:', y.shape)
print('A slice (T1, FLAIR, wmh):',
      x[0,...,0].shape, x[0,...,1].shape, y[0].shape)

In [None]:
# Plot one slice from each of 2 batches (the first and the last)

plt.figure(figsize=(9, 6))
i_data = 15

x,y = train_seq[0]
plt.subplot(2, 3, 1); plt.imshow(x[i_data,...,0])
plt.subplot(2, 3, 2); plt.imshow(x[i_data,...,1])
plt.subplot(2, 3, 3); plt.imshow(y[i_data])

x,y = train_seq[-1]
plt.subplot(2, 3, 4); plt.imshow(x[i_data,...,0])
plt.subplot(2, 3, 5); plt.imshow(x[i_data,...,1])
plt.subplot(2, 3, 6); plt.imshow(y[i_data])

## Prepare the model

The chosen model is a U-Net, created with the `keras_unet` library to avoid writing the boilerplate.

It takes the T1 and FLAIR images of a slice and returns an image containing the probability of WMH for each pixel.

The resulting probabilities can be transformed to boolean values by simply applying a threshold.
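
For instance, thresholding a small probability map could look like the sketch below. The helper name `probabilities_to_mask` and the default threshold of 0.5 are assumptions chosen for the example.

```python
import numpy as np

def probabilities_to_mask(probs, threshold=0.5):
    """Turn per-pixel probabilities into a boolean mask."""
    return probs > threshold

# Toy 2x2 "prediction": only the bottom row exceeds the threshold.
probs = np.array([[0.05, 0.40],
                  [0.60, 0.95]])
mask = probabilities_to_mask(probs)
print(mask)
# [[False False]
#  [ True  True]]
```

Raising the threshold makes the mask more conservative (fewer pixels flagged as WMH), while lowering it flags more pixels.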

In [None]:
from keras_unet.models import custom_unet

input_shape = (target_height, target_width, 2)
keras.backend.clear_session()
model = custom_unet(
    input_shape,
    num_classes=1,
    use_batch_norm=True,
    filters=32,
    num_layers=3,
    dropout=0.1,
    output_activation='sigmoid'
)
model.summary()

## Train the model

The validation set is loaded here in the same way as the training set.

It is used to tweak hyper-parameters to improve metrics and decrease overfitting.

In [None]:
val_seq_uncached = SlicesSequence(val_nib,
                                  target_height=target_height,
                                  target_width=target_width, 
                                  batch_size=batch_size,
                                  shuffle=shuffle)
val_seq = CachedSlicesSequence(val_seq_uncached)
len(val_seq)

Since a lot of slices do not contain any WMH, only those that contain some are kept for training.

This should speed up training and should not worsen the model's results, since the WMH areas are not located at the same place in every slice.

However, all slices in the validation and testing sets are kept, to avoid adding a bias to the metrics.

In [None]:
usable_train_data = np.any(train_seq.Y, axis=(1,2))

nb_wmh_slices = np.count_nonzero(usable_train_data)

print('Number of WMH-present training slices:', nb_wmh_slices)
print('Number of original    training slices:', len(train_seq.Y))

print(f'\n{100 * nb_wmh_slices / len(train_seq.Y):.2f}%',
      'of the training slices have been kept')

In [None]:
nb_wmh_slices_val = np.count_nonzero(np.any(val_seq.Y, axis=(1,2)))

print(f'{100 * nb_wmh_slices_val / val_seq.Y.shape[0]:.2f}%',
      'of validation slices contain WMH')

In [None]:
# Import from tensorflow.keras to stay consistent with the other Keras imports
from tensorflow.keras.callbacks import ModelCheckpoint

checkpoint_filename = 'segm_model_v0.h5'
callback_checkpoint = ModelCheckpoint(
    checkpoint_filename, 
    verbose=1, 
    monitor='val_loss', 
    save_best_only=True,
)

In [None]:
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
    loss="binary_crossentropy",
    metrics=[tf.keras.metrics.Recall(name='recall'),
             tf.keras.metrics.Precision(name='precision')]
)

In [None]:
history = model.fit(
    train_seq.X[usable_train_data],
    train_seq.Y[usable_train_data],
    epochs=40,
    validation_data=(val_seq.X, val_seq.Y),
    batch_size=32,
    shuffle=True,
    callbacks=[callback_checkpoint]
)

In [None]:
# Load the best model checkpoint
model.load_weights(checkpoint_filename)

In [None]:
# Display the learning curves with a rolling average (to make the plots easier to analyse)

def rolling_average(x, w):
    return np.convolve(x, np.ones(w), 'valid') / w

ravg_w = 3

plot_cols = ['loss', 'precision', 'recall']
plt.figure(figsize=(20,6))
for i, col in enumerate(plot_cols):
    plt.subplot(1, len(plot_cols), i+1)

    plt.plot(rolling_average(history.history[col], ravg_w), label='train')
    plt.plot(rolling_average(history.history[f'val_{col}'], ravg_w), label='val')

    plt.title(col)
    plt.xlabel('Epoch')
    plt.ylabel(col)
    plt.legend()

plt.show()

In [None]:
def check_model_with_set(model, seq, N = 10, bool_threshold = None):
    '''
    Display the model results on a random sample of data.
    
    Parameters:
    -----------
    model:
        The trained model.
    seq: CachedSlicesSequence
        The dataset to check (its `X` and `Y` arrays are used).
    N: int
        The number of samples to check.
    bool_threshold: Optional[float]
        The threshold to apply to the model results.
        Must be in the range: [0.0, 1.0].
        Set to None to display the raw results (probabilities).
    '''
    
    # Pick a random sample of the dataset
    i_samples = np.random.choice(np.arange(0, len(seq.X)), size=N, replace=False)
    
    # Predict the samples
    X = seq.X[i_samples]
    Y_gt = seq.Y[i_samples]
    Y_pred = model.predict(X)

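    # Compare against None so that a threshold of 0.0 is still applied
    if bool_threshold is not None:
        # np.bool was removed from recent NumPy versions: use the builtin bool
        Y_pred = (Y_pred > bool_threshold).astype(bool)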

    # Helper function to reshape the images, in case there are more single dimensions
    reshape_img = lambda img: np.reshape(img, (target_height, target_width))
        
    # Compare some predictions to the ground truth
    plt.figure(figsize=(20,5*N))
    for i in range(N):
        x_T1    = reshape_img(X[i,...,0])
        x_FLAIR = reshape_img(X[i,...,1])
        y_gt    = reshape_img(Y_gt[i])
        y_pred  = reshape_img(Y_pred[i])

        plt.subplot(N, 4, 4*i+1)
        plt.imshow(x_T1)
        plt.title('T1')

        plt.subplot(N, 4, 4*i+2)
        plt.imshow(x_FLAIR)
        plt.title('FLAIR')

        plt.subplot(N, 4, 4*i+3)
        plt.imshow(y_gt)
        plt.title('Ground Truth (wmh)')

        plt.subplot(N, 4, 4*i+4)
        plt.imshow(y_pred)
        plt.title('Predicted (wmh)')

In [None]:
check_model_with_set(model, val_seq)

## Evaluating the model

Evaluate the model on the testing set.

This should only be done once the model gives good results on the validation set.

**The testing set must not be used to tweak hyper-parameters** (otherwise, this defeats the purpose of the testing set).
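
As a reminder of what the two reported metrics measure, here is a small NumPy sketch that computes precision and recall from ground-truth labels and predicted probabilities. The helper `precision_recall` is hypothetical; the threshold of 0.5 matches the default used by the Keras `Precision` and `Recall` metrics.

```python
import numpy as np

def precision_recall(y_true, y_prob, threshold=0.5):
    """Compute pixel-wise precision and recall after thresholding."""
    y_true = np.asarray(y_true).astype(bool).ravel()
    y_pred = (np.asarray(y_prob) > threshold).ravel()

    tp = np.count_nonzero(y_true & y_pred)    # WMH pixels correctly found
    fp = np.count_nonzero(~y_true & y_pred)   # pixels wrongly flagged as WMH
    fn = np.count_nonzero(y_true & ~y_pred)   # WMH pixels that were missed

    # Guard against empty denominators (no predicted or no true positives)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall

# Toy example: 2 true positives, 1 false positive, 1 false negative
y_true = np.array([1, 1, 1, 0, 0, 0])
y_prob = np.array([0.9, 0.8, 0.2, 0.7, 0.1, 0.3])
precision, recall = precision_recall(y_true, y_prob)
print(f'Precision: {precision:.2f}, Recall: {recall:.2f}')
```

Precision drops when the model flags healthy pixels as WMH, and recall drops when it misses WMH pixels, so the two should be read together.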

In [None]:
# Fetch the test set
test_seq_uncached = SlicesSequence(test_nib,
                                   target_height=target_height,
                                   target_width=target_width, 
                                   batch_size=batch_size,
                                   shuffle=shuffle)
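# Cache the test sequence (not the validation one)
test_seq = CachedSlicesSequence(test_seq_uncached)
# Count the slices directly: the last batch may be smaller than batch_size
print('Number of slices in test set:', len(test_seq.Y))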

In [None]:
test_pred = model.predict(test_seq.X)

In [None]:
precision = keras.metrics.Precision()
precision.update_state(test_seq.Y, test_pred)
print(f'Test Precision: {100 * precision.result().numpy():.2f}%')

recall = keras.metrics.Recall()
recall.update_state(test_seq.Y, test_pred)
print(f'Test Recall   : {100 * recall.result().numpy():.2f}%')

In [None]:
# check_model_with_set(model, test_seq, N=20, bool_threshold=0.5)