In [None]:
# Switch the kernel's working directory to the project folder so the relative
# paths used by later cells ("data/...", 'trained_model', 'scale_variables.npy')
# resolve against it.
# NOTE(review): the path looks like a Colab-mounted Google Drive location; the
# drive must already be mounted for this to succeed — confirm for your setup.
%cd /content/drive/MyDrive/ISMRM_Code


In [2]:
# @title Imports

import UnetLib
import utils
import numpy as np
import tensorflow as tf
from tensorflow.keras.optimizers import Adam

In [3]:
# @title Load Data

# Contrast-weighted input volumes, read in the order they are stacked below.
T1w, T2w, PDw, flair = (
    np.load(npy_path)
    for npy_path in (
        "data/T1w.npy",
        "data/T2w.npy",
        "data/PDw.npy",
        "data/flair.npy",
    )
)

# Mask that is multiplied onto the estimated maps and plots in later cells.
brain_mask = np.load("data/brain_mask.npy")

# Reference parameter maps; only used for the side-by-side plots at the end.
ref_T1_map, ref_T2_map, ref_PD_map = (
    np.load(npy_path)
    for npy_path in (
        "data/ref_T1_map.npy",
        "data/ref_T2_map.npy",
        "data/ref_PD_map.npy",
    )
)

# Put the four contrasts on a new trailing axis: channel 0 = T1w, 1 = T2w,
# 2 = PDw, 3 = FLAIR.
input_images = np.stack([T1w, T2w, PDw, flair], axis=-1)


In [None]:
#@title Define Model

# All UNet2D hyper-parameters collected in one place; the dictionary is
# expanded into keyword arguments below, so keys must match the constructor's
# parameter names exactly.
unet_settings = {
    "row": 192,
    "col": 160,
    "kernel_size": 3,
    "num_out_chan_highest_level": 64,
    "depth": 5,
    "num_chan_increase_rate": 2,
    "activation_type": 'relu',
    "dropout_rate": 0.05,
    "USE_BN": True,
    "SKIP_CONNECTION_AT_THE_END": True,
    "num_input_chans": 4,   # one channel per stacked contrast in input_images
    "num_output_chans": 3,  # the predictions cell reads channels 0, 1 and 2
}

model = UnetLib.UNet2D(**unet_settings)

In [None]:
#@title Training

def _run_training_phase(model, optimizer, dataset, loss_scales, num_epochs,
                        extra_variables=()):
    """Run one phase of the custom training loop and print the mean loss per epoch.

    Args:
        model: Network to optimise; called as ``model([inputs], training=True)``.
        optimizer: Keras optimizer used to apply the gradients.
        dataset: Batched ``tf.data.Dataset`` whose elements hold the four input
            contrasts in channels ``0..3`` and the brain mask in the last channel.
        loss_scales: Scale variable forwarded to ``utils.self_supervised_loss``.
        num_epochs: Number of passes over ``dataset``.
        extra_variables: Variables to optimise in addition to
            ``model.trainable_variables`` (empty by default).
    """
    extra_variables = list(extra_variables)

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0

        for batch in dataset:
            # Split the packed tensor back into network inputs and mask.
            batch_inputs = batch[:, :, :, :4]
            batch_mask = batch[:, :, :, -1]

            variables = model.trainable_variables + extra_variables

            with tf.GradientTape() as tape:
                batch_predictions = model([batch_inputs], training=True)
                loss = utils.self_supervised_loss(
                    batch_inputs, batch_predictions, batch_mask, loss_scales
                )

            gradients = tape.gradient(loss, variables)
            optimizer.apply_gradients(zip(gradients, variables))

            epoch_loss += loss.numpy()
            num_batches += 1

        # Average over the batches actually seen; the guard avoids a division
        # by zero if the dataset is empty.
        avg_epoch_loss = epoch_loss / max(num_batches, 1)
        print(f"Avg. Loss for Epoch {epoch+1}: {avg_epoch_loss:.5f}")


# Training configuration
num_epochs = 500               # phase 1: network weights only
num_epochs_with_scale = 4000   # phase 2: network weights + scale variables
batch_size = 4
total_samples = len(input_images)

# Pack inputs and mask into one tensor so both are shuffled and batched together.
# The mask is cast to the dtype of the inputs because tf.concat requires all
# of its operands to share a dtype.
tensor_input_images = tf.convert_to_tensor(input_images)
tensor_brain_mask = tf.cast(tf.convert_to_tensor(brain_mask), tensor_input_images.dtype)
train_data = tf.concat([tensor_input_images, tf.expand_dims(tensor_brain_mask, axis=-1)], axis=-1)
train_dataset = tf.data.Dataset.from_tensor_slices(train_data)

# A shuffle buffer covering the whole dataset gives a uniform shuffle; the
# dataset is reshuffled every time it is iterated (i.e. every epoch).
batched_dataset = train_dataset.shuffle(
    max(total_samples, 1), reshuffle_each_iteration=True
).batch(batch_size)

scale_variables = tf.Variable(initial_value=tf.ones(4), dtype=tf.float32, trainable=True)

# Phase 1: optimise only the network; the scale variables stay at their
# initial value of one.
optimizer = Adam(learning_rate=0.001)
model.compile(optimizer=optimizer)
_run_training_phase(model, optimizer, batched_dataset, scale_variables, num_epochs)

# Phase 2: fresh optimizer with a larger learning rate; the scale variables
# are now optimised jointly with the network.
optimizer = Adam(learning_rate=0.01)
model.compile(optimizer=optimizer)
_run_training_phase(
    model,
    optimizer,
    batched_dataset,
    scale_variables,
    num_epochs_with_scale,
    extra_variables=[scale_variables],
)

# Later cells expect the scales as a plain NumPy array.
scale_variables = scale_variables.numpy()

In [4]:
#@title Load Trained Model & Scales

# Make the custom loss resolvable by name while Keras deserialises the model:
# it is added to the global custom-object registry under
# 'self_supervised_loss_scale'.
tf.keras.utils.get_custom_objects().update(
    {'self_supervised_loss_scale': utils.self_supervised_loss}
)

# Replace `model` and `scale_variables` with the versions stored on disk.
model = tf.keras.models.load_model('trained_model')
scale_variables = np.load('scale_variables.npy')

In [None]:
#@title Predictions

# Factors applied to the first two network output channels before masking.
# NOTE(review): presumably these rescale normalised outputs to physical T1/T2
# units (the plot ranges used later are 0-3000 and 0-250) — confirm.
T1_MAP_SCALE = 5000
T2_MAP_SCALE = 500

y_pred = model.predict([input_images])

# Output channels: 0 -> T1, 1 -> T2, 2 -> PD. Each map is multiplied by the
# brain mask.
est_T1_map = y_pred[:, :, :, 0] * T1_MAP_SCALE * brain_mask
est_T2_map = y_pred[:, :, :, 1] * T2_MAP_SCALE * brain_mask
est_PD_map = y_pred[:, :, :, 2] * brain_mask

# Generate weighted images from the estimated maps and the scale variables,
# then compare them with the acquired inputs.
weighted_test = utils.contrast_generation(est_T1_map, est_T2_map, est_PD_map, scale_variables)

# Uses the `tf` alias from the imports cell; no separate `import tensorflow`
# is needed here.
test_loss = tf.keras.losses.MeanSquaredError()(input_images, weighted_test)

print("MSE Loss:", test_loss.numpy())

In [None]:
# @title Plot Weighted Images

# Channel order matches how input_images was stacked: T1w, T2w, PDw, FLAIR.
contrast_names = ("T1w", "T2w", "PDw", "FLAIR")

for chan, contrast in enumerate(contrast_names):
    # Acquired image for this contrast.
    utils.imagesc(input_images[:, :, :, chan], 6, [0, 1], f"Input {contrast}")

    # Synthesised image for the same contrast. The brain mask is applied to
    # every contrast (including T1w) so all estimates are displayed the same way.
    utils.imagesc(
        weighted_test[:, :, :, chan].numpy() * brain_mask,
        6,
        [0, 1.0],
        f"Est {contrast}",
    )


In [None]:
# @title Plot Parameter Maps

# One row per figure: (volume, display range, title, extra keyword arguments).
# Reference and estimated maps alternate so each pair appears back to back.
# The PD maps pass no cmap and therefore use utils.imagesc's default.
hot_cmap = {"cmap": "hot"}
map_figures = (
    (ref_T1_map, [0, 3000], 'ref T1 map', hot_cmap),
    (est_T1_map, [0, 3000], 'est_T1 map', hot_cmap),
    (ref_T2_map, [0, 250], 'ref T2 map', hot_cmap),
    (est_T2_map, [0, 250], 'est_T2 map', hot_cmap),
    (ref_PD_map, [0, 1.0], 'ref PD map', {}),
    (est_PD_map, [0, 1.0], 'est_PD map', {}),
)

for volume, display_range, title, extra_kwargs in map_figures:
    utils.imagesc(volume, 6, display_range, title, **extra_kwargs)