In [None]:
import os
import csv
import numpy as np
import matplotlib.pyplot as plt
import scipy

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import ModelCheckpoint

import random
import math

import layers as custom_layers

In [None]:
# Path template for the axial MRNet training volumes; "{}" is filled in with a
# case id string via str.format.
MRI_Path = "./MRNet/MRNet-v1.0/train/axial/{}.npy"

In [None]:
# Load the reference ("fixed") volume for case 0701 from disk.
fixed_img_path = MRI_Path.format('0701')
fixedImg = np.load(fixed_img_path)

In [None]:
# Divide the fixed volume by its maximum intensity, then add a leading axis
# (batch) and a trailing axis (presumably channels) in a single indexing step.
inputImg = (fixedImg / np.max(fixedImg))[np.newaxis, ..., np.newaxis]

In [None]:
# Symbolic inputs: the image volume (per-sample shape, i.e. inputImg's shape
# without its leading batch axis) and a flat vector of 12 affine parameters.
imgInput = keras.Input(shape=list(inputImg.shape[1:]))
affineInput = keras.Input(shape=(12,))

# Project-local transformer layer; it receives the image and the affine
# parameters as a two-element list.
spatial_transformer = custom_layers.SpatialTransformer(
    interp_method='linear',
    add_identity=False,
    indexing="ij",
    shift_center=True,
)
affine_warped = spatial_transformer([imgInput, affineInput])
print(affine_warped.shape)

In [None]:
# Wrap the two inputs and the transformer output in a Keras Model so the
# transformation can be called directly on concrete arrays.
model = keras.Model(
    inputs=[imgInput, affineInput],
    outputs=affine_warped,
    name="affine_transformation",
)
model.summary()

In [None]:
# Build the synthetic training set: warp the fixed volume with 2000 random
# affine transforms and store each warped volume together with the inverse of
# the transform that produced it.

# Sampling ranges are loop-invariant, so they are defined once up front.
xTranslationRange = 5
yTranslationRange = 30
zTranslationRange = 30

path = "./affineTrainingData/affine{}.npz"
# Create the output directory first; opening the files below would otherwise
# fail when the directory does not exist yet.
os.makedirs(os.path.dirname(path), exist_ok=True)


def _random_affine_matrix():
    """Draw one random 4x4 homogeneous affine matrix.

    The result is ``translation @ rotZ @ rotY @ rotX @ scaling``. Random
    numbers are drawn from the global ``random`` module in a fixed order
    (translation x/y/z, rotation x/y/z, scaling x/y/z).

    Returns:
        numpy.ndarray of shape (4, 4).

    NOTE(review): translations are drawn from [0, range) and every rotation
    angle is positive (pi divided by a positive number), so the samples cover
    one direction per parameter only. Also, rotX and rotY are the transposes
    of the textbook right-handed rotation matrices (i.e. they rotate by the
    negated angle) while rotZ is the textbook form. Both are kept as they
    were so the generated distribution is unchanged -- confirm this is
    intended.
    """
    translation = np.eye(4)
    translation[0, 3] = xTranslationRange * random.random()
    translation[1, 3] = yTranslationRange * random.random()
    translation[2, 3] = zTranslationRange * random.random()

    # Rotation about the first axis: angle in [pi/80, pi/8].
    xRotAngle = math.pi / random.uniform(8, 80)
    rotX = np.eye(4)
    rotX[1, 1] = math.cos(xRotAngle)
    rotX[1, 2] = math.sin(xRotAngle)
    rotX[2, 1] = -math.sin(xRotAngle)
    rotX[2, 2] = math.cos(xRotAngle)

    # Rotation about the second axis: angle in [pi/160, pi/40].
    yRotAngle = math.pi / random.uniform(40, 160)
    rotY = np.eye(4)
    rotY[0, 0] = math.cos(yRotAngle)
    rotY[0, 2] = -math.sin(yRotAngle)
    rotY[2, 0] = math.sin(yRotAngle)
    rotY[2, 2] = math.cos(yRotAngle)

    # Rotation about the third axis: angle in [pi/160, pi/40].
    zRotAngle = math.pi / random.uniform(40, 160)
    rotZ = np.eye(4)
    rotZ[0, 0] = math.cos(zRotAngle)
    rotZ[0, 1] = -math.sin(zRotAngle)
    rotZ[1, 0] = math.sin(zRotAngle)
    rotZ[1, 1] = math.cos(zRotAngle)

    # Independent per-axis scale factors in [0.85, 1.15].
    scaling = np.eye(4)
    scaling[0, 0] = random.uniform(0.85, 1.15)
    scaling[1, 1] = random.uniform(0.85, 1.15)
    scaling[2, 2] = random.uniform(0.85, 1.15)

    return translation @ rotZ @ rotY @ rotX @ scaling


for i in range(2000):
    transMat = _random_affine_matrix()

    # The model's affine input takes 12 values: the top three rows of the
    # 4x4 matrix, flattened, with a leading batch axis.
    affine_param = np.expand_dims(transMat[0:3].flatten(), axis=0)

    transformedImg = model([inputImg, affine_param])
    transformedImg = tf.squeeze(transformedImg)

    # Store the inverse transform (same 12-value layout) next to the warped
    # image; the dataset check later in the notebook feeds it back through
    # the model together with the warped image.
    transMat_inv = np.linalg.inv(transMat)
    trf_to_save = transMat_inv[0:3].flatten()

    # The context manager closes the file even if np.savez raises.
    with open(path.format(i), "wb") as f:
        np.savez(f, img = transformedImg, trf = trf_to_save)
    if (i % 100 == 0):
        print("Milestone: file {} has been saved.".format(i))

## Check generated training dataset

In [None]:
# Sanity check: load one random sample, apply its stored transform to the
# stored (moving) image with the model, and show the result next to the
# fixed image.
trainDataPath = "./affineTrainingData/affine{}.npz"
sample_idx = random.randrange(2000)

# np.load on an .npz archive returns a lazy NpzFile that keeps the file open;
# the context manager closes it once both arrays have been read.
with np.load(trainDataPath.format(sample_idx)) as zip_obj:
    moving_img = zip_obj['img']
    target_trf = zip_obj['trf']

# Re-add the trailing and leading axes on the image and the leading (batch)
# axis on the transform before calling the model.
moving_img = np.expand_dims(moving_img, axis = -1)
moving_img = np.expand_dims(moving_img, axis = 0)
target_trf = np.expand_dims(target_trf, axis = 0)
recon_img = model([moving_img, target_trf])

sliceToCheck = 25
fig, axs = plt.subplots(1, 3, figsize=(15, 15))
panels = (
    ("Fixed", fixedImg[sliceToCheck, :, :]),
    ("Moving (sample {})".format(sample_idx), moving_img[0, sliceToCheck, :, :, 0]),
    ("Reconstructed", recon_img[0, sliceToCheck, :, :, 0]),
)
# Title every panel so the figure can be read without the code.
for ax, (title, image) in zip(axs, panels):
    ax.imshow(image)
    ax.set_title(title)

## For testing purposes: check random transformed image samples

In [None]:
# Draw a single random affine transform and apply it to the fixed volume.
# Random numbers are drawn in the order: translation x/y/z, rotation x/y/z,
# scaling x/y/z.

# Translation: per-axis offset drawn from [0, range).
xTranslationRange = 5
yTranslationRange = 30
zTranslationRange = 30
translation = np.eye(4)
translation[:3, 3] = [
    xTranslationRange * random.random(),
    yTranslationRange * random.random(),
    zTranslationRange * random.random(),
]

# Rotation about the first axis: angle in [pi/80, pi/8].
xRotAngle = math.pi / random.uniform(8, 80)
cosX, sinX = math.cos(xRotAngle), math.sin(xRotAngle)
rotX = np.eye(4)
rotX[1:3, 1:3] = [[cosX, sinX],
                  [-sinX, cosX]]

# Rotation about the second axis: angle in [pi/160, pi/40].
yRotAngle = math.pi / random.uniform(40, 160)
cosY, sinY = math.cos(yRotAngle), math.sin(yRotAngle)
rotY = np.eye(4)
# np.ix_ selects the 2x2 sub-block formed by rows/columns 0 and 2.
rotY[np.ix_([0, 2], [0, 2])] = [[cosY, -sinY],
                                [sinY, cosY]]

# Rotation about the third axis: angle in [pi/160, pi/40].
zRotAngle = math.pi / random.uniform(40, 160)
cosZ, sinZ = math.cos(zRotAngle), math.sin(zRotAngle)
rotZ = np.eye(4)
rotZ[0:2, 0:2] = [[cosZ, -sinZ],
                  [sinZ, cosZ]]

# Independent per-axis scale factors in [0.85, 1.15].
scalingX = random.uniform(0.85, 1.15)
scalingY = random.uniform(0.85, 1.15)
scalingZ = random.uniform(0.85, 1.15)
scaling = np.diag([scalingX, scalingY, scalingZ, 1.0])

transMat = translation @ rotZ @ rotY @ rotX @ scaling

# The model's affine input takes the top three rows of the 4x4 matrix,
# flattened to 12 values, with a leading batch axis.
affine_param_test = np.expand_dims(transMat[0:3].flatten(), axis=0)

# No placeholder array is allocated beforehand: the model call produces the
# result directly.
transformedImg = model([inputImg, affine_param_test])
transformedImg = tf.squeeze(transformedImg)

In [None]:
# Index of the slice to display along each of the three volume axes.
sliceToCheck_x = 25
sliceToCheck_y = 100
sliceToCheck_z = 100

# One row per volume axis; left column shows the fixed volume, right column
# the transformed one.
fig, axs = plt.subplots(3, 2, figsize=(15, 25))
rows = (
    ("x", sliceToCheck_x, fixedImg[sliceToCheck_x, :, :], transformedImg[sliceToCheck_x, :, :]),
    ("y", sliceToCheck_y, fixedImg[:, sliceToCheck_y, :], transformedImg[:, sliceToCheck_y, :]),
    ("z", sliceToCheck_z, fixedImg[:, :, sliceToCheck_z], transformedImg[:, :, sliceToCheck_z]),
)
# Title every panel so the figure can be read without the code.
for (fixed_ax, moved_ax), (axis_name, slice_idx, fixed_slice, moved_slice) in zip(axs, rows):
    fixed_ax.imshow(fixed_slice)
    fixed_ax.set_title("Fixed: {} slice {}".format(axis_name, slice_idx))
    moved_ax.imshow(moved_slice)
    moved_ax.set_title("Transformed: {} slice {}".format(axis_name, slice_idx))