## Import the necessary libraries

In [1]:
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
import tqdm
import scipy
import os


# specify the gpu to be utilized
os.environ["CUDA_VISIBLE_DEVICES"]="2"

2023-02-07 15:38:54.082369: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-02-07 15:38:54.218380: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [2]:
"""
The following lines of code can be used to identify the GPUs on the server and list their available memory

"""
import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")
if gpus:
    try:
        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.list_logical_devices("GPU")
        print("Detected {} Physical GPUs, {} Logical GPUs.".format(len(gpus), len(logical_gpus)))
        for i, gpu in enumerate(logical_gpus):
            # Obtain the memory information for the GPU
            memory = tf.compat.v1.Session().run(tf.compat.v1.contrib.memory_stats.BytesInUse(), feed_dict={tf.compat.v1.contrib.memory_stats.gpu_device_index(i): True})
            memory_info = memory / (1024 * 1024) # convert bytes to MB
            print("GPU {} ({}): Free memory: {:.2f} MB, Total memory: {:.2f} MB".format(i, gpu.name, memory_info, memory_info))
    except Exception as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)
else:
    print("No GPUs found.")



Detected 1 Physical GPUs, 1 Logical GPUs.
module 'tensorflow.compat.v1' has no attribute 'contrib'


2023-02-07 15:38:57.935633: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-02-07 15:38:58.751071: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 47205 MB memory:  -> device: 0, name: Quadro RTX 8000, pci bus id: 0000:af:00.0, compute capability: 7.5
2023-02-07 15:38:58.768125: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 47205 MB memory:  -> device: 0, name: Quadro RTX 8000, pci bus id: 0000:af:00.0, compute capability: 7.5


## define the data and model parameters

In [3]:
# image size
height_crop_size = 160 
width_crop_size = 64
slice_size = 64

# model training parameters
batch_size = 1
epochs = 20
epoch_decay = epochs // 10 # number of epoch to start decaying the learning rate

# optimizer parameters
lr = 0.0002 # learning rate
beta_1 = 0.5 # weight

## Load the training data

The following code defines a TensorFlow program that loads and processes a training dataset in the TensorFlow Record format.

1. The ``center_crop`` function is used to crop an image around the center to a specified size. The function takes an image tensor and the desired size as input, and returns the cropped image tensor.

2. The ``parser`` function is used to parse a single TensorFlow Record from the dataset. The function takes a single TensorFlow Record as input and returns the decoded anatomy and flow tensors.

3. The ``fldr`` variable defines the folder where the TensorFlow Record files are located. The tfrecord_paths variable is a list of the TensorFlow Record file paths in the folder.

4. The ``dataset_train`` variable is a TensorFlow Dataset that contains the TensorFlow Record files from the fldr folder. The dataset_train is processed using interleave and map functions to efficiently load and parse the TensorFlow Record files.

5. The ``data_size`` variable is the total size of the training dataset. The dataset is then ``shuffled`` and ``batched`` with a specified batch size. The ``len_dataset`` variable is the number of batches in the dataset.

In [4]:
# crop the enlarged images back to original size around the center
# This line of code is defining the center_crop function which takes two inputs, an image and a size. 
@tf.function
def center_crop(image, size):
    # This line of code checks if the size input is a tuple or a list. 
    # If it's not, the size is set to a list containing the size value twice. 
    if not isinstance(size, (tuple, list)):
        size = [size, size]
    # This line of code calculates the offset height value. 
    # The offset height is the difference between the height of the image and the desired height size, divided by 2.
    offset_height = (tf.shape(image)[-3]-size[0])//2
    # This line of code calculates the offset width value. 
    # The offset width is the difference between the width of the image and the desired width size, divided by 2.
    offset_width = (tf.shape(image)[-2]-size[1])//2
    # This line of code crops the image based on the calculated offset values and desired size values. 
    return tf.image.crop_to_bounding_box(image,offset_height,offset_width,size[0],size[1])


# Parse a single example from a tfrecord file
def parser(tfrecord):
    # Parse features from the tfrecord file
    feature = tf.io.parse_single_example(tfrecord,
                                          {'anatomy': tf.io.FixedLenFeature(shape=[], dtype=tf.string),
                                           'flow'  : tf.io.FixedLenFeature(shape=[], dtype=tf.string),
                                           'height': tf.io.FixedLenFeature(shape=[], dtype=tf.int64),
                                           'width' : tf.io.FixedLenFeature(shape=[], dtype=tf.int64),
                                           'depth' : tf.io.FixedLenFeature(shape=[], dtype=tf.int64)})
    # Convert height, width and depth from int64 to int32
    height = tf.cast(feature["height"], tf.int32)
    width  = tf.cast(feature["width"], tf.int32)
    depth  = tf.cast(feature["depth"], tf.int32)
    
    # Decode the anatomy feature from raw bytes to float32
    anatomy = tf.io.decode_raw(feature['anatomy'], tf.float32) 
    # Reshape the anatomy data to [height, width, depth]
    anatomy = tf.reshape(anatomy, [height, width, depth])
    # Crop the anatomy image
    anatomy = center_crop(anatomy, [height_crop_size,width_crop_size])
    
    # Decode the flow feature from raw bytes to float32
    flow = tf.io.decode_raw(feature['flow'], tf.float32) 
    # Reshape the flow data to [height, width, depth]
    flow = tf.reshape(flow, [height, width, depth])
    # Crop the flow image
    flow = center_crop(flow, [height_crop_size,width_crop_size])
    
    # Return the anatomy and flow images as a tuple
    return anatomy, flow


# set the directory to store the training data
fldr = 'training_data'

# get the path of all training data files in the directory
tfrecord_paths = tf.io.gfile.glob(fldr+"/*train*")

# print the number of files in the training dataset
print(f"The training dataset is composed of {len(tfrecord_paths)} files:\n{tfrecord_paths}")

# create a dataset from the training data file paths
dataset_train = tf.data.Dataset.list_files(tfrecord_paths)

# interleave multiple training data files for parallel reading
dataset_train = dataset_train.interleave(lambda filename: tf.data.TFRecordDataset(filename),
                                         cycle_length=len(tfrecord_paths),
                                         num_parallel_calls=tf.data.AUTOTUNE)

# apply the parsing function to each data sample
dataset_train = dataset_train.map(map_func=parser, num_parallel_calls=tf.data.AUTOTUNE)

# get the number of samples in the dataset
data_size = sum(1 for _ in dataset_train)

# shuffle the samples
dataset_train = dataset_train.shuffle(buffer_size=data_size)

# group the samples into batches
dataset_train = dataset_train.batch(batch_size)

# calculate the number of batches
len_dataset = int(data_size/batch_size)

# print the size of the loaded training dataset
print(f"Size of the loaded training data set: {len_dataset} batches")

The training dataset is composed of 3 files:
['training_data/bav_train.tfrecords', 'training_data/tav_train.tfrecords', 'training_data/ct2mri_train.tfrecords']
Size of the loaded training data set: 1095 batches


## Generator and discriminator models

1. **Generator**
- **generator** is the 3D generator model for a convolutional neural network (CNN) using the functions encoder_layer and decoder_layer. The model has an input layer with a specified shape and applies multiple encoder and decoder layers to generate an output.
    - The function ``encoder_layer`` implements an encoder layer for a 3D Convolutional Neural Network (3D-CNN) model. It takes as input a feature map "x_con" and applies multiple 3D Convolution, Batch Normalization, and Leaky ReLU activation operations to produce a new feature map. The new feature map is concatenated with the input feature map, and this concatenation is fed as input to the next iteration of 3D Convolution, Batch Normalization, and Leaky ReLU. A pooling operation is optionally applied at the end of the encoding layer, which is specified by the ``pool`` argument. The pooling operation is an Average Pooling operation that reduces the spatial dimensions of the feature map by a factor of 2.

    - The ``decoder layer`` defines a generator decoder layer. The layer takes three inputs, input_, x, and ch, which represent the tensor to be upsampled, a tensor to concatenate with the upsampled tensor, and the number of channels, respectively. The layer performs a transposed convolution with filters equal to 20, kernel_size equal to [2,2,1], and strides equal to [2,2,1] to upsample input_. The upsampled tensor and x are then concatenated along the last axis.

2. **Discriminator**
- The ``downsample`` creates a 3D convolutional neural network (CNN) layer followed by batch normalization and leaky rectified linear unit (ReLU) activation is is used in the discriminator.

In [5]:
def encoder_layer(x_con, iterations, name, training, pool=True, filters=20, kernel_size=(3, 3, 3)):
    with tf.name_scope("encoder_block_{}".format(name)):
        # create batch normalization layer
        bn = tf.keras.layers.BatchNormalization()
        # create leaky relu layer
        relu = tf.keras.layers.LeakyReLU()
        # loop through the number of iterations
        for i in range(iterations):
            # apply 3D convolution layer with specified filters and kernel size, with padding "SAME"
            x = tf.keras.layers.Conv3D(filters, 
                                       kernel_size, 
                                       padding='SAME')(x_con)
            # apply batch normalization to the output
            x = bn(x)
            # apply leaky relu activation to the output
            x = relu(x)
            # concatenate x and x_con along the last axis
            x_con = tf.concat([x, x_con], axis=-1)
        # if pool is True
        if pool:
            # apply average pooling with specified pool size and strides, and data format 'channels_last'
            pool = tf.keras.layers.AveragePooling3D(pool_size=(2, 2, 1), 
                                                    strides=(2, 2, 1),
                                                    data_format='channels_last')(x_con)
            # return both x_con and pool
            return x_con, pool
        # if pool is False
        return x_con
        # return only x_con


    
def decoder_layer(inputs, x, channels, name, upscale=(2, 2, 1), filters=20, kernel_size=(2, 2, 1)):
    # Upsample the input with a transposed convolution
    up = tf.keras.layers.Conv3DTranspose(filters=filters,
                                         kernel_size=kernel_size,
                                         strides=upscale,
                                         padding='SAME',
                                         name=f"upsample_{name}",
                                         use_bias=False)(inputs)
    # Concatenate the upsampled input with the other input `x`
    up = tf.concat([up, x], axis=-1, name=f"merge_{name}")
    return up

def generator():
    """
    Function to define the generator model of the CNN
    
    Returns:
    model (tf.keras.Model): The generator model
    """
    # Input layer with specified shape
    input_ = tf.keras.layers.Input(shape=[height_crop_size,width_crop_size,slice_size,1])
    
    # Encoder layers with specified iterations and name
    conv1, pool1 = encoder_layer(input_, iterations=2, name="encode_im1", training=True, pool=True)
    conv2, pool2 = encoder_layer(pool1, iterations=4, name="encode_im2", training=True, pool=True)
    conv3, pool3 = encoder_layer(pool2, iterations=6, name="encode_im3", training=True, pool=True)
    conv4 = encoder_layer(pool3, iterations=8, name="encode_im4", training=True, pool=False)
    
    # Decoder layers with specified name
    up1 = decoder_layer(conv4, conv3, 10, name=12)
    conv7 = encoder_layer(up1, iterations=6, name="conv_im6", training=True, pool=False)
    up2 = decoder_layer(conv7, conv2, 8, name=21)
    conv8 = encoder_layer(up2, iterations=4, name="encode_im7", training=True, pool=False)
    up3 = decoder_layer(conv8, conv1, 6, name=32)
    conv9 = encoder_layer(up3, iterations=2, name="encode_im8", training=True, pool=False)
    
    # Final Conv3D layer
    conv10 = tf.keras.layers.Conv3D(1, (1,1,1), name='logits_re_im', padding='SAME')(conv9)
    
    # Model definition
    model = tf.keras.Model(inputs=input_, outputs=conv10)
    
    return model


# disciminator downsampler
def downsample(filters, size, apply_batchnorm=True):
    """
    Function to create a downsampling layer in the generator network
    
    Parameters:
    filters (int): Number of filters in the Conv3D layer
    size (int): The size of the Conv3D layer
    apply_batchnorm (bool, optional): If True, adds BatchNormalization layer after Conv3D. Defaults to True.
    
    Returns:
    result (tf.keras.Sequential): The downsampling layer
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    result = tf.keras.Sequential()
    
    # Conv3D layer with specified number of filters and kernel size
    result.add(tf.keras.layers.Conv3D(filters, kernel_size=[3,3,3], padding='same',
                                      kernel_initializer=initializer))
    
    # Optional BatchNormalization layer and leaky ReLU activation
    if apply_batchnorm:
        result.add(tf.keras.layers.BatchNormalization())
        result.add(tf.keras.layers.LeakyReLU())
    
    return result

def discriminator():
    # Initialize a random normal initializer with mean 0 and standard deviation 0.02
    initializer = tf.random_normal_initializer(0., 0.02)
    
    # Create an input layer for the input image with shape [height_crop_size, width_crop_size, slice_size, 1]
    inp = tf.keras.layers.Input(shape=[height_crop_size, width_crop_size, slice_size, 1],
                                name='input_image')
    
    # Create an input layer for the target image with shape [height_crop_size, width_crop_size, slice_size, 1]
    tar = tf.keras.layers.Input(shape=[height_crop_size, width_crop_size, slice_size, 1],
                                name='target_image')
    
    # Concatenate the input and target images
    x = tf.keras.layers.concatenate([inp, tar])
    
    # Apply the downsample function with 16 filters and a kernel size of 2, and no batch normalization
    down1 = downsample(16, 2, False)(x)
    
    # Apply the downsample function with 32 filters and a kernel size of 2, and batch normalization
    down2 = downsample(32, 2)(down1)
    
    # Apply the downsample function with 64 filters and a kernel size of 2, and batch normalization
    down3 = downsample(64, 2)(down2)
    
    # Apply a zero padding layer to down3
    zero_pad1 = tf.keras.layers.ZeroPadding3D()(down3)
    
    # Apply a Conv3D layer with 128 filters and a kernel size of 4, and stride of 1
    conv = tf.keras.layers.Conv3D(128, 4, strides=1, padding='SAME')(zero_pad1)
    
    # Apply a batch normalization layer to the output of the Conv3D layer
    batchnorm1 = tf.keras.layers.BatchNormalization()(conv)
    
    # Apply a leaky ReLU activation layer to the output of the batch normalization layer
    leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)
    
    # Apply a zero padding layer to the output of the leaky ReLU activation layer
    zero_pad2 = tf.keras.layers.ZeroPadding3D()(leaky_relu)
    
    # Apply a Conv3D layer with 1 filter and a kernel size of 4, and stride of 1
    last = tf.keras.layers.Conv3D(1, 4, strides=1, padding='SAME')(zero_pad2)
    
    # Return a Model with input layers [inp, tar] and output layer `last`
    return tf.keras.Model(inputs=[inp, tar], outputs=last)

## Training step

In [6]:
def train_step(disc, gen, anatomy, flow, gen_optimizer, disc_optimizer):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:  
        # Normalize anatomy data between 0 and 1
        anatomy = tf.keras.utils.normalize(anatomy)

        # Generate anatomy2flow
        anatomy2flow = gen(anatomy) 

        # Calculate magnitude of anatomy2flow
        mag = tf.abs(tf.squeeze(anatomy2flow))

        # Loss functions
        mse = tf.keras.losses.MeanSquaredError()
        loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

        # Evaluate discriminator on real data and generated data
        discriminator_real_output = disc([anatomy, flow], training=True)
        discriminator_generated_output = disc([anatomy, anatomy2flow], training=True)

        # Calculate real loss and generated loss for discriminator
        real_loss = loss_object(tf.ones_like(discriminator_real_output), discriminator_real_output)
        generated_loss = loss_object(tf.zeros_like(discriminator_generated_output), discriminator_generated_output)
        total_discriminator_loss = real_loss + generated_loss

        # Calculate GAN loss
        gan_loss = loss_object(tf.ones_like(discriminator_generated_output), discriminator_generated_output)

        # Calculate gradients and laplacian for both magnitude and flow
        flow = tf.squeeze(flow)
        ai_grady = np.gradient(mag[...].numpy())
        man_grady = np.gradient(flow[...].numpy())
        ai_grady2 = scipy.ndimage.laplace(mag.numpy())
        man_grady2 = scipy.ndimage.laplace(flow.numpy())

        # Calculate structural similarity (SSIM)
        loss_ssimy = tf.reduce_mean(tf.image.ssim(mag[...]/tf.reduce_max(mag[...]), \
                                                 flow[...]/tf.reduce_max(flow[...]),1.0))

        # Calculate final loss function
        loss_fn = gan_loss + 1000 * mse(mag, flow) + (1-loss_ssimy) + 1000 * mse(ai_grady, man_grady) + 10 * mse(ai_grady2, man_grady2)
    
    # Trainable variables for discriminator and generator
    variables1 = disc.trainable_variables     
    variables = gen.trainable_variables 
    
    # Calculate gradients for generator and discriminator
    gradients = gen_tape.gradient(loss_fn, variables)
    gradients1 = disc_tape.gradient(total_discriminator_loss, variables1)
    
    # Apply gradients using optimizers
    gen_optimizer.apply_gradients(zip(gradients,variables)) 
    disc_optimizer.apply_gradients(zip(gradients1,variables1))
    return loss_fn, loss_ssimy , gan_loss

# define LinearDecay schedular
class LinearDecay(keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, initial_learning_rate, total_steps, step_decay):
        super(LinearDecay, self).__init__()
        self._initial_learning_rate = initial_learning_rate
        self._steps = total_steps
        self._step_decay = step_decay
        self.current_learning_rate = tf.Variable(initial_value=initial_learning_rate,\
                                                 trainable=False, dtype=tf.float32)
    def __call__(self, step):
        self.current_learning_rate.assign(tf.cond(
            step >= self._step_decay,
            true_fn=lambda: self._initial_learning_rate * (1 - 1 /  (self._steps - self._step_decay) \
                                                           *(step - self._step_decay)),
            false_fn=lambda: self._initial_learning_rate
        ))
        return self.current_learning_rate

# define the generator and discriminator optimizers
llr = LinearDecay(lr, epochs*len_dataset, epoch_decay*len_dataset)
gen_optimizer = tf.keras.optimizers.Adam(learning_rate=llr) 
disc_optimizer = tf.keras.optimizers.Adam(lr, beta_1=beta_1)

## Train the model
This code is checking if the output directory and checkpoint directory exists, if not it creates them. Then it defines the generator and discriminator models, the checkpoint prefix, and checkpoint object. Finally, it checks if a checkpoint exists, and if so, it restores the model from the latest checkpoint by using the ``tf.train.latest_checkpoint`` method. The number of trained epochs is also found by parsing the checkpoint address.

In [None]:
# Define the output directory
output_dir = './haben_output'

# If the output directory does not exist, create it
if not tf.io.gfile.isdir(output_dir):
    tf.io.gfile.mkdir(output_dir)

# Define the generator and discriminator models
gen = generator()
disc = discriminator()

# Define the checkpoint directory
checkpoint_dir = tf.io.gfile.join(output_dir, 'training_checkpoints')

# If the checkpoint directory does not exist, create it
if not tf.io.gfile.isdir(checkpoint_dir):
    tf.io.gfile.mkdir(checkpoint_dir)

# Define the checkpoint prefix and the checkpoint object
checkpoint_prefix = tf.io.gfile.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(gen_optimizer=gen_optimizer,
                                 disc_optimizer=disc_optimizer,
                                 gen=gen,
                                 disc=disc)

# Get the address of the latest checkpoint
checkpoint_address = tf.train.latest_checkpoint(checkpoint_dir)

# If a checkpoint exists, restore the model from it
epochs_so_far = 0
if checkpoint_address:
    checkpoint.restore(checkpoint_address)
    hyphen = checkpoint_address.index('-')
    epochs_so_far = int(checkpoint_address[hyphen+1:])
    print("Restored the model from epoch {}".format(epochs_so_far))


# Train the model for each outer epoch
for epoch in tqdm.trange(epochs_so_far+1,epochs+1,desc="Outer Epoch",total=epochs-epochs_so_far):
    # Train the model for each inner epoch
    for anatomy,flow in tqdm.tqdm(dataset_train,desc="Inner Epoch",total=len_dataset):
        # Reshape the anatomy and flow data to a specific shape
        anatomy = tf.reshape(anatomy,(1,height_crop_size,width_crop_size,slice_size,1))
        flow = tf.reshape(flow,(1,height_crop_size,width_crop_size,slice_size,1))
        # Train the generator and discriminator models
        loss, sss, gn = train_step(disc,gen,anatomy,flow, gen_optimizer, disc_optimizer)
    # Save the model checkpoint after each outer epoch
    checkpoint.save(file_prefix=checkpoint_prefix)

Outer Epoch:   0%|                                                                                                                                   | 0/20 [00:00<?, ?it/s]
Inner Epoch:   0%|                                                                                                                                 | 0/1095 [00:00<?, ?it/s][A2023-02-07 15:39:19.526280: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:390] Filling up shuffle buffer (this may take a while): 904 of 1095
2023-02-07 15:39:22.676156: I tensorflow/core/kernels/data/shuffle_dataset_op.cc:415] Shuffle buffer filled.
2023-02-07 15:39:23.534419: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8401
2023-02-07 15:39:25.289787: W tensorflow/stream_executor/gpu/asm_compiler.cc:80] Couldn't get ptxas version string: INTERNAL: Running ptxas --version returned 32512
2023-02-07 15:39:25.693552: W tensorflow/stream_executor/gpu/redzone_allocator.cc:314] INTERNAL: ptxas exited with non-zero 