In [None]:
import os
import csv
import numpy as np
import matplotlib.pyplot as plt
import scipy

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import ModelCheckpoint

import random

In [None]:
import layers as custom_layers

In [None]:
# Sanity check: report how many GPU devices TensorFlow can see on this machine.
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

In [None]:
# File Paths
# CSV whose first column holds the case indices to use (read in the next cell).
imgIdxCsvPath = './MRNet/MRNet-v1.0/similar.csv'
# Template for one volume file; '{}' is filled with the zero-padded case index.
MRI_Path = './MRNet/MRNet-v1.0/train/axial/{}.npy'

In [None]:
# Get indices of qualified training images in MRNet.
# Every row of the CSV contributes its first column; no filtering is applied here.
imageIndices = []
# newline='' is the csv-module recommended way to open files for csv.reader.
with open(imgIdxCsvPath, newline='') as file:
    for row in csv.reader(file):
        if not row:
            # Blank lines yield an empty row; skip them instead of raising IndexError.
            continue
        # File names use 4-digit, zero-padded case indices (e.g. '0042').
        imageIndices.append(str(row[0]).zfill(4))

In [None]:
# Show how many cases were found plus a short preview instead of dumping the whole list.
print(len(imageIndices), imageIndices[:10])

In [None]:
# Check which image has most slices, then we use it as our atlas/target/fixed image.
maxSlice = 0
targetIndex = None
for index in imageIndices:
    # Only the shape is needed here; memory-mapping avoids reading the whole
    # volume from disk just to count its slices.
    numSlices = np.load(MRI_Path.format(index), mmap_mode='r').shape[0]
    # Strict '>' keeps the first volume encountered when several tie for the maximum.
    if numSlices > maxSlice:
        targetIndex = index
        maxSlice = numSlices

if targetIndex is None:
    # Fail here with a clear message rather than later when loading 'None.npy'.
    raise ValueError('No usable image found: imageIndices is empty or all volumes have zero slices.')

print(maxSlice)
print(targetIndex)

## Load all qualified images into memory

In [None]:
# Load fixed image.
# targetIndex is the case with the most slices, selected in the cell above.
fixedImg = np.load(MRI_Path.format(targetIndex))

In [None]:
# Load moving images and zero-pad them.
# Every case except the atlas is loaded and padded with empty slices at the end
# of axis 0 so that all volumes share the atlas' slice count (maxSlice).
movingImgs = []
for index in imageIndices:
    if index != targetIndex:
        volume = np.load(MRI_Path.format(index))
        missingSlices = maxSlice - volume.shape[0]
        padWidths = ((0, missingSlices), (0, 0), (0, 0))
        movingImgs.append(np.pad(volume, padWidths))

In [None]:
# Append the channel axis to moving images and the fixed image.
# NOTE: this cell is not idempotent; running it twice appends a second trailing axis.
fixedImg = fixedImg[..., np.newaxis]
movingImgs = [volume[..., np.newaxis] for volume in movingImgs]
print(movingImgs[0].shape)
print(fixedImg.shape)

In [None]:
# Turn the list of equally-shaped padded volumes into a single array
# (one volume per entry along axis 0) and show its shape.
movingImgs = np.array(movingImgs)
print(movingImgs.shape)

In [None]:
# Repeat the fixed image once per moving image so both arrays pair up along axis 0.
fixedImgs = np.repeat(fixedImg[np.newaxis,...], movingImgs.shape[0], axis = 0)
print(fixedImgs.shape)

In [None]:
# dtype of the volumes as loaded from disk (before the float32 cast below).
print(movingImgs.dtype)

In [None]:
# Cast both arrays to float32 before handing them to TensorFlow.
movingImgs = movingImgs.astype('float32')
fixedImgs = fixedImgs.astype('float32')

In [None]:
# Confirm the cast: this should now report float32.
print(movingImgs.dtype)

In [None]:
# Pair each moving volume with its copy of the fixed volume.
dataset = tf.data.Dataset.from_tensor_slices((movingImgs, fixedImgs))
batch_size = 1
# Dataset.batch() returns a NEW dataset and never modifies `dataset` in place, so
# the result must be bound to a name to be usable. `dataset` itself is deliberately
# left unbatched because train() adds the batch axis to every element on its own.
batched_dataset = dataset.batch(batch_size)
batched_dataset

In [None]:
# Inspect the per-element shapes and dtypes of the (moving, fixed) dataset.
print(dataset.element_spec)

## Load data with a data generator

In [None]:
# Reload the atlas from disk, append the channel axis and cast to float32 in one
# step, so this section does not depend on the in-memory loading cells above.
fixedImg = np.load(MRI_Path.format(targetIndex))[..., np.newaxis].astype('float32')
def data_generator(batchSize = 1):
    """Endlessly yield randomly sampled (moving, fixed) training batches.

    Each moving volume is loaded from disk, zero-padded at the end of axis 0 up
    to ``maxSlice`` slices, given a trailing channel axis and cast to float32.
    The fixed volume is the module-level ``fixedImg``, repeated once per sample.

    Args:
        batchSize: Number of (moving, fixed) pairs per yielded batch.

    Yields:
        Tuple ``(movingImgs, fixedImgs)`` of arrays whose first axis has length
        ``batchSize``.

    Raises:
        ValueError: On first iteration, if there is no image other than the
            target to sample from.
    """
    # Never sample the atlas as its own moving image; this matches the in-memory
    # loader, which also skips targetIndex.
    candidates = [idx for idx in imageIndices if idx != targetIndex]
    if not candidates:
        raise ValueError('No moving images to sample: imageIndices is empty or holds only the target.')
    while True:
        movingImgs = []
        fixedImgs = []
        for _ in range(batchSize):
            fileIdx = candidates[random.randrange(len(candidates))]
            movingImg = np.load(MRI_Path.format(fileIdx))
            numSlicesToPad = maxSlice - movingImg.shape[0]
            movingImg = np.pad(movingImg, ((0, numSlicesToPad), (0, 0), (0, 0)))
            movingImg = np.expand_dims(movingImg, axis=-1)
            movingImgs.append(movingImg.astype('float32'))
            fixedImgs.append(fixedImg)
        yield (np.array(movingImgs), np.array(fixedImgs))

In [None]:
# Instantiate the generator with its default batch size of 1.
dataGen = data_generator()

In [None]:
# Construct NN
# A small 3-D CNN looks at the moving volume and regresses 12 transform
# parameters; the spatial transformer then resamples the same volume with them.
# NOTE(review): 12 values presumably encode a 3x4 affine matrix - confirm against
# custom_layers.SpatialTransformer.
inputs = keras.Input(shape = fixedImg.shape)
print(inputs.shape)
conv_0 = layers.Conv3D(filters=2, kernel_size=(3, 10, 10), activation="relu")(inputs)
conv_1 = layers.Conv3D(filters=3, kernel_size=(3, 15, 15), activation="relu")(conv_0)
filtered_conv_1 = layers.MaxPool3D((2, 2, 2))(conv_1)
conv_2 = layers.Conv3D(filters=4, kernel_size=(3, 15, 15), activation="relu")(filtered_conv_1)
filtered_conv_2 = layers.MaxPool3D((1, 2, 2))(conv_2)
print(filtered_conv_2.shape)
conv_3 = layers.Conv3D(filters=2, kernel_size=(6, 20, 20), activation="relu")(filtered_conv_2)
conv_4 = layers.Conv3D(filters=1, kernel_size=(6, 10, 10), activation="relu")(conv_3)
print(conv_4.shape)
flattened = layers.Flatten()(conv_4)
print(flattened.shape)
dense_0 = layers.Dense(256, activation="relu")(flattened)
# The regression head must be linear: a ReLU here would clamp every predicted
# transform parameter to >= 0, so no negative coefficient could ever be produced.
affine_pred = layers.Dense(12, activation=None)(dense_0)
warped = custom_layers.SpatialTransformer(interp_method='linear')([inputs, affine_pred])
print(warped.shape)

In [None]:
# Wrap the graph built above as a Keras model (moving image in, warped image out)
# and print its layer summary.
model = keras.Model(inputs=inputs, outputs=warped, name="affine_model")
model.summary()

In [None]:
# Optimiser and loss shared by model.fit() and the custom training loop below.
lr = 0.001
optimizer = tf.keras.optimizers.Adam(learning_rate = lr)
loss_object = tf.keras.losses.MeanSquaredError()

# One loss value is appended per train_step() call.
loss_history = []
# Checkpoint callback passed to model.fit(); the epoch number is part of the file name.
save_callback = ModelCheckpoint('./checkpoints/{epoch:02d}.h5')

In [None]:
# Attach the optimiser and loss so that model.fit() can be used.
model.compile(optimizer=optimizer, loss=loss_object)

In [None]:
# Train from the generator: 50 epochs of 50 generator batches each, with the
# checkpoint callback attached.
model.fit(dataGen, epochs=50, steps_per_epoch=50, callbacks=[save_callback])

In [None]:
def train_step(movingImages, fixedImages):
  """Run one optimisation step on a single batch.

  Warps ``movingImages`` with the module-level ``model``, computes the loss
  against ``fixedImages``, records it in ``loss_history`` and applies the
  gradients with the module-level ``optimizer``.

  Args:
    movingImages: Batch of moving volumes fed to the model.
    fixedImages: Batch of fixed (target) volumes the warped output is compared to.
  """
  with tf.GradientTape() as tape:
    warpedImages = model(movingImages, training=True)
    # Keras losses are called as loss(y_true, y_pred): the fixed image is the
    # target and the warped image is the prediction.
    loss_value = loss_object(fixedImages, warpedImages)

  loss_history.append(loss_value.numpy().mean())
  grads = tape.gradient(loss_value, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))

In [None]:
def train(epochs):
  """Custom training loop over the in-memory ``dataset``.

  For every epoch, each (moving, fixed) pair is given a leading batch axis and
  passed to ``train_step``. After the epoch the most recent loss is printed and
  the model weights are saved under ``./checkpoints``.

  Args:
    epochs: Number of passes over ``dataset``.
  """
  for epoch in range(epochs):
    for movingImage, fixedImage in dataset:
      # The dataset is iterated unbatched, so add the batch axis here.
      train_step(movingImage[np.newaxis, ...], fixedImage[np.newaxis, ...])
    print ('Epoch {} finished. Current loss value: {}'.format(epoch, loss_history[-1]))
    model.save_weights('./checkpoints/epoch_{}'.format(epoch))

In [None]:
# Run the custom loop. Note that it keeps training the same `model` object that
# model.fit() above operates on, so the two training paths are not independent.
epochs = 50
train(epochs)

In [None]:
# Second model over the same `inputs`/`warped` tensors, i.e. the same layers as `model`.
testModel = keras.Model(inputs=inputs, outputs=warped)

In [None]:
# Restore the weights that train() saved after its final epoch.
testModel.load_weights('./checkpoints/epoch_{}'.format(epochs-1))

In [None]:
# Warp the first moving volume; the 0:1 slice keeps the leading batch axis.
warped_test = testModel(movingImgs[0:1])

In [None]:
# Shape of the warped output for the single test volume.
print(warped_test.shape)

In [None]:
# Show the same slice of the fixed, moving and warped volumes side by side.
sliceToCheck = 30
fig, axs = plt.subplots(1, 3, figsize=(15, 15))
panels = (
    ("Fixed Image", fixedImg[sliceToCheck, :, :, 0]),
    ("Moving Image", movingImgs[0, sliceToCheck, :, :, 0]),
    ("Warped Image", warped_test[0, sliceToCheck, :, :, 0]),
)
for panelAx, (panelTitle, panelImage) in zip(axs, panels):
    panelAx.imshow(panelImage)
    panelAx.set_title(panelTitle)
plt.show()

In [None]:
# Loss curve of the custom training loop: one point per train_step() call.
fig, ax = plt.subplots()
ax.plot(range(len(loss_history)), loss_history)
# Label the figure so it can be read on its own.
ax.set_xlabel("Training step")
ax.set_ylabel("Loss (MSE)")
ax.set_title("Training loss")
# plt.show() also keeps the Line2D repr out of the cell output.
plt.show()

In [None]:
# Scratch check: compare a tensor's TensorShape with its plain-list form.
n = np.array([1., 2., 4.])
t = tf.constant(n)
print(t.shape)
print(t.shape.as_list())