# 2D model with preprocessing

## Colab integration

In [None]:
# Set to False when not running in Google Colab: the whole setup below
# (clone, checkout, dependency install, dataset download) is then skipped.
RUNNING_IN_COLAB = True
# Git branch of the project repository to check out.
Branch = "new-unet"

if RUNNING_IN_COLAB:
    REPO_URL = 'https://github.com/nicomem/imed-project.git'
    REPO_DIR = 'imed-project'
    # Google Drive link to the zipped dataset, fetched with gdown below.
    DATA_URL = 'https://drive.google.com/uc?id=1onHHWIhkhN5xYMit0rhhtVXlJrAlzCit'
    
    from pathlib import Path

    # Always start from /content so the clone location is fixed and
    # re-running this cell finds the existing checkout.
    %cd /content

    # Download the repository (skipped if it was already cloned)
    if not Path(REPO_DIR).is_dir():
        !git clone {REPO_URL} {REPO_DIR}
    
    %cd {REPO_DIR}
    # Switch to the requested branch; this also runs on re-executions.
    !git checkout $Branch
    # Install requirements
    # The grep filter only hides pip's 'Requirement already satisfied' lines.
    !pip install -r requirements.txt | grep -v 'Requirement already satisfied'
    !pip install gdown | grep -v 'Requirement already satisfied'
    
    import gdown
    # Download the dataset archive only once.
    if not Path('data.zip').is_file():
        gdown.download(DATA_URL, 'data.zip', quiet=False)
    
    # Extract only if ./data is missing (presumably the archive contains a
    # top-level data/ directory -- confirm against the zip contents).
    if not Path('data').is_dir():
        !unzip data.zip
    
    # Later cells load the dataset from '../data', relative to notebooks/.
    %cd notebooks
    !ls

## Imports

In [None]:
# 3rd-party imports
import numpy as np
import nibabel as nib
import tensorflow as tf
import tensorflow.keras as k
import matplotlib.pyplot as plt
import keras_unet
from tensorflow import keras

## Get dataset & split train/validation

In [None]:
from utils.load_data import get_dataset, SlicesSequence, CachedSlicesSequence

train_nib, val_nib = get_dataset('../data', verbose=True)

## Load data

In [None]:
# Hyper-parameters shared by the training and validation sequences.
batch_size = 32
shuffle = True
preprocess = True  # forwarded to CachedSlicesSequence
target_height, target_width = 256, 256
img_size = (target_height, target_width)
num_classes = 1

# Slice the training volumes to (target_height, target_width), then wrap the
# result in the cached sequence that is actually fed to the model.
train_seq_uncached = SlicesSequence(
    train_nib, target_height, target_width, batch_size, shuffle
)
train_seq = CachedSlicesSequence(
    train_seq_uncached, batch_size, shuffle, preprocess
)
len(train_seq)

In [None]:
# Build the validation sequence with exactly the same settings as training.
val_seq_uncached = SlicesSequence(
    val_nib, target_height, target_width, batch_size, shuffle
)
val_seq = CachedSlicesSequence(
    val_seq_uncached, batch_size, shuffle, preprocess
)
len(val_seq)

## Prepare the model

In [None]:
from keras_unet.models import custom_unet

# The preprocessing step yields 3 channels per slice, so the network input
# has depth 3 (not 1).
input_shape = (target_height, target_width, 3)
# Reset Keras global state (e.g. models left over from re-running this cell).
keras.backend.clear_session()
model = custom_unet(
    input_shape,
    # Tie the output depth to the data configuration instead of a literal.
    num_classes=num_classes,
    use_batch_norm=True,
    filters=32,
    num_layers=3,
    # dropout is not passed, so custom_unet's default applies; to experiment,
    # add dropout=0.3 here.
    output_activation='sigmoid',
)

In [None]:
from keras import backend as K

def dice_coef(y_true, y_pred, smooth=1):
    """
    Soft Dice coefficient, computed once per sample.

    Dice = (2*|X & Y|)/ (|X|+ |Y|)
         =  2*sum(|A*B|)/(sum(A^2)+sum(B^2))
    ref: https://arxiv.org/pdf/1606.04797v1.pdf

    All non-batch axes of each sample are flattened before summing, so the
    overlap is measured over the whole mask. Reducing only ``axis=-1`` would
    collapse just the last axis. For this notebook's single-channel output
    (num_classes=1) that axis has size 1, giving a per-pixel score instead of
    a Dice coefficient.

    Args:
        y_true: ground-truth mask batch; cast to ``y_pred``'s dtype.
        y_pred: predicted probabilities, with the same number of elements per
            sample as ``y_true`` (a trailing size-1 channel axis is fine).
        smooth: additive smoothing term that keeps the ratio defined when
            both masks are empty.

    Returns:
        Tensor of shape ``(batch,)`` holding one Dice score per sample.
    """
    # One row per sample; the cast allows integer/bool masks as targets.
    y_true_f = K.batch_flatten(K.cast(y_true, K.dtype(y_pred)))
    y_pred_f = K.batch_flatten(y_pred)
    intersection = K.sum(K.abs(y_true_f * y_pred_f), axis=-1)
    denominator = (K.sum(K.square(y_true_f), axis=-1)
                   + K.sum(K.square(y_pred_f), axis=-1))
    return (2. * intersection + smooth) / (denominator + smooth)

def dice_coef_loss(y_true, y_pred):
    """Return the Dice loss: one minus ``dice_coef(y_true, y_pred)``.

    Minimising this maximises the overlap between prediction and target.
    """
    score = dice_coef(y_true, y_pred)
    return 1 - score

In [None]:
from keras.callbacks import ModelCheckpoint

# Save the model to `model_filename` every time the validation loss improves.
model_filename = 'segm_model_prepro_v0.h5'
callback_checkpoint = ModelCheckpoint(
    filepath=model_filename,
    monitor='val_loss',
    save_best_only=True,
    verbose=1,
)

In [None]:
# Train on the Dice loss; recall/precision are tracked for monitoring.
# (Alternative previously tried here: loss="binary_crossentropy".)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
    loss=dice_coef_loss,
    metrics=[
        tf.keras.metrics.Recall(),
        tf.keras.metrics.Precision(),
    ],
)

In [None]:
# Store the masks as float32 so they share the dtype of the model's float
# predictions. One sequence at a time, training first, then validation.
for _seq in (train_seq, val_seq):
    _seq.Y = _seq.Y.astype(np.float32)
del _seq

In [None]:
# Train for 15 epochs, validating on val_seq after each epoch; the checkpoint
# callback writes the model whenever val_loss improves.
history = model.fit(
    train_seq,
    validation_data=val_seq,
    epochs=15,
    callbacks=[callback_checkpoint],
)

## Check the results

In [None]:
def rolling_average(x, w):
    """Return the simple moving average of ``x`` over windows of ``w`` samples.

    Only complete windows are averaged ("valid" convolution), so the result
    has ``len(x) - w + 1`` elements. When ``w`` exceeds ``len(x)`` there is
    no complete window, and an empty array is returned. Without this check,
    ``np.convolve`` would swap its operands and return misleading values.

    Args:
        x: 1-D sequence of numbers (e.g. a per-epoch loss history).
        w: window length; must be a positive integer.

    Returns:
        1-D ndarray of moving averages (possibly empty).

    Raises:
        ValueError: if ``w`` is smaller than 1.
    """
    if w < 1:
        raise ValueError(f"window length must be >= 1, got {w}")
    x = np.asarray(x)
    if w > len(x):
        return np.empty(0)
    return np.convolve(x, np.ones(w), 'valid') / w

# Window of the moving average; 1 plots the raw per-epoch values.
ravg_w = 1

# Overlay the training and validation loss curves.
for curve_label, metric_key in (('train', 'loss'), ('val', 'val_loss')):
    plt.plot(rolling_average(history.history[metric_key], ravg_w),
             label=curve_label)
del curve_label, metric_key

plt.legend()
plt.show()

In [None]:
# Predict masks for every validation slice in one call.
# NOTE(review): `model` holds the weights from the final training epoch; the
# best model by val_loss is only on disk at `model_filename`.
Y_pred = model.predict(val_seq.X)
print(np.shape(Y_pred))

In [None]:
# Show up to 10 random validation slices, one row each: the first two input
# channels, the ground-truth mask and the predicted mask.
# Sampling without replacement requires N <= number of slices.
N = min(10, Y_pred.shape[0])
i_samples = np.random.choice(Y_pred.shape[0], size=N, replace=False)

column_titles = ('input ch. 0', 'input ch. 1', 'ground truth', 'prediction')

plt.figure(figsize=(16, 16))
for i, i_sample in enumerate(i_samples):
    x = val_seq.X[i_sample]
    y = val_seq.Y[i_sample]
    print(i_sample)
    panels = (
        x[..., 0],
        x[..., 1],
        np.reshape(y, (target_height, target_width)),
        np.reshape(Y_pred[i_sample], (target_height, target_width)),
    )
    for j, (panel, title) in enumerate(zip(panels, column_titles)):
        plt.subplot(N, 4, 4 * i + j + 1)
        plt.imshow(panel)
        # Label the columns once, on the first row only.
        if i == 0:
            plt.title(title)