# U-Net

In [None]:
import tensorflow as tf
import tensorflow.keras as keras
from keras.layers import *

import nibabel as nib
from nilearn import plotting
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

import mia_datasets_tf as mia_data
import mia_evaluation as mia_eval
import mia_losses_tf as mia_losses
import mia_utils
import SimpleITK as sitk
import os 
import shutil
import inspect
import random

# Execute tf.function-decorated code eagerly (no graph tracing). This eases
# debugging, but it can make training considerably slower.
tf.config.run_functions_eagerly(run_eagerly=True)
# Silence SimpleITK's global warning output.
sitk.ProcessObject_SetGlobalWarningDisplay(False)

In [None]:
# Root folder under which every per-dataset output directory is created.
model_directory = os.path.join("/ssd2/jupyter/MIA", "simple_unet")

if not os.path.exists(model_directory):
    os.makedirs(model_directory)  # also creates any missing parent folders

## Utilities

In [None]:
def run_training(model_creator, fit_params, num_runs, directory, name):
    """Train several freshly built models and select the best checkpoint.

    Each run builds a new model with ``model_creator()``, fits it, and saves
    its lowest-``val_loss`` epoch to ``{directory}/{name}_{run}.keras``.

    Args:
        model_creator: Zero-argument callable returning a compiled model.
        fit_params: Mapping with a required ``'training_data'`` entry, passed
            as the first positional argument of ``Model.fit``. Every other key
            that names a ``Model.fit`` parameter is forwarded as a keyword
            argument, and other keys are ignored. Validation data must be
            supplied so that ``'val_loss'`` is recorded. A ``'callbacks'``
            list is copied, never modified.
        num_runs: Number of independent training runs.
        directory: Directory receiving the per-run checkpoint files.
        name: File-name prefix of the per-run checkpoints.

    Returns:
        ``(best_model_path, val_losses)``. ``best_model_path`` is the
        checkpoint of the run with the lowest validation loss, or ``None`` if
        no run beat infinity. ``val_losses`` is a list with each run's
        per-epoch ``val_loss`` history.
    """
    best_val_loss = float('inf')
    best_model_path = None

    val_losses = []

    valid_fit_args = inspect.signature(tf.keras.Model.fit).parameters.keys()
    filtered_fit_params = {key: value for key, value in fit_params.items() if key in valid_fit_args}
    # Take a private copy of the caller's callbacks. Appending to the shared
    # list would mutate it, and would leave every earlier run's
    # ModelCheckpoint attached to later runs. Those stale checkpoints can then
    # overwrite an earlier run's file with a later model.
    base_callbacks = list(filtered_fit_params.pop('callbacks', None) or [])

    for run in range(num_runs):
        print(f"Training run {run + 1}/{num_runs}\n")
        model = model_creator()

        run_model_path = os.path.join(directory, f"{name}_{run}.keras")

        checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(filepath=run_model_path, monitor='val_loss', save_best_only=True, mode='min')
        # Exactly one checkpoint callback (this run's) per fit call.
        run_fit_params = {**filtered_fit_params, 'callbacks': [*base_callbacks, checkpoint_cb]}

        history = model.fit(
            fit_params['training_data'],
            **run_fit_params)

        run_val_losses = history.history['val_loss']
        val_loss = min(run_val_losses)
        val_losses.append(run_val_losses)

        print(f"\nRun {run + 1} best val_loss = {val_loss:.6f}")

        if val_loss < best_val_loss:
            print("New best model found")
            best_val_loss = val_loss
            best_model_path = run_model_path

    print(f"\n Best model: {best_model_path} with val_loss = {best_val_loss:.6f}")

    return best_model_path, val_losses

In [None]:
def plot_side_by_side(images : list[np.ndarray], titles : list[str] | None = None):
    """Show the middle slice of each volume in one row of subplots.

    Args:
        images: Arrays (e.g. SimpleITK array views of 3-D volumes). For each,
            the slice at index ``shape[0] // 2`` along the first axis is shown.
        titles: Optional per-image titles, matched to ``images`` by position.
    """
    # squeeze=False keeps ``axs`` two-dimensional. Otherwise a single image
    # collapses it to a bare Axes object that cannot be indexed.
    fig, axs = plt.subplots(1, len(images), figsize=(15, 5), squeeze=False)

    for index, image in enumerate(images):
        ax = axs[0, index]
        ax.imshow(image[image.shape[0] // 2])
        if titles is not None:
            ax.set_title(titles[index])
        ax.axis('off')

    plt.tight_layout()

In [None]:
def show_history(history, y_lim=(0.0, 1.5)):
    """Plot every run's per-epoch validation loss on one shared figure.

    Args:
        history: Iterable of per-run validation-loss sequences, such as the
            second value returned by ``run_training``. Entry ``i`` is
            labelled ``Run i+1``.
        y_lim: ``(bottom, top)`` limits of the y axis.
    """
    plt.figure(figsize=(10, 6))
    for run_number, run_val_loss in enumerate(history, start=1):
        plt.plot(run_val_loss, label=f'Run {run_number}')

    # Titles, axis labels and limits are applied once, after all curves.
    plt.title('Validation Loss Across Runs')
    plt.xlabel('Epochs')
    plt.ylabel('Validation Loss')
    plt.ylim(y_lim)
    plt.legend()
    plt.show()

In [None]:
def seed_all(seed=2141):
    """Seed Python's ``random``, NumPy's and TensorFlow's global RNGs with ``seed``."""
    for apply_seed in (random.seed, np.random.seed, tf.random.set_seed):
        apply_seed(seed)

In [None]:
# Label colours as 8-bit RGB triples.
colors = [(255, 60, 60), (140, 255, 140), (140, 200, 255)]
# Matplotlib expects RGB channels as floats in [0, 1]. Raw 0-255 values make
# the colormap raise "RGBA values should be within 0-1 range" on first use.
custom_cmap = ListedColormap([tuple(channel / 255.0 for channel in rgb) for rgb in colors])

## Model

In [None]:
def encoder_block(inputs, n_filters=32, max_pooling=True):
    """One U-Net encoder stage: two 3x3x3 ReLU convolutions, then batch norm.

    Returns:
        ``(next_layer, skip_connection)``. ``skip_connection`` is the
        normalised feature map. ``next_layer`` is that map after 2x2x2 max
        pooling when ``max_pooling`` is true, otherwise the same tensor.
    """
    features = inputs
    for _ in range(2):
        features = Conv3D(n_filters, 3, activation='relu', padding='same')(features)
    features = BatchNormalization()(features)

    if not max_pooling:
        return features, features

    pooled = tf.keras.layers.MaxPooling3D(pool_size=(2, 2, 2))(features)
    return pooled, features

def decoder_block(prev_layer_input, skip_layer_input, n_filters=32):
    """One U-Net decoder stage.

    Upsamples 2x with a transposed convolution, concatenates the skip tensor
    on axis 4, and refines the result with two 3x3x3 ReLU convolutions.
    """
    upsampled = Conv3DTranspose(n_filters, (3, 3, 3), strides=(2, 2, 2), padding='same')(prev_layer_input)
    features = concatenate([upsampled, skip_layer_input], axis=4)

    for _ in range(2):
        features = Conv3D(n_filters, 3, activation='relu', padding='same')(features)

    return features

def unet_compiled(input_size, n_filters=32, n_classes=3, learning_rate=0.001):
    """Build and compile a 3-D U-Net that outputs per-voxel class probabilities.

    Args:
        input_size: Shape of one input sample without the batch axis
            (passed to ``Input``).
        n_filters: Filters of the first encoder stage. Each deeper stage
            doubles it, up to ``16 * n_filters`` in the bottleneck.
        n_classes: Number of output channels of the final softmax layer.
        learning_rate: Adam learning rate.

    Returns:
        A ``tf.keras.Model`` compiled with sparse categorical cross-entropy
        on probabilities (``from_logits=False``) and sparse categorical
        accuracy, so targets are integer class maps.

    Note:
        Four 2x2x2 poolings are undone by four stride-2 transposed
        convolutions. The spatial input dimensions therefore presumably need
        to be divisible by 16 for the skip concatenations to line up.
    """
    inputs = Input(input_size)

    # Contracting path: four pooled stages followed by an unpooled bottleneck.
    cblock1 = encoder_block(inputs, n_filters, max_pooling=True)
    cblock2 = encoder_block(cblock1[0], n_filters * 2, max_pooling=True)
    cblock3 = encoder_block(cblock2[0], n_filters * 4, max_pooling=True)
    cblock4 = encoder_block(cblock3[0], n_filters * 8, max_pooling=True)
    cblock5 = encoder_block(cblock4[0], n_filters * 16, max_pooling=False)

    # Expanding path: each stage consumes the matching encoder skip connection.
    ublock6 = decoder_block(cblock5[0], cblock4[1],  n_filters * 8)
    ublock7 = decoder_block(ublock6, cblock3[1],  n_filters * 4)
    ublock8 = decoder_block(ublock7, cblock2[1],  n_filters * 2)
    ublock9 = decoder_block(ublock8, cblock1[1],  n_filters)

    conv9 = Conv3D(n_filters, 3, activation='relu', padding='same')(ublock9)
    # Use the canonical lower-case "softmax" identifier. "Softmax" only works
    # through Keras 2's legacy layer-class lookup; Keras 3's
    # activations.get rejects it.
    conv10 = Conv3D(n_classes, 1, activation="softmax", padding='same')(conv9)

    model = tf.keras.Model(inputs=inputs, outputs=conv10)

    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=[keras.metrics.SparseCategoricalAccuracy(name="sparse_categorical_accuracy")])

    return model

## iSeg-2019

In [None]:
# Load the training / validation / testing splits of the preprocessed iSeg-2019 data.
iseg2019_training, iseg2019_validation, iseg2019_testing = (
    mia_data.Iseg2019Processed(rf'/ssd2/jupyter/MIA/split_datasets/iseg2019_ns/{split}')
    for split in ('training', 'validation', 'testing')
)

In [None]:
dataset_specific_directory = os.path.join(model_directory, "iseg2019_ns")
inference_directory = os.path.join(dataset_specific_directory, "inference")

# Create the dataset folder, then its inference sub-folder, when missing.
for _output_dir in (dataset_specific_directory, inference_directory):
    if not os.path.exists(_output_dir):
        os.makedirs(_output_dir)

# Foreground tissue classes; the model gets one extra class (presumably background = 0).
labels_dictionary = {"CSF": 1, "GM": 2, "WM": 3}

In [None]:
# Preview the first training subject's T1, T2 and label volumes.
t1, t2, label = (
    iseg2019_training.subjects[0].get_T1(),
    iseg2019_training.subjects[0].get_T2(),
    iseg2019_training.subjects[0].get_label(),
)

plot_side_by_side([sitk.GetArrayViewFromImage(volume) for volume in (t1, t2, label)], ["T1", "T2", "Label"])

### T1 + T2

#### Setup

In [None]:
model_file = os.path.join(dataset_specific_directory, "iseg2019_T1T2.keras")

training, validation, testing = (
    split.T1_T2_dataset()
    for split in (iseg2019_training, iseg2019_validation, iseg2019_testing)
)

batch_size = 1

# Shape of one unbatched input sample, read from the first training element.
input_shape = list(training.take(1).as_numpy_iterator())[0][0].shape
input_shape

#### Training

In [None]:
def unet_creator():
    """Return a freshly compiled U-Net for the iSeg-2019 T1+T2 setup."""
    return unet_compiled(
        input_size=input_shape,
        n_filters=16,
        n_classes=len(labels_dictionary) + 1,
        learning_rate=0.00001,
    )

callbacks_list = [keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=50)]

fit_params = dict(
    callbacks=callbacks_list,
    training_data=training.batch(batch_size),
    validation_data=validation.batch(batch_size),
    epochs=500,
)

seed_all()
best_model_path, history = run_training(unet_creator, fit_params, 10, dataset_specific_directory, "iseg2019_T1T2")

# Replace any previous copy with the best run's checkpoint.
if os.path.exists(model_file):
    os.remove(model_file)

shutil.copy(best_model_path, model_file)

In [None]:
# Plot each run's validation-loss curve with the default y-axis limits.
show_history(history, y_lim=(0.0, 1.5))

#### Inference

In [None]:
# Reload the best checkpoint that the training cell copied to model_file.
unet = keras.saving.load_model(filepath=model_file)

In [None]:
prediction = unet.predict(testing.batch(1))
# Class index with the highest probability at every voxel.
predicted_labels = np.argmax(prediction, axis=-1)

mia_utils.writeImagesArray(
    predicted_labels,
    inference_directory,
    lambda x: f"T1_T2{x}.nii.gz",
    lambda x: iseg2019_testing.subjects[x].get_label(),
)

In [None]:
# Interactive view of the first test subject's predicted label map.
mia_utils.interactive_display(
    predicted_labels[0],
    (1, 3),
    "Iseg 2019 T1 + T2 Inference",
    cmap=custom_cmap,
)

In [None]:
# Interactive view of the first test subject's ground-truth label map.
mia_utils.interactive_display(
    list(testing.take(1).as_numpy_iterator())[0][1],
    (1, 3),
    title="Iseg 2019 T1+T2 Ground truth",
    cmap=custom_cmap,
)

In [None]:
# Compare the first test subject's prediction with its ground truth.
truth_image = sitk.GetImageFromArray(next(testing.as_numpy_iterator())[1])
predicted_image = sitk.GetImageFromArray(predicted_labels[0])

eval = mia_eval.evaluateImage(predicted_image, truth_image, labels_dictionary)

In [None]:
# Pass the metrics returned by mia_eval.evaluateImage to mia_eval.createRecord
# under the name "iseg2019" (what createRecord stores or shows is defined in
# mia_evaluation, not here). NOTE(review): `eval` shadows the builtin.
mia_eval.createRecord("iseg2019", eval)

In [None]:
# Current and peak memory used by TensorFlow on the first GPU.
tf.config.experimental.get_memory_info(device='GPU:0')

## BONBID-HIE 2023

In [None]:
# Load the training / validation / testing splits of the preprocessed BONBID-HIE 2023 data.
bondid2023_training, bondid2023_validation, bondid2023_testing = (
    mia_data.BONDID2023Processed(rf'/ssd2/jupyter/MIA/split_datasets/bonbid2023_ns/{split}', "nii.gz")
    for split in ("training", "validation", "testing")
)

In [None]:
dataset_specific_directory = os.path.join(model_directory, "bonbid2023_ns")
inference_directory = os.path.join(dataset_specific_directory, "inference")

# Create the dataset folder, then its inference sub-folder, when missing.
for _output_dir in (dataset_specific_directory, inference_directory):
    if not os.path.exists(_output_dir):
        os.makedirs(_output_dir)

# Single foreground class; the model gets one extra class (presumably background = 0).
labels_dictionary = {"Lesion": 1}

In [None]:
# Preview the first training subject's ADC, Z_ADC and label volumes.
adc, z_adc, label = (
    bondid2023_training.subjects[0].get_ADC_ss(),
    bondid2023_training.subjects[0].get_Z_ADC(),
    bondid2023_training.subjects[0].get_label(),
)

plot_side_by_side([sitk.GetArrayViewFromImage(volume) for volume in (adc, z_adc, label)], ["ADC_ss", "Z_ADC", "Label"])

### ADC + Z_ADC


#### Setup

In [None]:
model_file = os.path.join(dataset_specific_directory, "bonbid2023_adc_zadc.keras")

training, testing, validation = (
    split.ADC_ss_Z_ADC_dataset()
    for split in (bondid2023_training, bondid2023_testing, bondid2023_validation)
)
batch_size = 2

# Shape of one unbatched input sample, read from the first training element.
input_shape = list(training.take(1).as_numpy_iterator())[0][0].shape
input_shape

#### Training

In [None]:
unet_creator = lambda : unet_compiled(input_size=input_shape, n_filters=16, n_classes=len(labels_dictionary) + 1, learning_rate=0.000001)

# Only early stopping is configured here; run_training attaches its own
# per-run ModelCheckpoint.
callbacks_list = [
    keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=50),
]

fit_params = {
    "callbacks" : callbacks_list,
    "training_data" : training.batch(batch_size),
    "validation_data" : validation.batch(batch_size),
    'epochs' : 500
}

seed_all()
best_model_path, history = run_training(unet_creator, fit_params, 10, dataset_specific_directory, "bonbid2023_adc_zadc")

# Clear any previous copy before copying the best checkpoint. A .keras model
# is a single file, and shutil.rmtree accepts only directories (it raised
# NotADirectoryError on re-runs), so handle both kinds of path.
if os.path.isdir(model_file):
    shutil.rmtree(model_file)
elif os.path.exists(model_file):
    os.remove(model_file)

shutil.copy(best_model_path, model_file)

In [None]:
# Plot each run's validation-loss curve with the default y-axis limits.
show_history(history, y_lim=(0.0, 1.5))

#### Inference

In [None]:
# Reload the best checkpoint that the training cell copied to model_file.
unet = keras.saving.load_model(filepath=model_file)

In [None]:
prediction = unet.predict(testing.batch(1))
# Class index with the highest probability at every voxel.
predicted_labels = np.argmax(prediction, axis=-1)
mia_utils.writeImagesArray(
    predicted_labels,
    inference_directory,
    lambda x: f"ADC_Z_ADC{x}.nii.gz",
    lambda x: bondid2023_testing.subjects[x].get_label(),
)

In [None]:
# Interactive view of the first test subject's predicted lesion map.
mia_utils.interactive_display(
    predicted_labels[0],
    (0, 1),
    "Bondid 2023 ADC + Z_ADC Inference",
    cmap=custom_cmap,
)

In [None]:
# Interactive view of the first test subject's ground-truth lesion map.
mia_utils.interactive_display(
    list(testing.take(1).as_numpy_iterator())[0][1],
    (0, 1),
    title="Bondid 2023 ADC + Z_ADC Ground truth",
    cmap=custom_cmap,
)

In [None]:
# Compare the first test subject's prediction with its ground truth.
truth_image = sitk.GetImageFromArray(next(testing.as_numpy_iterator())[1])
predicted_image = sitk.GetImageFromArray(predicted_labels[0])

eval = mia_eval.evaluateImage(predicted_image, truth_image, labels_dictionary)

In [None]:
# Pass the metrics returned by mia_eval.evaluateImage to mia_eval.createRecord
# under the name "bondid2023" (createRecord's behaviour is defined in
# mia_evaluation, not here). NOTE(review): "bondid" differs from the dataset
# name BONBID; left as-is in case existing records use it.
mia_eval.createRecord("bondid2023", eval)

In [None]:
# Current and peak memory used by TensorFlow on the first GPU.
tf.config.experimental.get_memory_info(device='GPU:0')

## TACR 6 HIE

In [None]:
# Load the training / validation / testing splits of the TACR-HIE data.
tacrhie_training, tacrhie_validation, tacrhie_testing = (
    mia_data.TACRHIE6Dataset(rf'/ssd2/jupyter/MIA/split_datasets/tacrhie/{split}', "nii.gz")
    for split in ("training", "validation", "testing")
)

### aseg

In [None]:
target_shape = (128, 128, 128)

# Wrap each split's aseg dataset in a CroppedDataset of target_shape.
training, validation, testing = (
    mia_data.CroppedDataset(split.aseg_dataset(), target_shape).dataset()
    for split in (tacrhie_training, tacrhie_validation, tacrhie_testing)
)

In [None]:
import json

dataset_specific_directory = os.path.join(model_directory, "tacrhie", "aseg")
inference_directory = os.path.join(dataset_specific_directory, "inference")
os.makedirs(dataset_specific_directory, exist_ok=True)
os.makedirs(inference_directory, exist_ok=True)

# JSON maps label id (string) -> label name. Drop id '0' and invert to name -> float id.
with open(r'/ssd2/jupyter/MIA/split_datasets/tacrhie/aseg_labels.json') as file:
    labels_dictionary = json.load(file)
labels_dictionary.pop('0')
labels_dictionary = {name: float(label_id) for label_id, name in labels_dictionary.items()}

# One colour per remaining label id, in ascending numeric order, scaled to [0, 1].
with open(r'/ssd2/jupyter/MIA/split_datasets/tacrhie/aseg_colors.json') as file:
    colours_dictionary = json.load(file)
colours_dictionary.pop('0')
custom_cmap = ListedColormap([
    np.array(colours_dictionary[str(label_id)][:3]) / 255.0
    for label_id in sorted(map(int, colours_dictionary))
])

In [None]:
# Preview the first training subject's norm, aseg and aseg+aparc volumes.
norm, aseg, aseg_aparc = (
    tacrhie_training.subjects[0].get_norm(),
    tacrhie_training.subjects[0].get_aseg(),
    tacrhie_training.subjects[0].get_aseg_aparc(),
)

plot_side_by_side([sitk.GetArrayViewFromImage(volume) for volume in (norm, aseg, aseg_aparc)], ["norm", "aseg", "aseg+aparc"])

### aseg


#### Setup

In [None]:
model_file = os.path.join(dataset_specific_directory, "tacrhie6_aseg.keras")

batch_size = 2
num_channels = 1

# The cropped samples appear to lack a channel axis, so append one for the model input.
input_shape = next(training.take(1).as_numpy_iterator())[0].shape + (num_channels,)
input_shape

#### Training

In [None]:
def unet_creator():
    """Return a freshly compiled U-Net for the TACR-HIE aseg setup."""
    return unet_compiled(
        input_size=input_shape,
        n_filters=32,
        n_classes=len(labels_dictionary) + 1,
        learning_rate=0.00001,
    )

callbacks_list = [keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=50)]

fit_params = dict(
    callbacks=callbacks_list,
    training_data=training.batch(batch_size),
    validation_data=validation.batch(batch_size),
    epochs=500,
)

seed_all()
best_model_path, history = run_training(unet_creator, fit_params, 10, dataset_specific_directory, "tacrhie_aseg")

# Replace any previous copy with the best run's checkpoint.
if os.path.exists(model_file):
    os.remove(model_file)

shutil.copy(best_model_path, model_file)

In [None]:
# Plot each run's validation-loss curve with the default y-axis limits.
show_history(history, y_lim=(0.0, 1.5))

#### Inference

In [None]:
# Reload the best checkpoint that the training cell copied to model_file.
unet = keras.saving.load_model(filepath=model_file)

In [None]:
prediction = unet.predict(testing.batch(1))
predicted_labels = np.argmax(prediction, axis=-1).astype(np.uint8)

# Target grid: same number of subjects, 256^3 voxels each. The offset centres
# the cropped volumes in it, and the subject (batch) axis gets no offset.
shape = (len(predicted_labels), 256, 256, 256)
begin = np.subtract(shape, predicted_labels.shape) // 2
begin[0] = 0

results = mia_utils.embed_tensor(predicted_labels, shape, begin)

mia_utils.writeImagesArray(
    results,
    inference_directory,
    lambda x: f"{tacrhie_testing.subjects[x].number}.nii.gz",
    lambda x: tacrhie_testing.subjects[x].get_aseg(),
)

In [None]:
# Evaluate each test subject separately on the cropped volumes.
# Loop-local names deliberately differ from `prediction` / `predicted_labels`.
# Later cells display and evaluate predicted_labels[0] against the FIRST test
# subject's ground truth, so this loop must not rebind those names.
# Prediction files are not written here: the batch-inference cell already
# writes every test subject, re-embedded to 256^3 with its aseg as reference.
evaluations = []

for index, (features, truth) in enumerate(testing):
    # Elements of `testing` are unbatched (image, label) pairs; add a batch axis.
    subject_prediction = unet.predict(features[tf.newaxis, ...])
    # Argmax over classes, then drop the batch axis so the prediction has the
    # same rank as the ground-truth label volume.
    subject_labels = tf.math.argmax(subject_prediction, -1).numpy()[0]

    metrics = mia_eval.evaluateImage(sitk.GetImageFromArray(subject_labels),
                                     sitk.GetImageFromArray(truth.numpy()),
                                     labels_dictionary)
    evaluations.append((str(index), metrics))

In [None]:
# Current and peak memory used by TensorFlow on the first GPU.
tf.config.experimental.get_memory_info(device='GPU:0')

In [None]:
# Interactive view of the first test subject's predicted (cropped) label map.
mia_utils.interactive_display(predicted_labels[0], (1, 256), cmap=custom_cmap, title="TACR-HIE aseg Inference")

In [None]:
# Ground-truth label volume of the first test subject.
truth_array = list(testing.take(1).as_numpy_iterator())[0][1]

mia_utils.interactive_display(truth_array, (1, 256), cmap=custom_cmap, title="TACR-HIE aseg Ground truth")

In [None]:
# Compare the first test subject's prediction with its ground truth.
truth_image = sitk.GetImageFromArray(truth_array)
predicted_image = sitk.GetImageFromArray(predicted_labels[0])

eval = mia_eval.evaluateImage(predicted_image, truth_image, labels_dictionary)

In [None]:
# Pass the metrics returned by mia_eval.evaluateImage to mia_eval.createRecord
# under the name "tarc_aseg". NOTE(review): "tarc" looks like a transposition
# of "tacr"; left unchanged in case existing records are keyed on it.
mia_eval.createRecord("tarc_aseg", eval)

In [None]:
# Current and peak memory used by TensorFlow on the first GPU.
tf.config.experimental.get_memory_info(device='GPU:0')

### aseg + aparc

In [None]:
target_shape = (128, 128, 128)

# Wrap each split's aseg+aparc dataset in a CroppedDataset of target_shape.
training, validation, testing = (
    mia_data.CroppedDataset(split.aseg_aparc_dataset(), target_shape).dataset()
    for split in (tacrhie_training, tacrhie_validation, tacrhie_testing)
)

In [None]:
import json

dataset_specific_directory = os.path.join(model_directory, "tacrhie", "aseg_aparc")
inference_directory = os.path.join(dataset_specific_directory, "inference")
os.makedirs(dataset_specific_directory, exist_ok=True)
os.makedirs(inference_directory, exist_ok=True)

# JSON maps label id (string) -> label name. Drop id '0' and invert to name -> float id.
with open(r'/ssd2/jupyter/MIA/split_datasets/tacrhie/aseg_aparc_labels.json') as file:
    labels_dictionary = json.load(file)
labels_dictionary.pop('0')
labels_dictionary = {name: float(label_id) for label_id, name in labels_dictionary.items()}

# One colour per remaining label id, in ascending numeric order, scaled to [0, 1].
with open(r'/ssd2/jupyter/MIA/split_datasets/tacrhie/aseg_aparc_colors.json') as file:
    colours_dictionary = json.load(file)
colours_dictionary.pop('0')
custom_cmap = ListedColormap([
    np.array(colours_dictionary[str(label_id)][:3]) / 255.0
    for label_id in sorted(map(int, colours_dictionary))
])

In [None]:
# Preview the first training subject's norm, aseg and aseg+aparc volumes.
norm, aseg, aseg_aparc = (
    tacrhie_training.subjects[0].get_norm(),
    tacrhie_training.subjects[0].get_aseg(),
    tacrhie_training.subjects[0].get_aseg_aparc(),
)

plot_side_by_side([sitk.GetArrayViewFromImage(volume) for volume in (norm, aseg, aseg_aparc)], ["norm", "aseg", "aseg+aparc"])

### aseg + aparc


#### Setup

In [None]:
model_file = os.path.join(dataset_specific_directory, "tacrhie6_aseg_aparc.keras")

batch_size = 2
num_channels = 1

# The cropped samples appear to lack a channel axis, so append one for the model input.
input_shape = next(training.take(1).as_numpy_iterator())[0].shape + (num_channels,)
input_shape

#### Training

In [None]:
def unet_creator():
    """Return a freshly compiled U-Net for the TACR-HIE aseg+aparc setup."""
    return unet_compiled(
        input_size=input_shape,
        n_filters=24,
        n_classes=len(labels_dictionary) + 1,
        learning_rate=0.0001,
    )

callbacks_list = [keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=50)]

fit_params = dict(
    callbacks=callbacks_list,
    training_data=training.batch(batch_size),
    validation_data=validation.batch(batch_size),
    epochs=500,
)

seed_all()
best_model_path, history = run_training(unet_creator, fit_params, 10, dataset_specific_directory, "tacrhie_aseg_aparc")

# Replace any previous copy with the best run's checkpoint.
if os.path.exists(model_file):
    os.remove(model_file)

shutil.copy(best_model_path, model_file)

In [None]:
# Plot each run's validation-loss curve with the default y-axis limits.
show_history(history, y_lim=(0.0, 1.5))

#### Inference

In [None]:
# Reload the best checkpoint that the training cell copied to model_file.
unet = keras.saving.load_model(filepath=model_file)

In [None]:
prediction = unet.predict(testing.batch(1))
predicted_labels = prediction.argmax(axis=-1).astype(np.uint8)

# Target grid: same number of subjects, 256^3 voxels each. The offset centres
# the cropped volumes in it, and the subject (batch) axis gets no offset.
shape = (predicted_labels.shape[0], 256, 256, 256)
begin = (np.array(shape) - predicted_labels.shape) // 2
begin[0] = 0

results = mia_utils.embed_tensor(predicted_labels, shape, begin)

# Use each subject's aseg+aparc volume (the segmentation this model was
# trained on) as the reference, not the plain aseg volume copied from the
# aseg cell.
mia_utils.writeImagesArray(results, inference_directory,
                           lambda x : f"{tacrhie_testing.subjects[x].number}.nii.gz",
                           lambda x : tacrhie_testing.subjects[x].get_aseg_aparc())

In [None]:
# Current and peak memory used by TensorFlow on the first GPU.
tf.config.experimental.get_memory_info(device='GPU:0')

In [None]:
# Interactive view of the first test subject's predicted (cropped) label map.
mia_utils.interactive_display(predicted_labels[0], (1, 256), cmap=custom_cmap, title="TACR-HIE aseg+aparc Inference")

In [None]:
# Ground-truth label volume of the first test subject.
truth_array = list(testing.take(1).as_numpy_iterator())[0][1]

mia_utils.interactive_display(truth_array, (1, 256), cmap=custom_cmap, title="TACR-HIE aseg+aparc  Ground truth")

In [None]:
# Compare the first test subject's prediction with its ground truth.
truth_image = sitk.GetImageFromArray(truth_array)
predicted_image = sitk.GetImageFromArray(predicted_labels[0])

eval = mia_eval.evaluateImage(predicted_image, truth_image, labels_dictionary)

In [None]:
# Pass the metrics returned by mia_eval.evaluateImage to mia_eval.createRecord
# under the name "tarc_aseg_aparc". NOTE(review): "tarc" looks like a
# transposition of "tacr"; left unchanged in case existing records use it.
mia_eval.createRecord("tarc_aseg_aparc", eval)

In [None]:
# Current and peak memory used by TensorFlow on the first GPU.
tf.config.experimental.get_memory_info(device='GPU:0')