In [1]:
import os
import csv
import numpy as np
import matplotlib.pyplot as plt
import scipy

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import ModelCheckpoint

import random

In [2]:
import layers as custom_layers

In [3]:
# Compatibility workaround used on this machine: cusolver64_11.dll was renamed
# to cusolver64_10.dll so that TensorFlow can find the library it asks for.
visible_gpus = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(visible_gpus))

Num GPUs Available:  1


In [4]:
# Dataset locations (relative to the notebook's working directory).
# MRI_Path is a template: format it with a case id such as "0701".
MRI_Path = './MRNet/MRNet-v1.0/train/axial/{}.npy'
# CSV file inside the MRNet folder; it is not read in the cells shown here.
imgIdxCsvPath = './MRNet/MRNet-v1.0/similar.csv'

## Load self-generated training data (by data loaders)

### Duplicate fixed image as labels

In [5]:
# Pre-generated affine training samples: path template and sample count.
trainDataSize = 200
trainDataPath = "./affineTrainingData/affine{}.npz"

# Fixed (reference) volume, case "0701": add a leading batch axis and a
# trailing channel axis in one indexing step, then cast to float.
fixedImg = np.load(MRI_Path.format("0701"))[np.newaxis, ..., np.newaxis].astype('float')

def data_generator(batchSize = 1):
    """Endlessly yield random ``(inputs, targets)`` batches for ``Model.fit``.

    Every sample is drawn at random (with replacement) from the
    ``trainDataSize`` archives addressed by ``trainDataPath``. From each
    archive the ``'img'`` array is used as the moving volume and the ``'trf'``
    array as the target affine parameters.

    Args:
        batchSize: Number of samples per yielded batch. The default of 1
            yields the same shapes and values as the earlier single-sample
            generator.

    Yields:
        ``(imgPair, [fixedBatch, tgtAffineTrf])`` where ``imgPair`` holds the
        moving volume in channel 0 and the fixed volume in channel 1,
        ``fixedBatch`` is ``fixedImg`` repeated ``batchSize`` times along the
        batch axis, and ``tgtAffineTrf`` stacks the ``'trf'`` arrays along a
        new leading batch axis.

    Raises:
        ValueError: If ``batchSize`` is smaller than 1.
    """
    if batchSize < 1:
        raise ValueError("batchSize must be >= 1, got {}".format(batchSize))

    # The fixed image is the same for every sample, so tile it once up front.
    fixedBatch = np.repeat(fixedImg, batchSize, axis=0)

    while True:
        movingImgs = []
        tgtAffineTrfs = []
        for _ in range(batchSize):
            idx = random.randrange(trainDataSize)
            # np.load on an .npz returns an NpzFile that keeps the archive
            # open; the with-block closes it so this endless loop does not
            # accumulate open file handles.
            with np.load(trainDataPath.format(idx)) as inputObj:
                movingImg = inputObj['img']
                tgtAffineTrf = inputObj['trf']
            # Add the channel axis and cast to float.
            movingImgs.append(np.expand_dims(movingImg, axis=-1).astype('float'))
            tgtAffineTrfs.append(tgtAffineTrf)

        # Stacking adds the batch axis (axis 0).
        movingBatch = np.stack(movingImgs, axis=0)
        imgPair = np.concatenate([movingBatch, fixedBatch], axis=4)
        tgtAffineBatch = np.stack(tgtAffineTrfs, axis=0)

        yield (imgPair, [fixedBatch, tgtAffineBatch])

In [None]:
# Sanity-check model: it passes the fixed-image channel of the input straight
# through as its first output and regresses 12 affine parameters from a single
# conv + max-pool stage as its second output.

# Get rid of the batch dimension
imgPairShape = list(fixedImg.shape)[1:]
# Change the channel dimension to get the image pair shape
imgPairShape[-1] = 2
inputs = keras.Input(shape = imgPairShape)
# NOTE(review): movingImg is not used by this model. It is kept because the
# compile cell addresses the next op by the name "tf.expand_dims_1", and
# removing this op would presumably shift that auto-generated name - confirm.
movingImg = tf.expand_dims(inputs[:, :, :, :, 0], axis = -1)
# Channel 1 of the pair is the fixed image; it is returned unchanged.
phOutputImg = tf.expand_dims(inputs[:, :, :, :, 1], axis = -1, name="phOutputImg")
print(phOutputImg.shape)

conv_temp_0 = layers.Conv3D(filters = 1, kernel_size = (10, 50, 50), strides = (2, 5, 5), activation = "relu")(inputs)
filtered_conv_temp_0 = layers.MaxPool3D((2, 2, 2))(conv_temp_0)
print(filtered_conv_temp_0.shape)
flattened_temp = layers.Flatten()(filtered_conv_temp_0)
# Linear output head: affine parameters can be negative (rotation/shear terms,
# translations), which a ReLU output can never produce.
affine_temp = layers.Dense(12, activation="linear")(flattened_temp)

model_temp = keras.Model(inputs=inputs, outputs=[phOutputImg, affine_temp], name="combined_model")
model_temp.summary()

In [None]:
# Write a diagram of model_temp's layer graph, including tensor shapes, to temp.png.
keras.utils.plot_model(model_temp, "temp.png", show_shapes=True)

In [None]:
# Adam optimizer for the sanity-check model; the learning rate stays available
# as the module-level name `lr`.
lr = 1e-3
optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

In [None]:
# Losses and weights are passed as lists, matched by position to the model's
# outputs [image, affine parameters]. The earlier dict form used the
# auto-generated layer names "tf.expand_dims_1" and "dense", which change
# whenever the model cell is re-run in the same session.
model_temp.compile(optimizer=optimizer,
              loss=["mean_squared_error", "mean_squared_error"],
              loss_weights=[1, 10000])

In [None]:
# Create the endless sample generator (default batchSize of 1).
dataGen = data_generator()

In [None]:
# The generator never ends, so the epoch length is set explicitly:
# trainDataSize random draws per epoch, matching the main model's fit call.
model_temp.fit(dataGen, epochs=20, steps_per_epoch=trainDataSize)

In [None]:
# Draw one sample, run the sanity-check model on it, and print the target
# affine parameters followed by the predicted ones.
moving_test, label_test = next(dataGen)
warped_test, affine_pred_test = model_temp(moving_test)
for affine_params in (label_test[1], affine_pred_test):
    print(affine_params)

## NN with labels as fixed images

In [11]:
# Quick type comparison: a TensorFlow constant versus the NumPy fixed image.
c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
for obj in (c, fixedImg):
    print(type(obj))

<class 'tensorflow.python.framework.ops.EagerTensor'>
<class 'numpy.ndarray'>


In [41]:
# Registration network: a shared convolutional encoder feeds
#   (a) a dense head that predicts 12 affine parameters, and
#   (b) a transposed-convolution decoder that predicts a 3-channel field.
# The moving image is warped first by the affine prediction and then by the
# field, using two SpatialTransformer layers.
# NOTE: the three prediction layers now use linear outputs. Checkpoints trained
# with the earlier ReLU outputs will not reproduce their results here; retrain.

# Get rid of the batch dimension
imgPairShape = list(fixedImg.shape)[1:]
# Change the channel dimension to get the image pair shape
imgPairShape[-1] = 2
inputs = keras.Input(shape = imgPairShape)
print(inputs.shape)
# Channel 0 of the image pair is the moving image.
movingImg = tf.expand_dims(inputs[:, :, :, :, 0], axis = -1)
print(movingImg.shape)

# Filter counts per encoder (down) and decoder (up) stage.
down_depths = [2, 8, 32, 8, 1]
up_depths = [2, 16, 32, 16, 3]

# ---- Encoder ----
conv_0 = layers.Conv3D(filters=down_depths[0], kernel_size=(3, 10, 10), activation="relu")(inputs)
print("conv_0: {}".format(conv_0.shape))
conv_1 = layers.Conv3D(filters=down_depths[1], kernel_size=(3, 15, 15), activation="relu")(conv_0)
print("conv_1: {}".format(conv_1.shape))
filtered_conv_1 = layers.MaxPool3D((2, 2, 2))(conv_1)
print("filtered_conv_1: {}".format(filtered_conv_1.shape))
conv_2 = layers.Conv3D(filters=down_depths[2], kernel_size=(3, 15, 15), activation="relu")(filtered_conv_1)
print("conv_2: {}".format(conv_2.shape))
filtered_conv_2 = layers.MaxPool3D((1, 2, 2))(conv_2)
print("filtered_conv_2: {}".format(filtered_conv_2.shape))
conv_3 = layers.Conv3D(filters=down_depths[3], kernel_size=(6, 20, 20), activation="relu")(filtered_conv_2)
print("conv_3: {}".format(conv_3.shape))
conv_4 = layers.Conv3D(filters=down_depths[4], kernel_size=(6, 10, 10), activation="relu")(conv_3)
print("conv_4: {}".format(conv_4.shape))

# ---- Affine head ----
flattened = layers.Flatten()(conv_4)
dense_0 = layers.Dense(256, activation="relu")(flattened)
# Linear outputs: entries of a rotation/scaling matrix and translations can be
# negative, which a ReLU output layer cannot produce.
rotation_scaling = layers.Dense(9, activation="linear")(dense_0)
translation = layers.Dense(3, activation="linear")(dense_0)
# Using tf functions here. Maybe replace them with Keras layer ops?
rot_scl_reshaped = tf.reshape(rotation_scaling, (-1, 3, 3))
trans_reshaped = tf.reshape(translation, (-1, 3, 1))
# Append the translation as a fourth column, then flatten to 12 values.
affine_pred = layers.Concatenate(axis=-1)([rot_scl_reshaped, trans_reshaped])
affine_pred = layers.Reshape((12,), name="affine_pred")(affine_pred)
print(affine_pred.shape)

# ---- Decoder ----
convTransposed_3 = layers.Conv3DTranspose(filters=up_depths[0], kernel_size=(6, 10, 10), activation="relu")(conv_4)
print("convTransposed_3: {}".format(convTransposed_3.shape))
filtered_convTransposed_2 = layers.Conv3DTranspose(filters=up_depths[1], kernel_size=(6, 20, 20), activation="relu")(convTransposed_3)
print("filtered_convTransposed_2: {}".format(filtered_convTransposed_2.shape))
convTransposed_2 = layers.UpSampling3D((1, 2, 2))(filtered_convTransposed_2)
print("convTransposed_2: {}".format(convTransposed_2.shape))
filtered_convTransposed_1 = layers.Conv3DTranspose(filters=up_depths[2], kernel_size=(3, 15, 15), activation="relu")(convTransposed_2)
print("filtered_convTransposed_1: {}".format(filtered_convTransposed_1.shape))
convTransposed_1 = layers.UpSampling3D((2, 2, 2))(filtered_convTransposed_1)
print("convTransposed_1: {}".format(convTransposed_1.shape))
# The in-plane kernel is 16 here (the matching encoder layer uses 15) because
# the (2, 2, 2) pooling reduced 233 to 116, so upsampling only gives 232.
convTransposed_0 = layers.Conv3DTranspose(filters=up_depths[3], kernel_size=(3, 16, 16), activation="relu")(convTransposed_1)
print("convTransposed_0: {}".format(convTransposed_0.shape))
# Linear output so the field can take either sign.
# NOTE(review): this assumes SpatialTransformer treats the field as a
# displacement - confirm against custom_layers.
deformation_field_pred = layers.Conv3DTranspose(filters=up_depths[4], kernel_size=(3, 10, 10), activation="linear")(convTransposed_0)
print("deformation_field_pred: {}".format(deformation_field_pred.shape))

# ---- Warping: affine first, then the field ----
affine_warped = custom_layers.SpatialTransformer(interp_method='linear')([movingImg, affine_pred])
deformable_warped = custom_layers.SpatialTransformer(interp_method='linear', add_identity=False, name="warped_image")([affine_warped, deformation_field_pred])
print(deformable_warped.shape)

(None, 52, 256, 256, 2)
(None, 52, 256, 256, 1)
conv_0: (None, 50, 247, 247, 2)
conv_1: (None, 48, 233, 233, 8)
filtered_conv_1: (None, 24, 116, 116, 8)
conv_2: (None, 22, 102, 102, 32)
filtered_conv_2: (None, 22, 51, 51, 32)
conv_3: (None, 17, 32, 32, 8)
conv_4: (None, 12, 23, 23, 1)
(None, 12)
convTransposed_3: (None, 17, 32, 32, 2)
filtered_convTransposed_2: (None, 22, 51, 51, 16)
convTransposed_2: (None, 22, 102, 102, 16)
filtered_convTransposed_1: (None, 24, 116, 116, 32)
convTransposed_1: (None, 48, 232, 232, 32)
convTransposed_0: (None, 50, 247, 247, 16)
deformation_field_pred: (None, 52, 256, 256, 3)
(None, 52, 256, 256, 1)


In [42]:
# Assemble the trainable model; the output order (warped image, affine
# parameters) matches the target list yielded by data_generator.
model_outputs = [deformable_warped, affine_pred]
model = keras.Model(name="combined_model", inputs=inputs, outputs=model_outputs)
model.summary()

Model: "combined_model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_26 (InputLayer)           [(None, 52, 256, 256 0                                            
__________________________________________________________________________________________________
conv3d_125 (Conv3D)             (None, 50, 247, 247, 1202        input_26[0][0]                   
__________________________________________________________________________________________________
conv3d_126 (Conv3D)             (None, 48, 233, 233, 10808       conv3d_125[0][0]                 
__________________________________________________________________________________________________
max_pooling3d_50 (MaxPooling3D) (None, 24, 116, 116, 0           conv3d_126[0][0]                 
_____________________________________________________________________________________

In [None]:
# Write a diagram of the registration model's layer graph, including tensor shapes, to basic_unet.png.
keras.utils.plot_model(model, "basic_unet.png", show_shapes=True)

In [25]:
# Fresh Adam optimizer for the registration model.
lr = 0.001
optimizer = tf.keras.optimizers.Adam(learning_rate = lr)
# loss_object = tf.keras.losses.MeanSquaredError()

# NOTE(review): loss_history is not used in any cell shown here - confirm it is still needed.
loss_history = []
# Make sure the checkpoint directory exists before training starts. Some
# tf.keras versions may not create it, and the call does nothing when the
# directory is already there.
os.makedirs('./checkpoints', exist_ok=True)
# Saves the model after every epoch as ./checkpoints/<epoch, two digits>.h5.
save_callback = ModelCheckpoint('./checkpoints/{epoch:02d}.h5')

In [None]:
"""
y = (deformed_img, affine_trf)
"""
# Disabled draft of a combined loss, kept for reference only: it sums the MSE
# between the fixed and deformed images with the MSE between the target and
# predicted affine parameters. Training instead uses the per-output
# "mean_squared_error" losses and loss_weights given to model.compile.
# def custom_loss(y_actual, y_pred):
#     fixed_img = y_actual[0]
#     deformed_img = y_pred[0]
#     tgtAffineTrf = y_actual[1]
#     predAffineTrf = y_pred[1]
#     mse = tf.keras.losses.MeanSquaredError()
#     deformation_loss = mse(fixed_img, deformed_img)
#     affine_loss = mse(tgtAffineTrf, predAffineTrf)
#     return deformation_loss + affine_loss

In [43]:
# Per-output losses and weights, keyed by the explicitly named output layers.
# The affine term is weighted 10000x relative to the image term.
output_losses = {"warped_image":"mean_squared_error", "affine_pred":"mean_squared_error"}
output_loss_weights = {"warped_image":1, "affine_pred":10000}
model.compile(optimizer=optimizer, loss=output_losses, loss_weights=output_loss_weights)

In [44]:
# (Re)create the endless sample generator for training the registration model.
dataGen = data_generator()

In [45]:
# Train for 20 epochs of trainDataSize generator draws each; save_callback
# writes a checkpoint file after every epoch.
model.fit(dataGen, epochs=20, steps_per_epoch=trainDataSize, callbacks=[save_callback])

Epoch 1/20
 18/200 [=>............................] - ETA: 38:21 - loss: 693516.6311 - warped_image_loss: 3931.4943 - affine_pred_loss: 68.9585

KeyboardInterrupt: 

## Test model (output == warpedImg)

In [None]:
# Evaluation model built from the same input and output tensors as `model`.
# NOTE(review): it therefore presumably shares layer weights with `model` - confirm.
testModel = keras.Model(inputs=inputs, outputs=[deformable_warped, affine_pred])

In [None]:
# Restore weights from a saved checkpoint. The file name is hard-coded to the
# one written for epoch 03 by save_callback's naming pattern.
# testModel.load_weights('./checkpoints/epoch_{}'.format(epochs-1))
testModel.load_weights('./checkpoints/03.h5')

In [None]:
# Debug check: shape of the first weight array of the layer at index 17.
# NOTE(review): the index is positional and will point at a different layer if the architecture changes.
print(testModel.layers[17].get_weights()[0].shape)

In [None]:
# Draw one sample, run the restored model on it, and print the target affine
# parameters followed by the predicted ones.
moving_test, label_test = next(dataGen)
warped_test, affine_pred_test = testModel(moving_test)
for affine_params in (label_test[1], affine_pred_test):
    print(affine_params)

In [None]:
# Debug summary of the prediction: warped-volume shape, input type, and the
# sum over the warped volume. One value is printed per line.
print(warped_test.shape, type(moving_test), np.sum(warped_test), sep="\n")

In [None]:
# Show five slices (every 5th, starting at sliceToCheck) of the fixed, moving
# and warped volumes side by side: one row per slice, one column per volume.
sliceToCheck = 0
panels = (
    ("Fixed Image", fixedImg),
    ("Moving Image", moving_test),
    ("Warped Image", warped_test),
)
fig, axs = plt.subplots(5, 3, figsize=(15, 35))
for i, axRow in enumerate(axs):
    sliceIdx = sliceToCheck + i * 5
    for ax, (title, volume) in zip(axRow, panels):
        # First batch element, channel 0.
        ax.imshow(volume[0, sliceIdx, :, :, 0])
        ax.set_title(title)
plt.show()