This code references the following tutorials:


1.   https://keras.io/examples/generative/conditional_gan/
2.   https://www.tensorflow.org/tutorials/generative/dcgan
3. https://keras.io/examples/generative/wgan_gp/
4. https://learnopencv.com/conditional-gan-cgan-in-pytorch-and-tensorflow/
5. https://www.youtube.com/watch?v=IZtv9s_Wx9I&ab_channel=AladdinPersson





# Requirements


Only needed if running the optimizer (this functionality is deprecated).


In [None]:
# !pip install wandb -qU &> /dev/null

In [None]:
# import wandb

In [None]:
# wandb.login()

Load the Training Set

In [None]:
!unzip pre_saved_assests/train_final.zip 

Remove the 'other' class folder and stray .ipynb_checkpoints files

In [None]:
!rm -rf content/images_train/other
!rm -rf content/images_train/.ipynb_checkpoints

In [None]:
import os
import tensorflow as tf
from tensorflow import keras 
from keras import layers
from IPython import display
import matplotlib.pyplot as plt
import numpy as np
import time
from tqdm import tqdm
from matplotlib import gridspec
import keras.backend as K
import scipy
import pandas as pd
import os.path
from os import path

These directories store the generated ECGs used for visual inspection, as well as the images used to compute FID (Fréchet Inception Distance) and KID (Kernel Inception Distance).

In [None]:
# Output directories, one sub-folder per ECG class:
#   Fake/<class>        - progress grids written by generate_and_save_images
#   new_images/<class>  - individual samples written for FID/KID evaluation
# makedirs(exist_ok=True) is idempotent and also creates a missing class
# sub-folder when its parent directory already exists.
for _root in ('Fake', 'new_images'):
    for _ecg_type in ('AF', 'NORMAL'):
        os.makedirs(os.path.join(_root, _ecg_type), exist_ok=True)

# Models


Minibatch discrimination layer for Keras which is taken from the following repository:  https://notebooks.githubusercontent.com/view/ipynb?browser=chrome&color_mode=auto&commit=81de5e4c30810f436dca7b54c5ac3095eeb3ba69&device=unknown&enc_url=68747470733a2f2f7261772e67697468756275736572636f6e74656e742e636f6d2f6761727269646f712f67616e2d67756964652f383164653565346333303831306634333664636137623534633561633330393565656233626136392f4d696e6962617463682532306469736372696d696e6174696f6e2e6970796e62&logged_in=false&nwo=garridoq%2Fgan-guide&path=Minibatch+discrimination.ipynb&platform=android&repository_id=193508311&repository_type=Repository&version=100

In [None]:
class MinibatchDiscrimination(tf.keras.layers.Layer): 
    '''
    Minibatch Discrimination custom Keras layer (Salimans et al., 2016,
    "Improved Techniques for Training GANs"), adapted from the garridoq/gan-guide
    repository:
    https://github.com/garridoq/gan-guide/blob/81de5e4c30810f436dca7b54c5ac3095eeb3ba69/Minibatch%20discrimination.ipynb

    For a 2-D input of shape (batch, in_features), every sample is projected to
    ``num_kernel`` vectors of length ``dim_kernel``. For each kernel, the L1
    distance to the matching vector of every sample in the batch is computed
    (including the sample itself, which contributes exp(0) = 1), and
    ``sum_j exp(-L1_ij)`` is appended to the input features.

    Parameters
    ----------
    num_kernel : int
        Number of similarity features appended to each sample.
    dim_kernel : int
        Length of each projected vector compared across the batch.
    kernel_initializer : str, dict or tf.keras.initializers.Initializer
        Initializer for the (in_features, num_kernel * dim_kernel) projection
        kernel (default 'glorot_uniform'). It is normalized to an Initializer
        object and serialized in ``get_config``, so the config stays
        serializable even when an Initializer instance is passed.

    Output shape: (batch, in_features + num_kernel).
    '''
    def __init__(self, num_kernel, dim_kernel,kernel_initializer='glorot_uniform', **kwargs):
        super(MinibatchDiscrimination, self).__init__(**kwargs)
        self.num_kernel = num_kernel
        self.dim_kernel = dim_kernel
        # Accepts a name, a serialized config (as produced by get_config) or an instance.
        self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)

    def build(self, input_shape):
        # Projection from the input features to num_kernel * dim_kernel values.
        self.kernel = self.add_weight(name='kernel',
                                      shape=(input_shape[1], self.num_kernel*self.dim_kernel),
                                      initializer=self.kernel_initializer,
                                      trainable=True)
        super(MinibatchDiscrimination, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        # (batch, num_kernel * dim_kernel) -> (batch, num_kernel, dim_kernel)
        activation = tf.matmul(x, self.kernel)
        activation = tf.reshape(activation, shape=(-1, self.num_kernel, self.dim_kernel))
        # M_i: (batch, num_kernel, dim_kernel, 1)
        tmp1 = tf.expand_dims(activation, 3)
        # M_j: (1, num_kernel, dim_kernel, batch)
        tmp2 = tf.expand_dims(tf.transpose(activation, perm=[1, 2, 0]), 0)

        # Broadcast to (batch, num_kernel, dim_kernel, batch): every pair (i, j).
        diff = tmp1 - tmp2

        l1 = tf.reduce_sum(tf.math.abs(diff), axis=2)        # (batch, num_kernel, batch)
        features = tf.reduce_sum(tf.math.exp(-l1), axis=2)   # (batch, num_kernel)
        return tf.concat([x, features], axis=1)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[1] + self.num_kernel)

    def get_config(self):
        config = super().get_config()
        config['dim_kernel'] = self.dim_kernel
        config['num_kernel'] = self.num_kernel
        # Serialize so the config is JSON-safe for any initializer type.
        config["kernel_initializer"] = tf.keras.initializers.serialize(self.kernel_initializer)
        return config

In [None]:
def gen_block(x,features,size,stride):
    """Append one upsampling block (Conv2DTranspose -> BatchNorm -> ReLU).

    Parameters
    ----------
    x : Keras tensor
        Output of the previous layer to build on.
    features : int
        Number of filters of the transposed convolution.
    size : int
        Kernel size of the transposed convolution.
    stride : int
        Stride of the transposed convolution ('same' padding, so the spatial
        size is multiplied by ``stride``).

    Returns
    -------
    Keras tensor
        Output of the ReLU at the end of the block.
    """
    # DCGAN-style weight init: N(0, 0.02); no bias because BatchNorm follows.
    weight_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
    upsample = layers.Conv2DTranspose(
        features,
        kernel_size=size,
        strides=stride,
        padding='same',
        kernel_initializer=weight_init,
        use_bias=False,
    )
    out = upsample(x)
    out = layers.BatchNormalization()(out)
    return layers.ReLU()(out)

In [None]:
def define_conditional_generator(n_classes=2, latent_dim = 100):
    """Build the conditional generator shared by all GAN variants.

    The class label is embedded and projected to a 4x4x1 map. The latent vector
    is projected to a 4x4x512 map. Both are concatenated and upsampled by five
    stride-2 transposed convolutions (4 -> 8 -> 16 -> 32 -> 64 -> 128).

    Parameters
    ----------
    n_classes : int, optional
        Number of classes the generator must be able to produce (default 2).
    latent_dim : int, optional
        Length of the random input vector; also used as the label-embedding
        length (default 100).

    Returns
    -------
    tf.keras.Model
        Model taking ``[latent_vector, label]`` and returning 128x128x1
        images with values in [-1, 1] (tanh output).
    """
    # Label branch: class index -> embedding -> 4x4x1 conditioning map.
    # As with the original DCGAN we start from 4x4 (https://doi.org/10.48550/arxiv.1511.06434).
    label = layers.Input(shape=(1,))
    label_map = layers.Embedding(n_classes, latent_dim)(label)
    label_map = layers.Dense(4*4*1)(label_map)
    label_map = layers.Reshape((4, 4, 1))(label_map)

    # Noise branch: unlike Radford et al., and closer to Li et al.
    # (https://www.sciencedirect.com/science/article/pii/S0020025521013049),
    # a reduced number of features is used.
    latent_vector = layers.Input(shape=(latent_dim,))
    noise_map = layers.Dense(512 * 4 * 4)(latent_vector)
    noise_map = layers.ReLU()(noise_map)
    noise_map = layers.Reshape((4, 4, 512))(noise_map)

    x = layers.Concatenate()([noise_map, label_map])
    for width in (512, 256, 128, 64):
        x = gen_block(x, width, 4, 2)

    # Final upsampling to a single-channel image in [-1, 1].
    to_image = layers.Conv2DTranspose(
        1, 4, 2,
        padding='same',
        kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
        use_bias=False,
        activation='tanh',
    )
    x = to_image(x)

    return tf.keras.Model([latent_vector, label], x)

In [None]:
def disc_block(x,features,size,stride,clip = False, gp=False):
    """Append one downsampling block of the discriminator/critic.

    Layout: Conv2D -> (BatchNorm unless ``gp``) -> LeakyReLU(0.2).

    Parameters
    ----------
    x : Keras tensor
        Output of the previous layer to build on.
    features : int
        Number of convolution filters.
    size : int
        Convolution kernel size.
    stride : int
        Convolution stride ('same' padding).
    clip : bool or number
        If falsy, no constraint is applied. Otherwise the kernel gets a Keras
        ``min_max_norm(-clip, clip)`` constraint over axes [0, 1, 2] (used by
        the WCGAN). ``True`` behaves as a bound of 1.
    gp : bool
        If True, batch normalization is omitted (used by the WCGANGP).

    Returns
    -------
    Keras tensor
        Output of the LeakyReLU at the end of the block.
    """
    conv_kwargs = dict(
        kernel_size=size,
        strides=stride,
        padding='same',
        kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
        use_bias=False,
    )
    if clip:
        # WCGAN: bound the kernel norm with Keras's kernel-constraint mechanism.
        conv_kwargs['kernel_constraint'] = keras.constraints.min_max_norm(-clip, clip, axis=[0, 1, 2])
    x = layers.Conv2D(features, **conv_kwargs)(x)
    if not gp:
        x = layers.BatchNormalization()(x)
    return layers.LeakyReLU(0.2)(x)

In [None]:
def define_conditional_discriminator(in_shape=(128,128,1), n_classes=2, embedding_dim=100):
    """Build the conditional discriminator used by the DCCGAN.

    The class label is embedded, projected by a Dense layer to one value per
    image element, and reshaped to ``in_shape``. It is then concatenated to the
    image along the channel axis before the convolutional stack.

    Parameters
    ----------
    in_shape : tuple of 3 ints, optional
        (height, width, channels) of the real and generated images; must match
        the generator output (default (128, 128, 1)).
    n_classes : int, optional
        Number of distinct class labels (default 2).
    embedding_dim : int, optional
        Length of the learned label embedding (default 100).

    Returns
    -------
    tf.keras.Model
        Model taking ``[image, label]`` and returning one sigmoid output per
        sample (trained towards 1 for real and 0 for fake images).
    """
    height, width, channels = in_shape[0], in_shape[1], in_shape[2]

    label = layers.Input(shape=(1,))
    label_plane = layers.Embedding(n_classes, embedding_dim)(label)
    label_plane = layers.Dense(height * width * channels)(label_plane)
    label_plane = layers.Reshape((height, width, channels))(label_plane)

    image = layers.Input(shape=in_shape)
    stacked = layers.Concatenate()([image, label_plane])

    # Gaussian noise on the input helps avoid mode collapse.
    x = layers.GaussianNoise(0.1)(stacked)

    first_conv = layers.Conv2D(
        64, kernel_size=4, strides=2, padding='same',
        kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02),
        use_bias=False,
    )
    x = layers.LeakyReLU(0.2)(first_conv(x))

    for width_out in (128, 256, 512):
        x = disc_block(x, width_out, 4, 3)

    x = layers.Flatten()(x)
    x = layers.Dropout(0.4)(x)
    x = layers.Dense(1, activation='sigmoid')(x)

    return tf.keras.Model([image, label], x)

In [None]:
def define_conditional_critic(clip,in_shape=(128,128,1), n_classes=2, embedding_dim=100):
    """Build the conditional critic used by the WCGAN.

    The layout matches the DCCGAN discriminator, with two differences:
    every convolution kernel carries a Keras ``min_max_norm(-clip, clip)``
    constraint, and the output is an unbounded score (no sigmoid), as the
    Wasserstein loss requires.

    Parameters
    ----------
    clip : float
        Bound passed to ``min_max_norm`` for every convolution, including the
        ones built by ``disc_block``. It should be > 0; a falsy value disables
        the constraint in the ``disc_block`` layers.
    in_shape : tuple of 3 ints, optional
        (height, width, channels) of the real and generated images; must match
        the generator output (default (128, 128, 1)).
    n_classes : int, optional
        Number of distinct class labels (default 2).
    embedding_dim : int, optional
        Length of the learned label embedding (default 100).

    Returns
    -------
    tf.keras.Model
        Model taking ``[image, label]`` and returning one critic score per
        sample.
    """
    label = layers.Input(shape=(1,))
    embedding = layers.Embedding(n_classes, embedding_dim)(label)
    dense = layers.Dense(in_shape[0] * in_shape[1] * in_shape[2])(embedding)
    # Reshape to the full (H, W, C) so the H*W*C Dense output fits for any
    # channel count (consistent with the other discriminators).
    condition = layers.Reshape((in_shape[0], in_shape[1], in_shape[2]))(dense)
    image = layers.Input(shape=in_shape)
    concat = layers.Concatenate()([image, condition])

    noise = layers.GaussianNoise(1)(concat) # added Gaussian Noise layer helps avoid mode collapse.

    x = layers.Conv2D(64, kernel_size=4, strides= 2, padding='same', kernel_initializer=tf.keras.initializers.RandomNormal(
    mean=0.0, stddev=0.02),  use_bias=False, kernel_constraint=keras.constraints.min_max_norm(-clip,clip, axis =[0,1,2]) )(noise)
    x = layers.LeakyReLU(0.2)(x)

    # Forward the caller's bound instead of a hard-coded True (which acts as 1).
    x = disc_block(x,128,4,3,clip = clip)
    x = disc_block(x,256,4,3,clip = clip)
    x = disc_block(x,512,4,3,clip = clip)

    flatten = layers.Flatten()(x) #3D to 2D
    dropout = layers.Dropout(0.4)(flatten)
    dense = layers.Dense(1)(dropout)  # linear output: Wasserstein critic score

    model = tf.keras.Model([image, label], dense)
    return model 

In [None]:
def define_conditional_critic_gp(in_shape=(128,128,1), n_classes=3, embedding_dim=100):
    """Build the conditional critic used by the WCGANGP.

    There are no weight constraints and no batch normalization, because the
    Lipschitz condition is enforced by the gradient penalty instead. The label
    is embedded, projected to ``in_shape`` and concatenated to the image.

    Parameters
    ----------
    in_shape : tuple of 3 ints, optional
        (height, width, channels) of the real and generated images; must match
        the generator output (default (128, 128, 1)).
    n_classes : int, optional
        Number of distinct class labels. Note the default is 3, unlike the
        other builders; ``trainGAN`` always passes it explicitly.
    embedding_dim : int, optional
        Length of the learned label embedding (default 100).

    Returns
    -------
    tf.keras.Model
        Model taking ``[image, label]`` and returning one critic score per
        sample.
    """
    height, width, channels = in_shape[0], in_shape[1], in_shape[2]

    label = layers.Input(shape=(1,))
    label_plane = layers.Embedding(n_classes, embedding_dim)(label)
    label_plane = layers.Dense(height * width * channels)(label_plane)
    label_plane = layers.Reshape((height, width, channels))(label_plane)
    image = layers.Input(shape=in_shape)
    x = layers.Concatenate()([image, label_plane])

    for filters, step in ((64, 2), (128, 3), (256, 3), (512, 3)):
        x = disc_block(x, filters, 4, step, gp=True)

    x = layers.Flatten()(x)
    # Optional: MinibatchDiscrimination(num_kernel=100, dim_kernel=2) could be inserted here.
    x = layers.Dropout(0.5)(x)
    score = layers.Dense(1)(x)

    return tf.keras.Model([image, label], score)

# TrainGAN


In [None]:
class trainGAN():
  '''
  A class used to run experiments with conditional GANs (DCCGAN, WCGAN,
  WCGANGP).

  Attributes
  ----------
  dataset : object
    The training data, stored unchanged on the instance. Presumably a
    tf.data.Dataset yielding (image, label) batches -- confirm against the
    caller.

  optimizer : str
    The name of an optimizer: 'adam', 'rms' or 'nadam'. It is converted to the
    matching tf.keras.optimizers class (Adam, RMSprop, Nadam). Separate
    instances are created for the generator and the discriminator/critic.

  num_classes : int
    The number of classes being represented. Should equal the number of classes
    in the dataset. For AF, NORMAL, and OTHER, this should be 3. For AF and 
    NORMAL only, this should be 2.

  embedding_dim : int
    Length of the label embedding in the discriminator/critic (embedding
    table shape: num_classes x embedding_dim). The generator instead embeds
    labels with length latent_dim.

  latent_dim : int
    The length of the latent vector supplied to the generator.

  batch_size : int
    Stored on the instance; presumably the batch size of the dataset.

  learning_rate : float
    Learning rate passed to both optimizers.
  
  epochs : int
    The number of training epochs.

  num_eg : int
    Number of rows of the fixed latent seed (self.seed) used when saving
    progress images (default 16).
  
  n_critic : int
    The number of critic iterations per training step.
    Applies only to wcgan and wcgangp.

  beta : float
    The beta_1 parameter for Adam and Nadam (not used by RMSprop).
  
  gan_type : str
    'dccgan', 'wcgan' or 'wcgangp'. Selects the discriminator/critic
    architecture and the loss used in GAN training.

  clip : int
    Bound passed to define_conditional_critic for its min_max_norm kernel
    constraints (WCGAN only). The original author recommends 1 so that the
    critic obeys the 1-Lipschitz constraint.

  gp_weight : float
    Fixed at 10.0; presumably the gradient-penalty weight used by the
    WCGANGP training step.

  Methods
  -------

  generator_loss(label : Tensor, image : Tensor)
      Binary cross-entropy loss of the generator.

  discriminator_loss(label : Tensor, image : Tensor)
      Binary cross-entropy loss of the discriminator.

  normalization(image : Tensor)
      Resizes image/s to 128x128 and rescales them to (-1,1).

  gradient_penalty(target : Tensor, real_imgs : Tensor, fake_imgs : Tensor)
      Calculates a gradient penalty on a random interpolation of a batch of
      real and fake images.

  label_maker(n_classes : int ,ecg_type='NORMAL' : str, num_eg=16 : int)
      Helper that builds a tensor of num_eg identical labels: 0 for AF,
      1 for NORMAL, 2 for anything else (OTHER).

  generate_and_save_images(model : Model, epoch : int, seed : Tensor, ecg_type = 'NORMAL' : str)
      Generates and saves a grid of images used to assess GAN performance
      from a visual-quality standpoint.

  generate_new_images(model : Model, num_eg : int, directory='new_images' : str ,ecg_type = 'NORMAL' : str, batch_size=1024 : int)
      Generates and saves individual images used to assess GAN performance
      with Fréchet Inception Distance and Kernel Inception Distance. The
      clean-fid package requires that images are saved in a directory.

  train_step_test(images : Tensor, target : Tensor)
      An experimental DCCGAN step that averages the real and fake
      discriminator losses into a single discriminator update.

  train_step_dcgan(images : Tensor, target : Tensor)
      A single training step used with the DCCGAN.
      This method does NOT combine the loss of the discriminator on real and
      fake samples; each is applied as a separate update.

  train_step_dcgan_combine(images : Tensor, target : Tensor)
      A single training step used with the DCCGAN.
      This method combines the loss of the discriminator on real and fake
      samples.

  train_step_wgan(images : Tensor, target : Tensor)
      A single training step used with the WCGAN.

  train_step_wgangp(images : Tensor, target : Tensor)
      A single training step used with the WCGANGP.

  train()
      The main training loop used for all GAN models. Calls a train_step
      function determined by the gan_type parameter.
  
  '''
  def __init__(self, dataset, optimizer,num_classes, embedding_dim, latent_dim, batch_size, 
               learning_rate, epochs, num_eg=16,n_critic = 3,beta = 0.5, gan_type = 'dccgan' ,clip = 1):
      """Store hyperparameters and build the models, optimizers and checkpoint.

      See the class docstring for the meaning of each parameter.

      Raises
      ------
      ValueError
          If ``optimizer`` is not 'adam', 'rms' or 'nadam', or if ``gan_type``
          is not 'dccgan', 'wcgan' or 'wcgangp'.
      """
      # Fail fast on unsupported names instead of leaving attributes unset.
      if optimizer not in ('adam', 'rms', 'nadam'):
          raise ValueError("optimizer must be 'adam', 'rms' or 'nadam', got {!r}".format(optimizer))
      if gan_type not in ('dccgan', 'wcgan', 'wcgangp'):
          raise ValueError("gan_type must be 'dccgan', 'wcgan' or 'wcgangp', got {!r}".format(gan_type))

      self.batch_size = batch_size
      self.gan_type = gan_type
      self.gp_weight=10.0  # presumably used by the WCGANGP gradient penalty
      self.dataset = dataset

      self.num_classes = num_classes
      self.embedding_dim = embedding_dim
      self.latent_dim = latent_dim
      self.learning_rate = learning_rate
      self.epochs = epochs
      self.n_critic = n_critic
      self.clip = clip

      # DCCGAN loss; the discriminator ends in a sigmoid, so from_logits stays False.
      self.binary_cross_entropy = tf.keras.losses.BinaryCrossentropy()

      # Fixed latent seed reused for every progress snapshot. num_eg is kept in
      # sync with the seed so labels and seed rows always have the same length.
      self.num_eg = num_eg
      self.seed = tf.random.normal([num_eg, self.latent_dim])

      self.gen = define_conditional_generator(n_classes = num_classes, latent_dim=self.latent_dim)

      # Separate optimizer instances so generator and discriminator keep independent state.
      if optimizer == 'adam':
          self.generator_opt = tf.keras.optimizers.Adam(learning_rate = learning_rate, beta_1 = beta, beta_2 = 0.999 )
          self.disc_opt = tf.keras.optimizers.Adam(learning_rate = learning_rate, beta_1 = beta, beta_2 = 0.999 )
      elif optimizer == 'rms':
          self.generator_opt = tf.keras.optimizers.RMSprop(learning_rate = learning_rate)
          self.disc_opt = tf.keras.optimizers.RMSprop(learning_rate = learning_rate)
      else:  # 'nadam'
          self.generator_opt = tf.keras.optimizers.Nadam(learning_rate = learning_rate, beta_1 = beta, beta_2 = 0.999 )
          self.disc_opt= tf.keras.optimizers.Nadam(learning_rate = learning_rate, beta_1 = beta, beta_2 = 0.999 )

      if gan_type == 'dccgan':
          self.disc = define_conditional_discriminator(n_classes=num_classes, embedding_dim=self.embedding_dim)
      elif gan_type == 'wcgan':
          self.disc = define_conditional_critic(self.clip, n_classes=num_classes, embedding_dim=self.embedding_dim)
      else:  # 'wcgangp'
          self.disc = define_conditional_critic_gp(n_classes=num_classes, embedding_dim=self.embedding_dim)

      self.checkpoint_dir = './checkpoints'
      self.checkpoint_prefix = os.path.join(self.checkpoint_dir, "ckpt")
      self.checkpoint = tf.train.Checkpoint(generator_opt=self.generator_opt,
                                disc_opt=self.disc_opt,
                                generator=self.gen,
                                discriminator=self.disc)

  def generator_loss(self,label, image):
      """Binary cross-entropy loss for the generator (DCCGAN).

      Parameters
      ----------
      label : tf.Tensor
          Target tensor. Callers pass all ones, so generated samples are
          scored against the "real" label.
      image : tf.Tensor
          Despite the name, the discriminator's output for the generated
          images, not the images themselves.

      Returns
      -------
      tf.Tensor
          The cross-entropy computed by ``self.binary_cross_entropy``.
      """
      loss_fn = self.binary_cross_entropy
      return loss_fn(label, image)

  def discriminator_loss(self,label, image):
      """Binary cross-entropy loss for the discriminator (DCCGAN).

      Parameters
      ----------
      label : tf.Tensor
          Target tensor: all ones for real images, all zeros for fake ones.
      image : tf.Tensor
          Despite the name, the discriminator's output for the images.

      Returns
      -------
      tf.Tensor
          The cross-entropy (disc_loss) computed by ``self.binary_cross_entropy``.
      """
      loss_fn = self.binary_cross_entropy
      return loss_fn(label, image)

  @tf.function
  def normalization(self,image):
      """Resize image(s) to 128x128 and map pixel values to [-1, 1].

      Applies ``x / 127.5 - 1``, which maps [0, 255] onto [-1, 1]. This
      assumes 8-bit pixel values -- confirm the input range with the caller.

      Parameters
      ----------
      image : tf.Tensor
          Image or batch of images accepted by ``tf.image.resize``.

      Returns
      -------
      tf.Tensor
          The resized, rescaled image(s).

      References
      ----------
      https://learnopencv.com/conditional-gan-cgan-in-pytorch-and-tensorflow/
      """
      resized = tf.image.resize(image, (128,128))
      return resized / 127.5 - 1

############
  def gradient_penalty(self, target, real_imgs, fake_imgs):
      """Gradient penalty (WGAN-GP) on random real/fake interpolations.

      Each sample gets its own mixing coefficient eps ~ U(0, 1). The critic is
      evaluated on ``eps * real + (1 - eps) * fake``, and the penalty is the mean
      of ``(||grad||_2 - 1)^2`` over the batch.

      Parameters
      ----------
      target : tf.Tensor
          Class labels (0, 1 or 2 for AF, NORMAL, OTHER) shared by the real
          and generated images.
      real_imgs : tf.Tensor
          Batch of real images, shape (batch, H, W, C).
      fake_imgs : tf.Tensor
          Batch of generated images with the same shape as ``real_imgs``.

      Returns
      -------
      tf.Tensor
          Scalar gradient penalty.

      References
      ----------
      https://keras.io/examples/generative/wgan_gp/
      """
      # Dynamic batch size: real_imgs.shape[0] is None when the batch dimension
      # is unknown at trace time inside tf.function.
      batch_size = tf.shape(real_imgs)[0]
      epsilon = tf.random.uniform([batch_size, 1, 1, 1])  # broadcast over H, W, C
      interpolated = real_imgs * epsilon + fake_imgs * (1 - epsilon)
      with tf.GradientTape() as gp_tape:
          gp_tape.watch(interpolated)
          pred = self.disc([interpolated, target], training=True)
      grads = gp_tape.gradient(pred, [interpolated])[0]
      norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
      return tf.reduce_mean((norm - 1.0) ** 2)

#############

  def label_maker(self,n_classes,ecg_type='NORMAL', num_eg=16):
      """Return ``num_eg`` copies of the class index for ``ecg_type``.

      Parameters
      ----------
      n_classes : int
          Unused; kept for call-site compatibility.
      ecg_type : str
          'AF' -> 0, 'NORMAL' -> 1; any other value maps to 2 (OTHER).
      num_eg : int
          Number of labels to produce (default 16).

      Returns
      -------
      tf.Tensor
          1-D int32 tensor of length ``num_eg``.

      References
      ----------
      https://learnopencv.com/conditional-gan-cgan-in-pytorch-and-tensorflow/
      """
      class_index = 0 if ecg_type == 'AF' else 1 if ecg_type == 'NORMAL' else 2
      single = tf.cast(class_index, dtype=tf.dtypes.int32)
      return tf.repeat(single, [num_eg], axis=None, name=None)

  def generate_and_save_images(self,model, epoch, seed, ecg_type = 'NORMAL'):
      """Generate a grid of images from a fixed seed and save it to disk.

      Used to assess GAN performance visually. The figure is written to
      ``Fake/<ecg_type>/image_at_epoch_<epoch>.png``.

      Parameters
      ----------
      model : tf.keras.Model
          The generator used to produce the images.
      epoch : int
          The current epoch of training (used in the file name).
      seed : tf.Tensor
          Fixed latent tensor of shape (num_images, latent_dim). One image is
          generated per row, and the grid is sized to fit all of them
          (16 rows -> 4x4).
      ecg_type : str
          'NORMAL', 'AF' or 'OTHER'. Converted to integer labels by
          ``label_maker``.

      References
      ----------
      https://learnopencv.com/conditional-gan-cgan-in-pytorch-and-tensorflow/
      """
      # One label per seed row; a hard-coded 16 breaks for any other seed size.
      num_images = int(seed.shape[0])
      labels = self.label_maker(n_classes=self.num_classes, ecg_type=ecg_type, num_eg=num_images)
      predictions = model([seed, labels], training=False)

      grid = int(np.ceil(np.sqrt(num_images)))
      # pyplot reuses axes at the same subplot position, so without clearing,
      # every call would stack new images onto the previous ones.
      plt.clf()
      for i in range(predictions.shape[0]):
          plt.subplot(grid, grid, i + 1)
          # Map tanh output [-1, 1] back to [0, 255].
          pred = np.array((predictions[i, :, :, :] + 1) * 127.5)
          img = keras.preprocessing.image.array_to_img(pred)
          plt.imshow(img, cmap='gray')
          plt.axis('off')

      plt.savefig('Fake/' + ecg_type +'/image_at_epoch_{:d}.png'.format(epoch))


          
  def generate_new_images(self, model, num_eg, directory='new_images',ecg_type = 'NORMAL', batch_size=1024):
      """Generate ``num_eg`` images and save each one as a PNG.

      Used to assess GAN performance with Fréchet Inception Distance and
      Kernel Inception Distance; the clean-fid package requires the images to
      be saved in a directory. Files are written to
      ``<directory>/<ecg_type>/generated_<ecg_type>_<n>.png``, and that
      directory is created if it is missing.

      Parameters
      ----------
      model : tf.keras.Model
          The generator used to produce the images.
      num_eg : int
          Total number of ECGs to generate. Must be a multiple of
          ``batch_size``; otherwise a message is printed and nothing is
          generated.
      directory : str
          Root output directory (default 'new_images').
      ecg_type : str
          'NORMAL', 'AF' or 'OTHER'. Converted to integer labels by
          ``label_maker``.
      batch_size : int
          Number of images generated per forward pass, to bound memory use.
      """
      if num_eg % batch_size != 0:
        print("please ensure batch size and number of examples are divisible")
        return

      os.makedirs(os.path.join(directory, ecg_type), exist_ok=True)

      count = 0
      for _ in range(num_eg // batch_size):
          latent = tf.random.normal([batch_size, self.latent_dim])
          labels = self.label_maker(n_classes=self.num_classes, ecg_type=ecg_type, num_eg=batch_size)
          predictions = model([latent, labels], training=False)

          for j in range(predictions.shape[0]):
              # Map tanh output [-1, 1] back to [0, 255].
              pred = np.asarray((predictions[j, :, :, :] + 1) * 127.5)
              img = keras.preprocessing.image.array_to_img(pred)
              img.save("{directory}/{ecg}/generated_{ecg}_{i}.png".format(directory=directory, ecg = ecg_type,i=count))
              count += 1


  @tf.function
  def train_step_test(self,images,target):
      """Experimental DCCGAN step with one averaged discriminator update.

      The real and fake discriminator losses are averaged and applied in a
      single optimizer step. This is followed by one generator step that
      reuses the same noise batch.

      Parameters
      ----------
      images : tf.Tensor
          Batch of real images.
      target : tf.Tensor
          Class label of each image in the batch.

      Returns
      -------
      tuple of tf.Tensor
          (disc_loss_real, disc_loss_fake, gen_loss).
      """
      # Dynamic batch size: target.shape[0] is None if the batch dimension is
      # unknown at trace time.
      noise = tf.random.normal([tf.shape(target)[0], self.latent_dim])

      with tf.GradientTape() as disc_tape_real:
          generated_images = self.gen([noise,target], training=True)

          real_output = self.disc([images,target], training=True)
          disc_loss_real = self.discriminator_loss(tf.ones_like(real_output), real_output)

          fake_output = self.disc([generated_images,target], training=True)
          disc_loss_fake = self.discriminator_loss(tf.zeros_like(fake_output), fake_output)

          disc_loss = 0.5*(disc_loss_real+disc_loss_fake)

      gradients_of_disc_real = disc_tape_real.gradient(disc_loss, self.disc.trainable_variables)
      self.disc_opt.apply_gradients(zip(gradients_of_disc_real,self.disc.trainable_variables))

      with tf.GradientTape() as gen_tape:
          generated_images = self.gen([noise,target], training=True)
          fake_output = self.disc([generated_images,target], training=True)
          # The generator wants its samples classified as real (1).
          gen_loss = self.generator_loss(tf.ones_like(fake_output), fake_output)

      gradients_of_gen = gen_tape.gradient(gen_loss, self.gen.trainable_variables)
      self.generator_opt.apply_gradients(zip(gradients_of_gen,self.gen.trainable_variables))

      return disc_loss_real, disc_loss_fake, gen_loss
        
  

  @tf.function
  def train_step_dcgan(self,images,target): 
      """One DCCGAN step with separate real and fake discriminator updates.

      The discriminator gets two optimizer steps per call, one on the
      real-image loss and one on the fake-image loss, so ``disc_opt`` advances
      twice per call. One generator step follows, reusing the same noise batch.

      Parameters
      ----------
      images : tf.Tensor
          A tensor of real images.
      target : tf.Tensor
          The true label of each image.

      Returns
      -------
      tuple of tf.Tensor
          (real_loss, fake_loss, gen_loss).
      """
      # Dynamic batch size: target.shape[0] is None if the batch dimension is
      # unknown at trace time.
      noise = tf.random.normal([tf.shape(target)[0], self.latent_dim])

      # Two tapes over one forward pass so real and fake losses are applied as
      # separate updates (presumably following Soumith Chintala's "ganhacks").
      with tf.GradientTape() as real_tape, tf.GradientTape() as fake_tape:
          generated_images = self.gen([noise,target], training=True)

          # real images should be classified as real (1)
          real_guess = self.disc([images,target], training=True)
          real_loss = self.discriminator_loss(tf.ones_like(real_guess), real_guess)

          # generated images should be classified as fake (0)
          fake_guess = self.disc([generated_images,target], training=True)
          fake_loss = self.discriminator_loss(tf.zeros_like(fake_guess), fake_guess)

      grad_disc_real = real_tape.gradient(real_loss, self.disc.trainable_variables)
      self.disc_opt.apply_gradients(zip(grad_disc_real,self.disc.trainable_variables))

      grad_disc_fake = fake_tape.gradient(fake_loss, self.disc.trainable_variables)
      self.disc_opt.apply_gradients(zip(grad_disc_fake,self.disc.trainable_variables))

      with tf.GradientTape() as gen_tape:
          generated_images = self.gen([noise,target], training=True)
          # the generator wants its samples classified as real (1)
          fake_guess = self.disc([generated_images,target], training=True)
          gen_loss = self.generator_loss(tf.ones_like(fake_guess), fake_guess)

      grad_gen = gen_tape.gradient(gen_loss, self.gen.trainable_variables)
      self.generator_opt.apply_gradients(zip(grad_gen,self.gen.trainable_variables))

      return real_loss, fake_loss, gen_loss

 
  
  @tf.function
  def train_step_wgan(self, images, target):
      """A single training step to be used in conjunction with the WCGAN.

      The critic (discriminator) is updated ``self.n_critic`` times, each time
      on a fresh batch of noise, followed by a single generator update.

      Parameters
      ----------
      images : tf.Tensor
          A tensor of real images
      target : tf.Tensor
          A list of numbers corresponding to the true label of each image

      Returns
      -------
      tuple of tf.Tensor
          ``(disc_loss, gen_loss)``, where ``disc_loss`` is the critic loss
          of the last critic iteration.
      """
      # Run-time batch size; robust to a smaller final batch and to traces
      # with an unknown batch dimension.
      batch_size = tf.shape(target)[0]

      for _ in range(self.n_critic):
          noise = tf.random.normal([batch_size, self.latent_dim])

          with tf.GradientTape() as disc_tape:
              generated_images = self.gen([noise, target], training=True)
              real_guess = self.disc([images, target], training=True)
              fake_guess = self.disc([generated_images, target], training=True)
              # Wasserstein critic loss: mean(D(fake)) - mean(D(real)).
              disc_loss = tf.math.reduce_mean(fake_guess) - tf.math.reduce_mean(real_guess)

          # The critic update must happen inside the loop so that it is applied
          # n_critic times per generator update.
          gradients_of_disc = disc_tape.gradient(disc_loss, self.disc.trainable_variables)
          self.disc_opt.apply_gradients(zip(gradients_of_disc, self.disc.trainable_variables))
          # NOTE(review): the original WGAN clips critic weights after every
          # critic update; no clipping happens here — confirm it is done
          # elsewhere (e.g. a kernel constraint inside self.disc).

      noise = tf.random.normal([batch_size, self.latent_dim])
      with tf.GradientTape() as gen_tape:
          generated_images = self.gen([noise, target], training=True)
          fake_guess = self.disc([generated_images, target], training=True)
          # Generator maximises the critic score of its samples.
          gen_loss = -tf.math.reduce_mean(fake_guess)
      gradients_of_gen = gen_tape.gradient(gen_loss, self.gen.trainable_variables)
      self.generator_opt.apply_gradients(zip(gradients_of_gen, self.gen.trainable_variables))

      return disc_loss, gen_loss

  @tf.function
  def train_step_wgan_gp(self, images, target, gp_weight=10):
      """A single training step to be used in conjunction with the WCGANGP.

      The critic (discriminator) is updated ``self.n_critic`` times and the
      generator once. The critic loss is the Wasserstein estimate
      ``mean(D(fake)) - mean(D(real))`` plus ``gp_weight`` times the value
      returned by ``self.gradient_penalty(labels, real, generated)``.

      Parameters
      ----------
      images : tf.Tensor
          A tensor of real images
      target : tf.Tensor
          A list of numbers corresponding to the true label of each image
      gp_weight : float, optional
          Weight of the gradient-penalty term (defaults to 10).

      Returns
      -------
      tuple of tf.Tensor
          ``(disc_loss, gen_loss)``, where ``disc_loss`` is the critic loss
          (including the penalty) of the last critic iteration.
      """
      # Run-time batch size; robust to a smaller final batch and to traces
      # with an unknown batch dimension.
      batch_size = tf.shape(target)[0]

      for _ in range(self.n_critic):
          noise = tf.random.normal([batch_size, self.latent_dim])

          with tf.GradientTape() as disc_tape:
              generated_images = self.gen([noise, target], training=True)
              real_guess = self.disc([images, target], training=True)
              fake_guess = self.disc([generated_images, target], training=True)
              disc_loss = tf.math.reduce_mean(fake_guess) - tf.math.reduce_mean(real_guess)
              gp = self.gradient_penalty(target, images, generated_images)
              disc_loss = disc_loss + gp * gp_weight

          gradients_of_disc = disc_tape.gradient(disc_loss, self.disc.trainable_variables)
          self.disc_opt.apply_gradients(zip(gradients_of_disc, self.disc.trainable_variables))

      noise = tf.random.normal([batch_size, self.latent_dim])
      with tf.GradientTape() as gen_tape:
          generated_images = self.gen([noise, target], training=True)
          fake_output = self.disc([generated_images, target], training=True)
          # Generator maximises the critic score of its samples.
          gen_loss = -tf.math.reduce_mean(fake_output)
      gradients_of_gen = gen_tape.gradient(gen_loss, self.gen.trainable_variables)
      self.generator_opt.apply_gradients(zip(gradients_of_gen, self.gen.trainable_variables))

      return disc_loss, gen_loss
  
     

  def train(self):
      """The main training loop used for all GAN models.

      For every epoch, each batch of ``self.dataset`` is cast to float32,
      passed through ``self.normalization`` and fed to the train-step method
      selected by ``self.gan_type`` (``'dccgan'``, ``'wcgan'`` or
      ``'wcgangp'``). After each epoch ``self.generate_and_save_images`` is
      called for the ``'NORMAL'`` and ``'AF'`` classes, and a checkpoint is
      saved every 10 epochs, starting with the first.

      Returns
      -------
      tuple of list
          ``(disc_losses, gen_losses)``: the discriminator and generator loss
          of the last batch of every epoch.

      Raises
      ------
      ValueError
          If ``self.gan_type`` is not a supported GAN type, or if
          ``self.dataset`` yields no batches.
      """
      step_fns = {
          'dccgan': self.train_step_dcgan,
          'wcgan': self.train_step_wgan,
          'wcgangp': self.train_step_wgan_gp,
      }
      # Fail fast instead of training for a whole epoch and then crashing on
      # an undefined loss.
      if self.gan_type not in step_fns:
          raise ValueError(
              f"Unknown gan_type {self.gan_type!r}; expected one of {sorted(step_fns)}")
      train_step = step_fns[self.gan_type]

      disc_losses = []
      gen_losses = []
      disc_loss = None
      gen_loss = None

      for epoch in range(self.epochs):
          start = time.time()

          for image_batch, target in tqdm(self.dataset):
              imgs = self.normalization(tf.cast(image_batch, tf.float32))

              if self.gan_type == 'dccgan':
                  # The DCCGAN step reports the real and fake discriminator
                  # losses separately; average them into a single value.
                  disc_real, disc_fake, gen_loss = train_step(imgs, target)
                  disc_loss = (disc_real + disc_fake) / 2
              else:
                  disc_loss, gen_loss = train_step(imgs, target)

          if disc_loss is None:
              raise ValueError('The training dataset yielded no batches.')

          self.generate_and_save_images(self.gen, epoch + 1, self.seed, 'NORMAL')
          self.generate_and_save_images(self.gen, epoch + 1, self.seed, 'AF')

          # epoch is 0-based, so this also covers the very first epoch.
          if epoch % 10 == 0:
              print('Saving...')
              self.checkpoint.save(file_prefix=self.checkpoint_prefix)

          # The train steps are called eagerly and return eager tensors, so
          # .numpy() works directly; no tf.compat.v1.Session is needed.
          disc_value = disc_loss.numpy()
          print(disc_value)
          disc_losses.append(disc_value)
          gen_losses.append(gen_loss.numpy())

          print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))

      # NOTE: Weights & Biases logging was removed; the wandb cells at the top
      # of the notebook are marked as deprecated.
      return disc_losses, gen_losses


# Use Cases


Helper function to read in datasets

In [None]:
# RUN ME
def load_grayscale_images(dir, in_shape=(128, 128), batch_size=64):
    """Read a labelled grayscale image dataset from a directory tree.

    Every subfolder of ``dir`` is one class; labels are inferred from the
    subfolders and encoded as integers.

    Parameters
    ----------
    dir : str
        Root directory whose subfolders hold the images of each class.
    in_shape : tuple
        ``(height, width)`` every image is resized to (default ``(128, 128)``).
    batch_size : int
        Number of images in each batch produced by the dataset.

    Returns
    -------
    tf.data.Dataset
        Dataset yielding ``(images, labels)`` batches.
    """
    loader_options = dict(
        labels="inferred",
        label_mode="int",
        image_size=in_shape,
        batch_size=batch_size,
        color_mode="grayscale",
    )
    return tf.keras.utils.image_dataset_from_directory(dir, **loader_options)
  
    

### Select the experiment you want to run

## AFib-GAN:
- Batch Size = 64
- Optimizer = Adam
- Learning Rate = 0.0002
- Embedding Dimension = 100
- Latent Dimension = 100


In [None]:
if __name__ == "__main__":
    # AFib-GAN: conditional DCGAN ('dccgan') trained with Adam at lr 2e-4.
    BATCH_SIZE = 64
    dataset = load_grayscale_images('content/images_train', batch_size=BATCH_SIZE)

    t = trainGAN(
        # data
        dataset=dataset,
        num_classes=2,
        batch_size=BATCH_SIZE,
        # architecture
        gan_type='dccgan',
        embedding_dim=100,
        latent_dim=100,
        # optimisation
        optimizer='adam',
        learning_rate=0.0002,
        epochs=100,
        # train_step_dcgan reads neither n_critic nor clip.
        n_critic=3,
        clip=1,
        num_eg=16,
    )

    t.train()


## WCGAN:
- Batch Size = 64
- Optimizer = RMSprop
- Learning Rate = 0.00005
- Embedding Dimension = 100
- Latent Dimension = 100
- Number of times critic is trained = 3

In [None]:
if __name__ == "__main__":
    # WCGAN: conditional Wasserstein GAN ('wcgan') trained with RMSprop at lr 5e-5.
    BATCH_SIZE = 64
    dataset = load_grayscale_images('content/images_train', batch_size=BATCH_SIZE)

    t = trainGAN(
        # data
        dataset=dataset,
        num_classes=2,
        batch_size=BATCH_SIZE,
        # architecture
        gan_type='wcgan',
        embedding_dim=100,
        latent_dim=100,
        # optimisation
        optimizer='rms',
        learning_rate=0.00005,
        epochs=100,
        n_critic=3,  # critic updates per generator update
        # NOTE(review): train_step_wgan does not read clip — confirm weight
        # clipping is applied elsewhere (the WGAN paper clips to 0.01).
        clip=1,
        num_eg=16,
    )

    t.train()


## WCGANGP-RMSprop:
- Batch Size = 64
- Optimizer = RMSprop
- Learning Rate = 0.00005
- Embedding Dimension = 100
- Latent Dimension = 100
- Number of times critic is trained = 3

In [None]:
if __name__ == "__main__":
    # WCGANGP-RMSprop: Wasserstein GAN with gradient penalty ('wcgangp'),
    # trained with RMSprop at lr 5e-5.
    BATCH_SIZE = 64
    dataset = load_grayscale_images('content/images_train', batch_size=BATCH_SIZE)

    t = trainGAN(
        # data
        dataset=dataset,
        num_classes=2,
        batch_size=BATCH_SIZE,
        # architecture
        gan_type='wcgangp',
        embedding_dim=100,
        latent_dim=100,
        # optimisation
        optimizer='rms',
        learning_rate=0.00005,
        epochs=100,
        n_critic=3,  # critic updates per generator update
        clip=1,  # not read by train_step_wgan_gp
        num_eg=16,
    )

    t.train()


## WCGANGP-Adam:
- Batch Size = 64
- Optimizer = Adam
- Learning Rate = 0.0002
- Embedding Dimension = 100
- Latent Dimension = 100
- Number of times critic is trained = 3

In [None]:
if __name__ == "__main__":
    # WCGANGP-Adam: Wasserstein GAN with gradient penalty ('wcgangp'),
    # trained with Adam at lr 2e-4.
    BATCH_SIZE = 64
    dataset = load_grayscale_images('content/images_train', batch_size=BATCH_SIZE)

    t = trainGAN(
        # data
        dataset=dataset,
        num_classes=2,
        batch_size=BATCH_SIZE,
        # architecture
        gan_type='wcgangp',
        embedding_dim=100,
        latent_dim=100,
        # optimisation
        optimizer='adam',
        learning_rate=0.0002,
        epochs=100,
        n_critic=3,  # critic updates per generator update
        clip=1,  # not read by train_step_wgan_gp
        num_eg=16,
    )

    t.train()

## Saving Datasets and Generators

## Generate image sets for FID and KID 
and move them into the pre_saved_assests folder (the directory name used by the code)

In [None]:
# Generate one image set per ECG class for the FID and KID evaluation.
# FID is biased at small sample sizes (it tends to be over-estimated when too
# few images are used), so at least 10,000 images per class are produced.
# NOTE: `t` is the trainer from whichever experiment cell ran last.
t.generate_new_images(
    model=t.gen,
    num_eg=1024 * 10,
    ecg_type='AF',
    batch_size=64,
)
t.generate_new_images(
    model=t.gen,
    num_eg=1024 * 10,
    ecg_type='NORMAL',
    batch_size=64,
)

In [None]:
# Bundle the FID/KID image sets (new_images/AF and new_images/NORMAL) into one archive.
!zip -r for_testing.zip new_images

In [None]:
# Copy the archive into pre_saved_assests (-a preserves file attributes, -v is verbose).
%cp -av for_testing.zip pre_saved_assests

## Save the generator
and move it into the pre_saved_assests folder (the directory name used by the code)

In [None]:
# Persist the trained generator under 'test_gen' (no file extension, so it is
# presumably written as a directory — it is zipped recursively below).
model = t.gen
model.save(filepath='test_gen')

In [None]:
# Archive the saved generator; the archive name must match the test_gen.zip copied next.
!zip -r test_gen.zip test_gen

In [None]:
# Copy the generator archive into pre_saved_assests (-a preserves file attributes, -v is verbose).
%cp -av test_gen.zip pre_saved_assests

# Next:
With the generator trained and the FID/KID image sets generated, we can now move on to the evaluation notebook: https://colab.research.google.com/drive/1b38zrwvoWTGgBm_PwrP2Sp_XGu9Ixfhg?usp=sharing