In [None]:
import os
import csv
import numpy as np
import matplotlib.pyplot as plt
import scipy

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import ModelCheckpoint

import random

In [None]:
import layers as custom_layers

In [None]:
# Local environment workaround (not performed by this notebook): cusolver64_11.dll was
# renamed to cusolver64_10.dll to solve the compatibility issue.
gpuDevices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpuDevices))

In [None]:
# File Paths (all relative to the MRNet v1.0 dataset root)
MRNET_ROOT = './MRNet/MRNet-v1.0'
imgIdxCsvPath = MRNET_ROOT + '/similar.csv'
# Per-case axial volumes; format with the case id string, e.g. MRI_Path.format("0701").
MRI_Path = MRNET_ROOT + '/train/axial/{}.npy'

## Load self-generated training data (produced by the data loaders)

### Reuse one fixed image for every sample (the labels are the moving images)

In [None]:
# Pre-generated affine training samples, indexed 0 .. trainDataSize-1.
trainDataPath = "./affineTrainingData/affine{}.npz"
trainDataSize = 2000

# Reference volume (case "0701") shared by every training sample: cast to float64,
# then add a leading batch axis and a trailing channel axis. The plotting cell indexes
# it as [0, slice, :, :, 0], so the raw .npy is presumably 3-D (slices, H, W).
fixedImg = np.load(MRI_Path.format("0701")).astype('float')[np.newaxis, ..., np.newaxis]

def data_generator(batchSize = 1):
    """Yield training batches forever, drawing pre-generated samples at random.

    Each sample is the ``.npz`` file ``trainDataPath.format(idx)`` for a random
    ``idx`` in ``[0, trainDataSize)``; it must contain an ``'img'`` array and a
    ``'trf'`` array. Draws are independent, i.e. sampling is with replacement.

    Args:
        batchSize: number of samples per yielded batch (default 1).

    Yields:
        ``([fixed, trf], moving)`` where
          * ``fixed`` is the module-level ``fixedImg`` (repeated ``batchSize``
            times along axis 0 when ``batchSize > 1``),
          * ``trf`` stacks each sample's ``'trf'`` array along a new leading axis,
          * ``moving`` stacks each sample's ``'img'`` (cast to float, with a
            trailing channel axis added) along a new leading axis.

    Raises:
        ValueError: if ``batchSize < 1``.
    """
    if batchSize < 1:
        raise ValueError("batchSize must be >= 1, got {}".format(batchSize))
    while True:
        movingImgs = []
        affineTrfs = []
        for _ in range(batchSize):
            idx = random.randrange(trainDataSize)
            # The context manager closes the NpzFile handle; indexing reads each
            # array fully into memory, so the arrays stay valid after close.
            with np.load(trainDataPath.format(idx)) as inputObj:
                movingImgs.append(np.expand_dims(inputObj['img'], axis=-1).astype('float'))
                affineTrfs.append(inputObj['trf'])
        # fixedImg has a leading batch axis of 1 (see its setup cell); tile it to match.
        fixedBatch = fixedImg if batchSize == 1 else np.repeat(fixedImg, batchSize, axis=0)
        yield ([fixedBatch, np.stack(affineTrfs)], np.stack(movingImgs))

## NN: warp the fixed image with the input affine (label: the moving image)

In [None]:
# Keras Input shapes exclude the batch axis, so drop fixedImg's leading dimension.
imgInput = keras.Input(shape=fixedImg.shape[1:])
# 12 affine parameters per sample (presumably a flattened 3x4 matrix, given
# add_identity=False below — TODO confirm against custom_layers.SpatialTransformer).
affineInput = keras.Input(shape=(12,))

# Warp the image input with the supplied affine parameters (linear interpolation).
spatialTransformer = custom_layers.SpatialTransformer(
    interp_method='linear',
    add_identity=False,
    indexing="xy",
)
affine_warped = spatialTransformer([imgInput, affineInput])
print(affine_warped.shape)

In [None]:
# Functional model: (image, affine params) -> warped image. summary() prints the layer table.
model = keras.Model(
    name="combined_model",
    inputs=[imgInput, affineInput],
    outputs=affine_warped,
)
model.summary()

In [None]:
# Write an architecture diagram (with tensor shapes) to test_model.png; Keras needs pydot + graphviz for this.
keras.utils.plot_model(model, to_file="test_model.png", show_shapes=True)

In [None]:
# Smoke test: draw one sample and push it through the (not yet trained) model.
dataGen = data_generator(batchSize=1)
input_test, label_test = next(dataGen)
warped_test = model(input_test)

In [None]:
# Model output vs. label shape; these must agree for the MSE loss used in training.
print(warped_test.shape, label_test.shape, sep="\n")

In [None]:
# Optimiser: Adam with a fixed learning rate.
lr = 0.001
optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

# Placeholder for recorded per-epoch training losses (empty at this point).
loss_history = []
# Write a checkpoint at the end of every epoch: ./checkpoints/01.h5, 02.h5, ...
save_callback = ModelCheckpoint(filepath='./checkpoints/{epoch:02d}.h5')

In [None]:
# `model` has a single output (the SpatialTransformer result), and data_generator
# yields a single label (the moving image), so one MSE loss is used. A per-output
# loss dict would need keys matching the model's output layer names.
model.compile(optimizer=optimizer,
              loss="mean_squared_error",
              # NOTE(review): eager execution is slower; presumably kept for debugging.
              run_eagerly=True)

In [None]:
dataGen = data_generator(batchSize=1)  # fresh generator for training (one sample per step)

In [None]:
# 20 epochs of trainDataSize randomly drawn samples each, checkpointing every epoch.
history = model.fit(dataGen, epochs=20, steps_per_epoch=trainDataSize, callbacks=[save_callback])
# fit() returns a History object. Keep its per-epoch training loss in loss_history
# so the loss curve is not lost once the cell finishes.
loss_history.extend(history.history['loss'])

## Test model (output == warpedImg)

In [None]:
testModel = keras.Model(inputs=inputs, outputs=[deformable_warped, affine_pred])

In [None]:
# testModel.load_weights('./checkpoints/epoch_{}'.format(epochs-1))
testModel.load_weights('./checkpoints/15.h5')

In [None]:
moving_test, label_test = next(dataGen)
(warped_test, affine_pred_test) = testModel(moving_test)
print(label_test[1])
print(affine_pred_test)

In [None]:
# Inspect the inference result: output shape, the container type of the model
# input drawn from the generator, and the sum of all warped values as a quick
# scalar summary.
print(warped_test.shape)
print(type(moving_test))
print(np.sum(warped_test))

In [None]:
# Visual check of one test sample. For several slices, show side by side the fixed
# image (model input), the warped output, and the moving image used as the label.
sliceToCheck = 0  # first slice index shown
numRows = 5       # number of slices (figure rows) shown
sliceStep = 5     # slice spacing between consecutive rows

# Convert the model output to NumPy once, rather than on every imshow call.
warpedArr = np.asarray(warped_test)
# Highest slice index present in all three volumes. Requested slices beyond it are
# clamped, so a volume with few slices does not raise IndexError.
lastSlice = min(fixedImg.shape[1], warpedArr.shape[1], label_test.shape[1]) - 1

panels = (
    (fixedImg, "Fixed Image"),
    (warpedArr, "Warped Image"),
    (label_test, "Original Moving Image"),
)
# squeeze=False keeps `axs` 2-D even when numRows == 1.
fig, axs = plt.subplots(numRows, 3, figsize=(15, 7 * numRows), squeeze=False)
for row in range(numRows):
    sliceIdx = min(sliceToCheck + row * sliceStep, lastSlice)
    for col, (volume, title) in enumerate(panels):
        axs[row, col].imshow(volume[0, sliceIdx, :, :, 0])
        axs[row, col].set_title("{} (slice {})".format(title, sliceIdx))
plt.show()