In [1]:
import os
import csv
import numpy as np
import matplotlib.pyplot as plt
import scipy

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import ModelCheckpoint

import random

In [2]:
import layers as custom_layers

In [3]:
# cusolver64_11.dll was renamed to cusolver64_10.dll to work around the
# compatibility issue; the count printed below confirms TensorFlow sees the GPU.
print("Num GPUs Available: ", len(tf.config.list_physical_devices(device_type='GPU')))

Num GPUs Available:  1


In [4]:
# Input locations, relative to the notebook's working directory.
# `MRI_Path` is a template: fill it with a case id string such as "0701".
MRI_Path = './MRNet/MRNet-v1.0/train/axial/{}.npy'
imgIdxCsvPath = './MRNet/MRNet-v1.0/similar.csv'

## Load self-generated training data (produced by the data loaders)

### Duplicate fixed image as labels

In [5]:
# Template for the pre-generated training samples and how many of them exist.
trainDataPath = "./affineTrainingData/affine{}.npz"
trainDataSize = 2000

# Fixed (reference) volume: divide by its maximum so the peak intensity is 1,
# store as float32 and append a trailing channel axis.
fixedImg = np.load(MRI_Path.format("0701"))
fixedImg = (fixedImg / fixedImg.max()).astype(np.float32)[..., np.newaxis]

In [6]:
def image_input_gen(batch_size = 5):
    """Yield an endless stream of training image-pair batches.

    Each batch element is a randomly chosen moving image (the 'img' entry of
    a training archive) concatenated with the module-level `fixedImg` along
    the last axis.

    Args:
        batch_size: number of image pairs per yielded batch.

    Yields:
        Array of shape (batch_size, *fixedImg.shape[:-1], 2); channel 0 is the
        moving image and channel 1 is the fixed image.
    """
    while True:
        imgPair_batch = np.zeros((batch_size, *fixedImg.shape[:-1], 2))
        for i in range(batch_size):
            idx = random.randrange(trainDataSize)
            # np.load on an .npz returns an archive holding an open file
            # handle; the context manager closes it once the array is read.
            with np.load(trainDataPath.format(idx)) as inputObj:
                movingImg = inputObj['img']
            movingImg = np.expand_dims(movingImg, axis=-1)
            movingImg = movingImg.astype('float32')
            imgPair_batch[i] = np.concatenate([movingImg, fixedImg], axis=3)

        yield imgPair_batch

In [None]:
def data_generator(batch_size = 5):
    """Yield an endless stream of (inputs, labels) training batches.

    This generator is called by the smoke-test, training and evaluation cells
    further down, so it has to be defined for the notebook to run top to
    bottom.

    Args:
        batch_size: number of samples per yielded batch.

    Yields:
        A tuple (imgPair_batch, [fixedImg_batch, tgtAffineTrf_batch]) where
        - imgPair_batch has shape (batch_size, *fixedImg.shape[:-1], 2), with
          the moving image in channel 0 and the fixed image in channel 1;
        - fixedImg_batch repeats `fixedImg` once per sample;
        - tgtAffineTrf_batch has shape (batch_size, 12) and holds the 'trf'
          entry of each training archive.
    """
    while True:
        imgPair_batch = np.zeros((batch_size, *fixedImg.shape[:-1], 2))
        fixedImg_batch = np.zeros((batch_size, *fixedImg.shape))
        tgtAffineTrf_batch = np.zeros((batch_size, 12))
        for i in range(batch_size):
            idx = random.randrange(trainDataSize)
            # Close the .npz archive as soon as both arrays have been read.
            with np.load(trainDataPath.format(idx)) as inputObj:
                movingImg = inputObj['img']
                tgtAffineTrf = inputObj['trf']
            movingImg = np.expand_dims(movingImg, axis=-1)
            movingImg = movingImg.astype('float32')
            imgPair = np.concatenate([movingImg, fixedImg], axis=3)
            imgPair_batch[i] = imgPair

            # assumes 'trf' holds 12 values per sample — TODO confirm against
            # the data loader that wrote the archives.
            tgtAffineTrf = tgtAffineTrf.astype('float32')
            tgtAffineTrf_batch[i] = tgtAffineTrf
            fixedImg_batch[i] = fixedImg

        yield (imgPair_batch, [fixedImg_batch, tgtAffineTrf_batch])

### Define the Generator and the Discriminator

In [7]:
def conv_activation(x, alpha=0.3):
    """Apply a LeakyReLU (slope `alpha`) followed by batch normalization to `x`.

    Args:
        x: tensor produced by a preceding layer.
        alpha: value forwarded to `layers.LeakyReLU(alpha=...)`.

    Returns:
        The batch-normalized, activated tensor.
    """
    activated = layers.LeakyReLU(alpha=alpha)(x)
    return layers.BatchNormalization()(activated)

In [None]:
# Generator: network input and convolutional encoder.
# The input stacks the moving image (channel 0) and the fixed image
# (channel 1); the moving channel is split off for the spatial transformer.
inputs = keras.Input(shape = (*fixedImg.shape[:-1], 2))
moving_input = tf.expand_dims(inputs[:, :, :, :, 0], axis = -1)
# fixed_input = tf.expand_dims(inputs[:, :, :, :, 1], axis = -1)

down_depths = [32, 64, 128, 256, 512]
up_depths = [2, 16, 32, 16, 3]
kernel_size = (7, 7, 7)

x = layers.Conv3D(filters=down_depths[0], kernel_size=(6, 32, 32), strides=(2, 2, 2), activation="relu")(inputs)
# x = layers.MaxPool3D((1, 2, 2))(x)
print("0: {}".format(x.shape))

x = layers.Conv3D(filters=down_depths[0], kernel_size=(6, 32, 32), strides=(1, 2, 2))(x)
x = conv_activation(x)

# NOTE(review): the chain below is disabled. It started from `f_conv_0`,
# which is never defined in this notebook (NameError), and its result
# `f_conv_4` was never consumed: the regression head pools `x`.
# Re-enable it, fed from `x`, only if these extra layers are really wanted.
# f_conv_1 = layers.Conv3D(filters=down_depths[1], kernel_size=(6, 16, 16), strides=(2, 2, 2), )(f_conv_0)
# f_conv_1 = layers.BatchNormalization()(f_conv_1)
# f_conv_2 = layers.Conv3D(filters=down_depths[2], kernel_size=(6, 8, 8), activation="relu")(f_conv_1)
# f_conv_2 = layers.BatchNormalization()(f_conv_2)
# f_conv_3 = layers.Conv3D(filters=down_depths[3], kernel_size=(4, 8, 8), strides=(1, 2, 2), activation="relu")(f_conv_2)
# f_conv_3 = layers.BatchNormalization()(f_conv_3)
# f_conv_4 = layers.Conv3D(filters=down_depths[4], kernel_size=(4, 4, 4), activation="relu")(f_conv_3)
# f_conv_4 = layers.BatchNormalization()(f_conv_4)

# concat_feats = tf.concat([conv_4, f_conv_4], axis = -1)
# print(concat_feats.shape)

# Regression head: pool and flatten the encoder features, then two dense
# blocks (Dense -> BatchNorm -> LeakyReLU) before the 12-value affine output.
x = layers.AveragePooling3D((2, 2, 2))(x)
x = layers.Flatten()(x)
x = layers.Dense(512)(x)
# BatchNormalization must be instantiated first and then called on the
# tensor; passing `x` to the constructor would set its `axis` argument.
x = layers.BatchNormalization()(x)
x = layers.LeakyReLU()(x)
# x = layers.Dropout(0.2)(x)
x = layers.Dense(64)(x)
x = layers.BatchNormalization()(x)
x = layers.LeakyReLU()(x)
affine_pred = layers.Dense(12, activation="linear", name="affine_pred")(x)

# convTransposed_3 = layers.Conv3DTranspose(filters=up_depths[0], kernel_size=(6, 10, 10), activation="relu")(conv_4)
# print("convTransposed_3: {}".format(convTransposed_3.shape))
# filtered_convTransposed_2 = layers.Conv3DTranspose(filters=up_depths[1], kernel_size=(6, 20, 20), activation="relu")(convTransposed_3)
# print("filtered_convTransposed_2: {}".format(filtered_convTransposed_2.shape))
# convTransposed_2 = layers.UpSampling3D((1, 2, 2))(filtered_convTransposed_2)
# print("convTransposed_2: {}".format(convTransposed_2.shape))
# filtered_convTransposed_1 = layers.Conv3DTranspose(filters=up_depths[2], kernel_size=(3, 15, 15), activation="relu")(convTransposed_2)
# print("filtered_convTransposed_1: {}".format(filtered_convTransposed_1.shape))
# convTransposed_1 = layers.UpSampling3D((2, 2, 2))(filtered_convTransposed_1)
# print("convTransposed_1: {}".format(convTransposed_1.shape))
# convTransposed_0 = layers.Conv3DTranspose(filters=up_depths[3], kernel_size=(3, 16, 16), activation="relu")(convTransposed_1)
# print("convTransposed_0: {}".format(convTransposed_0.shape))
# deformation_field_pred = layers.Conv3DTranspose(filters=up_depths[4], kernel_size=(3, 10, 10), activation="relu")(convTransposed_0)
# print("deformation_field_pred: {}".format(deformation_field_pred.shape))

# Warp the moving image with the predicted affine parameters. The layer name
# "warped_image" is the key used for this output in the compile() loss dict.
affine_warped = custom_layers.SpatialTransformer(
    interp_method='linear',
    add_identity=False,
    name="warped_image",
    shift_center=True,
)([moving_input, affine_pred])
# deformable_warped = custom_layers.SpatialTransformer(interp_method='linear', add_identity=False, name="warped_image", shift_center=False)([affine_warped, deformation_field_pred])
print(affine_warped.shape)

In [None]:
# Two-output model: the warped moving image and the predicted affine parameters.
model = keras.Model(
    inputs=inputs,
    outputs=[affine_warped, affine_pred],
    name="combined_model",
)
model.summary()

In [None]:
# Smoke test: push one sample through the model and compare the
# predicted affine parameters with the target ones.
dataGen = data_generator(batch_size = 1)
moving_test, label_test = next(dataGen)
print(moving_test.shape)
warped_test, affine_pred_test = model(moving_test)
# label_test is [fixed image batch, target affine batch]; index 1 is the affine target.
print(label_test[1], label_test[1].dtype, affine_pred_test, sep="\n")
mse = tf.keras.losses.MeanSquaredError()
print(mse(label_test[1], affine_pred_test))

In [None]:
# Render the layer graph (with tensor shapes) to test.png.
keras.utils.plot_model(model, to_file="test.png", show_shapes=True)

In [None]:
# Optimizer settings.
lr = 0.001
optimizer = keras.optimizers.Adam(learning_rate=lr)
# loss_object = tf.keras.losses.MeanSquaredError()

loss_history = list()
# The checkpoint file name is formatted with the epoch number.
save_callback = ModelCheckpoint(filepath='./checkpoints/{epoch:02d}.h5')

In [None]:
def affine_loss(y_actual, y_pred):
    """Weighted mean-squared error between target and predicted affine parameters.

    The 12 parameters (last axis) are split into three index groups, each
    scored with its own MSE:
      * diagonal     -> indices 0, 5, 10          (weight 10)
      * off-diagonal -> indices 1, 2, 4, 6, 8, 9  (weight 10)
      * translation  -> indices 3, 7, 11          (weight 1)
    This grouping matches a row-major 3x4 affine matrix — NOTE(review):
    confirm against the layout `SpatialTransformer` expects.

    Args:
        y_actual: target parameters, shape (..., 12).
        y_pred: predicted parameters, same shape as `y_actual`.

    Returns:
        Scalar tensor: 10 * off_diagonal_mse + 10 * diagonal_mse + translation_mse,
        averaged over every sample in the batch.
    """
    diag_idx = [0, 5, 10]
    corner_idx = [1, 2, 4, 6, 8, 9]
    translation_idx = [3, 7, 11]

    y_actual = tf.convert_to_tensor(y_actual)
    y_pred = tf.convert_to_tensor(y_pred)
    mse = tf.keras.losses.MeanSquaredError()

    def group_mse(indices):
        # tf.gather selects the wanted columns; `+` on tensors would add
        # them elementwise instead of concatenating them.
        return mse(tf.gather(y_actual, indices, axis=-1),
                   tf.gather(y_pred, indices, axis=-1))

    diag_loss = group_mse(diag_idx)
    corner_loss = group_mse(corner_idx)
    translation_loss = group_mse(translation_idx)
    return corner_loss * 10 + diag_loss * 10 + translation_loss

In [None]:
# Both outputs are trained with plain MSE; the dict keys are the output layer names.
model.compile(
    optimizer=optimizer,
    loss={
        "warped_image": "mean_squared_error",
        "affine_pred": "mean_squared_error",
    },
    # loss_weights={"warped_image":1, "affine_pred":1},
    run_eagerly=True,
)

In [None]:
# Training stream: batches of 4 samples. The divisor used for
# steps_per_epoch in the fit() call below must match this batch size.
dataGen = data_generator(batch_size=4)

In [None]:
# steps_per_epoch must be an integer: one pass over trainDataSize samples in
# batches of 4 (the batch size `dataGen` was created with).
model.fit(dataGen, epochs=20, steps_per_epoch=trainDataSize // 4, callbacks=[save_callback])

## Test model (outputs: warped image and predicted affine parameters)

In [None]:
# Evaluation model built from the same input/output tensors as `model`.
testModel = keras.Model(
    inputs=inputs,
    outputs=[affine_warped, affine_pred],
)

In [None]:
# testModel.load_weights('./checkpoints/epoch_{}'.format(epochs-1))
# Restore the weights saved by the checkpoint callback at epoch 13.
testModel.load_weights(filepath='./checkpoints/13.h5')

In [None]:
# Evaluate the restored model on one random sample.
dataGen_test = data_generator(batch_size = 1)
moving_test, label_test = next(dataGen_test)
warped_test, affine_pred_test = testModel(moving_test)
# Target affine parameters first, then the prediction, then their MSE.
print(label_test[1], affine_pred_test, sep="\n")
mse = tf.keras.losses.MeanSquaredError()
print(mse(label_test[1], affine_pred_test))

In [None]:
# Show five slices (every 5th, starting at sliceToCheck) of the fixed, moving
# and warped volumes side by side: one row per slice, one column per volume.
sliceToCheck = 0
fig, axs = plt.subplots(5, 3, figsize=(15, 35))
for row in range(5):
    slice_idx = sliceToCheck + row * 5
    panels = (
        ("Fixed Image", fixedImg[slice_idx, :, :, 0]),
        ("Moving Image", moving_test[0, slice_idx, :, :, 0]),
        ("Warped Image", warped_test[0, slice_idx, :, :, 0]),
    )
    for ax, (title, image) in zip(axs[row], panels):
        ax.imshow(image)
        ax.set_title(title)
plt.show()