# **Conditional Deep Convolutional GAN (cDCGAN) - Breast histopathology (IDC) dataset**

# **Utility functions**

In [None]:
%%capture
# Install Weights & Biases (experiment logging) into the Colab kernel.
# NOTE(review): unpinned; consider pinning a version (e.g. wandb==X.Y.Z) for reproducibility.
%pip install wandb

In [None]:
# Mount Google Drive at /content/drive; the dataset zip and the repo clone below are read from / written to it.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!mkdir breast-histopathology
 

In [None]:
%%capture
!unzip "/content/drive/MyDrive/datasets/IDC_regular_ps50_idx5.zip" -d "/content"


In [None]:
# Move the extracted folder into breast-histopathology/. Because that directory already exists
# (created by the mkdir cell), the result is breast-histopathology/IDC_regular_ps50_idx5.
# NOTE(review): assumes the cwd is /content (where the unzip cell extracted to), and it is not
# idempotent — a rerun fails once the source folder has already been moved.
!mv IDC_regular_ps50_idx5 breast-histopathology

In [None]:
%cd drive/MyDrive/
%rm -rf H-VAE
!git clone https://github.com/nderus/H-VAE

In [None]:
# Work from inside the cloned repo so its local modules (e.g. `datasets`) are importable.
%cd H-VAE

In [1]:
# Installs wandb into the current kernel. The recorded execution counts start here (In [1]),
# so this covers runs that skip the Colab setup cells. NOTE(review): unpinned version.
%pip install wandb
import wandb
# Presumably the repo's local `datasets` module (cwd is H-VAE). NOTE(review): a pip-installed
# HuggingFace `datasets` package has the same name and could shadow it — verify.
from datasets import data_loader

Note: you may need to restart the kernel to use updated packages.


In [2]:
import numpy as np
import time
import matplotlib.pyplot as plt
import random
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import backend as K
from tensorflow.keras.utils import to_categorical
import tensorflow as tf
print(tf.__version__)  # record the TensorFlow version this run used

2.7.0


In [3]:
import os
from tensorflow import keras
import tensorflow as tf
from tensorflow.keras import backend
import wandb
from wandb.keras import WandbCallback
from tensorflow.keras import regularizers


In [4]:
# Reset Keras' global state (e.g. auto-generated layer-name counters) before any model is built.
K.clear_session()

In [5]:
# TO DO: this should be passed as arguments
dataset_name = 'histo'
model_name = 'GAN'
input_noise_dim=512   # length of the generator's noise vector
epoch_count=50
batch_size=100
learning_rate = 0.0001

# Seed every RNG the notebook draws from — Python `random` (batch shuffling),
# numpy (noise vectors, random conditions) and TensorFlow (weight initialisation) —
# so a Restart & Run All reproduces the same run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [6]:
def find_indices(lst, condition):
    """Return the positions of the elements of ``lst`` that satisfy ``condition``.

    Args:
      lst: any iterable (e.g. a list or 1-D numpy array of labels).
      condition: predicate called once per element.

    Returns:
      1-D integer numpy array of matching indices. The dtype is forced to int so an
      empty result can still index another array: ``np.array([])`` defaults to float64,
      which numpy refuses to use as an index.
    """
    return np.array([i for i, elem in enumerate(lst) if condition(elem)], dtype=int)
    
def plot_2d_data_categorical(data_2d, y, titles=None, figsize = (7, 7), category_count=10):
  """Scatter 2-D embeddings in a grid with one row per class and one column per dataset.

  Args:
    data_2d: sequence of arrays, each indexed as ``[:, 0]`` / ``[:, 1]`` (N_i x 2 points).
    y: sequence of label sequences aligned with ``data_2d``.
    titles: optional per-column titles.
    figsize: matplotlib figure size.
    category_count: number of classes (rows of the grid).

  The completed figure is logged once to W&B under "Embdedding_classes".
  """
  # squeeze=False keeps axs 2-D even when there is a single row or column
  fig, axs = plt.subplots(category_count, len(data_2d), figsize=figsize, squeeze=False)
  colors = np.array(['#7FFFD4', '#458B74', '#0000CD', '#EE3B3B', '#7AC5CD', '#66CD00',
                     '#EE7621', '#3D59AB', '#CD950C', '#483D8B'])
  for i in range(len(data_2d)):
    for k in range(category_count):
      index = find_indices(y[i], lambda e: e == k)
      data_2d_k = data_2d[i][index, ]
      ax = axs[k, i]

      if titles is not None:
        ax.set_title("{} - Class: {}".format(titles[i], k))

      # One fixed colour per class (cycled beyond 10 classes). A labelled artist gives a real
      # legend entry; legend_elements() needs a colormapped array, which a single colour lacks.
      ax.scatter(data_2d_k[:, 0], data_2d_k[:, 1],
                 s=1, c=colors[k % len(colors)], label=str(k))
      ax.legend()
      ax.set_xlim([-3, 3])
      ax.set_ylim([-3, 3])

  # Log the finished figure once, not once per subplot while it is still being filled in
  wandb.log({"Embdedding_classes": wandb.Image(plt)})
        
def plot_2d_data(data_2d, y, titles=None, figsize = (7, 7)):
  """Scatter several 2-D embeddings side by side, coloured by label.

  Args:
    data_2d: sequence of arrays, each indexed as ``[:, 0]`` / ``[:, 1]`` (N_i x 2 points).
    y: sequence of label arrays aligned with ``data_2d`` (used as scatter colours).
    titles: optional per-panel titles.
    figsize: matplotlib figure size.

  The completed figure is logged once to W&B under "Embdedding".
  """
  # squeeze=False so a single dataset still yields an indexable row of axes
  _, axs = plt.subplots(1, len(data_2d), figsize=figsize, squeeze=False)

  for i in range(len(data_2d)):
    ax = axs[0, i]
    if titles is not None:
      ax.set_title(titles[i])
    scatter = ax.scatter(data_2d[i][:, 0], data_2d[i][:, 1],
                         s=1, c=y[i], cmap=plt.cm.Paired)
    ax.legend(*scatter.legend_elements())

  # Log once after all panels are drawn (previously logged a partial figure per panel)
  wandb.log({"Embdedding": wandb.Image(plt)})

def plot_history(history,metric=None):
  """Plot train/val loss per epoch, optionally with a train/val metric on a second y-axis.

  Args:
    history: Keras ``History``-like object whose ``.history`` dict holds 'loss', 'val_loss'
      and, when ``metric`` is given, ``metric`` and 'val_' + metric.
    metric: optional metric key drawn against the right-hand axis.
  """
  hist = history.history
  fig, ax1 = plt.subplots(figsize=(10, 8))

  epoch_count = len(hist['loss'])
  epochs = range(1, epoch_count + 1)

  loss_line, = ax1.plot(epochs, hist['loss'], label='train_loss', color='orange')
  loss_color = loss_line.get_color()
  ax1.plot(epochs, hist['val_loss'], label='val_loss', color=loss_color, linestyle='--')
  ax1.set_xlim([1, epoch_count])
  ax1.set_ylim([0, max(max(hist['loss']), max(hist['val_loss']))])
  ax1.set_ylabel('loss', color=loss_color)
  ax1.tick_params(axis='y', labelcolor=loss_color)
  ax1.set_xlabel('Epochs')
  _ = ax1.legend(loc='lower left')

  if metric is None:
    return

  # Secondary axis for the metric (solid = train, dashed = validation)
  val_metric = 'val_' + metric
  ax2 = ax1.twinx()
  metric_line, = ax2.plot(epochs, hist[metric], label='train_' + metric)
  metric_color = metric_line.get_color()
  ax2.plot(epochs, hist[val_metric], label=val_metric, color=metric_color, linestyle='--')
  ax2.set_ylim([0, max(max(hist[metric]), max(hist[val_metric]))])
  ax2.set_ylabel(metric, color=metric_color)
  ax2.tick_params(axis='y', labelcolor=metric_color)
  _ = ax2.legend(loc='upper right')

def plot_generated_images(generated_images, nrows, ncols,no_space_between_plots=False, figsize=(20, 20), epoch=None):
  """Show a grid of generated images, log it to W&B under "Generations", then display it.

  Args:
    generated_images: nested sequence indexed as ``generated_images[row][col]``; each item is
      an HxWxC float image expected in [0, 1] (it is scaled by 255 for display).
    nrows, ncols: grid dimensions.
    no_space_between_plots: remove spacing between subplots.
    figsize: matplotlib figure size.
    epoch: value shown in the W&B caption.
  """
  _, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)

  for i in range(nrows):
    for j in range(ncols):
      axs[i, j].axis('off')
      # Clip before the uint8 cast: values outside [0, 1] (the generator ends in tanh, so it
      # can emit negatives) would otherwise turn into wrapped/undefined pixel values.
      image = np.clip(generated_images[i][j], 0.0, 1.0)
      axs[i, j].imshow((image * 255).astype(np.uint8))

  if no_space_between_plots:
    plt.subplots_adjust(wspace=0, hspace=0)
  wandb.log({"Generations": wandb.Image(plt, caption="Epoch:{}".format(epoch))})
  plt.show()


In [7]:
def Train_Val_Plot(loss, val_loss, reconstruction_loss, val_reconstruction_loss, kl_loss, val_kl_loss):
    """Plot training vs. validation curves for total, reconstruction and KL loss side by side.

    Each argument is a per-epoch sequence of values. The figure is logged to W&B under
    "Training" and then shown.
    """
    # Exactly three panels for three metrics: requesting four axes and unpacking them into
    # three names raised "too many values to unpack" before anything was drawn.
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 4))
    fig.suptitle(" MODEL'S METRICS VISUALIZATION ")

    panels = [
        (ax1, loss, val_loss, 'History of Loss', 'Epochs', 'Loss'),
        (ax2, reconstruction_loss, val_reconstruction_loss,
         'History of reconstruction_loss', 'Epochs', 'reconstruction_loss'),
        (ax3, kl_loss, val_kl_loss, ' History of kl_loss', ' Epochs ', 'kl_loss'),
    ]
    for ax, train_values, val_values, title, xlabel, ylabel in panels:
        ax.plot(range(1, len(train_values) + 1), train_values)
        ax.plot(range(1, len(val_values) + 1), val_values)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.legend(['training', 'validation'])

    wandb.log({"Training": wandb.Image(plt)})
    plt.show()

    


In [8]:
def plot_gan_losses(d_losses,g_losses):
  """Plot per-epoch discriminator loss (left y-axis) and generator loss (right y-axis).

  Args:
    d_losses: sequence of discriminator losses, one per epoch.
    g_losses: sequence of generator losses, one per epoch.
  """
  fig, ax1 = plt.subplots(figsize=(10, 8))

  epoch_count = len(d_losses)
  epochs = range(1, epoch_count + 1)

  line1, = ax1.plot(epochs, d_losses, label='discriminator_loss', color='orange')
  ax1.set_xlim([1, epoch_count])
  ax1.set_ylim([0, max(d_losses)])
  # The x-label belongs on ax1: twinx() hides the twin's x-axis, so a label set there never shows.
  ax1.set_xlabel('Epochs')
  ax1.set_ylabel('discriminator_loss', color=line1.get_color())
  ax1.tick_params(axis='y', labelcolor=line1.get_color())
  _ = ax1.legend(loc='lower left')

  ax2 = ax1.twinx()
  line2, = ax2.plot(epochs, g_losses, label='generator_loss')
  ax2.set_ylim([0, max(g_losses)])
  ax2.set_ylabel('generator_loss', color=line2.get_color())
  ax2.tick_params(axis='y', labelcolor=line2.get_color())
  _ = ax2.legend(loc='upper right')

In [9]:
def get_gan_random_input(batch_size,noise_dim,*_):
  """Draw a (batch_size, noise_dim) batch of standard-normal noise for an unconditional generator.

  Any extra positional arguments (e.g. a condition count) are accepted and ignored, giving it
  the same call shape as the conditional variant.
  """
  noise_shape = (batch_size, noise_dim)
  return np.random.normal(loc=0, scale=1, size=noise_shape)

def get_gan_fake_batch(generator,batch_size,generator_input):
  """Generate a batch of fake samples labelled 0 for discriminator training.

  Returns:
    (samples, labels): ``generator.predict(generator_input)`` and a float array of
    ``batch_size`` zeros.
  """
  fake_samples = generator.predict(generator_input)
  fake_labels = np.zeros(batch_size)
  return fake_samples, fake_labels

In [10]:
def concatenate_cgan_batches(real_batch_x,fake_batch_x):
  """Stack a real and a fake conditional-GAN batch into one discriminator batch.

  Both arguments are ``[samples, condition_info]``; the result keeps that layout, with the
  real rows placed before the fake rows in each part.
  """
  # part 0 = samples, part 1 = condition info
  return [np.concatenate((real_batch_x[part], fake_batch_x[part])) for part in (0, 1)]

In [11]:
def chunks(list, n):
    """Lazily yield consecutive slices of ``list`` of length ``n``; the last may be shorter."""
    yield from (list[start:start + n] for start in range(0, len(list), n))

def get_random_batch_indices(data_count,batch_size):
    """Shuffle the indices 0..data_count-1 and split them into batches of ``batch_size``.

    Shuffling uses Python's ``random`` module, so it follows ``random.seed``. The final batch
    is shorter when ``data_count`` is not a multiple of ``batch_size``.
    """
    shuffled = [*range(data_count)]
    random.shuffle(shuffled)
    return [*chunks(shuffled, batch_size)]

def get_cgan_real_batch(dataset,batch_indices,label):
  """Select real samples and their conditions for the given indices, all labelled ``label``.

  Args:
    dataset: ``[inputs, condition_info]``, both indexable by ``batch_indices``.
    batch_indices: indices of the rows to take.
    label: target for every sample (train_gan passes 1, or 0.9 with one-sided labels).

  Returns:
    ``([inputs[batch_indices], condition_info[batch_indices]], labels)``.
  """
  inputs, condition_info = dataset[0], dataset[1]
  selected = [inputs[batch_indices], condition_info[batch_indices]]
  return selected, np.full(len(batch_indices), label)

In [12]:
def get_cgan_random_input(batch_size,noise_dim,condition_count):
  """Draw generator input for a conditional GAN: noise plus uniformly random one-hot classes.

  Returns:
    ``[noise, condition_info]`` — a (batch_size, noise_dim) standard-normal array and a
    (batch_size, condition_count) one-hot encoding of random labels in [0, condition_count).
  """
  noise = np.random.normal(0, 1, size=(batch_size, noise_dim))
  random_labels = np.random.randint(0, condition_count, size=batch_size)
  return [noise, to_categorical(random_labels, condition_count)]

def get_cgan_fake_batch(generator,batch_size,generator_input):
  """Generate fake conditional samples labelled 0, paired with the conditions used to make them.

  ``generator_input`` is ``[noise, condition_info]``; the returned batch reuses
  ``condition_info`` so each fake is presented to the discriminator with its own condition.
  """
  generated = generator.predict(generator_input)
  conditions = generator_input[1]
  return [generated, conditions], np.zeros(batch_size)

In [13]:
def _plot_generator_samples(generator, generator_input, example_count, example_shape, epoch):
    """Run the generator on a fixed input and plot the outputs as one captioned row of images."""
    generated_output = generator.predict(generator_input)
    generated_images = generated_output.reshape(example_count, *example_shape)
    plot_generated_images([generated_images], 1, example_count, figsize=(15, 5), epoch=epoch)


def train_gan(gan,generator,discriminator,train_x,train_data_count,input_noise_dim,epoch_count, batch_size,
              get_random_input_func,get_real_batch_func,get_fake_batch_func,concatenate_batches_func,condition_count=-1,
              use_one_sided_labels=False,plt_frq=None,plt_example_count=10,example_shape=(64,64,3)):
    """Train a GAN with alternating discriminator / generator updates on shuffled mini-batches.

    The batch- and noise-building steps are injected, so the same loop serves the plain GAN
    (get_gan_* helpers) and the conditional GAN (get_cgan_* helpers).

    Args:
      gan: combined generator -> discriminator model, compiled with the discriminator frozen.
      generator: generator model (used via ``predict``).
      discriminator: compiled standalone discriminator.
      train_x: training data in the form ``get_real_batch_func`` expects.
      train_data_count: number of training samples.
      input_noise_dim: length of the generator's noise vector.
      epoch_count: number of epochs.
      batch_size: samples per batch.
      get_random_input_func: ``(batch_size, noise_dim, condition_count) -> generator input``.
      get_real_batch_func: ``(train_x, batch_indices, label) -> (batch_x, batch_y)``.
      get_fake_batch_func: ``(generator, batch_size, generator_input) -> (batch_x, batch_y)``.
      concatenate_batches_func: ``(real_batch_x, fake_batch_x) -> discriminator batch input``.
      condition_count: forwarded to ``get_random_input_func``.
      use_one_sided_labels: label real samples 0.9 instead of 1.
      plt_frq: plot samples every ``plt_frq`` epochs (and once before training); None disables.
      plt_example_count: number of samples per plot.
      example_shape: per-sample shape used to reshape generator output for plotting.

    Returns:
      (d_epoch_losses, g_epoch_losses): per-epoch, sample-weighted mean losses.

    Raises:
      ValueError: if ``train_data_count`` is not positive.
    """
    if train_data_count < 1:
        raise ValueError('train_data_count must be positive, got {}'.format(train_data_count))

    # Only full batches are trained each epoch (the shuffled remainder is skipped); a dataset
    # smaller than one batch is trained as a single partial batch instead of not at all.
    iteration_count = max(1, train_data_count // batch_size)

    print('Epochs: ', epoch_count)
    print('Batch size: ', batch_size)
    print('Iterations: ', iteration_count)
    print('')

    if plt_frq is not None:
        print('Before training:')
        # Fixed generator input, reused for every plot so epochs are visually comparable
        noise_to_plot = get_random_input_func(plt_example_count, input_noise_dim, condition_count)
        _plot_generator_samples(generator, noise_to_plot, plt_example_count, example_shape, epoch=0)

    d_epoch_losses = []
    g_epoch_losses = []
    real_label = 0.9 if use_one_sided_labels else 1

    for e in range(1, epoch_count + 1):
        start_time = time.time()
        d_loss_sum = 0
        g_loss_sum = 0
        samples_seen = 0

        # Training indices are shuffled and grouped into batches
        batch_indices = get_random_batch_indices(train_data_count, batch_size)

        for current_indices in batch_indices[:iteration_count]:
            current_batch_size = len(current_indices)

            # 1. create a batch with real images from the training set
            real_batch_x, real_batch_y = get_real_batch_func(train_x, current_indices, real_label)

            # 2. create noise vectors for the generator and generate the images from the noise
            generator_input = get_random_input_func(current_batch_size, input_noise_dim, condition_count)
            fake_batch_x, fake_batch_y = get_fake_batch_func(generator, current_batch_size, generator_input)

            # 3. concatenate real and fake batches into a single batch
            discriminator_batch_x = concatenate_batches_func(real_batch_x, fake_batch_x)
            discriminator_batch_y = np.concatenate((real_batch_y, fake_batch_y))

            # 4. train discriminator
            d_loss = discriminator.train_on_batch(discriminator_batch_x, discriminator_batch_y)

            # 5. fresh noise labelled as real ("flipped" labels) so the generator learns to fool D
            gan_batch_x = get_random_input_func(current_batch_size, input_noise_dim, condition_count)
            gan_batch_y = np.ones(current_batch_size)

            # 6. train generator
            g_loss = gan.train_on_batch(gan_batch_x, gan_batch_y)

            # 7. accumulate sample-weighted losses
            d_loss_sum += d_loss * current_batch_size
            g_loss_sum += g_loss * current_batch_size
            samples_seen += current_batch_size

        # Average over the samples actually trained on; dividing by train_data_count
        # under-reports the loss whenever the final partial batch is skipped.
        avg_d_loss = d_loss_sum / samples_seen
        avg_g_loss = g_loss_sum / samples_seen

        d_epoch_losses.append(avg_d_loss)
        g_epoch_losses.append(avg_g_loss)

        end_time = time.time()

        print('Epoch: {0} exec_time={1:.1f}s d_loss={2:.3f} g_loss={3:.3f}'.format(e,end_time - start_time,avg_d_loss,avg_g_loss))

        # Update the plots
        if plt_frq is not None and e % plt_frq == 0:
            _plot_generator_samples(generator, noise_to_plot, plt_example_count, example_shape, epoch=e)

    return d_epoch_losses, g_epoch_losses

# **Data import and manipulation**

In [14]:
# TO DO: move datasets in the repo and change root_folder

# Load data splits, one-hot labels, input shape, class count and label names (as named by
# the unpacked variables) from the repo's data_loader.
# NOTE(review): the recorded run raised "ValueError: Sample larger than population or is
# negative", the message Python's random.sample gives when asked for more items than exist —
# presumably data_loader found fewer images under root_folder than it tried to sample
# (e.g. the dataset was not under /content/ in that environment). Verify before rerunning.
(train_x, test_x, val_x,
 train_y, test_y, val_y,
 train_y_one_hot, test_y_one_hot, val_y_one_hot,
 input_shape, category_count, labels) = data_loader(name=dataset_name,
                                                    root_folder='/content/')

ValueError: Sample larger than population or is negative

# **GAN model**

In [None]:
import wandb
from wandb.keras import WandbCallback
#wandb.init(project="my-test-project", entity="nrderus")

In [None]:
# Authenticate with W&B without embedding a secret in the notebook: use WANDB_API_KEY from
# the environment if set, otherwise prompt for it. The key previously hard-coded here was
# committed in plain text and should be revoked.
import os
from getpass import getpass

wandb.login(key=os.environ.get("WANDB_API_KEY") or getpass("W&B API key: "), relogin=True)

In [None]:
patience = 5  # logged to the W&B config only; not used elsewhere in the visible cells

# Start a W&B run; the config mirrors the hyper-parameters set in the configuration cell.
wandb.init(project="GAN", entity="nrderus",
  config = {
  "dataset": dataset_name,
  "model": model_name,  # was the literal "CVAE", which mislabelled this GAN run
  "learning_rate": learning_rate,
  "epochs": epoch_count,
  "batch_size": batch_size,
  "patience": patience,
  "input_noise_dim": input_noise_dim,
})

In [None]:
def build_cdcgan(input_noise_dim,condition_dim):
  """Build a conditional DCGAN producing 64x64x3 images.

  Args:
    input_noise_dim: length of the generator's noise vector (int).
    condition_dim: length of the one-hot condition vector, i.e. the number of classes (int).

  Returns:
    (cdcgan, generator, discriminator):
      generator:     [noise, condition] -> 64x64x3 image (tanh output, values in [-1, 1]).
      discriminator: [image, condition] -> single sigmoid score.
      cdcgan:        [noise, condition] -> discriminator([generator(noise, condition), condition]);
                     the condition Input is shared by both sub-models.
  """
  image_height, image_width, image_channels = 64, 64, 3

  def upsample_block(x, filters, kernel_size, padding='same'):
    # Stride-2 transposed conv -> BatchNorm -> LeakyReLU
    x = layers.Conv2DTranspose(filters, kernel_size, strides=2, padding=padding)(x)
    x = layers.BatchNormalization()(x)
    return layers.LeakyReLU(alpha=0.2)(x)

  def downsample_block(x, filters, kernel_size):
    # Stride-2 conv -> BatchNorm -> LeakyReLU
    x = layers.Conv2D(filters, kernel_size, strides=2, padding='same')(x)
    x = layers.BatchNormalization()(x)
    return layers.LeakyReLU(alpha=0.2)(x)

  input_noise = layers.Input(shape=(input_noise_dim,), name='input_noise')
  input_condition = layers.Input(shape=(condition_dim,), name='input_condition')

  # Treat both vectors as 1x1 feature maps so they can be concatenated channel-wise.
  # Sizes come from the arguments, so any noise length / class count works.
  input_noise_reshaped = layers.Reshape((1, 1, input_noise_dim))(input_noise)
  input_condition_reshaped = layers.Reshape((1, 1, condition_dim))(input_condition)

  # Generator: 1x1 -> 8x8 (valid 8x8 kernel, stride 2) -> 16 -> 32 -> 64 (stride-2 'same')
  generator_input = layers.Concatenate(name='generator_input')([input_noise_reshaped, input_condition_reshaped])
  prev_layer = upsample_block(generator_input, 512, 8, padding='valid')
  prev_layer = upsample_block(prev_layer, 128, 3)
  prev_layer = upsample_block(prev_layer, 64, 3)
  prev_layer = upsample_block(prev_layer, 64, 3)
  generator_output = layers.Conv2DTranspose(image_channels, 3, strides=1, padding='same',
                                            activation='tanh', name='generator_output')(prev_layer)

  generator = keras.Model([input_noise, input_condition], generator_output, name='generator')

  # Discriminator: the condition is projected to one extra 64x64 channel stacked on the image;
  # four stride-2 convs reduce 64 -> 4, then a valid 4x4 conv yields a 1x1 sigmoid score.
  discriminator_input_sample = layers.Input(shape=(image_height, image_width, image_channels),
                                            name='discriminator_input_sample')

  input_condition_dense = layers.Dense(image_height * image_width)(input_condition)
  discriminator_input_condition = layers.Reshape((image_height, image_width, 1))(input_condition_dense)

  discriminator_input = layers.Concatenate(name='discriminator_input')([discriminator_input_sample, discriminator_input_condition])
  prev_layer = downsample_block(discriminator_input, 64, 3)
  prev_layer = downsample_block(prev_layer, 128, 5)
  prev_layer = downsample_block(prev_layer, 256, 5)
  prev_layer = downsample_block(prev_layer, 512, 3)
  prev_layer = layers.Conv2D(1, 4, strides=1, padding='valid', activation='sigmoid')(prev_layer)

  discriminator_output = layers.Reshape((1,), name='discriminator_output')(prev_layer)

  discriminator = keras.Model([discriminator_input_sample, input_condition], discriminator_output, name='discriminator')

  # cDCGAN: the generator's image is judged together with the same condition it was generated from
  cdcgan = keras.Model(generator.input, discriminator([generator.output, input_condition]), name='cdcgan')

  return cdcgan, generator, discriminator


In [None]:


# Build the combined cDCGAN and its generator / discriminator; the condition vector length is
# category_count, matching the one-hot labels built with to_categorical(..., category_count).
cdcgan,cdcgan_generator,cdcgan_discriminator=build_cdcgan(input_noise_dim,category_count)

In [None]:
# Layer-by-layer summary of the combined model (generator followed by the discriminator sub-model)
cdcgan.summary()

In [None]:
# One Adam instance per model, both using the configured learning_rate (also logged to W&B).
# Sharing a single optimizer would also share its step counter, and with it Adam's bias
# correction, between the discriminator and the combined model.
discriminator_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
gan_optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)

cdcgan_discriminator.compile(loss='binary_crossentropy', optimizer=discriminator_optimizer)

# Freeze the discriminator before compiling the combined model, so gan.train_on_batch updates
# only the generator; the standalone discriminator was compiled while trainable and still learns.
cdcgan_discriminator.trainable = False
cdcgan.compile(loss='binary_crossentropy', optimizer=gan_optimizer)

In [None]:


# Train on the training split so the validation split stays held out.
d_epoch_losses,g_epoch_losses=train_gan(cdcgan,
                                        cdcgan_generator,
                                        cdcgan_discriminator,
                                        [train_x,train_y_one_hot],
                                        train_x.shape[0],
                                        input_noise_dim,
                                        epoch_count,
                                        batch_size,
                                        get_cgan_random_input,
                                        get_cgan_real_batch,
                                        get_cgan_fake_batch,
                                        concatenate_cgan_batches,
                                        condition_count=category_count,
                                        use_one_sided_labels=True,
                                        plt_frq=1,
                                        plt_example_count=15)

In [None]:
# Per-epoch discriminator vs. generator loss curves returned by train_gan
plot_gan_losses(d_epoch_losses,g_epoch_losses)

In [None]:
# Generate and display one sample conditioned on a single class label.
digit_label=0  # class index to condition on

noise = np.random.normal(0, 1, size=(1, input_noise_dim))
digit_label_one_hot=to_categorical(digit_label, category_count).reshape(1,-1)

generated_x = cdcgan_generator.predict([noise,digit_label_one_hot])
digit = generated_x[0].reshape(input_shape)

plt.axis('off')
# The generator ends in tanh, so values can be negative; clip to the [0, 1] range imshow uses for
# float RGB. No cmap: assuming input_shape is (64, 64, 3) like the generator output, imshow
# ignores cmap for RGB data.
plt.imshow(np.clip(digit, 0.0, 1.0))
plt.title('Generated sample, class {}'.format(digit_label))
plt.show()

In [None]:
n = 10  # images generated per class; the grid has one row per class

generated_images = []
for digit_label in range(category_count):
  # n independent noise vectors, all conditioned on the same class
  noise = np.random.normal(0, 1, size=(n, input_noise_dim))
  digit_label_one_hot = to_categorical(np.full(n, digit_label), category_count)
  generated_x = cdcgan_generator.predict([noise, digit_label_one_hot])
  generated_images.append(list(map(lambda sample: sample.reshape(input_shape), generated_x)))

plot_generated_images(generated_images, category_count, n)

In [None]:
wandb.finish(exit_code=0, quiet=True)  # close the W&B run explicitly at the end of the notebook