In [None]:
# Set to True when running on Google Colab; gates the Drive-mount cell below.
colab = False  
# Set to True to run the `pip install` cells below.
install = True

In [None]:
# Mount Google Drive
%%capture
if colab:
    from google.colab import drive # import drive from google colab

    ROOT = "/content/drive"     # default location for the drive

    drive.mount(ROOT);           # we mount the google drive at /content/drive
    
    # Set working directory
    %cd /content/drive/My Drive/restoration-mapper/tree_gan

In [None]:
# Install the pinned TensorFlow 1.13.1 (only when `install` is True); the cells
# below use TF1-style APIs such as tf.ConfigProto and tf.Session.
if install:
    !pip install tensorflow==1.13.1

In [None]:
# Install the packages listed in requirements.txt (only when `install` is True).
if install:
    !pip install -r requirements.txt 

In [None]:
# Install array2gif (only when `install` is True).
# NOTE(review): no version is pinned here; presumably needed by the GIF-writing
# helper in f1_utils - confirm.
if install:
    !pip install array2gif

In [None]:
import tensorflow as tf
print(tf.__version__)

import os
# Session options reused by every tf.Session created in this notebook.
config=tf.ConfigProto()
config.gpu_options.allow_growth=True
config.allow_soft_placement=True

import matplotlib
# Select the Agg backend before pyplot is imported.
matplotlib.use('Agg')

import matplotlib.pyplot as plt
import numpy as np
import time

# pathlib is used below to create output directories
import pathlib

import sys

# NOTE(review): wildcard import; shuffle_minibatch, augmentation_function and
# load_val_imgs used further down presumably come from here - confirm.
from utils import *

import argparse

# reload allows re-importing edited project modules without restarting the kernel
from importlib import reload

# Silence all Python warnings for the rest of the session.
import warnings
warnings.filterwarnings('ignore')

# math.floor is used to compute the number of minibatches per epoch
import math

In [None]:
class params:
    """Static run configuration, read as class attributes (``params.<name>``)."""

    # --- data selection -------------------------------------------------
    dataset = 'acdc'
    # number of training images; one of 'tr1', 'tr3', 'tr5', 'tr15', 'tr40'
    no_of_tr_imgs = 'tr3'
    # which combination of training images; one of 'c1' .. 'c5'
    comb_tr_imgs = 'c1'

    # --- optimisation ---------------------------------------------------
    lr_seg = 1e-05    # learning rate of the segmentation U-Net
    lr_gen = 1e-04    # learning rate of the generator
    lr_disc = 1e-04   # learning rate of the discriminator
    z_lat_dim = 100   # dimensionality of the sampled latent vector z
    beta_val = 0.9    # beta value passed to the Adam optimizer

    # --- GAN setup ------------------------------------------------------
    ra_en = 0             # 0 - disabled, 1 - enabled
    gan_type = 'lsgan'    # one of 'lsgan', 'gan', 'wgan-gp', 'ngan'
    en_1hot = 1           # 1 - represent labels with one-hot encoding

    # --- loss weights (lamda factors) -----------------------------------
    lamda_dsc = 1         # weight of the segmentation loss term
    lamda_adv = 1         # weight of the adversarial loss term
    # deformation-field cGAN: weight of the negative L1 loss on the
    # per-pixel flow (deformation) field
    lamda_l1_g = 0.001
    # intensity-field cGAN: weight of the negative L1 loss on the additive
    # intensity field
    lamda_l1_i = 0.001

    # --- run bookkeeping ------------------------------------------------
    ver = 0               # version of the run
    data_aug_seg = 1      # data augmentation: 0 - disabled, 1 - enabled
    # segmentation loss to optimise:
    # 0 - weighted cross entropy, 1 - dice score loss
    dsc_loss = 0

## Deformation Network

In [None]:
# Keep the configured value, then turn the flag into a real boolean
# (True only when the configured value equals 1).
ra_en_val = params.ra_en
params.ra_en = bool(params.ra_en == 1)

import experiment_init.init_acdc as cfg
import experiment_init.data_cfg_acdc as data_list

In [None]:
######################################
# Build the project helper objects from the experiment config `cfg`
# ####################################
# data loader
from dataloaders import dataloaderObj
dt = dataloaderObj(cfg)


# handle to the loader method for original ACDC images (passed to the
# test-inference call near the end of the notebook)
orig_img_dt=dt.load_acdc_imgs

# model factory: the network graphs below are built through this object
import models 
model = models.modelObj(cfg)
# metric / plotting utilities; constructed from the config and the data loader
from f1_utils import f1_utilsObj
f1_util = f1_utilsObj(cfg,dt)

In [None]:
######################################
# Directory that receives checkpoints, logs and the best model.
# The trailing slash matters: later cells build paths by plain string
# concatenation onto save_dir.
save_dir = 'models/'
# exist_ok=True creates the directory when missing without a separate
# existence check (no check-then-create race), the same way best_model_dir
# is created further down.
os.makedirs(save_dir, exist_ok=True)
print('save_dir: ',save_dir)
######################################

In [None]:
######################################
# Load the labeled training set, the validation set and the unlabeled set.
# Each array is rearranged with np.moveaxis so that the axis the loader puts
# last (axis 3 for images, axis 2 for labels) becomes axis 0.
# NOTE(review): presumably that axis indexes samples - confirm against dt.load_imgs.
#train_list = data_list.train_data(params.no_of_tr_imgs,params.comb_tr_imgs)
#load train data cropped images directly
print('loading train imgs')
train_imgs,train_labels = dt.load_imgs(dataset= 'lab')
# move the last axis to the front before feeding into the minibatch function
train_imgs = np.moveaxis(train_imgs,3,0)
train_labels = np.moveaxis(train_labels,2,0)
print(train_imgs.shape)
print(train_labels.shape)
#D: we might have to adjust this if our mini sample is too small compared to batch size
# if(params.no_of_tr_imgs=='tr1'):
#     train_imgs_copy=np.copy(train_imgs)
#     train_labels_copy=np.copy(train_labels)
#     while(train_imgs.shape[2]<cfg.batch_size):
#         train_imgs=np.concatenate((train_imgs,train_imgs_copy),axis=2)
#         train_labels=np.concatenate((train_labels,train_labels_copy),axis=2)
#     del train_imgs_copy,train_labels_copy

# validation images and labels, loaded and rearranged the same way
print('loading val imgs')
val_imgs,val_labels = dt.load_imgs(dataset= 'val')
# size of the axis that is moved to the front below (read before the move)
val_n = val_imgs.shape[3]
#L: change to (20,16,16,16)
val_imgs = np.moveaxis(val_imgs,3,0)
val_labels = np.moveaxis(val_labels,2,0)
print('val image dims after reshape')
print(val_imgs.shape)
print(val_labels.shape)

# unlabeled images (no labels are returned for this split)
#unl_list = data_list.unlabeled_data()
print('loading unlabeled imgs')
unlabeled_imgs=dt.load_imgs(dataset= 'unlab')
# move the last axis to the front before feeding into the minibatch function
unlabeled_imgs = np.moveaxis(unlabeled_imgs,3,0)
print('unlabeled_imgs',unlabeled_imgs.shape)


# get test list
#print('get test imgs list')
#D: will have to add this back once test data created
#test_list = data_list.test_data()
#D: will have to figure out struct name in our case - it is used for computing the 
# model performance, segmentation mask etc. in f1_util.pred_segs_acdc_test_subjs
#struct_name=cfg.struct_name
# validation (and summary writing) runs every val_step_update epochs in the training loop
val_step_update=cfg.val_step_update
######################################

In [None]:
######################################

def get_samples(labeled_imgs,unlabeled_imgs):
    """Draw one latent batch plus one shuffled unlabeled and one shuffled labeled image batch.

    Args:
        labeled_imgs: labeled image array handed to ``shuffle_minibatch``.
        unlabeled_imgs: unlabeled image array handed to ``shuffle_minibatch``.

    Returns:
        Tuple ``(z_samples, ld_img_batch, unld_img_batch)``; ``z_samples`` is a
        float32 array of shape ``(cfg.batch_size, params.z_lat_dim)`` drawn from
        a standard normal distribution.
    """
    # Both image batches are drawn with the same minibatch settings.
    minibatch_opts = dict(batch_size=int(cfg.batch_size),
                          num_channels=cfg.num_channels,
                          labels_present=0,
                          axis=2)

    # Sampling order (z, unlabeled, labeled) is kept so the random stream is consumed identically.
    latent_shape = (cfg.batch_size, params.z_lat_dim)
    latent = np.random.normal(loc=0.0, scale=1.0, size=latent_shape).astype(np.float32)
    unlabeled_batch = shuffle_minibatch([unlabeled_imgs], **minibatch_opts)
    labeled_batch = shuffle_minibatch([labeled_imgs], **minibatch_opts)

    return latent, labeled_batch, unlabeled_batch

def plt_func(sess,ae,save_dir,z_samples,ld_img_batch,unld_img_batch,index=0):
    """Plot deformations of one fixed input image under different sampled z vectors.

    Args:
        sess: TensorFlow session holding the trained variables.
        ae: dict of graph tensors; the keys read here are 'flow_vec', 'y_trans',
            'z_cost', 'x_l', 'z', 'x_unl', 'select_mask' and 'train_phase'.
        save_dir: output directory forwarded to the f1_util plotting helpers.
        z_samples: batch of latent vectors, one per deformation to generate.
        ld_img_batch: batch of labeled images; only image `index` is used.
        unld_img_batch: batch of unlabeled images fed to ae['x_unl'].
        index: position in `ld_img_batch` of the image to deform.
    """
    ld_img_tmp=np.zeros_like(ld_img_batch)
    # Repeat the selected image along the whole batch axis. Broadcasting covers
    # every row, whereas a fixed count of 20 breaks for smaller batches and
    # leaves rows empty for larger ones.
    # NOTE(review): only channel 0 is copied, the other channels stay zero -
    # confirm this is intended for multi-channel inputs.
    ld_img_tmp[:,:,:,0]=ld_img_batch[index,:,:,0]

    # 'z_cost' is still fetched so the same ops run as before; its value is not used.
    flow_vec,y_geo_deformed,_=sess.run([ae['flow_vec'],ae['y_trans'],ae['z_cost']], feed_dict={ae['x_l']: ld_img_tmp, ae['z']:z_samples,\
                          ae['x_unl']: unld_img_batch, ae['select_mask']: True, ae['train_phase']: False})

    f1_util.plot_deformed_imgs(ld_img_tmp,y_geo_deformed,flow_vec,save_dir,index=index)

    # Write a gif of all the deformed images generated for the fixed input image
    f1_util.write_gif_func(ip_img=y_geo_deformed, imsize=(cfg.img_size_x,cfg.img_size_y),save_dir=save_dir,index=index)

In [None]:
######################################
# Names used when saving: checkpoint file prefix, plus the TensorBoard log
# directory and best-model directory, both placed directly under save_dir.
checkpoint_filename = 'unet_{}'.format(params.dataset)
logs_path = '{}tensorflow_logs/'.format(save_dir)
best_model_dir = '{}best_model/'.format(save_dir)
######################################

In [None]:
# Class weights derived from the training labels, so they adapt when the
# training data changes. Each class is weighted by the proportion of the other
# class: index 0 gets the share of label value 1, index 1 gets the rest.
_, label_counts = np.unique(train_labels, return_counts=True)
ntrlabs = np.sum(label_counts)
propnotree = (ntrlabs - label_counts[1]) / ntrlabs
proptree = 1 - propnotree
classweight = tf.constant([[proptree, propnotree]], name='class_weights')

In [None]:
from IPython.core.interactiveshell import InteractiveShell
# Display the value of every top-level expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"

In [None]:
######################################
# Build the deformation-field generator cGAN + U-Net graph.
# `ae` is indexed by string keys further down (e.g. 'x', 'y_l', 'z', 'seg_cost',
# 'optimizer_disc') to feed inputs and fetch losses / optimizers.
# All hyper-parameters come from `params`; the class weights were computed from
# the training labels and the channel count comes from `cfg`.
ae = model.spatial_generator_cgan_unet(learn_rate_gen=params.lr_gen,learn_rate_disc=params.lr_disc,\
                        beta1_val=params.beta_val,gan_type=params.gan_type,ra_en=params.ra_en,\
                        learn_rate_seg=params.lr_seg,dsc_loss=params.dsc_loss,en_1hot=params.en_1hot,\
                        lamda_dsc=params.lamda_dsc,lamda_adv=params.lamda_adv,lamda_l1_g=params.lamda_l1_g,
                        class_weights = classweight, num_channels = cfg.num_channels)

In [None]:
######################################
#  training parameters
# the training loop runs epochs start_epoch .. n_epochs-1
start_epoch=0
#L: make this 10 for quick training to test
n_epochs = 2000
# NOTE(review): disp_step and print_step are not read by the deformation
# training loop in this notebook - confirm whether they are still needed.
disp_step=400
print_step=2000
# number of initial epochs that train only the segmentation network on labeled
# data, without any cGAN generated data
seg_tr_limit=500
# previous-best validation F1 and minimum improvement; only referenced by the
# best-model saving code that is currently commented out in the training loop
f1_val_prev=0.1
threshold_f1=0.000001
# make sure the best-model directory exists (parents included)
pathlib.Path(best_model_dir).mkdir(parents=True, exist_ok=True)
######################################

In [None]:
######################################
# Helper graph used by the training loop through the keys 'y_tmp' (label
# input), 'flow_v' (flow-field input), 'y_tmp_1hot' (one-hot labels) and
# 'deform_y_1hot' (one-hot labels after applying the flow field).
#L: change this to new function, deform net clip
df_ae= model.deform_netclip()
######################################

In [None]:
######################################
# Summary writers; all three write to the same logs_path directory.
# writer for train summary
train_writer = tf.summary.FileWriter(logs_path)
# writers for dice score and validation summary
dsc_writer = tf.summary.FileWriter(logs_path)
val_sum_writer = tf.summary.FileWriter(logs_path)
######################################

######################################
# create a session with the shared session config and initialize all variables
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())
# checkpoint saver; max_to_keep=2 limits how many recent checkpoints are retained
saver = tf.train.Saver(max_to_keep=2)
######################################

In [None]:
# Run for n_epochs

# Metric histories: six independent, initially empty arrays that the training
# loop grows with np.append.
(seg_cost_epoch,
 seg_acc_epoch,
 g_loss_epoch,
 d_loss_epoch,
 val_acc_epoch,
 val_f1_epoch) = (np.array([]) for _ in range(6))


for epoch_i in range(start_epoch,n_epochs):
    # Training phase for this epoch. Epochs before seg_tr_limit train only the
    # segmentation network; from seg_tr_limit onward (inclusive) the generator,
    # discriminator and segmentation network are all trained. One flag drives
    # both the batch preparation and the optimizer step, so the epoch
    # epoch_i == seg_tr_limit is no longer prepared for the cGAN phase and then
    # left without any optimizer step.
    gan_phase = epoch_i >= seg_tr_limit

    # shuffled minibatches of unlabeled images for this epoch
    unld_img_batches=shuffle_minibatch([unlabeled_imgs],batch_size=int(cfg.batch_size),labels_present=0)

    # shuffled minibatches of labeled images and their labels
    ld_img_batches,ld_label_batches=shuffle_minibatch([train_imgs,train_labels],batch_size=cfg.batch_size)

    # number of minibatches per epoch: labeled-set size / batch size, rounded half up
    minibatches = math.floor(len(train_imgs)/cfg.batch_size+0.5)
    for b in range(minibatches):

        unld_img_batch = unld_img_batches[b]
        ld_img_batch = ld_img_batches[b]
        ld_label_batch = ld_label_batches[b]

        # sample z's from Gaussian Distribution
        z_samples = np.random.normal(loc=0.0, scale=1.0, size=(cfg.batch_size, params.z_lat_dim)).astype(np.float32)

        if(cfg.aug_en==1):
            # Apply affine transformations
            ld_img_batch,ld_label_batch=augmentation_function([ld_img_batch,ld_label_batch],dt)
            unld_img_batch=augmentation_function([unld_img_batch],dt,labels_present=0)

        # keep the (affine-augmented) originals; part of the batch is restored from them below
        ld_img_batch_tmp=np.copy(ld_img_batch)
        # Compute 1 hot encoding of the segmentation mask labels
        ld_label_batch_1hot = sess.run(df_ae['y_tmp_1hot'],feed_dict={df_ae['y_tmp']:ld_label_batch})

        if gan_phase:
            # Build a batch made of both cGAN-deformed image,label pairs and
            # original affine transformed image,label pairs.
            ld_label_batch_tmp=np.copy(ld_label_batch)
            ###########################
            ## use Deformation field cGAN to generate additional augmented image,label pairs from labeled samples
            flow_vec,ld_img_batch=sess.run([ae['flow_vec'],ae['y_trans']],\
                                        feed_dict={ae['x_l']: ld_img_batch_tmp, ae['z']:z_samples, ae['train_phase']: False})

            # apply the same flow field to the labels (result is one-hot)
            ld_label_batch=sess.run([df_ae['deform_y_1hot']],feed_dict={df_ae['y_tmp']:ld_label_batch,df_ae['flow_v']:flow_vec})
            ld_label_batch=ld_label_batch[0]

            ###########################
            # Randomly choose how many leading samples stay original; the rest
            # of the batch keeps the cGAN-deformed versions.
            no_orig=np.random.randint(5, high=15)
            ld_img_batch[0:no_orig] = ld_img_batch_tmp[0:no_orig]
            if(params.en_1hot==1):
                ld_label_batch[0:no_orig] = ld_label_batch_1hot[0:no_orig]
            #D: this else statement here baffles me - we will have to look into it at some later point - why would you turn images back from 1hot to
            # regular 1d classes with that argmax? When the ld_label batch assumes 1hot input? and especially in a scenario called 1hot == 0
            else:
                ld_label_batch = np.argmax(ld_label_batch,axis=3)
                ld_label_batch[0:no_orig] = ld_label_batch_tmp[0:no_orig]

            #Pick equal number of images from each category
            # ld_img_batch[0:10]=ld_img_batch_tmp[0:10]
            # ld_label_batch[0:10]=ld_label_batch_1hot[0:10]

        else:
            # Segmentation-only phase: images are used as they are, labels in one-hot form
            ld_label_batch=ld_label_batch_1hot

        if not gan_phase:
            # Optimize only the Segmentation Network during the initial seg_tr_limit epochs
            train_summary, seg_cost, _ =sess.run([ae['seg_summary'], ae['seg_cost'], ae['optimizer_unet_seg']], feed_dict={ae['x']: ld_img_batch, ae['y_l']: ld_label_batch,\
                                       ae['select_mask']: False, ae['train_phase']: True})

            if(b==minibatches-1):
                # record metrics once per epoch, on the last minibatch
                pred_sf_mask = f1_util.calc_pred_sf_mask_full(sess, ae, ld_img_batch)
                acc = np.mean(f1_util.calc_accuracy(np.argmax(pred_sf_mask, -1),np.argmax(ld_label_batch, -1)))
                seg_cost_epoch = np.append(seg_cost_epoch,seg_cost)
                seg_acc_epoch = np.append(seg_acc_epoch,acc)
                # the printed values are running means over all epochs so far
                print("Epoch: ", epoch_i, "Seg loss: ", np.mean(seg_cost_epoch), "Seg acc: ", np.mean(seg_acc_epoch))

        else:
            # Optimize Generator (G), Discriminator (D) and Segmentation (S) networks for the remaining epochs

            # update both Generator and Segmentation Net parameters in the framework using total loss value
            train_summary, z_cost, cost_a1_seg, seg_cost, _ =sess.run([ae['train_summary'],\
                                                                       ae['z_cost'],  ae['cost_a1_seg'], ae['seg_cost'], ae['optimizer_l2_both_gen_unet']],\
                                                                       feed_dict={ae['x']: ld_img_batch,ae['x_l']: ld_img_batch,ae['y_l']: ld_label_batch,\
                                       ae['z']:z_samples, ae['x_unl']: unld_img_batch, ae['select_mask']: True, ae['train_phase']: True})
            # update Discriminator Net (D) parameters in the setup using only discriminator loss
            train_summary,_ =sess.run([ae['train_summary'],ae['optimizer_disc']], feed_dict={ae['x']: ld_img_batch, ae['x_l']: ld_img_batch, ae['z']:z_samples,\
                                  ae['y_l']: ld_label_batch,ae['x_unl']: unld_img_batch, ae['select_mask']: True, ae['train_phase']: True})

            if(b==minibatches-1):
                # record metrics once per epoch, on the last minibatch
                pred_sf_mask = f1_util.calc_pred_sf_mask_full(sess, ae, ld_img_batch)
                acc = np.mean(f1_util.calc_accuracy(np.argmax(pred_sf_mask, -1),np.argmax(ld_label_batch, -1)))
                seg_cost_epoch = np.append(seg_cost_epoch,seg_cost)
                seg_acc_epoch = np.append(seg_acc_epoch,acc)
                g_loss_epoch = np.append(g_loss_epoch, cost_a1_seg)
                d_loss_epoch = np.append(d_loss_epoch, z_cost)
                print("Epoch: ", epoch_i, "Seg loss: ", np.mean(seg_cost_epoch), "Seg acc: ", np.mean(seg_acc_epoch), "Disc loss: ", z_cost, "Gen loss: ", cost_a1_seg,)


    if(epoch_i%val_step_update==0):
        # write the most recent training summary
        train_writer.add_summary(train_summary, epoch_i)
        train_writer.flush()

        ##Save the model with best DSC for Validation Image
        f1_arr=[]

        # Compute segmentation mask, F1 score and accuracy on the validation data
        val_pred_sf_mask = f1_util.calc_pred_sf_mask_full(sess, ae, val_imgs)
        f1_val = np.mean(f1_util.calc_f1_score(np.argmax(val_pred_sf_mask, -1),val_labels))
        val_f1_epoch = np.append(val_f1_epoch, f1_val)
        val_acc = np.mean(f1_util.calc_accuracy(np.argmax(val_pred_sf_mask, -1),val_labels))
        val_acc_epoch = np.append(val_acc_epoch, val_acc)

        print("Epoch: ", epoch_i, "F1 (val): ", f1_val, "Acc (val): ", val_acc)

        # if (f1_val-f1_val_prev>threshold_f1 and epoch_i!=start_epoch):
        #     print("prev f1_val; present_f1_val", f1_val_prev, f1_val)
        #     f1_val_prev = f1_val
        #     # to save the best model with maximum dice score over the entire n_epochs
        #     print("best model saved at epoch no. ", epoch_i)
        #     mp_best = str(best_model_dir) + str(checkpoint_filename) + '_best_model_epoch_' + str(epoch_i) + '.ckpt'
        #     saver.save(sess, mp_best)

        # # calc. and save validation image dice summary
        # dsc_summary_msg = sess.run(ae['val_f1_summary'], feed_dict={ae['f1']:f1_val})
        # val_sum_writer.add_summary(dsc_summary_msg, epoch_i)
        # val_sum_writer.flush()

    if ((epoch_i==n_epochs-1) and (epoch_i != start_epoch)):
        # model saved at last epoch
        mp = str(save_dir) + str(checkpoint_filename) + '_epochs_' + str(epoch_i) + '.ckpt'
        saver.save(sess, mp)
        # fall back to the last-epoch checkpoint when no best model was saved
        try:
            mp_best
        except NameError:
            mp_best=mp

# Training is finished: release the training session.
sess.close()
######################################
# Restore a checkpoint into a fresh session and use it for inference below.
# NOTE(review): the best-model save inside the training loop is commented out,
# so mp_best is the last-epoch checkpoint unless it was defined elsewhere; it
# is undefined (NameError) if the last-epoch save branch never ran.
saver_new = tf.train.Saver()
sess_new = tf.Session(config=config)
saver_new.restore(sess_new, mp_best)
print("best model chkpt name",mp_best)
print("Model restored")

In [None]:
# Switch matplotlib to the TkAgg backend for the plotting cells that follow
# (the setup cell selected Agg).
# NOTE(review): presumably fails where Tk or a display is unavailable - confirm.
matplotlib.use('TkAgg')
import itertools

In [None]:
fig, ax1 = plt.subplots()

# Left axis: segmentation cost recorded once per epoch.
ax1.set_xlabel('epoch')
ax1.set_ylabel('Batch Segmentation Cost', color='red')
ln1 = ax1.plot(seg_cost_epoch, color='red', label='SegCost')
ax1.tick_params(axis='y', labelcolor='red')

ax2 = ax1.twinx()  # second y-axis sharing the same x-axis

# Right axis: training accuracy plus the validation F1 score and accuracy.
ax2.set_ylabel('F1 Validation Score', color='blue')  # the x-label was already set on ax1
ln2 = ax2.plot(seg_acc_epoch, color ='green', label = 'Seg Acc')
# Validation metrics are only computed every val_step_update epochs. np.repeat
# holds each value for val_step_update consecutive positions so the curves line
# up with the per-epoch x-axis (flat between validation steps).
val_f1_per_epoch = np.repeat(val_f1_epoch, val_step_update)
val_acc_per_epoch = np.repeat(val_acc_epoch, val_step_update)
ln3 = ax2.plot(val_f1_per_epoch, color='blue', label = "F1score")
ln4 = ax2.plot(val_acc_per_epoch, color="orange", label = "Valid. Acc.")
ax2.tick_params(axis='y', labelcolor='blue')

fig.legend(loc="center right", bbox_to_anchor=(1,0.9), bbox_transform=ax1.transAxes)

plt.title("Tree GAN Training (2k Epochs). Medium Dataset")

fig.tight_layout()  # otherwise the right y-label is slightly clipped
#plt.savefig('../data/FES Team/Naive CCN Training.png')
plt.show()

In [None]:
# Position, within the validation history, of the highest validation accuracy.
np.argmax(val_acc_epoch)

In [None]:
# Best validation accuracy, looked up through argmax rather than an index
# copied from an earlier run (which goes stale, or out of range, as soon as the
# training length or results change).
val_acc_epoch[np.argmax(val_acc_epoch)]

In [None]:
fig2, ax1 = plt.subplots()

# Left axis: generator loss recorded once per epoch of the cGAN phase.
ax1.set_xlabel('epoch')
ax1.set_ylabel('Batch Generator Loss', color='red')
ln1 = ax1.plot(g_loss_epoch, color='red', label='Gen Loss')
ax1.tick_params(axis='y', labelcolor='red')

ax2 = ax1.twinx()  # second y-axis sharing the same x-axis

# Right axis: discriminator loss. Label, line and ticks all use green so the
# axis can be matched to its curve (label spelling corrected as well).
ax2.set_ylabel('Batch Discriminator Loss', color='green')  # the x-label was already set on ax1
ln2 = ax2.plot(d_loss_epoch, color ='green', label = 'Disc Loss')
ax2.tick_params(axis='y', labelcolor='green')

fig2.legend(loc="center right", bbox_to_anchor=(1,0.9), bbox_transform=ax1.transAxes)

plt.title("GAN Losses")

fig2.tight_layout()  # otherwise the right y-label is slightly clipped
#plt.savefig('../data/FES Team/Naive CCN Training.png')
plt.show()

In [None]:
import matplotlib

In [None]:
def np_normalize(array):
    """Min-max scale `array` into [0, 1] using its global minimum and maximum.

    Args:
        array: numeric numpy array (any shape, non-empty).

    Returns:
        A new float array of the same shape with the minimum mapped to 0 and
        the maximum mapped to 1. When every element is equal (zero range) an
        all-zero array is returned instead of dividing by zero.
    """
    x_min = np.min(array)
    x_range = np.max(array) - x_min
    if x_range == 0:
        # constant input: avoid 0/0 (NaN) and report "no variation" as zeros
        return np.zeros_like(array, dtype=float)
    return (array - x_min) / x_range

##

# Build the plot inputs from the last training batch. Rows from `no_orig`
# onward are the cGAN-deformed samples (the rows before them were overwritten
# with the originals inside the training loop); `ld_img_batch_tmp` and
# `ld_label_batch_tmp` hold the same rows before deformation.
# The image slice keeps rows/columns 1..14 and channels 2..4.
# NOTE(review): presumably those three channels are the RGB bands - confirm.
deformed_images = np_normalize(np.copy(ld_img_batch[no_orig:])[:,1:15,1:15,2:5])
# one-hot labels are collapsed to class indices before scaling
deformed_labels = np_normalize(np.argmax(np.copy(ld_label_batch[no_orig:]), -1))

real_images = np_normalize(np.copy(ld_img_batch_tmp[no_orig:])[:,1:15,1:15,2:5])
real_labels = np_normalize(np.copy(ld_label_batch_tmp[no_orig:]))

##

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

def highlight_cell(x,y, ax=None, **kwargs):
    """Draw an unfilled 1x1 rectangle centred on (x, y) and return it.

    Args:
        x, y: centre of the cell in data coordinates.
        ax: axes to draw on; the current axes are used when this is falsy.
        **kwargs: forwarded to ``plt.Rectangle`` (e.g. color, linewidth, alpha).
    """
    target_axes = plt.gca() if not ax else ax
    lower_left = (x - .5, y - .5)
    outline = plt.Rectangle(lower_left, 1, 1, fill=False, **kwargs)
    target_axes.add_patch(outline)
    return outline

# Plot original / deformed image pairs with their label cells outlined
def plot_deformed_imgs(deformed_images, deformed_labels, real_images, real_labels, plot_ax=(7,2)):
    """Plot each original image next to its deformed version, outlining labelled cells.

    One grid row is drawn per sample: the original image on the left and the
    deformed image on the right. Every cell whose label equals 1 is outlined in
    red. The figure is saved to "deformed.png" and then shown.

    Args:
        deformed_images: sequence of deformed images accepted by ``plt.imshow``.
        deformed_labels: sequence of 2D label masks for the deformed images.
        real_images: sequence of original images.
        real_labels: sequence of 2D label masks for the original images.
        plot_ax: (rows, columns) of the grid; rows must equal
            ``len(deformed_images)`` and the layout uses two columns.

    Raises:
        AssertionError: if ``len(deformed_images)`` differs from ``plot_ax[0]``.
    """
    assert len(deformed_images) == plot_ax[0], "dimensions do not match"

    plt.figure(figsize=(10,20))
    gs1 = gridspec.GridSpec(plot_ax[0], plot_ax[1])
    gs1.update(wspace=0.5, hspace=0.2) # set the spacing between axes.

    for n in range(plot_ax[0]):
        panels = ((real_images[n], real_labels[n], 'True RGB Image'),
                  (deformed_images[n], deformed_labels[n], 'Deformed RGB'))
        for col, (image, labels, title) in enumerate(panels):
            # grid cells are numbered row by row: 2*n is the left panel, 2*n+1 the right
            plt.subplot(gs1[2*n + col])
            plt.imshow(image)
            # argwhere yields (row, column) pairs; rows are y and columns are x,
            # so each axis is bounded by its own dimension and non-square masks
            # are handled correctly.
            for y, x in np.argwhere(np.asarray(labels) == 1):
                highlight_cell(x,y,color="red", linewidth=0.5, alpha = 1)
            plt.title(title)
    plt.savefig("deformed.png", dpi=300)
    plt.show()

##

# Size the grid from the data. The number of deformed samples is the batch size
# minus no_orig, and no_orig is drawn at random during training, so a fixed row
# count trips the length assertion inside plot_deformed_imgs on most runs.
plot_deformed_imgs(deformed_images, deformed_labels, real_images, real_labels, plot_ax=(len(deformed_images),2))

In [None]:
# Number of deformed samples available for plotting.
len(deformed_images)

In [None]:
import tensorboard

In [None]:
# Show the directory the summary writers were created with.
logs_path

In [None]:
%load_ext tensorboard.notebook
%tensorboard --logdir data_aug_seg/models/dabes/tflogs/

In [None]:
#########################
# Run inference on the test images with the restored model.
# NOTE(review): test_list and struct_name are only assigned in lines that are
# commented out in the data-loading cell, so this call raises NameError on a
# fresh kernel until they are defined.
f1_util.pred_segs_acdc_test_subjs(sess_new,ae,save_dir,orig_img_dt,test_list,struct_name)
#########################
# Plot augmented images generated by the trained deformation cGAN: five rounds,
# each with freshly sampled z vectors and batches, using image j of the batch.
for j in range(0,5):
    z_samples,ld_img_batch,unld_img_batch=get_samples(train_imgs,unlabeled_imgs)
    # save_dir already ends with '/', so this path contains a double slash
    save_dir_tmp=str(save_dir)+'/ep_best_model/'
    plt_func(sess_new,ae,save_dir_tmp,z_samples,ld_img_batch,unld_img_batch,index=j)
######################################
#D: we will have to put back some of these validation data references
# To compute inference on validation images on the best model
#save_dir_tmp=str(save_dir)+'/val_imgs/'
#f1_util.pred_segs_acdc_test_subjs(sess_new,ae,save_dir_tmp,orig_img_dt,val_list,struct_name)
######################################

# Train Additive Intensity Field GAN

In [None]:
# Keep the current value of the flag, then store it as a real boolean
# (True only when the current value equals 1).
ra_en_val = params.ra_en
params.ra_en = bool(params.ra_en == 1)


# Import the experiment configuration for the selected dataset. Only 'acdc' is
# supported; any other value raises ValueError carrying the dataset name.
if params.dataset == 'acdc':
    #print('load acdc configs')
    import experiment_init.init_acdc as cfg
    import experiment_init.data_cfg_acdc as data_list
else:
    raise ValueError(params.dataset)

######################################
# Build the project helper objects from the experiment config `cfg`
# ####################################
# data loader
from dataloaders import dataloaderObj
dt = dataloaderObj(cfg)

if params.dataset == 'acdc':
    # handle to the loader method for original ACDC images
    orig_img_dt=dt.load_acdc_imgs

# model factory: the network graphs below are built through this object
from models import modelObj
model = modelObj(cfg)
# metric / plotting utilities; constructed from the config and the data loader
from f1_utils import f1_utilsObj
f1_util = f1_utilsObj(cfg,dt)

######################################
# Build save_dir for the intensity cGAN run by concatenating the run settings:
# base dir / dataset / ra_en + gan type / augmentation on-off / lamda weights /
# training-image selection + version / optimizer settings.
# NOTE(review): ra_en_val is copied from params.ra_en at the top of this cell;
# if the deformation section already ran in this kernel that value is a bool,
# so the path contains 'ra_en_False' / 'ra_en_True' rather than 0 / 1 - confirm
# which is intended.
save_dir=str(cfg.base_dir)+'/models/'+str(params.dataset)+'/tr_intensity_cgan_unet/ra_en_'+str(ra_en_val)+'_gantype_'+str(params.gan_type)+'/'

# cfg.aug_en is only overridden when augmentation is switched off in params
if(params.data_aug_seg==0):
    save_dir=str(save_dir)+'no_data_aug/'
    cfg.aug_en=params.data_aug_seg
else:
    save_dir=str(save_dir)+'with_data_aug/'

save_dir=str(save_dir)+'lamda_dsc_'+str(params.lamda_dsc)+'_lamda_adv_'+str(params.lamda_adv)+'_lamda_i_'+str(params.lamda_l1_i)+'/'+\
         str(params.no_of_tr_imgs)+'/'+str(params.comb_tr_imgs)+'_v'+str(params.ver)+\
         '/unet_model_beta1_'+str(params.beta_val)+'_lr_seg_'+str(params.lr_seg)+'_lr_gen_'+str(params.lr_gen)+'_lr_disc_'+str(params.lr_disc)+'/'

print('save_dir',save_dir)
######################################
######################################

######################################
# Load the labeled training set, the validation set, the unlabeled set and the
# test list for the intensity cGAN run.
# training subjects are chosen by the number / combination settings in params
train_list = data_list.train_data(params.no_of_tr_imgs,params.comb_tr_imgs)
# load train data cropped images directly
print('loading train imgs')
train_imgs,train_labels = dt.load_img_labels(train_list)

# With the 'tr1' setting, repeat the data along axis 2 until that axis is at
# least cfg.batch_size long, so a full batch can be drawn.
if(params.no_of_tr_imgs=='tr1'):
    train_imgs_copy=np.copy(train_imgs)
    train_labels_copy=np.copy(train_labels)
    while(train_imgs.shape[2]<cfg.batch_size):
        train_imgs=np.concatenate((train_imgs,train_imgs_copy),axis=2)
        train_labels=np.concatenate((train_labels,train_labels_copy),axis=2)
    del train_imgs_copy,train_labels_copy

val_list = data_list.val_data()
# load both val data and its cropped images
print('loading val imgs')
val_label_orig,val_img_crop,val_label_crop,pixel_val_list=load_val_imgs(val_list,dt,orig_img_dt)

# load unlabeled images (label_present=0: no labels are returned)
unl_list = data_list.unlabeled_data()
print('loading unlabeled imgs')
unlabeled_imgs=dt.load_img_labels(unl_list,label_present=0)

# get test list
print('get test imgs list')
test_list = data_list.test_data()
struct_name=cfg.struct_name
val_step_update=cfg.val_step_update
######################################

######################################

def get_samples(labeled_imgs,unlabeled_imgs):
    """Draw one latent batch plus one shuffled unlabeled and one shuffled labeled image batch.

    Args:
        labeled_imgs: labeled image array handed to ``shuffle_minibatch``.
        unlabeled_imgs: unlabeled image array handed to ``shuffle_minibatch``.

    Returns:
        Tuple ``(z_samples, ld_img_batch, unld_img_batch)``; ``z_samples`` is a
        float32 array of shape ``(cfg.batch_size, params.z_lat_dim)`` drawn from
        a standard normal distribution.
    """
    # Both image batches are drawn with the same minibatch settings.
    minibatch_opts = dict(batch_size=int(cfg.batch_size),
                          num_channels=cfg.num_channels,
                          labels_present=0,
                          axis=2)

    # Sampling order (z, unlabeled, labeled) is kept so the random stream is consumed identically.
    latent_shape = (cfg.batch_size, params.z_lat_dim)
    latent = np.random.normal(loc=0.0, scale=1.0, size=latent_shape).astype(np.float32)
    unlabeled_batch = shuffle_minibatch([unlabeled_imgs], **minibatch_opts)
    labeled_batch = shuffle_minibatch([labeled_imgs], **minibatch_opts)

    return latent, labeled_batch, unlabeled_batch

def plt_func(sess,ae,save_dir,z_samples,ld_img_batch,unld_img_batch,index=0):
    """Plot intensity transformations of one fixed input image under different sampled z vectors.

    Args:
        sess: TensorFlow session holding the trained variables.
        ae: dict of graph tensors; the keys read here are 'int_c1', 'y_int',
            'z_cost', 'x', 'z', 'x_unl', 'select_mask' and 'train_phase'.
        save_dir: output directory forwarded to the f1_util plotting helper.
        z_samples: batch of latent vectors, one per transformation to generate.
        ld_img_batch: batch of labeled images; only image `index` is used.
        unld_img_batch: batch of unlabeled images fed to ae['x_unl'].
        index: position in `ld_img_batch` of the image to transform.
    """
    ld_img_tmp=np.zeros_like(ld_img_batch)
    # Repeat the selected image along the whole batch axis. Broadcasting covers
    # every row, whereas a fixed count of 20 breaks for smaller batches and
    # leaves rows empty for larger ones.
    # NOTE(review): only channel 0 is copied, the other channels stay zero -
    # confirm this is intended for multi-channel inputs.
    ld_img_tmp[:,:,:,0]=ld_img_batch[index,:,:,0]

    # 'z_cost' is still fetched so the same ops run as before; its value is not used.
    int_vec,y_int_deformed,_=sess.run([ae['int_c1'],ae['y_int'],ae['z_cost']], feed_dict={ae['x']: ld_img_tmp, ae['z']:z_samples,\
                          ae['x_unl']: unld_img_batch, ae['select_mask']: True, ae['train_phase']: False})

    f1_util.plot_intensity_transformed_imgs(ld_img_tmp,y_int_deformed,int_vec,save_dir,index=index)

    # Plot gif of all the transformed images generated for the fixed input image
    #f1_util.write_gif_func(ip_img=y_int_deformed, imsize=(cfg.img_size_x,cfg.img_size_y),save_dir=save_dir,index=index)

######################################
# Names used when saving: checkpoint file prefix, plus the TensorBoard log
# directory and best-model directory, both placed directly under save_dir.
checkpoint_filename = 'unet_{}'.format(params.dataset)
logs_path = '{}tensorflow_logs/'.format(save_dir)
best_model_dir = '{}best_model/'.format(save_dir)
######################################

######################################
# Build the additive intensity field generator cGAN + U-Net graph; all
# hyper-parameters come from `params`.
# NOTE(review): this graph is built without resetting the default TF graph, so
# if the deformation section already ran in this kernel both models presumably
# share one graph - confirm that variable names do not collide.
ae = model.intensity_transform_cgan_unet(learn_rate_gen=params.lr_gen,learn_rate_disc=params.lr_disc,\
                        beta1_val=params.beta_val,gan_type=params.gan_type,ra_en=params.ra_en,\
                        learn_rate_seg=params.lr_seg,dsc_loss=params.dsc_loss,en_1hot=params.en_1hot,\
                        lamda_dsc=params.lamda_dsc,lamda_adv=params.lamda_adv,lamda_l1_i=params.lamda_l1_i)

######################################
#  training parameters
# the training loop runs epochs start_epoch .. n_epochs-1
start_epoch=0
n_epochs = 10000
disp_step=400
print_step=2000
# number of initial epochs that train only the segmentation network on labeled
# data, without any cGAN generated data
seg_tr_limit=400
# previous-best mean validation F1 and the minimum improvement threshold
mean_f1_val_prev=0.1
threshold_f1=0.00001
# make sure the best-model directory exists (parents included)
pathlib.Path(best_model_dir).mkdir(parents=True, exist_ok=True)
######################################

######################################
# helper graph; the training loop reads its 'y_tmp' input and 'y_tmp_1hot'
# output to compute the 1 hot encoding for an input label
df_ae= model.deform_net()
######################################

######################################
# Summary writers; all three write to the same logs_path directory.
# writer for train summary
train_writer = tf.summary.FileWriter(logs_path)
# writers for dice score and validation summary
dsc_writer = tf.summary.FileWriter(logs_path)
val_sum_writer = tf.summary.FileWriter(logs_path)
######################################

######################################
# create a session with the shared session config and initialize all variables
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())
# checkpoint saver; max_to_keep=2 limits how many recent checkpoints are retained
saver = tf.train.Saver(max_to_keep=2)
######################################

# Run for n_epochs
# Start every run without a best checkpoint. In a notebook the name `mp_best` may still be bound
# to a checkpoint written by a previously executed cell; that one must never be restored as the
# best model of this run.
mp_best=None
for epoch_i in range(start_epoch,n_epochs):

    # sample z's from Gaussian Distribution
    z_samples = np.random.normal(loc=0.0, scale=1.0, size=(cfg.batch_size, params.z_lat_dim)).astype(np.float32)

    # sample Unlabeled shuffled batch
    unld_img_batch=shuffle_minibatch([unlabeled_imgs],batch_size=int(cfg.batch_size),num_channels=cfg.num_channels,labels_present=0,axis=2)

    # sample Labeled shuffled batch
    ld_img_batch,ld_label_batch=shuffle_minibatch([train_imgs,train_labels],batch_size=cfg.batch_size,num_channels=cfg.num_channels,axis=2)

    if(cfg.aug_en==1):
        # Apply affine transformations
        ld_img_batch,ld_label_batch=augmentation_function([ld_img_batch,ld_label_batch],dt)
        unld_img_batch=augmentation_function([unld_img_batch],dt,labels_present=0)

    # Keep the images as sampled (and optionally affine transformed): they are the input of the
    # generator and also fill the first part of the mixed batch below
    ld_img_batch_tmp=np.copy(ld_img_batch)
    # Compute 1 hot encoding of the segmentation mask labels
    ld_label_batch_1hot = sess.run(df_ae['y_tmp_1hot'],feed_dict={df_ae['y_tmp']:ld_label_batch})

    if(epoch_i<seg_tr_limit):
        # Warm-up phase (first seg_tr_limit iterations):
        # optimize only the Segmentation Network on the labeled batch, without cGAN generated data
        ld_label_batch=ld_label_batch_1hot
        train_summary,_ =sess.run([ae['seg_summary'],ae['optimizer_unet_seg']], feed_dict={ae['x']: ld_img_batch, ae['y_l']: ld_label_batch,\
                                   ae['select_mask']: False, ae['train_phase']: True})
    else:
        # Joint phase (all remaining iterations, starting AT epoch_i == seg_tr_limit so that no
        # iteration is left without an optimizer step):
        # optimize Generator (G), Discriminator (D) and Segmentation (S) networks.

        ###########################
        # use additive intensity field cGAN to generate additional augmented images from labeled samples.
        # The labels do not change on application of intensity transformation since it is an additive operation.
        ld_img_batch=sess.run(ae['y_int'],\
                              feed_dict={ae['x']: ld_img_batch_tmp, ae['z']:z_samples, ae['train_phase']: False})

        ###########################
        # shuffle the quantity/number of images chosen from intensity field cGAN augmented images:
        # the first no_orig (5 to 14) images of the batch are the original images with conventional
        # affine transformations, the rest are cGAN augmented images
        no_orig=np.random.randint(5, high=15)
        ld_img_batch[0:no_orig] = ld_img_batch_tmp[0:no_orig]
        if(params.en_1hot==1):
            ld_label_batch = ld_label_batch_1hot
        # otherwise ld_label_batch keeps the labels exactly as sampled above

        # update both Generator and Segmentation Net parameters in the framework using total loss value
        train_summary,_ =sess.run([ae['train_summary'],ae['optimizer_l2_both_gen_unet']], feed_dict={ae['x']: ld_img_batch,ae['y_l']: ld_label_batch,\
                                   ae['z']:z_samples, ae['x_unl']: unld_img_batch, ae['select_mask']: True, ae['train_phase']: True})

        # update Discriminator Net (D) parameters in the setup using only discriminator loss
        train_summary,_ =sess.run([ae['train_summary'],ae['optimizer_disc']], feed_dict={ae['x']: ld_img_batch,ae['z']:z_samples,\
                              ae['y_l']: ld_label_batch,ae['x_unl']: unld_img_batch, ae['select_mask']: True, ae['train_phase']: True})

    if(epoch_i%val_step_update==0):
        # log the train summary of this iteration
        train_writer.add_summary(train_summary, epoch_i)
        train_writer.flush()

        ##Save the model with best DSC for Validation Image
        mean_f1_arr=[]
        f1_arr=[]
        for val_id_no in range(0,len(val_list)):
            val_img_crop_tmp=val_img_crop[val_id_no]
            val_label_orig_tmp=val_label_orig[val_id_no]
            pixel_size_val=pixel_val_list[val_id_no]

            # Compute segmentation mask and dice_score for each validation subject
            pred_sf_mask = f1_util.calc_pred_sf_mask_full(sess, ae, val_img_crop_tmp)
            re_pred_mask_sys,f1_val = f1_util.reshape_img_and_f1_score(pred_sf_mask, val_label_orig_tmp, pixel_size_val)

            #concatenate dice scores of each val image (index 0 is skipped)
            mean_f1_arr.append(np.mean(f1_val[1:cfg.num_classes]))
            f1_arr.append(f1_val[1:cfg.num_classes])

        #avg mean over all val subjects
        mean_f1_arr = np.asarray(mean_f1_arr)
        mean_f1 = np.mean(mean_f1_arr)
        f1_arr = np.asarray(f1_arr)

        if ((epoch_i%disp_step == 0) or (epoch_i==n_epochs-1)):
            print('mean_f1',epoch_i, mean_f1)
        if (mean_f1-mean_f1_val_prev>threshold_f1 and epoch_i!=start_epoch):
            print("prev f1_val; present_f1_val", mean_f1_val_prev, mean_f1, mean_f1_arr)
            mean_f1_val_prev = mean_f1
            # to save the best model with maximum dice score over the entire n_epochs
            print("best model saved at epoch no. ", epoch_i)
            mp_best = str(best_model_dir) + str(checkpoint_filename) + '_best_model_epoch_' + str(epoch_i) + '.ckpt'
            saver.save(sess, mp_best)

        #calc. and save validation image dice summary
        dsc_summary_msg = sess.run(ae['val_dsc_summary'], feed_dict={ae['rv_dice']:np.mean(f1_arr[:,0]),\
                                ae['myo_dice']:np.mean(f1_arr[:,1]),ae['lv_dice']:np.mean(f1_arr[:,2]),ae['mean_dice']: mean_f1})
        # the summary has to be handed to a writer, otherwise it never reaches the log directory
        val_sum_writer.add_summary(dsc_summary_msg, epoch_i)
        val_sum_writer.flush()

    if ((epoch_i==n_epochs-1) and (epoch_i != start_epoch)):
        # model saved at last epoch
        mp = str(save_dir) + str(checkpoint_filename) + '_epochs_' + str(epoch_i) + '.ckpt'
        saver.save(sess, mp)
        # no validation improvement was ever recorded in this run: fall back to the last model
        if mp_best is None:
            mp_best=mp

# Training is over: release the training session before opening the one used for evaluation
sess.close()

######################################
# Reload the best checkpoint (path in mp_best) into a fresh session
saver_new = tf.train.Saver()
sess_new = tf.Session(config=config)
saver_new.restore(sess_new, mp_best)
print("best model chkpt",mp_best)
print("Model restored")

#########################
# Inference on the test subjects with the restored model
f1_util.pred_segs_acdc_test_subjs(sess_new, ae, save_dir, orig_img_dt, test_list, struct_name)

#########################
# Plot augmented images generated by the restored cGAN for 5 freshly sampled batches
save_dir_tmp = '{}/ep_best_model/'.format(save_dir)
for j in range(5):
    z_samples, ld_img_batch, unld_img_batch = get_samples(train_imgs, unlabeled_imgs)
    plt_func(sess_new, ae, save_dir_tmp, z_samples, ld_img_batch, unld_img_batch, index=j)

######################################
# Inference on the validation subjects with the restored model
save_dir_tmp = '{}/val_imgs/'.format(save_dir)
f1_util.pred_segs_acdc_test_subjs(sess_new, ae, save_dir_tmp, orig_img_dt, val_list, struct_name)
######################################
# Train Unet with trained GANs

In [None]:
# ra_en_val is the 0/1 form of the flag and is only used to build directory names.
# params.ra_en is overwritten with a bool right below, so when this cell is executed again (or
# after any other cell applied the same conversion) it already holds True/False. int() keeps the
# directory names at 'ra_en_0'/'ra_en_1' instead of silently switching to 'ra_en_False'/'ra_en_True',
# which would point at directories that no training run has written.
ra_en_val=int(params.ra_en)
# the model builders receive the flag as a bool
params.ra_en=(params.ra_en==1)

# dataset specific configuration; only 'acdc' is supported
if params.dataset == 'acdc':
    print('load acdc configs')
    import experiment_init.init_acdc as cfg
    import experiment_init.data_cfg_acdc as data_list
else:
    raise ValueError(params.dataset)

######################################
# class loaders
# ####################################
#  load dataloader object
from dataloaders import dataloaderObj
dt = dataloaderObj(cfg)

if params.dataset == 'acdc':
    print('set acdc img dataloader handle')
    # function handle that is passed on to the validation / test helpers to load the images
    orig_img_dt=dt.load_acdc_imgs

#  load model object
from models import modelObj
model = modelObj(cfg)

#  load f1_utils object
from f1_utils import f1_utilsObj
f1_util = f1_utilsObj(cfg,dt)

######################################
# Build save_dir, the directory tree this run writes into. It encodes the experiment name,
# the data augmentation setting, the lamda factors, the training set selection and the
# segmentation loss / learning rate.
proj_save_name = 'tr_deform_and_int_cgans_data_aug/ra_en_{}_gantype_{}'.format(ra_en_val, params.gan_type)

# when data augmentation is switched off in params, switch it off in the config as well
if params.data_aug_seg == 0:
    cfg.aug_en = params.data_aug_seg

save_dir = '{}/models/{}/{}/{}/'.format(
    cfg.base_dir,
    params.dataset,
    proj_save_name,
    'no_data_aug' if params.data_aug_seg == 0 else 'with_data_aug',
)
save_dir += 'lamda_dsc_{}_lamda_adv_{}_lamda_g_{}_lamda_i_{}/{}/{}_v{}/unet_model_dsc_loss_{}_lr_seg_{}/'.format(
    params.lamda_dsc,
    params.lamda_adv,
    params.lamda_l1_g,
    params.lamda_l1_i,
    params.no_of_tr_imgs,
    params.comb_tr_imgs,
    params.ver,
    params.dsc_loss,
    params.lr_seg,
)
print('save_dir',save_dir)
######################################

######################################
# Training data: ids of the selected subjects, then their images and labels
train_list = data_list.train_data(params.no_of_tr_imgs, params.comb_tr_imgs)
print('loading train imgs')
train_imgs, train_labels = dt.load_img_labels(train_list)

if params.no_of_tr_imgs == 'tr1':
    # Keep appending the loaded data to itself along axis 2 until that axis
    # holds at least cfg.batch_size entries
    train_imgs_copy, train_labels_copy = np.copy(train_imgs), np.copy(train_labels)
    while train_imgs.shape[2] < cfg.batch_size:
        train_imgs, train_labels = (
            np.concatenate((train_imgs, train_imgs_copy), axis=2),
            np.concatenate((train_labels, train_labels_copy), axis=2),
        )
    del train_imgs_copy, train_labels_copy

# Validation data: original labels, cropped images / labels and the pixel size of every subject
val_list = data_list.val_data()
print('loading val imgs')
val_label_orig, val_img_crop, val_label_crop, pixel_val_list = load_val_imgs(val_list, dt, orig_img_dt)

# Test subjects are only listed here; their images are loaded during inference
print('get test imgs list')
test_list = data_list.test_data()

# Settings taken over from the config
struct_name, val_step_update = cfg.struct_name, cfg.val_step_update
######################################

######################################
# Checkpoint file name prefix, and the log / best-model directories underneath save_dir
checkpoint_filename = 'unet_{}'.format(params.dataset)
logs_path = '{}tensorflow_logs/'.format(save_dir)
best_model_dir = '{}best_model/'.format(save_dir)
######################################

########################################################################
# Load the two pre-trained generator nets: the deformation field cGAN and
# the additive intensity field cGAN. Each one gets its own graph and session.
########################################################################
def _pretrained_cgan_dir(model_dir_name, l1_tag, l1_weight):
    """Return the checkpoint directory of a cGAN trained with the current `params`.

    Both cGANs use the same directory layout; only the experiment folder and the
    L1 term encoded in the path differ.

    Args:
        model_dir_name: experiment folder, e.g. 'tr_deformation_cgan_unet'.
        l1_tag: name of the L1 lamda factor as it appears in the path ('lamda_g' or 'lamda_i').
        l1_weight: value of that lamda factor.

    Returns:
        The directory path as a string, ending in '/'.
    """
    aug_dir = 'no_data_aug/' if params.data_aug_seg==0 else 'with_data_aug/'
    return (str(cfg.base_dir)+'/models/'+str(params.dataset)+'/'+str(model_dir_name)
            +'/ra_en_'+str(ra_en_val)+'_gantype_'+str(params.gan_type)+'/'+aug_dir
            +'lamda_dsc_'+str(params.lamda_dsc)+'_lamda_adv_'+str(params.lamda_adv)+'_'+str(l1_tag)+'_'+str(l1_weight)+'/'
            +str(params.no_of_tr_imgs)+'/'+str(params.comb_tr_imgs)+'_v'+str(params.ver)
            +'/unet_model_beta1_'+str(params.beta_val)+'_lr_seg_'+str(params.lr_seg)
            +'_lr_gen_'+str(params.lr_gen)+'_lr_disc_'+str(params.lr_disc)+'/')

# when data augmentation is switched off in params, switch it off in the config as well
if(params.data_aug_seg==0):
    cfg.aug_en=params.data_aug_seg

# hyper-parameters that both cGAN graphs are built with
cgan_common_kwargs = dict(learn_rate_gen=params.lr_gen,learn_rate_disc=params.lr_disc,
                          beta1_val=params.beta_val,gan_type=params.gan_type,ra_en=params.ra_en,
                          learn_rate_seg=params.lr_seg,dsc_loss=params.dsc_loss,en_1hot=params.en_1hot,
                          lamda_dsc=params.lamda_dsc,lamda_adv=params.lamda_adv)

########################################################################
#load deformation field generator net
########################################################################
# Define the model graph
tf.reset_default_graph()
ae_geo = model.spatial_generator_cgan_unet(lamda_l1_g=params.lamda_l1_g, **cgan_common_kwargs)

# define model path
model_path=_pretrained_cgan_dir('tr_deformation_cgan_unet','lamda_g',params.lamda_l1_g)

mp=get_max_chkpt_file(model_path)
print('loading deformation field cGAN checkpoint file',mp)
# create a session and load the parameters learned.
# The session is created while the deformation graph is the default graph and launches that graph.
saver_geo = tf.train.Saver(max_to_keep=2)
sess_geo = tf.Session(config=config)
saver_geo.restore(sess_geo,mp)
######################################

########################################################################
#load additive intensity field generator net
########################################################################
# Define the model graph
tf.reset_default_graph()
ae_int = model.intensity_transform_cgan_unet(lamda_l1_i=params.lamda_l1_i, **cgan_common_kwargs)

# define model path
model_path=_pretrained_cgan_dir('tr_intensity_cgan_unet','lamda_i',params.lamda_l1_i)

mp=get_max_chkpt_file(model_path)
print('loading additive intensity field cGAN checkpoint file ',mp)
# create a session and load the parameters learned.
# The session is created while the intensity graph is the default graph and launches that graph.
saver_int = tf.train.Saver(max_to_keep=2)
sess_int = tf.Session(config=config)
saver_int.restore(sess_int,mp)

######################################

######################################
# Training schedule
start_epoch, n_epochs = 0, 10000
disp_step = 500
# best validation score so far, and the minimum improvement that counts as a new best model
mean_f1_val_prev = 0.1
threshold_f1 = 1e-5
debug_en = 0
# the directory for best-model checkpoints has to exist before the first save
pathlib.Path(best_model_dir).mkdir(parents=True, exist_ok=True)
######################################

######################################
# Start a new default graph and build the segmentation U-Net that is trained in this cell
tf.reset_default_graph()
ae = model.unet(
    learn_rate_seg=params.lr_seg,
    en_1hot=params.en_1hot,
    dsc_loss=params.dsc_loss,
)
######################################

######################################
# Helper graph (same default graph) that 1 hot encodes labels and deforms them with a flow field
df_ae = model.deform_net()
######################################

######################################
# Three summary writers on the same log directory:
# train summaries, dice scores and validation summaries
train_writer, dsc_writer, val_sum_writer = (tf.summary.FileWriter(logs_path) for _ in range(3))
######################################

######################################
# Open a session on the graph built above and initialize all of its variables
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())
# Saver for the checkpoints written during training (only the 2 most recent are kept)
saver = tf.train.Saver(max_to_keep=2)
######################################

# Run for n_epochs
# Start every run without a best checkpoint. In a notebook the name `mp_best` may still be bound
# to a checkpoint written by a previously executed cell (it would belong to a different graph);
# that one must never be restored as the best model of this run.
mp_best=None
for epoch_i in range(start_epoch,n_epochs):

    # sample z's from Gaussian Distribution
    z_samples = np.random.normal(loc=0.0, scale=1.0, size=(cfg.batch_size, params.z_lat_dim)).astype(np.float32)

    #sample Labelled data shuffled batch
    ld_img_batch,ld_label_batch=shuffle_minibatch([train_imgs,train_labels],batch_size=cfg.batch_size,num_channels=cfg.num_channels,axis=2)
    if(cfg.aug_en==1):
        # Apply affine transformations
        ld_img_batch,ld_label_batch=augmentation_function([ld_img_batch,ld_label_batch],dt)

    ld_img_batch_orig_tmp=np.copy(ld_img_batch)
    ld_label_batch_orig_tmp=np.copy(ld_label_batch)
    # Compute 1 hot encoding of the segmentation mask labels
    ld_label_batch_orig_1hot = sess.run(df_ae['y_tmp_1hot'],feed_dict={df_ae['y_tmp']:ld_label_batch_orig_tmp})

    ############################
    ## use Deformation field cGAN to generate additional augmented image,label pairs from labeled samples
    flow_vec,ld_img_batch_geo=sess_geo.run([ae_geo['flow_vec'],ae_geo['y_trans']],\
                                feed_dict={ae_geo['x_l']: ld_img_batch_orig_tmp, ae_geo['z']:z_samples, ae_geo['train_phase']: False})

    # the labels are deformed with the same flow field as the images
    ld_label_batch_geo=sess.run(df_ae['deform_y_1hot'],feed_dict={df_ae['y_tmp']:ld_label_batch_orig_tmp,df_ae['flow_v']:flow_vec})

    ############################
    # use additive Intensity field cGAN to generate additional augmented images from labeled samples;
    # these images keep the (1 hot encoded) labels of the original images
    ld_img_batch_int=sess_int.run(ae_int['y_int'], feed_dict={ae_int['x']: ld_img_batch_orig_tmp, ae_int['z']:z_samples, ae_int['train_phase']: False})

    ############################
    # use additive intensity field cGAN over augmented images generated from deformation field cGAN to create augmented images
    # that have both spatial deformations and intensity transformations applied in them;
    # these images keep the labels deformed by the deformation field cGAN
    ld_img_batch_geo_int=sess_int.run(ae_int['y_int'], feed_dict={ae_int['x']: ld_img_batch_geo, ae_int['z']:z_samples, ae_int['train_phase']: False})

    # shuffle the quantity/number of images chosen from each source. The batch is laid out as
    #   [0, no_g)     deformation field cGAN  (no_g is drawn from 1..4)
    #   [no_g, no_i)  intensity field cGAN    (no_i is drawn from 5..9)
    #   [no_i, no_b)  both cGANs              (no_b is drawn from 10..14)
    #   [no_b, end)   original images with conventional affine transformations
    no_g=np.random.randint(1, high=5)
    no_i=np.random.randint(5, high=10)
    no_b=np.random.randint(10, high=15)

    # the batch is assembled in place on top of the original images and their 1 hot labels
    ld_img_batch=ld_img_batch_orig_tmp
    ld_label_batch=ld_label_batch_orig_1hot

    ld_img_batch[0:no_g] = ld_img_batch_geo[0:no_g]
    ld_label_batch[0:no_g] = ld_label_batch_geo[0:no_g]
    # intensity transformed images: ld_label_batch already holds their labels at these positions
    ld_img_batch[no_g:no_i] = ld_img_batch_int[no_g:no_i]
    ld_img_batch[no_i:no_b] = ld_img_batch_geo_int[no_i:no_b]
    ld_label_batch[no_i:no_b] = ld_label_batch_geo[no_i:no_b]

    #Optimize the segmentation network over this batch of images
    train_summary,_ =sess.run([ae['train_summary'],ae['optimizer_unet_seg']], feed_dict={ae['x']: ld_img_batch, ae['y_l']: ld_label_batch,\
                               ae['select_mask']: False, ae['train_phase']: True})

    if(epoch_i%val_step_update==0):
        # log the train summary of this iteration
        train_writer.add_summary(train_summary, epoch_i)
        train_writer.flush()

        ##Save the model with best DSC for Validation Image
        mean_f1_arr=[]
        f1_arr=[]
        for val_id_no in range(0,len(val_list)):
            val_img_crop_tmp=val_img_crop[val_id_no]
            val_label_orig_tmp=val_label_orig[val_id_no]
            pixel_size_val=pixel_val_list[val_id_no]

            # Compute segmentation mask and dice_score for each validation subject
            pred_sf_mask = f1_util.calc_pred_sf_mask_full(sess, ae, val_img_crop_tmp)
            re_pred_mask_sys,f1_val = f1_util.reshape_img_and_f1_score(pred_sf_mask, val_label_orig_tmp, pixel_size_val)

            #concatenate dice scores of each val image (index 0 is skipped)
            mean_f1_arr.append(np.mean(f1_val[1:cfg.num_classes]))
            f1_arr.append(f1_val[1:cfg.num_classes])

        #avg mean over all val subjects
        mean_f1_arr=np.asarray(mean_f1_arr)
        mean_f1=np.mean(mean_f1_arr)
        f1_arr=np.asarray(f1_arr)

        if (mean_f1-mean_f1_val_prev>threshold_f1 and epoch_i!=start_epoch):
            print("prev f1_val; present_f1_val", mean_f1_val_prev, mean_f1, mean_f1_arr)
            mean_f1_val_prev = mean_f1

            # to save the best model with maximum dice score over the entire n_epochs
            print("best model saved at epoch no. ", epoch_i)
            mp_best = str(best_model_dir) + str(checkpoint_filename) + '_best_model_epoch_' + str(epoch_i) + '.ckpt'
            saver.save(sess, mp_best)

        #calc. and save validation image dice summary
        dsc_summary_msg = sess.run(ae['val_dsc_summary'], feed_dict={ae['rv_dice']:np.mean(f1_arr[:,0]),\
                                ae['myo_dice']:np.mean(f1_arr[:,1]),ae['lv_dice']:np.mean(f1_arr[:,2]),ae['mean_dice']: mean_f1})
        val_sum_writer.add_summary(dsc_summary_msg, epoch_i)
        val_sum_writer.flush()

    if ((epoch_i==n_epochs-1) and (epoch_i != start_epoch)):
        # model saved at last epoch
        mp = str(save_dir) + str(checkpoint_filename) + '_epochs_' + str(epoch_i) + '.ckpt'
        saver.save(sess, mp)
        # no validation improvement was ever recorded in this run: fall back to the last model
        if mp_best is None:
            mp_best=mp

# Training is over: release the training session ...
sess.close()
# ... and the sessions of the two pre-trained cGANs. They were only used inside the training loop
# to generate augmented batches and are not needed for the inference below.
sess_geo.close()
sess_int.close()
######################################
# restore best model and predict segmentations on test subjects
saver_new = tf.train.Saver()
sess_new = tf.Session(config=config)
saver_new.restore(sess_new, mp_best)
print("best model chkpt",mp_best)
print("Model restored")

#########################
# To compute inference on test images on the model that yields best dice score on validation images
f1_util.pred_segs_acdc_test_subjs(sess_new,ae,save_dir,orig_img_dt,test_list,struct_name)
######################################
# To compute inference on validation images on the best model
save_dir_tmp=str(save_dir)+'/val_imgs/'
f1_util.pred_segs_acdc_test_subjs(sess_new,ae,save_dir_tmp,orig_img_dt,val_list,struct_name)
######################################