In [1]:
import os
import csv
import numpy as np
import matplotlib.pyplot as plt
import scipy

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import ModelCheckpoint

import random

In [2]:
import layers as custom_layers

In [3]:
# We renamed cusolver64_11.dll to cusolver64_10.dll to solve the compatibility issue.
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

Num GPUs Available:  1


In [4]:
# File Paths
imgIdxCsvPath = './MRNet/MRNet-v1.0/similar.csv'
MRI_Path = './MRNet/MRNet-v1.0/train/axial/{}.npy'

## Load self-generated training data (by data loaders)

### Duplicate fixed image as labels

In [5]:
bs = 1

In [6]:
trainDataPath = "./affineTrainingData/affine{}.npz"
trainDataSize = 200
fixedImg = np.load(MRI_Path.format("0701"))
fixedImg = fixedImg / np.max(fixedImg)
fixedImg = fixedImg.astype('float32')
fixedImg = np.expand_dims(fixedImg, axis=-1)

def img_pair_input_gen(batch_size = bs):
    while True:
        imgPair_batch = np.zeros((batch_size, *fixedImg.shape[:-1], 2))
        # fixedImg_batch = np.zeros((batch_size, *fixedImg.shape))
        # tgtAffineTrf_batch = np.zeros((batch_size, 12))
        for i in range(batch_size):
            idx = random.randrange(trainDataSize)
            inputObj = np.load(trainDataPath.format(idx))
            movingImg = inputObj['img']
            movingImg = np.expand_dims(movingImg, axis=-1)
            movingImg = movingImg.astype('float32')
            imgPair = np.concatenate([movingImg, fixedImg], axis=3)
            imgPair_batch[i] = imgPair

            # tgtAffineTrf = inputObj['trf']
            # tgtAffineTrf = tgtAffineTrf.astype('float32')
            # tgtAffineTrf_batch[i] = tgtAffineTrf
            # fixedImg_batch[i] = fixedImg
        
        yield imgPair_batch


def gt_fixed_img_input_gen(batch_size = bs):
    while True:
        fixedImg_batch = np.zeros((batch_size, *fixedImg.shape))
        for i in range(batch_size):
            fixedImg_batch[i] = fixedImg
        
        yield fixedImg_batch

In [7]:
img_pair_input = img_pair_input_gen()
gt_fixed_img_input = gt_fixed_img_input_gen()

## Define the generator

In [9]:
def make_generator(lrelu_alpha=0.3):
    def conv_activation(inputs):
        outputs = layers.BatchNormalization()(inputs)
        outputs = layers.LeakyReLU(alpha=lrelu_alpha)(outputs)
        return outputs

    inputs = keras.Input(shape = (*fixedImg.shape[:-1], 2))
    moving_input = tf.expand_dims(inputs[:, :, :, :, 0], axis = -1)
    # fixed_input = tf.expand_dims(inputs[:, :, :, :, 1], axis = -1)

    down_depths = [32, 64, 128, 512]
    up_depths = [512, 128, 64, 32, 3]

    # Down-sampling path
    x = layers.Conv3D(filters=down_depths[0], kernel_size=(7, 15, 15), strides=(1, 1, 1), padding="same", activation=layers.LeakyReLU())(inputs)
    x = layers.MaxPool3D((1, 2, 2), padding="same")(x)
    print("0: {}".format(x.shape))

    x = layers.Conv3D(filters=down_depths[1], kernel_size=(3, 3, 3), strides=(1, 2, 2), padding="same")(x)
    x = conv_activation(x)
    x = layers.Conv3D(filters=down_depths[1], kernel_size=(3, 3, 3), strides=(1, 1, 1), padding="same")(x)
    x = conv_activation(x)
    print("1: {}".format(x.shape))

    x = layers.Conv3D(filters=down_depths[2], kernel_size=(3, 3, 3), strides=(2, 2, 2), padding="same")(x)
    x = conv_activation(x)
    x = layers.Conv3D(filters=down_depths[2], kernel_size=(3, 3, 3), strides=(1, 1, 1), padding="same")(x)
    x = conv_activation(x)
    print("2: {}".format(x.shape))

    x = layers.Conv3D(filters=down_depths[3], kernel_size=(3, 3, 3), strides=(2, 2, 2), padding="same")(x)
    x = conv_activation(x)
    x = layers.Conv3D(filters=down_depths[3], kernel_size=(3, 3, 3), strides=(1, 1, 1), padding="same")(x)
    x = conv_activation(x)
    print("3: {}".format(x.shape))


    # Up-sampling path
    x = layers.Conv3DTranspose(filters=up_depths[0], kernel_size=(3, 3, 3), strides=(2, 2, 2), padding="same")(x)
    x = conv_activation(x)
    x = layers.Conv3D(filters=up_depths[0], kernel_size=(3, 3, 3), strides=(1, 1, 1), padding="same")(x)
    print("4: {}".format(x.shape))

    x = layers.Conv3DTranspose(filters=up_depths[1], kernel_size=(3, 3, 3), strides=(2, 2, 2), padding="same")(x)
    x = conv_activation(x)
    x = layers.Conv3D(filters=up_depths[1], kernel_size=(3, 3, 3), strides=(1, 1, 1), padding="same")(x)
    print("5: {}".format(x.shape))

    x = layers.Conv3DTranspose(filters=up_depths[2], kernel_size=(3, 3, 3), strides=(1, 2, 2), padding="same")(x)
    x = conv_activation(x)
    x = layers.Conv3D(filters=up_depths[2], kernel_size=(3, 3, 3), strides=(1, 1, 1), padding="same")(x)
    print("6: {}".format(x.shape))


    x = layers.Conv3DTranspose(filters=up_depths[3], kernel_size=(3, 3, 3), strides=(1, 2, 2), padding="same")(x)
    x = conv_activation(x)
    x = layers.Conv3D(filters=up_depths[3], kernel_size=(3, 3, 3), strides=(1, 1, 1), padding="same")(x)
    print("7: {}".format(x.shape))

    deformation_field_pred = layers.Conv3D(filters=up_depths[4], kernel_size=(3, 3, 3), strides=(1, 1, 1), padding="same")(x)
    print("8: {}".format(x.shape))

    # affine_warped = custom_layers.SpatialTransformer(interp_method='linear', add_identity=False, name="warped_image", shift_center=True)([moving_input, affine_pred])
    deformable_warped = custom_layers.SpatialTransformer(interp_method='linear', add_identity=False, name="warped_image", shift_center=True)([moving_input, deformation_field_pred])

    model = keras.Model(inputs=inputs, outputs=deformable_warped, name="generator_model")
#     model.summary()

    return model

In [10]:
generator = make_generator()

0: (None, 52, 128, 128, 32)
1: (None, 52, 64, 64, 64)
2: (None, 26, 32, 32, 128)
3: (None, 13, 16, 16, 512)
4: (None, 26, 32, 32, 512)
5: (None, 52, 64, 64, 128)
6: (None, 52, 128, 128, 64)
7: (None, 52, 256, 256, 32)
8: (None, 52, 256, 256, 32)
Instructions for updating:
Use fn_output_signature instead


## Define the discriminator

In [11]:
def make_discriminator(lrelu_alpha=0.3, drop_out_rate=0.2):
    def conv_activation(inputs):
        outputs = layers.BatchNormalization()(inputs)
        outputs = layers.LeakyReLU(alpha=lrelu_alpha)(outputs)
        outputs = layers.Dropout(drop_out_rate)(outputs)
        return outputs
    
    inputs = keras.Input(shape = fixedImg.shape)

    x = layers.Conv3D(filters=32, kernel_size=(7, 15, 15), strides=(1, 2, 2), padding="same")(inputs)
    x = conv_activation(x)

    x = layers.Conv3D(filters=64, kernel_size=(5, 5, 5), strides=(1, 2, 2), padding="same")(x)
    x = conv_activation(x)

    x = layers.Conv3D(filters=128, kernel_size=(5, 5, 5), strides=(2, 2, 2), padding="same")(x)
    x = conv_activation(x)

    x = layers.Conv3D(filters=256, kernel_size=(5, 5, 5), strides=(2, 2, 2), padding="same")(x)
    x = conv_activation(x)

    x = layers.Conv3D(filters=512, kernel_size=(5, 5, 5), strides=(2, 2, 2), padding="same")(x)
    x = conv_activation(x)

    x = layers.AveragePooling3D((2, 2, 2))(x)
    print(x.shape)
    x = layers.Flatten()(x)
    x = layers.Dense(128)(x)
    x = conv_activation(x)
    x = layers.Dense(32)(x)
    x = conv_activation(x)
    pred = layers.Dense(1)(x)
    
    model = keras.Model(inputs=inputs, outputs=pred, name="discriminator_model")
#     model.summary()

    return model
    

In [12]:
discriminator = make_discriminator()

(None, 3, 4, 4, 512)


In [13]:
# This method returns a helper function to compute cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [14]:
def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

In [15]:
def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

In [16]:
lr = 1e-4
generator_optimizer = tf.keras.optimizers.Adam(lr)
discriminator_optimizer = tf.keras.optimizers.Adam(lr)

In [17]:
checkpoint_dir = './GAN_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

In [18]:
@tf.function
def train_step(gt_fixed_images):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(next(img_pair_input), training=True)

      real_output = discriminator(gt_fixed_images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

In [19]:
def train(epochs, steps_per_epoch):
    for epoch in range(epochs):
        for step in range(steps_per_epoch):
            gt_fixed_img = next(gt_fixed_img_input)
            train_step(gt_fixed_img)

        # Save the model every 15 epochs
        if (epoch + 1) % 5 == 0:
            checkpoint.save(file_prefix = checkpoint_prefix)
            
        print("Epoch {} finished".format(epoch))

In [20]:
epochs = 50
steps_per_epoch = int(trainDataSize / bs)
train(epochs, steps_per_epoch)

ResourceExhaustedError:  OOM when allocating tensor with shape[1,64,52,128,128] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[node generator_model/conv3d_9/Conv3D (defined at <ipython-input-18-ece6db4c675b>:4) ]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.
 [Op:__inference_train_step_15228]

Function call stack:
train_step


## Test model (output == warpedImg)

In [None]:
testModel = keras.Model(inputs=inputs, outputs=[affine_warped, affine_pred])

In [None]:
# testModel.load_weights('./checkpoints/epoch_{}'.format(epochs-1))
testModel.load_weights('./checkpoints/13.h5')

In [None]:
dataGen_test = data_generator(batch_size = 1)
moving_test, label_test = next(dataGen_test)
(warped_test, affine_pred_test) = testModel(moving_test)
print(label_test[1])
print(affine_pred_test)
mse = tf.keras.losses.MeanSquaredError()
print(mse(label_test[1], affine_pred_test))

In [None]:
sliceToCheck = 0
fig, axs = plt.subplots(5, 3, figsize=(15, 35))
for i in range(5):
    axs[i, 0].imshow(fixedImg[sliceToCheck + i * 5,:,:,0])
    axs[i, 0].set_title("Fixed Image")
    axs[i, 1].imshow(moving_test[0,sliceToCheck + i * 5,:,:,0])
    axs[i, 1].set_title("Moving Image")
    axs[i, 2].imshow(warped_test[0,sliceToCheck + i * 5,:,:,0])
    axs[i, 2].set_title("Warped Image")
plt.show()