# StackGAN
Authors: Vishal Kundar and Yash Mathur

stackgan.ipynb contains the code for both Stage I and Stage II. Run the driver function `run_model()` to
start training. Model weights and sample images are saved to the working directory during training so that progress can be checked.

In [None]:
#Packages

import tensorflow as tf
gpus = tf.config.experimental.list_physical_devices('GPU')
# Ask TensorFlow to allocate GPU memory on demand instead of reserving the
# whole device as soon as it is initialised.
for device in gpus:
  tf.config.experimental.set_memory_growth(device, True)

import tensorflow.keras.backend as K
import numpy as np
import os
from matplotlib import pyplot as plt
import pickle
import cv2
import time

In [None]:
def load_class_ids_filenames(class_id_path, filename_path):
    """Unpickle and return ``(class_ids, filenames)`` from the two given paths.

    Both files are read with ``encoding='latin1'``, which lets Python 3 read
    pickles produced by Python 2. Only use on trusted files: unpickling can
    execute arbitrary code.
    """
    def _unpickle(source):
        with open(source, 'rb') as handle:
            return pickle.load(handle, encoding='latin1')

    # Tuple elements are evaluated left to right, so the class ids are read first.
    return _unpickle(class_id_path), _unpickle(filename_path)

def load_text_embeddings(text_embeddings):
    """Unpickle the embeddings stored at *text_embeddings* and return them as a NumPy array.

    Read with ``encoding='latin1'`` so Python-2 pickles load under Python 3.
    Only use on trusted files: unpickling can execute arbitrary code.
    """
    with open(text_embeddings, 'rb') as handle:
        raw = pickle.load(handle, encoding='latin1')
    return np.array(raw)

def load_images(pickle_file):
    """Load and return the image collection stored in *pickle_file*.

    The file is unpickled as-is (no ``encoding`` argument, unlike the other
    loaders here); callers convert the result with ``np.array``.
    Only use on trusted files: unpickling can execute arbitrary code.
    """
    with open(pickle_file, 'rb') as f_in:
        return pickle.load(f_in)

def parse_function(self, image_path, embeddings, bounding_box):
    """Load one training image and pick a random caption embedding for it.

    Args:
        self: Unused; kept so existing positional call sites keep working.
        image_path: Path to an encoded image file (any format
            ``tf.io.decode_image`` accepts).
        embeddings: Array of caption embeddings for this image; one row is
            chosen uniformly at random.
        bounding_box: Unused (as before); accepted for interface compatibility.

    Returns:
        ``(image, embedding)``: ``image`` is a 64x64x3 float32 tensor scaled to
        [-1, 1] and ``embedding`` is the selected row of *embeddings*.
    """
    # Previously ``image`` was used before being assigned; load it from the path.
    raw = tf.io.read_file(image_path)
    image = tf.io.decode_image(raw, channels=3, expand_animations=False)
    image = tf.image.resize(image, (64, 64))  # float32, values still in [0, 255]
    image.set_shape([64, 64, 3])
    image = (image - 127.5) / 127.5

    # randint's upper bound is exclusive: use shape[0] so the last embedding
    # can be selected too.
    embedding_index = np.random.randint(0, embeddings.shape[0])
    embedding = embeddings[embedding_index]
    return image, embedding

def load_data(filename_path, class_id_path, embeddings_path, pickle_file):
    """
    Loads the data in pickle file along with class id and embeddings.

    For every entry in the filename list, one of that image's caption
    embeddings is chosen uniformly at random.

    Args:
        filename_path: Pickle with the list of image filenames.
        class_id_path: Pickle with the class id of each image.
        embeddings_path: Pickle with the caption embeddings, indexed as
            ``embeddings[image, caption, :]``.
        pickle_file: Pickle with the images themselves.

    Returns:
        ``(x, y, embeds)`` NumPy arrays: images, class ids, and the sampled
        embedding for each image.
    """
    class_id, filenames = load_class_ids_filenames(class_id_path, filename_path)
    embeddings = load_text_embeddings(embeddings_path)
    y, embeds = [], []

    for i, _ in enumerate(filenames):
        try:
            captions = embeddings[i, :, :]
            # randint's upper bound is exclusive. Using shape[0] - 1 never
            # picked the last caption and raised for single-caption images
            # (which were then silently dropped by the handler below).
            embed = captions[np.random.randint(0, captions.shape[0]), :]
            label = class_id[i]
        except (IndexError, KeyError, ValueError) as err:
            # Best effort, as before: report the entry and skip it.
            # NOTE(review): skipped entries leave ``x`` (loaded separately
            # below) misaligned with ``y`` / ``embeds``.
            print(f'{err}')
            continue

        y.append(label)
        embeds.append(embed)

    x = np.array(load_images(pickle_file))
    y = np.array(y)
    embeds = np.array(embeds)

    return x, y, embeds

In [None]:
# Data locations shared by both stages (CUB birds dataset layout).
path = "../input/gandata20/birds"
train_path = f"{path}/train"
test_path = f"{path}/test"

# Caption embeddings
embedding_train = f"{train_path}/char-CNN-RNN-embeddings.pickle"
embedding_test = f"{test_path}/char-CNN-RNN-embeddings.pickle"

# Filename lists
filename_train = f"{train_path}/filenames.pickle"
filename_test = f"{test_path}/filenames.pickle"

# Class ids
class_id_train = f"{train_path}/class_info.pickle"
class_id_test = f"{test_path}/class_info.pickle"

# Image pickles: 64x64 (stage 1) and 256x256 (stage 2)
pickle_train_low = f"{train_path}/64images.pickle"
pickle_test_low = f"{test_path}/64images.pickle"
pickle_train_high = f"{train_path}/256images.pickle"
pickle_test_high = f"{test_path}/256images.pickle"

In [None]:
def KL_loss(y_true, y_pred):
  """Mean KL divergence KL(N(mu, sigma^2) || N(0, 1)) over every entry.

  `y_pred` packs mu in its first 128 columns and log(sigma) in the remaining
  columns. `y_true` is ignored; it only exists to match the Keras loss signature.
  """
  mu, log_sigma = y_pred[:, :128], y_pred[:, 128:]
  # Per-entry KL: -log(sigma) + (sigma^2 + mu^2 - 1) / 2
  divergence = -log_sigma + 0.5*(-1 + K.exp(2.0*log_sigma) + K.square(mu))
  return K.mean(divergence)

class ConditioningAugmentation(tf.keras.Model):
  """Conditioning Augmentation: map a text embedding to a sampled condition vector.

  A 256-unit dense layer followed by leaky ReLU gives `phi`. Its first 128
  columns are the mean and its last 128 columns are log(sigma). The condition is
  sampled with the reparameterisation trick, C = mean + epsilon * sigma.
  """
  def __init__(self):
    super(ConditioningAugmentation, self).__init__()
    self.dense = tf.keras.layers.Dense(units = 256)

  def call(self, E):
    """Return `(C, phi)`: the sampled condition and the mean/log-sigma tensor used by `KL_loss`."""
    X = self.dense(E)
    phi = tf.nn.leaky_relu(X)
    mean = phi[:, :128]
    std = K.exp(phi[:, 128:])
    # Independent noise for every sample and dimension. A (128,)-shaped draw
    # would be broadcast, giving every sample in the batch the same epsilon.
    epsilon = tf.random.normal(shape = tf.shape(mean), dtype = mean.dtype)
    C = mean + epsilon*std
    return C, phi

class EmbeddingCompressor(tf.keras.Model):
  """Project a text embedding down to 128 features and apply ReLU."""

  def __init__(self):
    super().__init__()
    self.dense = tf.keras.layers.Dense(units = 128)

  def call(self, E):
    return tf.nn.relu(self.dense(E))

class Stage1Generator(tf.keras.Model):
  """Stage-I generator: (text embedding, noise) -> 64x64x3 image with tanh values in [-1, 1].

  The embedding goes through `ConditioningAugmentation` to get a condition C.
  C is concatenated with the noise Z and projected to a 4x4x1024 feature map.
  Four stride-2 transposed convolutions then upsample it 4 -> 8 -> 16 -> 32 -> 64.
  """
  def __init__(self):
    super(Stage1Generator, self).__init__()
    self.canet = ConditioningAugmentation()
    self.concat = tf.keras.layers.Concatenate(axis = 1)
    self.dense = tf.keras.layers.Dense(units = 128*8*4*4, kernel_initializer = tf.random_normal_initializer(stddev = 0.02))
    self.reshape = tf.keras.layers.Reshape(target_shape = (4, 4, 128*8), input_shape = (128*8*4*4, ))
    # NOTE(review): in `call`, batchnorm1..batchnorm3 follow deconv1..deconv3,
    # and batchnorm4 is never used. Attribute names are left unchanged because
    # renaming/removing them may alter the saved-checkpoint layout.
    self.batchnorm1 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.deconv1 = tf.keras.layers.Conv2DTranspose(filters = 512, kernel_size = 4, strides = (2, 2), padding = "same", kernel_initializer = tf.random_normal_initializer(stddev = 0.02))
    self.batchnorm2 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.deconv2 = tf.keras.layers.Conv2DTranspose(filters = 256, kernel_size = 4, strides = (2, 2), padding = "same", kernel_initializer = tf.random_normal_initializer(stddev = 0.02))
    self.batchnorm3 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.deconv3 = tf.keras.layers.Conv2DTranspose(filters = 128, kernel_size = 4, strides = (2, 2), padding = "same", kernel_initializer = tf.random_normal_initializer(stddev = 0.02))
    self.batchnorm4 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.deconv4 = tf.keras.layers.Conv2DTranspose(filters = 3, kernel_size = 4, strides = (2, 2), padding = "same", kernel_initializer = tf.random_normal_initializer(stddev = 0.02))

  def call(self, inputs): 
    """Generate images.

    Args:
      inputs: `[E, Z]`, the text embeddings and the noise vectors.

    Returns:
      `(images, phi)`: 64x64x3 tanh images and the CA mean/log-sigma tensor,
      which the training loop feeds to `KL_loss`.
    """
    E, Z = inputs
    C, phi = self.canet(E)

    gen_input = self.concat([C, Z])
    X = self.dense(gen_input)
    X = self.reshape(X)  # -> (batch, 4, 4, 1024)
    # No batch norm after the projection (disabled below).
    #X = self.batchnorm1(X)
    X = tf.nn.relu(X)

    X = self.deconv1(X)  # -> 8x8x512
    X = self.batchnorm1(X)
    X = tf.nn.relu(X)

    X = self.deconv2(X)  # -> 16x16x256
    X = self.batchnorm2(X)
    X = tf.nn.relu(X)

    X = self.deconv3(X)  # -> 32x32x128
    X = self.batchnorm3(X)
    X = tf.nn.relu(X)

    X = self.deconv4(X)  # -> 64x64x3
    return tf.nn.tanh(X), phi

class Stage1Discriminator(tf.keras.Model): 
  """Stage-I discriminator: one real/fake logit per (64x64 image, text embedding) pair.

  Four stride-2 convolutions reduce the image to a 4x4x512 feature map. The
  compressed embedding is tiled to 4x4 and concatenated on the channel axis.
  A 1x1 conv followed by a 4x4 'valid' conv then gives one logit per sample.
  """
  def __init__(self):
    super(Stage1Discriminator, self).__init__()
    self.conv1 = tf.keras.layers.Conv2D(filters = 64, kernel_size = 4, strides = 2, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.conv2 = tf.keras.layers.Conv2D(filters = 128, kernel_size = 4, strides = 2, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm1 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv3 = tf.keras.layers.Conv2D(filters = 256, kernel_size = 4, strides = 2, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm2 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv4 = tf.keras.layers.Conv2D(filters = 512, kernel_size = 4, strides = 2, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm3 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.embed = EmbeddingCompressor()
    self.reshape = tf.keras.layers.Reshape(target_shape = (1, 1, 128))
    self.concat = tf.keras.layers.Concatenate()
    self.conv5 = tf.keras.layers.Conv2D(filters = 64*8, kernel_size = 1, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm4 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv6 = tf.keras.layers.Conv2D(filters = 1, kernel_size = 4, strides = 1, padding = "valid", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))

  def call(self, inputs):
    """`inputs` is `[I, E]` (images, embeddings). Returns logits of shape (batch,)."""
    I, E = inputs
    X = self.conv1(I)
    X = tf.nn.leaky_relu(X)

    X = self.conv2(X)
    X = self.batchnorm1(X)
    X = tf.nn.leaky_relu(X)

    X = self.conv3(X)
    X = self.batchnorm2(X)
    X = tf.nn.leaky_relu(X)

    X = self.conv4(X)
    X = self.batchnorm3(X)
    X = tf.nn.leaky_relu(X)

    # Tile the compressed text embedding over the 4x4 spatial grid.
    T = self.embed(E)
    T = self.reshape(T)
    T = tf.tile(T, (1, 4, 4, 1))
    merged_input = self.concat([X, T])

    Y = self.conv5(merged_input)
    Y = self.batchnorm4(Y)
    Y = tf.nn.leaky_relu(Y)

    Y = self.conv6(Y)
    # Squeeze only the spatial/channel axes of the (batch, 1, 1, 1) output.
    # A bare tf.squeeze also drops the batch axis when the batch holds one
    # sample, so the logits would no longer match the per-sample label shape.
    return tf.squeeze(Y, axis = [1, 2, 3])

class Stage1Model(tf.keras.Model):
  """Stage-I StackGAN: trains the 64x64 text-conditioned generator against its discriminator.

  Training reads the notebook-level globals `x_train`, `train_embeds` and
  `test_embeds`. The stage-1 networks require 64x64 images in `x_train`.
  """
  def __init__(self):
    super(Stage1Model, self).__init__()
    self.stage1_generator = Stage1Generator()
    self.stage1_discriminator = Stage1Discriminator()

  def train(self, train_ds, batch_size = 128, num_epochs = 500, z_dim = 100, c_dim = 128, stage1_generator_lr = 0.0004, stage1_discriminator_lr = 0.0004):
    """Run the adversarial training loop.

    Args:
      train_ds: Unused; data comes from the globals `x_train` / `train_embeds`.
      batch_size: Images per step; a trailing partial batch is dropped.
      num_epochs: Number of passes over `x_train`.
      z_dim: Size of the noise vector fed to the generator.
      c_dim: Unused; `ConditioningAugmentation` fixes the condition size at 128.
      stage1_generator_lr, stage1_discriminator_lr: Initial Adam learning
        rates, halved every 150 epochs.

    Every 50 epochs (and after the last one), a sample for `test_embeds[0]` and
    both networks' weights are written to the working directory.
    """
    generator_optimizer = tf.keras.optimizers.Adam(learning_rate = stage1_generator_lr, beta_1 = 0.5, beta_2 = 0.999)
    discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate = stage1_discriminator_lr, beta_1 = 0.5, beta_2 = 0.999)

    for epoch in range(num_epochs):
      print("Epoch %d/%d:\n ["%(epoch + 1, num_epochs), end = "")
      start_time = time.time()
      # Decay after each full 150-epoch period. Testing `epoch % 150 == 0`
      # alone also fired at epoch 0 and halved the configured rate before training started.
      if epoch > 0 and epoch % 150 == 0:
        K.set_value(generator_optimizer.learning_rate, generator_optimizer.learning_rate / 2)
        K.set_value(discriminator_optimizer.learning_rate, discriminator_optimizer.learning_rate / 2)

      generator_loss_log = []
      discriminator_loss_log = []
      num_batches = x_train.shape[0] // batch_size
      for i in range(num_batches):
        if i % 5 == 0:
          print("=", end = "")

        image_batch = x_train[i * batch_size:(i+1) * batch_size]
        image_batch = (image_batch - 127.5) / 127.5
        embedding_batch = train_embeds[i * batch_size:(i+1) * batch_size]

        generator_loss, discriminator_loss = self._fit_batch(
            image_batch, embedding_batch, batch_size, z_dim,
            generator_optimizer, discriminator_optimizer)
        generator_loss_log.append(generator_loss)
        discriminator_loss_log.append(discriminator_loss)

      epoch_time = time.time() - start_time
      template = "] - generator_loss: {:.4f} - discriminator_loss: {:.4f} - epoch_time: {:.2f} s"
      print(template.format(tf.reduce_mean(generator_loss_log), tf.reduce_mean(discriminator_loss_log), epoch_time))

      if (epoch + 1) % 50 == 0 or epoch == num_epochs - 1:
        self._save_samples_and_weights(epoch, z_dim)

  def _fit_batch(self, image_batch, embedding_batch, batch_size, z_dim, generator_optimizer, discriminator_optimizer):
    """Apply one generator and one discriminator update. Returns `(generator_loss, discriminator_loss)`."""
    z_noise = tf.random.normal((batch_size, z_dim))

    # Rolling the images by one pairs each with another sample's embedding,
    # giving "real image, wrong text" examples.
    mismatched_images = tf.roll(image_batch, shift = 1, axis = 0)

    # Smoothed labels: real in [0.9, 1.0), fake/mismatched in [0.0, 0.1).
    real_labels = tf.random.uniform(shape = (batch_size, ), minval = 0.9, maxval = 1.0)
    fake_labels = tf.random.uniform(shape = (batch_size, ), minval = 0.0, maxval = 0.1)
    mismatched_labels = tf.random.uniform(shape = (batch_size, ), minval = 0.0, maxval = 0.1)

    with tf.GradientTape() as generator_tape, tf.GradientTape() as discriminator_tape:
      # training=True makes the BatchNormalization layers use and update batch
      # statistics. Called without it, they run in inference mode.
      fake_images, phi = self.stage1_generator([embedding_batch, z_noise], training = True)

      real_logits = self.stage1_discriminator([image_batch, embedding_batch], training = True)
      fake_logits = self.stage1_discriminator([fake_images, embedding_batch], training = True)
      mismatched_logits = self.stage1_discriminator([mismatched_images, embedding_batch], training = True)

      l_sup = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(real_labels, fake_logits))
      # KL_loss ignores y_true, so no dummy target tensor is needed.
      l_klreg = KL_loss(None, phi)
      generator_loss = l_sup + 2.0*l_klreg

      l_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(real_labels, real_logits))
      l_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(fake_labels, fake_logits))
      l_mismatched = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(mismatched_labels, mismatched_logits))
      discriminator_loss = 0.5*tf.add(l_real, 0.5*tf.add(l_fake, l_mismatched))

    generator_gradients = generator_tape.gradient(generator_loss, self.stage1_generator.trainable_variables)
    discriminator_gradients = discriminator_tape.gradient(discriminator_loss, self.stage1_discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(generator_gradients, self.stage1_generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, self.stage1_discriminator.trainable_variables))

    return generator_loss, discriminator_loss

  def _save_samples_and_weights(self, epoch, z_dim):
    """Write a sample image for `test_embeds[0]` and checkpoint both networks."""
    temp_batch_size = 1
    temp_z_noise = tf.random.normal((temp_batch_size, z_dim))
    temp_embedding_batch = test_embeds[0:temp_batch_size]
    fake_images, _ = self.stage1_generator([temp_embedding_batch, temp_z_noise], training = False)
    for i, image in enumerate(fake_images):
      image = 127.5*image + 127.5
      image = image.numpy().astype('uint8')
      path = "./gen1_" + str(epoch + 1)
      cv2.imwrite(path + "_%d.png"%(i), image)

    self.stage1_generator.save_weights("./stage1_generator_" + str(epoch + 1) + ".ckpt")
    self.stage1_discriminator.save_weights("./stage1_discriminator_" + str(epoch + 1) + ".ckpt")

# Stage 2

In [None]:
class ResidualBlock(tf.keras.layers.Layer):
  """Residual block: two 3x3, 512-filter conv + batch-norm stages plus an identity skip.

  The input must already have 512 channels so it can be added to the branch output.
  """
  def __init__(self):
    super(ResidualBlock, self).__init__()
    self.conv1 = tf.keras.layers.Conv2D(filters = 128*4, kernel_size = 3, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm1 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv2 = tf.keras.layers.Conv2D(filters = 128*4, kernel_size = 3, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm2 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    # Created once here instead of building a new Add layer on every call.
    self.residual_add = tf.keras.layers.Add()

  # `call` must be a method of the class, not a function nested inside
  # `__init__`; otherwise Keras never dispatches to it and the block acts as
  # an identity map.
  def call(self, I):
    X = self.conv1(I)
    X = self.batchnorm1(X)
    X = tf.nn.relu(X)

    X = self.conv2(X)
    X = self.batchnorm2(X)
    # NOTE(review): standard ResNet blocks add the skip before any ReLU on the
    # branch. This ReLU is kept as in the original design.
    X = tf.nn.relu(X)
    X = self.residual_add([X, I])
    X = tf.nn.relu(X)
    return X

class Stage2Generator(tf.keras.Model):
  """Stage-II generator: refines a 64x64 Stage-I image into a 256x256x3 tanh image.

  The image is encoded down to 16x16x512 (conv2 and conv3 have stride 2) and
  passed through four residual blocks. Four upsample + 3x3 conv stages then
  grow it 16 -> 32 -> 64 -> 128 -> 256.
  """
  def __init__(self):
    super(Stage2Generator, self).__init__()
    self.canet = ConditioningAugmentation()
    self.conv1 = tf.keras.layers.Conv2D(128, kernel_size = 3, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.conv2 = tf.keras.layers.Conv2D(256, kernel_size = 4, strides = 2, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm1 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv3 = tf.keras.layers.Conv2D(512, kernel_size = 4, strides = 2, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm2 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv4 = tf.keras.layers.Conv2D(512, kernel_size = 3, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm3 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.resblock1 = ResidualBlock()
    self.resblock2 = ResidualBlock()
    self.resblock3 = ResidualBlock()
    self.resblock4 = ResidualBlock()
    self.upsamp1 = tf.keras.layers.UpSampling2D(size = (2, 2))
    self.conv5 = tf.keras.layers.Conv2D(256, kernel_size = 3, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm4 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.upsamp2 = tf.keras.layers.UpSampling2D(size = (2, 2))
    self.conv6 = tf.keras.layers.Conv2D(filters = 128, kernel_size = 3, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm5 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.upsamp3 = tf.keras.layers.UpSampling2D(size = (2, 2))
    self.conv7 = tf.keras.layers.Conv2D(filters = 64, kernel_size = 3, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm6 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.upsamp4 = tf.keras.layers.UpSampling2D(size = (2, 2))
    self.conv8 = tf.keras.layers.Conv2D(filters = 32, kernel_size = 3, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm7 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv9 = tf.keras.layers.Conv2D(filters = 3, kernel_size = 3, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
  
  def call(self, inputs):
    """Refine images.

    Args:
      inputs: `[E, I]`, the text embeddings and the 64x64 Stage-I images.

    Returns:
      `(images, phi)`: 256x256x3 tanh images and the CA mean/log-sigma tensor
      fed to `KL_loss`.
    """
    E, I = inputs
    C, phi = self.canet(E)

    X = self.conv1(I)
    X = tf.nn.relu(X)
    
    X = self.conv2(X)  # -> 32x32x256
    X = self.batchnorm1(X)
    X = tf.nn.relu(X)

    X = self.conv3(X)  # -> 16x16x512
    X = self.batchnorm2(X)
    X = tf.nn.relu(X)

    # Tile the 128-d condition over the 16x16 grid and join it with the image features.
    C = K.expand_dims(C, axis = 1)
    C = K.expand_dims(C, axis = 1)
    C = K.tile(C, [1, 16, 16, 1])
    J = K.concatenate([C, X], axis = 3)

    # NOTE(review): `J` is never used. conv4 consumes `X`, so the output image
    # does not depend on the condition C; the embedding only affects training
    # through `phi` (the KL term). Feeding `J` here would change conv4's input
    # from 512 to 640 channels, and previously saved stage-2 checkpoints would
    # then fail to load. Fix together with retraining.
    X = self.conv4(X)
    X = self.batchnorm3(X)
    X = tf.nn.relu(X)

    X = self.resblock1(X)
    X = self.resblock2(X)
    X = self.resblock3(X)
    X = self.resblock4(X)

    X = self.upsamp1(X)  # -> 32x32
    X = self.conv5(X)
    X = self.batchnorm4(X)
    X = tf.nn.relu(X)
    
    X = self.upsamp2(X)  # -> 64x64
    X = self.conv6(X)
    X = self.batchnorm5(X)
    X = tf.nn.relu(X)
    
    X = self.upsamp3(X)  # -> 128x128
    X = self.conv7(X)
    X = self.batchnorm6(X)
    X = tf.nn.relu(X)
    
    X = self.upsamp4(X)  # -> 256x256
    X = self.conv8(X)
    X = self.batchnorm7(X)
    X = tf.nn.relu(X)
    
    X = self.conv9(X)  # -> 3 channels
    return tf.nn.tanh(X), phi

class Stage2Discriminator(tf.keras.Model):
  """Stage-II discriminator: one real/fake logit per (256x256 image, text embedding) pair.

  - Six stride-2 convolutions reduce the image to 4x4x2048.
  - conv7/conv8 (1x1) bring it to 512 channels, and conv9-conv11 form a
    residual branch added back onto it.
  - The tiled 4x4 text embedding is concatenated.
  - A 1x1 conv and a 4x4 'valid' conv give one logit per sample.
  """
  def __init__(self):
    super(Stage2Discriminator, self).__init__()
    self.embed = EmbeddingCompressor()
    self.reshape = tf.keras.layers.Reshape(target_shape = (1, 1, 128))
    self.conv1 = tf.keras.layers.Conv2D(filters = 64, kernel_size = 4, strides = 2, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.conv2 = tf.keras.layers.Conv2D(filters = 128, kernel_size = 4, strides = 2, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm1 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv3 = tf.keras.layers.Conv2D(filters = 256, kernel_size = 4, strides = 2, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm2 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv4 = tf.keras.layers.Conv2D(filters = 512, kernel_size = 4, strides = 2, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm3 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv5 = tf.keras.layers.Conv2D(filters = 1024, kernel_size = 4, strides = 2, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm4 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv6 = tf.keras.layers.Conv2D(filters = 2048, kernel_size = 4, strides = 2, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm5 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv7 = tf.keras.layers.Conv2D(filters = 1024, kernel_size = 1, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm6 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv8 = tf.keras.layers.Conv2D(filters = 512, kernel_size = 1, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm7 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv9 = tf.keras.layers.Conv2D(filters = 128, kernel_size = 1, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm8 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv10 = tf.keras.layers.Conv2D(filters = 128, kernel_size = 3, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm9 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv11 = tf.keras.layers.Conv2D(filters = 512, kernel_size = 3, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm10 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv12 = tf.keras.layers.Conv2D(filters = 64*8, kernel_size = 1, strides = 1, padding = "same", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))
    self.batchnorm11 = tf.keras.layers.BatchNormalization(axis = -1, momentum = 0.99)
    self.conv13 = tf.keras.layers.Conv2D(filters = 1, kernel_size = 4, strides = 1, padding = "valid", kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev = 0.02))

  def call(self, inputs):
    """`inputs` is `[I, E]` (images, embeddings). Returns logits of shape (batch,)."""
    I, E = inputs
    # Text features tiled over the final 4x4 spatial grid.
    T = self.embed(E)
    T = self.reshape(T)
    T = tf.tile(T, (1, 4, 4, 1))

    X = self.conv1(I)
    X = tf.nn.leaky_relu(X)

    X = self.conv2(X)
    X = self.batchnorm1(X)
    X = tf.nn.leaky_relu(X)
    
    X = self.conv3(X)
    X = self.batchnorm2(X)
    X = tf.nn.leaky_relu(X)
    
    X = self.conv4(X)
    X = self.batchnorm3(X)
    X = tf.nn.leaky_relu(X)
    
    X = self.conv5(X)
    X = self.batchnorm4(X)
    X = tf.nn.leaky_relu(X)
   
    X = self.conv6(X)
    X = self.batchnorm5(X)
    X = tf.nn.leaky_relu(X)
    
    X = self.conv7(X)
    X = self.batchnorm6(X)
    X = tf.nn.leaky_relu(X)
    
    X = self.conv8(X)
    X = self.batchnorm7(X)

    # Residual branch on top of the 4x4x512 features.
    Y = self.conv9(X)
    Y = self.batchnorm8(Y)
    Y = tf.nn.leaky_relu(Y)

    Y = self.conv10(Y)
    Y = self.batchnorm9(Y)
    Y = tf.nn.leaky_relu(Y)

    Y = self.conv11(Y)
    Y = self.batchnorm10(Y)

    A = tf.keras.layers.Add()([X, Y])
    A = tf.nn.leaky_relu(A)

    merged_input = tf.keras.layers.concatenate([A, T])

    Z = self.conv12(merged_input)
    Z = self.batchnorm11(Z)
    Z = tf.nn.leaky_relu(Z)
    
    Z = self.conv13(Z)
    # Squeeze only the spatial/channel axes of the (batch, 1, 1, 1) output.
    # A bare tf.squeeze also drops the batch axis for a one-sample batch.
    return tf.squeeze(Z, axis = [1, 2, 3])

class Stage2Model(tf.keras.Model):
  """Stage-II StackGAN: refines frozen Stage-I 64x64 samples into 256x256 images.

  The constructor restores the Stage-I generator (epoch 300) and the Stage-II
  generator/discriminator (epoch 101) from `../input/allweights`. Training reads
  the notebook-level globals `x_train`, `train_embeds` and `test_embeds`.
  """
  def __init__(self):
    super(Stage2Model, self).__init__()
    self.stage1_generator = Stage1Generator()
    self.stage1_generator.compile(loss = "mse", optimizer = "adam")
    self.stage1_generator.load_weights("../input/allweights/stage1_generator_300.ckpt").expect_partial()
    
    self.stage2_generator = Stage2Generator()
    self.stage2_discriminator = Stage2Discriminator()
    self.stage2_generator.compile(loss = "mse", optimizer = "adam")
    self.stage2_generator.load_weights("../input/allweights/stage2_generator_101.ckpt").expect_partial()
    self.stage2_discriminator.compile(loss = "mse", optimizer = "adam")
    self.stage2_discriminator.load_weights("../input/allweights/stage2_discriminator_101.ckpt").expect_partial()
    
  def train(self, train_ds, batch_size = 64, num_epochs = 800, z_dim = 100, stage1_generator_lr = 0.0001, stage1_discriminator_lr = 0.0001):
    """Run the Stage-II adversarial training loop. Only Stage-II variables are updated.

    Args:
      train_ds: Unused; data comes from the globals `x_train` / `train_embeds`.
      batch_size: Images per step; a trailing partial batch is dropped.
      num_epochs: Number of passes over `x_train`.
      z_dim: Size of the noise vector fed to the Stage-I generator.
      stage1_generator_lr, stage1_discriminator_lr: Initial Adam learning rates
        for the *Stage-II* generator and discriminator (the names are
        historical), halved every 100 epochs.

    Every 50 epochs (and after the last one), samples for the first five
    `test_embeds` and both Stage-II networks' weights are written to the
    working directory.
    """
    generator_optimizer = tf.keras.optimizers.Adam(learning_rate = stage1_generator_lr, beta_1 = 0.5, beta_2 = 0.999)
    discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate = stage1_discriminator_lr, beta_1 = 0.5, beta_2 = 0.999)
    
    for epoch in range(num_epochs):
      print("Epoch %d/%d:\n ["%(epoch + 1, num_epochs), end = "")
      start_time = time.time()
      # Decay after each full 100-epoch period. Testing `epoch % 100 == 0`
      # alone also fired at epoch 0 and halved the configured rate up front.
      if epoch > 0 and epoch % 100 == 0:
        K.set_value(generator_optimizer.learning_rate, generator_optimizer.learning_rate / 2)
        K.set_value(discriminator_optimizer.learning_rate, discriminator_optimizer.learning_rate / 2)
    
      generator_loss_log = []
      discriminator_loss_log = []
      num_batches = x_train.shape[0] // batch_size
      for i in range(num_batches):
        if i % 5 == 0:
          print("=", end = "")
        
        image_batch = x_train[i * batch_size:(i+1) * batch_size]
        hr_image_batch = (image_batch - 127.5) / 127.5
        embedding_batch = train_embeds[i * batch_size:(i+1) * batch_size]

        generator_loss, discriminator_loss = self._fit_batch(
            hr_image_batch, embedding_batch, batch_size, z_dim,
            generator_optimizer, discriminator_optimizer)
        generator_loss_log.append(generator_loss)
        discriminator_loss_log.append(discriminator_loss)

      epoch_time = time.time() - start_time
      template = "] - generator_loss: {:.4f} - discriminator_loss: {:.4f} - epoch_time: {:.2f} s"
      print(template.format(tf.reduce_mean(generator_loss_log), tf.reduce_mean(discriminator_loss_log), epoch_time))

      if epoch % 50 == 0 or epoch == num_epochs - 1:
        self._save_samples_and_weights(epoch, z_dim)

  def _fit_batch(self, hr_image_batch, embedding_batch, batch_size, z_dim, generator_optimizer, discriminator_optimizer):
    """Apply one Stage-II generator and discriminator update. Returns `(generator_loss, discriminator_loss)`."""
    z_noise = tf.random.normal((batch_size, z_dim))

    # Rolling the images by one pairs each with another sample's embedding,
    # giving "real image, wrong text" examples.
    mismatched_images = tf.roll(hr_image_batch, shift = 1, axis = 0)

    # Smoothed labels: real in [0.9, 1.0), fake/mismatched in [0.0, 0.1).
    real_labels = tf.random.uniform(shape = (batch_size, ), minval = 0.9, maxval = 1.0)
    fake_labels = tf.random.uniform(shape = (batch_size, ), minval = 0.0, maxval = 0.1)
    mismatched_labels = tf.random.uniform(shape = (batch_size, ), minval = 0.0, maxval = 0.1)

    # The Stage-I generator is frozen (never differentiated), so its forward
    # pass stays outside the tapes and keeps running in inference mode.
    lr_fake_images, _ = self.stage1_generator([embedding_batch, z_noise], training = False)

    with tf.GradientTape() as generator_tape, tf.GradientTape() as discriminator_tape:
      # training=True makes the BatchNormalization layers use and update batch
      # statistics. Called without it, they run in inference mode.
      hr_fake_images, phi = self.stage2_generator([embedding_batch, lr_fake_images], training = True)
      real_logits = self.stage2_discriminator([hr_image_batch, embedding_batch], training = True)
      fake_logits = self.stage2_discriminator([hr_fake_images, embedding_batch], training = True)
      mismatched_logits = self.stage2_discriminator([mismatched_images, embedding_batch], training = True)

      l_sup = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(real_labels, fake_logits))
      # KL_loss ignores y_true, so no dummy target tensor is needed.
      l_klreg = KL_loss(None, phi)
      generator_loss = l_sup + 2.0*l_klreg

      l_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(real_labels, real_logits))
      l_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(fake_labels, fake_logits))
      l_mismatched = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(mismatched_labels, mismatched_logits))
      discriminator_loss = 0.5*tf.add(l_real, 0.5*tf.add(l_fake, l_mismatched))

    generator_gradients = generator_tape.gradient(generator_loss, self.stage2_generator.trainable_variables)
    discriminator_gradients = discriminator_tape.gradient(discriminator_loss, self.stage2_discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(generator_gradients, self.stage2_generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, self.stage2_discriminator.trainable_variables))

    return generator_loss, discriminator_loss

  def _save_samples_and_weights(self, epoch, z_dim):
    """Write samples for the first five `test_embeds` and checkpoint the Stage-II networks."""
    save_path = "./"
    temp_batch_size = 5
    temp_z_noise = tf.random.normal((temp_batch_size, z_dim))
    temp_embedding_batch = test_embeds[0:temp_batch_size]
    lr_fake_images, _ = self.stage1_generator([temp_embedding_batch, temp_z_noise], training = False)
    hr_fake_images, _ = self.stage2_generator([temp_embedding_batch, lr_fake_images], training = False)
    for i, image in enumerate(hr_fake_images):
      image = 127.5*image + 127.5
      image = image.numpy().astype('uint8')
      cv2.imwrite(save_path + "gen2_%d.png"%(i), image)
    self.stage2_generator.save_weights("./stage2_generator_" + str(epoch + 1) + ".ckpt")
    self.stage2_discriminator.save_weights("./stage2_discriminator_" + str(epoch + 1) + ".ckpt")

In [None]:
# Both splits use the 256x256 image pickles (what Stage II trains on).
x_train, y_train, train_embeds = load_data(
    filename_path=filename_train,
    class_id_path=class_id_train,
    embeddings_path=embedding_train,
    pickle_file=pickle_train_high,
)
x_test, y_test, test_embeds = load_data(
    filename_path=filename_test,
    class_id_path=class_id_test,
    embeddings_path=embedding_test,
    pickle_file=pickle_test_high,
)

In [None]:
def run_model(stage=2):
    """Train StackGAN Stage I (``stage=1``) or Stage II (``stage=2``, the default).

    Stage I resumes from the epoch-300 stage-1 checkpoints in
    ``../input/allweights``. Its networks work on 64x64 images, so ``x_train``
    must be loaded from ``pickle_train_low`` before calling this with
    ``stage=1``; the data-loading cell loads the 256x256 pickles used by
    Stage II. Stage II builds ``Stage2Model``, whose constructor restores its
    own checkpoints.

    Raises:
        ValueError: if *stage* is not 1 or 2.
    """
    if stage == 1:
        model = Stage1Model()
        model.stage1_generator.load_weights("../input/allweights/stage1_generator_300.ckpt").expect_partial()
        model.stage1_discriminator.load_weights("../input/allweights/stage1_discriminator_300.ckpt").expect_partial()
        model.train(0)
    elif stage == 2:
        model = Stage2Model()
        model.train(0)
    else:
        raise ValueError(f"stage must be 1 or 2, got {stage!r}")

In [None]:
# Launch training with the driver defined above (Stage II by default).
run_model()

In [None]:
def stage1_generateImage(z_dim=100, samples=1):
    """Generate *samples* Stage-I images from the first test embeddings and plot each one.

    Restores the epoch-300 stage-1 generator and discriminator checkpoints,
    prints the embeddings being used, and shows every generated image in its
    own matplotlib figure.
    """
    model = Stage1Model()
    model.stage1_generator.load_weights("../input/allweights/stage1_generator_300.ckpt").expect_partial()
    model.stage1_discriminator.load_weights("../input/allweights/stage1_discriminator_300.ckpt").expect_partial()

    noise = tf.random.normal((samples, z_dim))
    embedding_batch = test_embeds[0:samples]
    print("Embeddings: ", embedding_batch)

    generated, _ = model.stage1_generator([embedding_batch, noise])
    for picture in generated:
        # Map tanh output from [-1, 1] back to 8-bit pixel values.
        pixels = (127.5*picture + 127.5).numpy().astype('uint8')
        figure = plt.figure()
        axes = figure.add_subplot(1, 1, 1)
        axes.imshow(pixels)

In [None]:
 #stage1_generateImage(samples=5)