# 2D model without preprocessing

## Colab integration

In [None]:
# Switch to True when running on Google Colab: the block below then clones the
# repository, installs its requirements and fetches the dataset archive.
RUNNING_IN_COLAB = False

if RUNNING_IN_COLAB:
    # Repository to clone, branch to check out, local checkout directory and
    # the Google Drive URL of the zipped dataset.
    REPO_URL = 'https://github.com/nicomem/imed-project.git'
    BRANCH   = 'master'
    REPO_DIR = 'imed-project'
    DATA_URL = 'https://drive.google.com/uc?id=1onHHWIhkhN5xYMit0rhhtVXlJrAlzCit'
    
    from pathlib import Path

    # Work from /content so the relative paths below are predictable
    %cd /content

    # Download the repository
    # (shallow clone of a single branch, skipped when the directory exists)
    if not Path(REPO_DIR).is_dir():
        !git clone --branch {BRANCH} --depth=1 -- {REPO_URL} {REPO_DIR}
    
    %cd {REPO_DIR}

    # Install requirements
    # (grep -v hides the "Requirement already satisfied" lines from pip's output)
    !pip install -r requirements.txt | grep -v 'Requirement already satisfied'
    !pip install gdown | grep -v 'Requirement already satisfied'
    
    # Download the dataset archive once; skipped when data.zip is already there
    import gdown
    if not Path('data.zip').is_file():
        gdown.download(DATA_URL, 'data.zip', quiet=False)
    
    # Extract the archive only when no data directory exists yet
    if not Path('data').is_dir():
        !unzip -q -- data.zip
    
    # Move into the notebooks directory and list its content
    # NOTE(review): later cells load '../data', which presumably is the
    # directory extracted above -- confirm the archive layout.
    %cd notebooks
    !ls

## Imports

In [None]:
# 3rd-party imports
import numpy as np
import nibabel as nib
import tensorflow as tf
import matplotlib.pyplot as plt
import keras_unet
import cv2

from tensorflow import keras

## Get dataset & split train/test

In [None]:
# Project-local helpers: dataset discovery and the slice sequences used below
from utils.load_data import get_dataset, SlicesSequence, CachedSlicesSequence

# Load the dataset found under '../data' as three splits (train / val / test).
# Each split is used as a mapping in the following cells (.keys() / .values()).
train_nib, val_nib, test_nib = get_dataset('../data', verbose=True)

In [None]:
# Keys under which the training split is organised
train_nib.keys()

In [None]:
# Number of entries stored under each key of the training split
list(map(len, train_nib.values()))

In [None]:
# Number of entries stored under each key of the validation split
list(map(len, val_nib.values()))

In [None]:
# Number of entries stored under each key of the test split
list(map(len, test_nib.values()))

## Load train & analyze

In [None]:
# Batching configuration shared by the sequences built in this notebook
batch_size, shuffle = 32, True

# Height / width passed to the slice sequences
target_height = target_width = 256
img_size = (target_height, target_width)

# Number of classes to predict
num_classes = 1

# Slice sequence over the training split, then its cached wrapper
train_seq_uncached = SlicesSequence(
    train_nib, target_height, target_width, batch_size, shuffle
)
train_seq = CachedSlicesSequence(train_seq_uncached, batch_size, shuffle)

# Length of the cached training sequence
len(train_seq)

In [None]:
# Inspect the first and the last training batch: shapes of the inputs and
# targets, then of one sample's two input channels and its target.
for position, batch_idx in enumerate((0, -1)):
    if position > 0:
        print('---')

    x, y = train_seq[batch_idx]
    for batch_array in (x, y):
        print(batch_array.shape)
    print(x[0, ..., 0].shape, x[0, ..., 1].shape, y[0].shape)

In [None]:
# Sample shown within each batch
i_data = 15

# One row per batch (first, last); columns: input channel 0, input channel 1,
# target.
plt.figure(figsize=(9, 6))
for row, batch_idx in enumerate((0, -1)):
    x, y = train_seq[batch_idx]
    panels = (x[i_data, ..., 0], x[i_data, ..., 1], y[i_data])
    for col, image in enumerate(panels, start=1):
        plt.subplot(2, 3, 3 * row + col)
        plt.imshow(image)

## Prepare the model

In [None]:
from keras_unet.models import custom_unet

# Two input channels per slice (displayed elsewhere in this notebook as
# T1 and FLAIR).
input_shape = (target_height, target_width, 2)

# Reset Keras' global state so that re-running this cell starts from scratch
keras.backend.clear_session()

# U-Net with a sigmoid output. The number of output classes comes from the
# notebook-level `num_classes` setting instead of a duplicated literal.
model = custom_unet(
    input_shape,
    num_classes=num_classes,
    use_batch_norm=True,
    filters=32,
    num_layers=3,
    dropout=0.1,
    output_activation='sigmoid'
)
model.summary()

## Train the model

In [None]:
# Validation slices, built with the same size and batching settings as the
# training sequence (img_size is (target_height, target_width)).
val_seq_uncached = SlicesSequence(val_nib, *img_size, batch_size, shuffle)
val_seq = CachedSlicesSequence(val_seq_uncached, batch_size, shuffle)

# Length of the cached validation sequence
len(val_seq)

In [None]:
# A training slice is "usable" when its target has at least one non-zero value
usable_train_data = np.any(train_seq.Y, axis=(1, 2))

# (number of usable slices, total number of slices)
n_usable_train = np.count_nonzero(usable_train_data)
n_total_train = train_seq.Y.shape[0]
n_usable_train, n_total_train

In [None]:
# Same count for the validation set:
# (slices with a non-empty target, total number of slices)
val_has_target = np.any(val_seq.Y, axis=(1, 2))
np.count_nonzero(val_has_target), val_seq.Y.shape[0]

In [None]:
# Import the callback from tf.keras rather than the standalone `keras`
# package: the model is compiled with tf.keras optimizers and metrics, and
# mixing the two packages is fragile across TensorFlow/Keras versions.
from tensorflow.keras.callbacks import ModelCheckpoint

# File the best model is written to
model_filename = 'segm_model_v0.h5'

# Keep only the model with the lowest validation loss seen so far
callback_checkpoint = ModelCheckpoint(
    model_filename,
    verbose=1,
    monitor='val_loss',
    save_best_only=True,
)

In [None]:
# Binary cross-entropy on the sigmoid output, optimised with Adam; recall and
# precision are tracked under the names 'recall' and 'precision'.
model.compile(
    loss="binary_crossentropy",
    optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
    metrics=[
        tf.keras.metrics.Recall(name='recall'),
        tf.keras.metrics.Precision(name='precision'),
    ],
)

In [None]:
# Train only on the slices whose target is non-empty (usable_train_data) and
# validate on the whole validation set.
history = model.fit(
    train_seq.X[usable_train_data],
    train_seq.Y[usable_train_data],
    epochs=40,
    validation_data=(val_seq.X, val_seq.Y),
    # Reuse the notebook-wide batch size instead of repeating the literal
    batch_size=batch_size,
    shuffle=True,
    # The checkpoint callback was created but never registered, so no model
    # file was ever written; pass it so the best model is actually saved.
    callbacks=[callback_checkpoint],
)

In [None]:
def rolling_average(x, w):
    """Return the moving average of `x` over windows of `w` samples.

    Only complete windows are averaged, so the result has
    ``len(x) - w + 1`` values. When `x` holds fewer than `w` samples no
    complete window exists and an empty array is returned.

    Parameters
    ----------
    x : 1-D sequence of numbers.
    w : window length, an integer >= 1.

    Raises
    ------
    ValueError
        If `w` is smaller than 1.
    """
    if w < 1:
        raise ValueError(f'window size must be >= 1, got {w}')

    values = np.asarray(x)
    # np.convolve silently swaps its operands when the kernel is longer than
    # the signal, which would yield values that are not window averages.
    if values.size < w:
        return np.empty(0, dtype=float)

    return np.convolve(values, np.ones(w), 'valid') / w

# Window of the rolling average applied to every curve
ravg_w = 3

# Metrics recorded in `history`; each also has a 'val_' counterpart
plot_cols = ['loss', 'precision', 'recall']

# One subplot per metric, with the smoothed train and validation curves
plt.figure(figsize=(20,6))
for i, col in enumerate(plot_cols):
    plt.subplot(1, len(plot_cols), i+1)

    for history_key, curve_label in ((col, 'train'), (f'val_{col}', 'val')):
        smoothed = rolling_average(history.history[history_key], ravg_w)
        plt.plot(smoothed, label=curve_label)

    plt.title(col)
    plt.xlabel('Epoch')
    plt.ylabel(col)
    plt.legend()

plt.show()

In [None]:
def check_model_with_set(model, seq, N = 10, res_to_bool = True):
    """Plot random samples of `seq` next to the model's predictions.

    Every sampled slice gets one row of four panels: the two input channels
    (titled 'T1' and 'FLAIR'), the ground truth and the prediction.

    Parameters
    ----------
    model : object with a ``predict`` method; it is called on ``seq.X``.
    seq : object exposing ``X`` (inputs) and ``Y`` (targets), both indexable
        by sample.
    N : number of samples to show. Clamped to the number of predictions, so
        asking for more samples than available no longer raises.
    res_to_bool : when True, predictions are thresholded at 0.5 before being
        displayed; when False, the raw model outputs are shown.
    """
    # Predict the whole set
    Y_pred = model.predict(seq.X)
    print('Y_pred shape:', Y_pred.shape)

    if res_to_bool:
        Y_pred = Y_pred > 0.5

    # Sampling without replacement fails when asked for more samples than
    # exist, so never request more than what is available.
    n_available = Y_pred.shape[0]
    n_shown = min(N, n_available)
    if n_shown <= 0:
        return

    # Compare some predictions to the ground truth
    i_samples = np.random.choice(n_available, size=n_shown, replace=False)

    plt.figure(figsize=(20, 5 * n_shown))
    for i, i_sample in enumerate(i_samples):
        x = seq.X[i_sample]
        t1, flair = x[..., 0], x[..., 1]
        # Take the 2-D plane shape from the sample itself rather than from
        # notebook globals, so the function works for any slice size.
        plane_shape = np.shape(t1)

        panels = (
            ('T1', t1),
            ('FLAIR', flair),
            ('Ground Truth (wmh)', np.reshape(seq.Y[i_sample], plane_shape)),
            ('Predicted (wmh)', np.reshape(Y_pred[i_sample], plane_shape)),
        )
        for j, (title, image) in enumerate(panels, start=1):
            plt.subplot(n_shown, 4, 4 * i + j)
            plt.imshow(image)
            plt.title(title)

In [None]:
# Show validation samples with the raw model outputs (no 0.5 thresholding)
check_model_with_set(model, val_seq, res_to_bool=False)

## Test the model

In [None]:
# Fetch the test set
test_seq_uncached = SlicesSequence(test_nib,
                                   target_height, target_width,
                                   batch_size, shuffle)
# Cache the *test* sequence. Wrapping the validation sequence here would make
# the evaluation below run on validation data instead of held-out data.
test_seq = CachedSlicesSequence(test_seq_uncached, batch_size, shuffle)
# NOTE(review): this is batches * batch_size, so it may overcount if the last
# batch is partial -- confirm against CachedSlicesSequence.__len__.
print('Number of slices in test set:', len(test_seq) * batch_size)

In [None]:
# Show 20 random samples from test_seq with predictions thresholded at 0.5
check_model_with_set(model, test_seq, N=20, res_to_bool=True)