# Synthesizing high-quality images from text descriptions

## Implementation of "Stage 1 " of **StackGAN**

### Stage I of StackGAN 

#### 1- takes input as text, 

#### 2- convert the text to embedding using our pre-trained character level embedding. 

#### 3- Then, we give this embedding to Conditional Augmentation (CA) and 

#### 4- then to Stage I Generator which gives us low-resolution 64*64 images. 

In [49]:
# ==============================================================================
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""StackGAN.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import pickle
import random
import time

import numpy as np
import pandas as pd
import tensorflow as tf

assert tf.__version__.startswith('2')

import PIL
from PIL import Image
import matplotlib.pyplot as plt
import tensorflow.keras.backend as K
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import LeakyReLU, BatchNormalization, ReLU, Activation
from tensorflow.keras.layers import UpSampling2D, Conv2D, Concatenate, Dense, concatenate
from tensorflow.keras.layers import Flatten, Lambda, Reshape, ZeroPadding2D, add
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator

import warnings
warnings.simplefilter("ignore")
warnings.filterwarnings( "ignore", module = "matplotlib\..*" )


############################################################
# Conditioning Augmentation Network
############################################################


Computer doesn’t understand words, but it can represent the words in terms of something it does “understand”. That’s the “text embedding”, and it’s used as the c




In [50]:
# conditioned by the text.
def conditioning_augmentation(x):
	"""The mean_logsigma passed as argument is converted into the text conditioning variable.

	Args:
		x: The output of the text embedding passed through a FC layer with LeakyReLU non-linearity.

	Returns:
	 	c: The text conditioning variable after computation.
	"""
	mean = x[:, :128]
	log_sigma = x[:, 128:]

	stddev = tf.math.exp(log_sigma)
	epsilon = K.random_normal(shape=K.constant((mean.shape[1], ), dtype='int32'))
	c = mean + stddev * epsilon
	return c

def build_ca_network():
	"""Builds the conditioning augmentation network.
	"""
	input_layer1 = Input(shape=(1024,)) #size of the vocabulary in the text data
	mls = Dense(256)(input_layer1)
	mls = LeakyReLU(alpha=0.2)(mls)
	ca = Lambda(conditioning_augmentation)(mls)
	return Model(inputs=[input_layer1], outputs=[ca]) 


############################################################
# Stage 1 Generator Network 
############################################################

1. The generator is fed with the text captions in the form of Embedding vectors which will be used to condition its generation of features.
2. A vector with random noise.


In [51]:


def UpSamplingBlock(x, num_kernels):
	"""An Upsample block with Upsampling2D, Conv2D, BatchNormalization and a ReLU activation.

	Args:
		x: The preceding layer as input.
		num_kernels: Number of kernels for the Conv2D layer.

	Returns:
		x: The final activation layer after the Upsampling block.
	"""
	x = UpSampling2D(size=(2,2))(x)
	x = Conv2D(num_kernels, kernel_size=(3,3), padding='same', strides=1, use_bias=False,
				kernel_initializer='he_uniform')(x)
	x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x) #prevent from mode collapse
	x = ReLU()(x)
	return x


def build_stage1_generator():

	input_layer1 = Input(shape=(1024,))
	ca = Dense(256)(input_layer1)
	ca = LeakyReLU(alpha=0.2)(ca)

	# Obtain the conditioned text
	c = Lambda(conditioning_augmentation)(ca)

	input_layer2 = Input(shape=(100,))
	concat = Concatenate(axis=1)([c, input_layer2]) 

	x = Dense(16384, use_bias=False)(concat) 
	x = ReLU()(x)
	x = Reshape((4, 4, 1024), input_shape=(16384,))(x)

	x = UpSamplingBlock(x, 512) 
	x = UpSamplingBlock(x, 256)
	x = UpSamplingBlock(x, 128)
	x = UpSamplingBlock(x, 64)   # upsampled our image to 64*64*3 

	x = Conv2D(3, kernel_size=3, padding='same', strides=1, use_bias=False,
				kernel_initializer='he_uniform')(x)
	x = Activation('tanh')(x)

	stage1_gen = Model(inputs=[input_layer1, input_layer2], outputs=[x, ca]) 
	return stage1_gen



In [52]:
generator = build_stage1_generator()
generator.summary()

Model: "model_68"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_128 (InputLayer)         [(None, 1024)]       0           []                               
                                                                                                  
 dense_67 (Dense)               (None, 256)          262400      ['input_128[0][0]']              
                                                                                                  
 leaky_re_lu_101 (LeakyReLU)    (None, 256)          0           ['dense_67[0][0]']               
                                                                                                  
 lambda_26 (Lambda)             (None, 128)          0           ['leaky_re_lu_101[0][0]']        
                                                                                           


############################################################
# Stage 1 Discriminator Network
############################################################	


In [53]:
def ConvBlock(x, num_kernels, kernel_size=(4,4), strides=2, activation=True):
	"""A ConvBlock with a Conv2D, BatchNormalization and LeakyReLU activation.

	Args:
		x: The preceding layer as input.
		num_kernels: Number of kernels for the Conv2D layer.

	Returns:
		x: The final activation layer after the ConvBlock block.
	"""
	x = Conv2D(num_kernels, kernel_size=kernel_size, padding='same', strides=strides, use_bias=False,
				kernel_initializer='he_uniform')(x)
	x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)
	
	if activation:
		x = LeakyReLU(alpha=0.2)(x)
	return x


def build_embedding_compressor():
    """Build embedding compressor model
    """
    input_layer1 = Input(shape=(1024,)) 
    x = Dense(128)(input_layer1)
    x = ReLU()(x)

    model = Model(inputs=[input_layer1], outputs=[x])
    return model

# the discriminator is fed with two inputs, the feature from Generator and the text embedding
def build_stage1_discriminator():
	"""Builds the Stage 1 Discriminator that uses the 64x64 resolution images from the generator
	and the compressed and spatially replicated embedding.

	Returns:
		Stage 1 Discriminator Model for StackGAN.
	"""
	input_layer1 = Input(shape=(64, 64, 3))  

	x = Conv2D(64, kernel_size=(4,4), strides=2, padding='same', use_bias=False,
				kernel_initializer='he_uniform')(input_layer1)
	x = LeakyReLU(alpha=0.2)(x)

	x = ConvBlock(x, 128)
	x = ConvBlock(x, 256)
	x = ConvBlock(x, 512)

	# Obtain the compressed and spatially replicated text embedding
	input_layer2 = Input(shape=(4, 4, 128)) #2nd input to discriminator, text embedding
	concat = concatenate([x, input_layer2])

	x1 = Conv2D(512, kernel_size=(1,1), padding='same', strides=1, use_bias=False,
				kernel_initializer='he_uniform')(concat)
	x1 = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)
	x1 = LeakyReLU(alpha=0.2)(x)

	# Flatten and add a FC layer to predict.
	x1 = Flatten()(x1)
	x1 = Dense(1)(x1)
	x1 = Activation('sigmoid')(x1)

	stage1_dis = Model(inputs=[input_layer1, input_layer2], outputs=[x1])  
	return stage1_dis


In [54]:
discriminator = build_stage1_discriminator()
discriminator.summary()

Model: "model_69"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_130 (InputLayer)         [(None, 64, 64, 3)]  0           []                               
                                                                                                  
 conv2d_150 (Conv2D)            (None, 32, 32, 64)   3072        ['input_130[0][0]']              
                                                                                                  
 leaky_re_lu_102 (LeakyReLU)    (None, 32, 32, 64)   0           ['conv2d_150[0][0]']             
                                                                                                  
 conv2d_151 (Conv2D)            (None, 16, 16, 128)  131072      ['leaky_re_lu_102[0][0]']        
                                                                                           


############################################################
# Stage 1 Adversarial Model  (Building a GAN)
############################################################

Generator and discriminator are stacked together. Output of the former is the input of the latter.

In [55]:
# Building GAN with Generator and Discriminator

def build_adversarial(generator_model, discriminator_model):
	"""Stage 1 Adversarial model.

	Args:
		generator_model: Stage 1 Generator Model
		discriminator_model: Stage 1 Discriminator Model

	Returns:
		Adversarial Model.
	"""
	input_layer1 = Input(shape=(1024,))  
	input_layer2 = Input(shape=(100,)) 
	input_layer3 = Input(shape=(4, 4, 128)) 

	x, ca = generator_model([input_layer1, input_layer2]) #text,noise

	discriminator_model.trainable = False 

	probabilities = discriminator_model([x, input_layer3]) 
	adversarial_model = Model(inputs=[input_layer1, input_layer2, input_layer3], outputs=[probabilities, ca])
	return adversarial_model



In [56]:
ganstage1 = build_adversarial(generator, discriminator)
ganstage1.summary()

Model: "model_70"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_132 (InputLayer)         [(None, 1024)]       0           []                               
                                                                                                  
 input_133 (InputLayer)         [(None, 100)]        0           []                               
                                                                                                  
 model_68 (Functional)          [(None, 64, 64, 3),  10270400    ['input_132[0][0]',              
                                 (None, 256)]                     'input_133[0][0]']              
                                                                                                  
 input_134 (InputLayer)         [(None, 4, 4, 128)]  0           []                        

############################################################
# Train Utilities
############################################################


In [57]:

def checkpoint_prefix():
	checkpoint_dir = './training_checkpoints'
	checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')

	return checkpoint_prefix

def adversarial_loss(y_true, y_pred):
	mean = y_pred[:, :128]
	ls = y_pred[:, 128:]
	loss = -ls + 0.5 * (-1 + tf.math.exp(2.0 * ls) + tf.math.square(mean))
	loss = K.mean(loss)
	return loss

def normalize(input_image, real_image):
	input_image = (input_image / 127.5) - 1
	real_image = (real_image / 127.5) - 1

	return input_image, real_image

def load_class_ids_filenames(class_id_path, filename_path):
	with open(class_id_path, 'rb') as file:
		class_id = pickle.load(file, encoding='latin1')

	with open(filename_path, 'rb') as file:
		filename = pickle.load(file, encoding='latin1')

	return class_id, filename

def load_text_embeddings(text_embeddings):
	with open(text_embeddings, 'rb') as file:
		embeds = pickle.load(file, encoding='latin1')
		embeds = np.array(embeds)

	return embeds

def load_bbox(data_path):
	bbox_path = data_path + '/bounding_boxes.txt'
	image_path = data_path + '/images.txt'
	bbox_df = pd.read_csv(bbox_path, delim_whitespace=True, header=None).astype(int)
	filename_df = pd.read_csv(image_path, delim_whitespace=True, header=None)

	filenames = filename_df[1].tolist()
	bbox_dict = {i[:-4]:[] for i in filenames[:2]}

	for i in range(0, len(filenames)):
		bbox = bbox_df.iloc[i][1:].tolist()
		dict_key = filenames[i][:-4]
		bbox_dict[dict_key] = bbox

	return bbox_dict

def load_images(image_path, bounding_box, size):
	"""Crops the image to the bounding box and then resizes it.
	"""
	image = Image.open(image_path).convert('RGB')
	w, h = image.size
	if bounding_box is not None:
		r = int(np.maximum(bounding_box[2], bounding_box[3]) * 0.75)
		c_x = int((bounding_box[0] + bounding_box[2]) / 2)
		c_y = int((bounding_box[1] + bounding_box[3]) / 2)
		y1 = np.maximum(0, c_y - r)
		y2 = np.minimum(h, c_y + r)
		x1 = np.maximum(0, c_x - r)
		x2 = np.minimum(w, c_x + r)
		image = image.crop([x1, y1, x2, y2])

	image = image.resize(size, PIL.Image.BILINEAR)
	return image

def load_data(filename_path, class_id_path, dataset_path, embeddings_path, size):
	"""Loads the Dataset.
	"""
	data_dir = "./data/birds"
	train_dir = data_dir + "/train"
	test_dir = data_dir + "/test"
	embeddings_path_train = train_dir + "/char-CNN-RNN-embeddings.pickle"
	embeddings_path_test = test_dir + "/char-CNN-RNN-embeddings.pickle"
	filename_path_train = train_dir + "/filenames.pickle"
	filename_path_test = test_dir + "/filenames.pickle"
	class_id_path_train = train_dir + "/class_info.pickle"
	class_id_path_test = test_dir + "/class_info.pickle"
	dataset_path = "./data/CUB_200_2011"
	class_id, filenames = load_class_ids_filenames(class_id_path, filename_path)
	embeddings = load_text_embeddings(embeddings_path)
	bbox_dict = load_bbox(dataset_path)

	x, y, embeds = [], [], []

	for i, filename in enumerate(filenames):
		bbox = bbox_dict[filename]

		try:	
			image_path = f'{dataset_path}/images/{filename}.jpg'
			image = load_images(image_path, bbox, size)
			e = embeddings[i, :, :]
			embed_index = np.random.randint(0, e.shape[0] - 1)
			embed = e[embed_index, :]

			x.append(np.array(image))
			y.append(class_id[i])
			embeds.append(embed)

		except Exception as e:
			print(f'{e}')
	
	x = np.array(x)
	y = np.array(y)
	embeds = np.array(embeds)
	
	return x, y, embeds

def save_image(file, save_path):
	"""Saves the image at the specified file path.
	"""
	image = plt.figure()
	ax = image.add_subplot(1,1,1)
	ax.imshow(file)
	ax.axis("off")
	plt.savefig(save_path)




In [58]:

############################################################
# StackGAN class
############################################################

class StackGanStage1(object):
  """StackGAN Stage 1 class."""

  data_dir = "./data/birds"
  train_dir = data_dir + "/train"
  test_dir = data_dir + "/test"
  embeddings_path_train = train_dir + "/char-CNN-RNN-embeddings.pickle"
  embeddings_path_test = test_dir + "/char-CNN-RNN-embeddings.pickle"
  filename_path_train = train_dir + "/filenames.pickle"
  filename_path_test = test_dir + "/filenames.pickle"
  class_id_path_train = train_dir + "/class_info.pickle"
  class_id_path_test = test_dir + "/class_info.pickle"
  dataset_path = "./data/CUB_200_2011"
  def __init__(self, epochs=500, z_dim=100, batch_size=64, enable_function=True, stage1_generator_lr=0.0002, stage1_discriminator_lr=0.0002):
	  self.epochs = epochs
	  self.z_dim = z_dim
	  self.enable_function = enable_function
	  self.stage1_generator_lr = stage1_generator_lr
	  self.stage1_discriminator_lr = stage1_discriminator_lr
	  self.image_size = 64
	  self.conditioning_dim = 128
	  self.batch_size = batch_size
        
	  self.stage1_generator_optimizer = Adam(learning_rate=stage1_generator_lr, beta_1=0.5, beta_2=0.999)
	  self.stage1_discriminator_optimizer = Adam(learning_rate=stage1_discriminator_lr, beta_1=0.5, beta_2=0.999)
        
	  self.stage1_generator = build_stage1_generator()
	  self.stage1_generator.compile(loss='mse', optimizer=self.stage1_generator_optimizer)

	  self.stage1_discriminator = build_stage1_discriminator()
	  self.stage1_discriminator.compile(loss='binary_crossentropy', optimizer=self.stage1_discriminator_optimizer)

	  self.ca_network = build_ca_network()
	  self.ca_network.compile(loss='binary_crossentropy', optimizer='Adam')

	  self.embedding_compressor = build_embedding_compressor()
	  self.embedding_compressor.compile(loss='binary_crossentropy', optimizer='Adam')

	  self.stage1_adversarial = build_adversarial(self.stage1_generator, self.stage1_discriminator)
	  self.stage1_adversarial.compile(loss=['binary_crossentropy', adversarial_loss], loss_weights=[1, 2.0], optimizer=self.stage1_generator_optimizer)

	  self.checkpoint1 = tf.train.Checkpoint(
        	generator_optimizer=self.stage1_generator_optimizer,
        	discriminator_optimizer=self.stage1_discriminator_optimizer,
        	generator=self.stage1_generator,
        	discriminator=self.stage1_discriminator)

  def visualize_stage1(self):
	  """Running Tensorboard visualizations.
		"""
	  tb = TensorBoard(log_dir="logs/".format(time.time()))
	  tb.set_model(self.stage1_generator)
	  tb.set_model(self.stage1_discriminator)
	  tb.set_model(self.ca_network)
	  tb.set_model(self.embedding_compressor)

  def train_stage1(self):
	  """Trains the stage1 StackGAN.
    """
	  x_train, y_train, train_embeds = load_data(filename_path=StackGanStage1.filename_path_train, class_id_path=StackGanStage1.class_id_path_train,
      dataset_path=StackGanStage1.dataset_path, embeddings_path=StackGanStage1.embeddings_path_train, size=(64, 64))

	  x_test, y_test, test_embeds = load_data(filename_path=StackGanStage1.filename_path_test, class_id_path=StackGanStage1.class_id_path_test, 
      dataset_path=StackGanStage1.dataset_path, embeddings_path=StackGanStage1.embeddings_path_test, size=(64, 64))

	  real = np.ones((self.batch_size, 1), dtype='float') * 0.9
	  fake = np.zeros((self.batch_size, 1), dtype='float') * 0.1

	  for epoch in range(self.epochs):
		  print(f'Epoch: {epoch}')

		  gen_loss = []
		  dis_loss = []

		  num_batches = int(x_train.shape[0] / self.batch_size)

		  for i in range(num_batches):

		    latent_space = np.random.normal(0, 1, size=(self.batch_size, self.z_dim))
		    embedding_text = train_embeds[i * self.batch_size:(i + 1) * self.batch_size]
		    compressed_embedding = self.embedding_compressor.predict_on_batch(embedding_text)
		    compressed_embedding = np.reshape(compressed_embedding, (-1, 1, 1, 128))
		    compressed_embedding = np.tile(compressed_embedding, (1, 4, 4, 1))

		    image_batch = x_train[i * self.batch_size:(i+1) * self.batch_size]
		    image_batch = (image_batch - 127.5) / 127.5

		    gen_images, _ = self.stage1_generator.predict([embedding_text, latent_space])

		    discriminator_loss = self.stage1_discriminator.train_on_batch([image_batch, compressed_embedding], 
					np.reshape(real, (self.batch_size, 1)))

		    discriminator_loss_gen = self.stage1_discriminator.train_on_batch([gen_images, compressed_embedding],
					np.reshape(fake, (self.batch_size, 1)))

		    discriminator_loss_wrong = self.stage1_discriminator.train_on_batch([gen_images[: self.batch_size-1], compressed_embedding[1:]], 
					np.reshape(fake[1:], (self.batch_size-1, 1)))

		    # Discriminator loss
		    d_loss = 0.5 * np.add(discriminator_loss, 0.5 * np.add(discriminator_loss_gen, discriminator_loss_wrong))
		    dis_loss.append(d_loss)

		    print(f'Discriminator Loss: {d_loss}')

		    # Generator loss
		    g_loss = self.stage1_adversarial.train_on_batch([embedding_text, latent_space, compressed_embedding],
					[K.ones((self.batch_size, 1)) * 0.9, K.ones((self.batch_size, 256)) * 0.9])

		    print(f'Generator Loss: {g_loss}')
		    gen_loss.append(g_loss)

		    if epoch % 5 == 0:
				    latent_space = np.random.normal(0, 1, size=(self.batch_size, self.z_dim))
				    embedding_batch = test_embeds[0 : self.batch_size]
				    gen_images, _ = self.stage1_generator.predict_on_batch([embedding_batch, latent_space])

				    for i, image in enumerate(gen_images[:10]):
				        save_image(image, f'test/gen_1_{epoch}_{i}')

		    if epoch % 25 == 0:
		      self.stage1_generator.save_weights('weights/stage1_gen.h5')
		      self.stage1_discriminator.save_weights("weights/stage1_disc.h5")
		      self.ca_network.save_weights('weights/stage1_ca.h5')
		      self.embedding_compressor.save_weights('weights/stage1_embco.h5')
		      self.stage1_adversarial.save_weights('weights/stage1_adv.h5')      

	  self.stage1_generator.save_weights('weights/stage1_gen.h5')
	  self.stage1_discriminator.save_weights("weights/stage1_disc.h5")


In [None]:
stage1 = StackGanStage1()
stage1.train_stage1()

Generator Loss: [0.34689196944236755, 0.3445568084716797, 0.0011675840942189097]
Discriminator Loss: 0.16721034577858518
Generator Loss: [0.3330877125263214, 0.33111700415611267, 0.0009853534866124392]
Discriminator Loss: 0.1643047828820272
Generator Loss: [0.32858261466026306, 0.32644474506378174, 0.0010689407354220748]
Discriminator Loss: 0.16448607121401437
Generator Loss: [0.33131614327430725, 0.32853537797927856, 0.0013903887011110783]
Discriminator Loss: 0.16362930030481948
Generator Loss: [0.3318607807159424, 0.3291829228401184, 0.0013389221858233213]
Discriminator Loss: 0.16356134990473947
Generator Loss: [0.33617669343948364, 0.3326166272163391, 0.0017800278728827834]
Discriminator Loss: 0.1632985854321305
Generator Loss: [0.33573663234710693, 0.3327975273132324, 0.001469555776566267]
Discriminator Loss: 0.16331455534964334
Generator Loss: [0.32869425415992737, 0.32584816217422485, 0.0014230498345568776]
Discriminator Loss: 0.1634480877128226
Generator Loss: [0.336102455854415



Generator Loss: [0.349570631980896, 0.3471940755844116, 0.0011882728431373835]




Discriminator Loss: 0.16782294503082085




Generator Loss: [0.32981517910957336, 0.3272801637649536, 0.001267510000616312]




Discriminator Loss: 0.1699772397050765




Generator Loss: [0.32890400290489197, 0.32656970620155334, 0.0011671483516693115]




Discriminator Loss: 0.18600664493897057




Generator Loss: [0.32875779271125793, 0.3264427185058594, 0.0011575422249734402]




Discriminator Loss: 0.16845972325006642




Generator Loss: [0.3404940366744995, 0.33841314911842346, 0.0010404482018202543]




Discriminator Loss: 0.16679122864638884




Generator Loss: [0.33168163895606995, 0.329505980014801, 0.0010878229513764381]




Discriminator Loss: 0.16498785443639008




Generator Loss: [0.32930752635002136, 0.32752904295921326, 0.0008892389596439898]




Discriminator Loss: 0.16331208209453507




Generator Loss: [0.3292350769042969, 0.32655516266822815, 0.0013399519957602024]




Discriminator Loss: 0.16680414535221644




Generator Loss: [0.33617326617240906, 0.3329398036003113, 0.0016167263966053724]




Discriminator Loss: 0.17808512269334642




Generator Loss: [0.3309853971004486, 0.32845598459243774, 0.0012647053226828575]




Discriminator Loss: 0.16829663959458685




Generator Loss: [0.32821816205978394, 0.32625433802604675, 0.0009819114347919822]




Discriminator Loss: 0.16551147451718862




Generator Loss: [0.33114850521087646, 0.3292493224143982, 0.0009495978592894971]




Discriminator Loss: 0.16583868222130604




Generator Loss: [0.33827465772628784, 0.33640432357788086, 0.0009351635235361755]




Discriminator Loss: 0.16553107474192075




Generator Loss: [0.3271615207195282, 0.32571643590927124, 0.0007225432200357318]




Discriminator Loss: 0.16884309674634324




Generator Loss: [0.3588038682937622, 0.35563355684280396, 0.0015851580537855625]




Discriminator Loss: 0.1686988127821678




Generator Loss: [0.33960872888565063, 0.3367953598499298, 0.0014066861476749182]




Discriminator Loss: 0.18504910694554155




Generator Loss: [0.3951444923877716, 0.3928047716617584, 0.0011698566377162933]




Discriminator Loss: 0.1880569755153374




Generator Loss: [0.44827884435653687, 0.4465698003768921, 0.0008545215241611004]




Discriminator Loss: 0.23559399645364465




Generator Loss: [0.340821236371994, 0.33888399600982666, 0.0009686269331723452]




Discriminator Loss: 0.16345601806460763




Generator Loss: [0.36196795105934143, 0.359857976436615, 0.00105498475022614]




Discriminator Loss: 0.16908421897733206




Generator Loss: [0.33362486958503723, 0.33193784952163696, 0.0008435111958533525]




Discriminator Loss: 0.1802650562076451




Generator Loss: [0.33074572682380676, 0.3290255665779114, 0.0008600751752965152]




Discriminator Loss: 0.16391936307718424




Generator Loss: [0.327721506357193, 0.3255583643913269, 0.0010815657442435622]




Discriminator Loss: 0.163363995774489




Generator Loss: [0.32857567071914673, 0.32583487033843994, 0.0013703987933695316]




Discriminator Loss: 0.16717188506299863




Generator Loss: [0.38204097747802734, 0.3790574073791504, 0.0014917822554707527]




Discriminator Loss: 0.18047613905218896




Generator Loss: [0.3567337989807129, 0.35457515716552734, 0.0010793162509799004]




Discriminator Loss: 0.17883850710188653




Generator Loss: [0.40229085087776184, 0.40018385648727417, 0.001053491374477744]




Discriminator Loss: 0.17005273809627397




Generator Loss: [0.32866430282592773, 0.32651767134666443, 0.0010733201634138823]




Discriminator Loss: 0.167172594985459




Generator Loss: [0.6137602925300598, 0.6118025779724121, 0.000978859607130289]




Discriminator Loss: 0.1662841953723273




Generator Loss: [0.4335481524467468, 0.431615948677063, 0.000966096471529454]




Discriminator Loss: 0.17488791206687893




Generator Loss: [0.4234292209148407, 0.4214141368865967, 0.001007536193355918]




Discriminator Loss: 0.16671724437082958




Generator Loss: [0.32868170738220215, 0.32639580965042114, 0.0011429502628743649]




Discriminator Loss: 0.16865230096891537




Generator Loss: [0.33856263756752014, 0.33646637201309204, 0.0010481354547664523]




Discriminator Loss: 0.1675433806522051




Generator Loss: [0.3357126712799072, 0.3332224190235138, 0.0012451321817934513]




Discriminator Loss: 0.1750022378473659




Generator Loss: [0.3605329692363739, 0.3580000400543213, 0.0012664573732763529]




Discriminator Loss: 0.16620773186878068




Generator Loss: [0.33129143714904785, 0.3287670612335205, 0.0012621944770216942]




Discriminator Loss: 0.16723675170942442




Generator Loss: [0.4086776077747345, 0.4063156247138977, 0.0011809919960796833]




Discriminator Loss: 0.16806060721864924




Generator Loss: [0.3324609696865082, 0.3301434814929962, 0.0011587449116632342]




Discriminator Loss: 0.17790284450893523




Generator Loss: [0.3385346829891205, 0.3362070620059967, 0.0011638079304248095]




Discriminator Loss: 0.17094234473188408




Generator Loss: [0.33608147501945496, 0.3340541422367096, 0.0010136717464774847]




Discriminator Loss: 0.183657104133772




Generator Loss: [0.47825464606285095, 0.47605133056640625, 0.0011016508797183633]




Discriminator Loss: 0.16791165394261043




Generator Loss: [0.3288414180278778, 0.32698678970336914, 0.0009273100877180696]




Discriminator Loss: 0.1699952299659344




Generator Loss: [0.32964861392974854, 0.327084481716156, 0.001282065873965621]




Discriminator Loss: 0.1676368082553381




Generator Loss: [0.35142379999160767, 0.34917861223220825, 0.0011225922498852015]




Discriminator Loss: 0.18279491207795218




Generator Loss: [0.45124751329421997, 0.4489346146583557, 0.0011564460583031178]




Discriminator Loss: 0.1822815996947611




Generator Loss: [0.3702956438064575, 0.3680974245071411, 0.0010991133749485016]




Discriminator Loss: 0.16425416101628798




Generator Loss: [0.38869592547416687, 0.3865468502044678, 0.0010745381005108356]




Discriminator Loss: 0.16714221453548816




Generator Loss: [0.33197176456451416, 0.32971566915512085, 0.0011280494509264827]




Discriminator Loss: 0.16870891696453327




Generator Loss: [0.3416307866573334, 0.3390541076660156, 0.001288343919441104]




Discriminator Loss: 0.16632556264084997




Generator Loss: [0.3580024838447571, 0.3560272753238678, 0.0009876105468720198]




Discriminator Loss: 0.17509101399127758




Generator Loss: [0.3284602761268616, 0.3265739977359772, 0.000943132909014821]




Discriminator Loss: 0.1634475983005359




Generator Loss: [0.3279673755168915, 0.32644614577293396, 0.000760617374908179]




Discriminator Loss: 0.1643428481115734




Generator Loss: [0.32983872294425964, 0.3278132677078247, 0.0010127285495400429]




Discriminator Loss: 0.1638206550533141




Generator Loss: [0.3350582718849182, 0.3331434428691864, 0.000957409618422389]




Discriminator Loss: 0.1650939573883079




Generator Loss: [0.3349127173423767, 0.3327100872993469, 0.0011013124603778124]




Discriminator Loss: 0.16656488342596276




Generator Loss: [0.4681478440761566, 0.4661793112754822, 0.0009842619765549898]




Discriminator Loss: 0.1711461480778098




Generator Loss: [0.328888863325119, 0.326418936252594, 0.001234964351169765]




Discriminator Loss: 0.17541944434447032




Generator Loss: [0.3532147705554962, 0.35107874870300293, 0.0010680126724764705]




Discriminator Loss: 0.19315669859815898




Generator Loss: [0.33105525374412537, 0.3291296362876892, 0.0009628133848309517]




Discriminator Loss: 0.1690988003720122




Generator Loss: [0.339456707239151, 0.3371143341064453, 0.001171188778243959]




Discriminator Loss: 0.17045024776052742




Generator Loss: [0.3275028467178345, 0.3258702754974365, 0.0008162871235981584]




Discriminator Loss: 0.18158754778050934




Generator Loss: [0.35333552956581116, 0.3517950177192688, 0.0007702523143962026]




Discriminator Loss: 0.16750206128739364




Generator Loss: [0.32770222425460815, 0.32586944103240967, 0.0009163885843008757]




Discriminator Loss: 0.1716197830287456




Generator Loss: [0.3627641201019287, 0.36083322763442993, 0.0009654398891143501]




Discriminator Loss: 0.19107032186491324




Generator Loss: [0.3288339078426361, 0.32601726055145264, 0.0014083182904869318]




Discriminator Loss: 0.16713285901914787




Generator Loss: [0.3279878497123718, 0.32620978355407715, 0.0008890348835848272]




Discriminator Loss: 0.16773653461859794




Generator Loss: [0.32914188504219055, 0.3276864290237427, 0.000727728649508208]




Discriminator Loss: 0.169735772667309




Generator Loss: [0.3396962881088257, 0.3376564085483551, 0.0010199418757110834]




Discriminator Loss: 0.16586772568462038




Generator Loss: [0.3303203582763672, 0.3280467092990875, 0.0011368240229785442]




Discriminator Loss: 0.1732151413543761




Generator Loss: [0.3519201874732971, 0.3497545123100281, 0.001082835253328085]




Discriminator Loss: 0.17862261558320824




Generator Loss: [0.35845470428466797, 0.356209933757782, 0.001122383284382522]




Discriminator Loss: 0.19566898835546453




Generator Loss: [0.3510192036628723, 0.34929633140563965, 0.0008614416001364589]




Discriminator Loss: 0.17352400694380776




Generator Loss: [0.3737642467021942, 0.3721340000629425, 0.0008151269285008311]




Discriminator Loss: 0.18302850691634376




Generator Loss: [0.3763836622238159, 0.3749007284641266, 0.0007414601277559996]




Discriminator Loss: 0.18840650508172985




Generator Loss: [0.33370354771614075, 0.3319414258003235, 0.0008810547878965735]




Discriminator Loss: 0.17297730775953823




Generator Loss: [0.33565881848335266, 0.3342251181602478, 0.0007168569136410952]




Discriminator Loss: 0.16519515958907505




Generator Loss: [0.34063720703125, 0.33896583318710327, 0.0008356900652870536]




Discriminator Loss: 0.16648905567308248




Generator Loss: [0.337488055229187, 0.3358268141746521, 0.0008306150557473302]




Discriminator Loss: 0.16738515588804148




Generator Loss: [0.33138740062713623, 0.3297920227050781, 0.0007976944325491786]




Discriminator Loss: 0.17498804914021093




Generator Loss: [0.3348241448402405, 0.3332856297492981, 0.0007692547515034676]




Discriminator Loss: 0.16449279285188823




Generator Loss: [0.3307141065597534, 0.3289868235588074, 0.0008636402199044824]




Discriminator Loss: 0.16572536900639534




Generator Loss: [0.32839787006378174, 0.3268481194972992, 0.0007748737698420882]




Discriminator Loss: 0.1689552678135442




Generator Loss: [0.3295048475265503, 0.3279215097427368, 0.0007916648173704743]




Discriminator Loss: 0.16463232636397152




Generator Loss: [0.32943522930145264, 0.32781749963760376, 0.0008088629110716283]




Discriminator Loss: 0.16403884067017316




Generator Loss: [0.3364451825618744, 0.33508703112602234, 0.0006790811894461513]




Discriminator Loss: 0.1651581399087263




Generator Loss: [0.32736364006996155, 0.32588422298431396, 0.0007397111039608717]




Discriminator Loss: 0.16463007718448353




Generator Loss: [0.3284435570240021, 0.3265082538127899, 0.0009676589397713542]




Discriminator Loss: 0.16865374949520628




Generator Loss: [0.3280424475669861, 0.3261411190032959, 0.0009506630594842136]




Discriminator Loss: 0.16447857771800045




Generator Loss: [0.32800722122192383, 0.3255643844604492, 0.0012214126763865352]




Discriminator Loss: 0.16665287296075348




Generator Loss: [0.33401426672935486, 0.3320516347885132, 0.0009813089855015278]




Discriminator Loss: 0.1656033219114761




Generator Loss: [0.32801908254623413, 0.32613077759742737, 0.0009441527654416859]




Discriminator Loss: 0.1698079564883983




Generator Loss: [0.3293520212173462, 0.32669341564178467, 0.001329303253442049]




Discriminator Loss: 0.16692921517096693




Generator Loss: [0.3372160792350769, 0.3346555233001709, 0.001280277967453003]




Discriminator Loss: 0.17160852990934927




Generator Loss: [0.33069393038749695, 0.3283679783344269, 0.0011629746295511723]




Discriminator Loss: 0.16637303049196817




Generator Loss: [0.33190637826919556, 0.3299211859703064, 0.0009926013881340623]




Discriminator Loss: 0.16724687425005413




Generator Loss: [0.328888475894928, 0.32640960812568665, 0.0012394317891448736]




Discriminator Loss: 0.16485022523033876




Generator Loss: [0.3293684124946594, 0.3270271420478821, 0.0011706424411386251]




Discriminator Loss: 0.16756926527028781




Generator Loss: [0.3280171751976013, 0.3258439898490906, 0.001086593372747302]




Discriminator Loss: 0.1642353895474571




Generator Loss: [0.33348557353019714, 0.33132755756378174, 0.001079007750377059]




Discriminator Loss: 0.16514167845434713




Generator Loss: [0.33848658204078674, 0.3367270231246948, 0.0008797834161669016]




Discriminator Loss: 0.167165158168757




Generator Loss: [0.3418460190296173, 0.34033286571502686, 0.0007565838168375194]




Discriminator Loss: 0.17128264840914653




Generator Loss: [0.3365499973297119, 0.33498552441596985, 0.0007822381448931992]




Discriminator Loss: 0.16433508991451617




Generator Loss: [0.330282598733902, 0.3282490670681, 0.0010167695581912994]




Discriminator Loss: 0.16528202613244503




Generator Loss: [0.32750311493873596, 0.32581034302711487, 0.0008463849080726504]




Discriminator Loss: 0.1672878661111099




Generator Loss: [0.3269176781177521, 0.32558882236480713, 0.0006644277600571513]




Discriminator Loss: 0.16333226361598463




Generator Loss: [0.32952365279197693, 0.328016996383667, 0.000753330416046083]




Discriminator Loss: 0.1630746024421228




Generator Loss: [0.32777607440948486, 0.32615309953689575, 0.000811489881016314]




Discriminator Loss: 0.163377859793286




Generator Loss: [0.327444851398468, 0.3259052634239197, 0.0007697881665080786]




Discriminator Loss: 0.16309196262091064




Generator Loss: [0.3339134156703949, 0.33233028650283813, 0.0007915657479315996]




Discriminator Loss: 0.16331331969536222




Generator Loss: [0.327687531709671, 0.3258465826511383, 0.0009204679518006742]




Discriminator Loss: 0.1640526213591329




Generator Loss: [0.3315715789794922, 0.3296159505844116, 0.0009778074454516172]




Discriminator Loss: 0.16616034508297162




Generator Loss: [0.3393208086490631, 0.336974561214447, 0.0011731258127838373]




Discriminator Loss: 0.17326680873566147




Generator Loss: [0.32881975173950195, 0.32677197456359863, 0.0010238959221169353]




Discriminator Loss: 0.16561215991077916




Generator Loss: [0.3367987871170044, 0.33471956849098206, 0.0010396105935797095]




Discriminator Loss: 0.16546477580982355




Generator Loss: [0.3294481337070465, 0.32753872871398926, 0.0009546979563310742]




Discriminator Loss: 0.16449417629166874




Generator Loss: [0.32970130443573, 0.327822208404541, 0.0009395420202054083]




Discriminator Loss: 0.16384411305671165




Generator Loss: [0.32742470502853394, 0.32586199045181274, 0.0007813568226993084]




Discriminator Loss: 0.16494346872900678




Generator Loss: [0.3282647728919983, 0.3264898955821991, 0.0008874352788552642]




Discriminator Loss: 0.16327388569027335




Generator Loss: [0.3271483778953552, 0.32566267251968384, 0.0007428543176501989]




Discriminator Loss: 0.1636407354756102




Generator Loss: [0.3308268189430237, 0.32902437448501587, 0.0009012164082378149]




Discriminator Loss: 0.16407453057331622




Generator Loss: [0.3286234736442566, 0.3267214894294739, 0.000950989022385329]




Discriminator Loss: 0.1668444848019135




Generator Loss: [0.3289119601249695, 0.3270811140537262, 0.0009154232684522867]




Discriminator Loss: 0.16387204339253003




Generator Loss: [0.3289167582988739, 0.32731300592422485, 0.0008018704247660935]




Discriminator Loss: 0.16299110676959572




Generator Loss: [0.3295151889324188, 0.3279815912246704, 0.0007668000180274248]




Discriminator Loss: 0.16432771086033426




Generator Loss: [0.3310815691947937, 0.3292272090911865, 0.0009271849412471056]




Discriminator Loss: 0.16484704285494445




Generator Loss: [0.32982951402664185, 0.32802408933639526, 0.0009027062333188951]




Discriminator Loss: 0.17072134183490562




Generator Loss: [0.3288557529449463, 0.32702937722206116, 0.0009131896076723933]




Discriminator Loss: 0.16352938752856971




Generator Loss: [0.32782119512557983, 0.32568836212158203, 0.0010664190631359816]




Discriminator Loss: 0.16483305491806277




Generator Loss: [0.32847169041633606, 0.32626962661743164, 0.001101034227758646]




Discriminator Loss: 0.168911514409956




Generator Loss: [0.32950273156166077, 0.32774075865745544, 0.0008809860446490347]




Discriminator Loss: 0.16319497772838076




Generator Loss: [0.3295949697494507, 0.3274466395378113, 0.0010741609148681164]




Discriminator Loss: 0.16302514397443701




Generator Loss: [0.3280581533908844, 0.32641759514808655, 0.0008202718454413116]




Discriminator Loss: 0.16293269543439237




Generator Loss: [0.3272004723548889, 0.3257344365119934, 0.0007330219377763569]




Discriminator Loss: 0.16297066529182302




Generator Loss: [0.32940399646759033, 0.32768118381500244, 0.0008614059770479798]




Discriminator Loss: 0.16297172906422475




Generator Loss: [0.32755059003829956, 0.3261566162109375, 0.0006969815585762262]




Discriminator Loss: 0.1631070739063034




Generator Loss: [0.3272955119609833, 0.3259870707988739, 0.0006542244227603078]




Discriminator Loss: 0.1629131050168553




Generator Loss: [0.3270067274570465, 0.3255382180213928, 0.0007342563476413488]




Discriminator Loss: 0.163083693385488




Generator Loss: [0.32885217666625977, 0.32724225521087646, 0.0008049545576795936]




Epoch: 11
Discriminator Loss: 0.16388035123782174
Generator Loss: [0.33646780252456665, 0.33437883853912354, 0.0010444774525240064]
Discriminator Loss: 0.16408699674411764
Generator Loss: [0.3316735029220581, 0.3294943571090698, 0.001089578028768301]
Discriminator Loss: 0.16375214943082028
Generator Loss: [0.32774147391319275, 0.3257142901420593, 0.0010135932825505733]
Discriminator Loss: 0.16428598634513492
Generator Loss: [0.3327715992927551, 0.3307664394378662, 0.0010025853989645839]
Discriminator Loss: 0.16359606753201206
Generator Loss: [0.3276744782924652, 0.3258747160434723, 0.0008998815901577473]
Discriminator Loss: 0.16567851742354378
Generator Loss: [0.327513188123703, 0.3256192207336426, 0.0009469882352277637]
Discriminator Loss: 0.16666939260312574
Generator Loss: [0.3283228278160095, 0.32679229974746704, 0.0007652672356925905]
Discriminator Loss: 0.17686068294756296
Generator Loss: [0.33826926350593567, 0.33583498001098633, 0.001217134646140039]
Discriminator Loss: 0.16643

## Check test folder for gernerated images from Stage1 Generator

## Let's Implement Stage 2 Generator

In [11]:
############################################################
# Stage 2 Generator Network
############################################################

def concat_along_dims(inputs):
	"""Joins the conditioned text with the encoded image along the dimensions.

	Args:
		inputs: consisting of conditioned text and encoded images as [c,x].

	Returns:
		Joint block along the dimensions.
	"""
	c = inputs[0]
	x = inputs[1]

	c = K.expand_dims(c, axis=1)
	c = K.expand_dims(c, axis=1)
	c = K.tile(c, [1, 16, 16, 1])
	return K.concatenate([c, x], axis = 3)

def residual_block(input):
	"""Residual block with plain identity connections.

	Args:
		inputs: input layer or an encoded layer

	Returns:
		Layer with computed identity mapping.
	"""
	x = Conv2D(512, kernel_size=(3,3), padding='same', use_bias=False,
				kernel_initializer='he_uniform')(input)
	x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)
	x = ReLU()(x)
	
	x = Conv2D(512, kernel_size=(3,3), padding='same', use_bias=False,
				kernel_initializer='he_uniform')(x)
	x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)
	
	x = add([x, input])
	x = ReLU()(x)

	return x

def build_stage2_generator():
	"""Build the Stage 2 Generator Network using the conditioning text and images from stage 1.

	Returns:
		Stage 2 Generator Model for StackGAN.
	"""
	input_layer1 = Input(shape=(1024,))
	input_images = Input(shape=(64, 64, 3))

	# Conditioning Augmentation
	ca = Dense(256)(input_layer1)
	mls = LeakyReLU(alpha=0.2)(ca)
	c = Lambda(conditioning_augmentation)(mls)

	# Downsampling block
	x = ZeroPadding2D(padding=(1,1))(input_images)
	x = Conv2D(128, kernel_size=(3,3), strides=1, use_bias=False,
				kernel_initializer='he_uniform')(x)
	x = ReLU()(x)

	x = ZeroPadding2D(padding=(1,1))(x)
	x = Conv2D(256, kernel_size=(4,4), strides=2, use_bias=False,
				kernel_initializer='he_uniform')(x)
	x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)
	x = ReLU()(x)

	x = ZeroPadding2D(padding=(1,1))(x)
	x = Conv2D(512, kernel_size=(4,4), strides=2, use_bias=False,
				kernel_initializer='he_uniform')(x)
	x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)
	x = ReLU()(x)

	# Concatenate text conditioning block with the encoded image
	concat = concat_along_dims([c, x])

	# Residual Blocks
	x = ZeroPadding2D(padding=(1,1))(concat)
	x = Conv2D(512, kernel_size=(3,3), use_bias=False, kernel_initializer='he_uniform')(x)
	x = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x)
	x = ReLU()(x)

	x = residual_block(x)
	x = residual_block(x)
	x = residual_block(x)
	x = residual_block(x)

	# Upsampling Blocks
	x = UpSamplingBlock(x, 512)
	x = UpSamplingBlock(x, 256)
	x = UpSamplingBlock(x, 128)
	x = UpSamplingBlock(x, 64)

	x = Conv2D(3, kernel_size=(3,3), padding='same', use_bias=False, kernel_initializer='he_uniform')(x)
	x = Activation('tanh')(x)
	
	stage2_gen = Model(inputs=[input_layer1, input_images], outputs=[x, mls])
	return stage2_gen



In [None]:
generator_stage2 = build_stage2_generator()
generator_stage2.summary()

In [13]:

############################################################
# Stage 2 Discriminator Network
############################################################

def build_stage2_discriminator():
	"""Builds the Stage 2 Discriminator that uses the 256x256 resolution images from the generator
	and the compressed and spatially replicated embeddings.

	Returns:
		Stage 2 Discriminator Model for StackGAN.
	"""
	input_layer1 = Input(shape=(256, 256, 3))

	x = Conv2D(64, kernel_size=(4,4), padding='same', strides=2, use_bias=False,
				kernel_initializer='he_uniform')(input_layer1)
	x = LeakyReLU(alpha=0.2)(x)

	x = ConvBlock(x, 128)
	x = ConvBlock(x, 256)
	x = ConvBlock(x, 512)
	x = ConvBlock(x, 1024)
	x = ConvBlock(x, 2048)
	x = ConvBlock(x, 1024, (1,1), 1)
	x = ConvBlock(x, 512, (1,1), 1, False)

	x1 = ConvBlock(x, 128, (1,1), 1)
	x1 = ConvBlock(x1, 128, (3,3), 1)
	x1 = ConvBlock(x1, 512, (3,3), 1, False)

	x2 = add([x, x1])
	x2 = LeakyReLU(alpha=0.2)(x2)

	# Concatenate compressed and spatially replicated embedding
	input_layer2 = Input(shape=(4, 4, 128))
	concat = concatenate([x2, input_layer2])

	x3 = Conv2D(512, kernel_size=(1,1), strides=1, padding='same', kernel_initializer='he_uniform')(concat)
	x3 = BatchNormalization(gamma_initializer='ones', beta_initializer='zeros')(x3)
	x3 = LeakyReLU(alpha=0.2)(x3)

	# Flatten and add a FC layer
	x3 = Flatten()(x3)
	x3 = Dense(1)(x3)
	x3 = Activation('sigmoid')(x3)

	stage2_dis = Model(inputs=[input_layer1, input_layer2], outputs=[x3])
	return stage2_dis



In [None]:
discriminator_stage2 = build_stage2_discriminator()
discriminator_stage2.summary()

In [15]:

############################################################
# Stage 2 Adversarial Model
############################################################

def stage2_adversarial_network(stage2_disc, stage2_gen, stage1_gen):
	"""Stage 2 Adversarial Network.

	Args:
		stage2_disc: Stage 2 Discriminator Model.
		stage2_gen: Stage 2 Generator Model.
		stage1_gen: Stage 1 Generator Model.

	Returns:
		Stage 2 Adversarial network.
	"""
	conditioned_embedding = Input(shape=(1024, ))
	latent_space = Input(shape=(100, ))
	compressed_replicated = Input(shape=(4, 4, 128))
    
	#the discriminator is trained separately and stage1_gen already trained, and this is the reason why we freeze its layers by setting the property trainable=false
	input_images, ca = stage1_gen([conditioned_embedding, latent_space])
	stage2_disc.trainable = False
	stage1_gen.trainable = False

	images, ca2 = stage2_gen([conditioned_embedding, input_images])
	probability = stage2_disc([images, compressed_replicated])

	return Model(inputs=[conditioned_embedding, latent_space, compressed_replicated],
		outputs=[probability, ca2])


In [None]:
adversarial_stage2 = stage2_adversarial_network(discriminator_stage2, generator_stage2, generator)
adversarial_stage2.summary()

In [17]:

class StackGanStage2(object):
	"""StackGAN Stage 2 class.

	Args:
		epochs: Number of epochs
		z_dim: Latent space dimensions
		batch_size: Batch Size
		enable_function: If True, training function is decorated with tf.function
		stage2_generator_lr: Learning rate for stage 2 generator
		stage2_discriminator_lr: Learning rate for stage 2 discriminator
	"""
	def __init__(self, epochs=500, z_dim=100, batch_size=64, enable_function=True, stage2_generator_lr=0.0002, stage2_discriminator_lr=0.0002):
		self.epochs = epochs
		self.z_dim = z_dim
		self.enable_function = enable_function
		self.stage1_generator_lr = stage2_generator_lr
		self.stage1_discriminator_lr = stage2_discriminator_lr
		self.low_image_size = 64
		self.high_image_size = 256
		self.conditioning_dim = 128
		self.batch_size = batch_size
		self.stage2_generator_optimizer = Adam(lr=stage2_generator_lr, beta_1=0.5, beta_2=0.999)
		self.stage2_discriminator_optimizer = Adam(lr=stage2_discriminator_lr, beta_1=0.5, beta_2=0.999)
		self.stage1_generator = build_stage1_generator()
		self.stage1_generator.compile(loss='binary_crossentropy', optimizer=self.stage2_generator_optimizer)
		self.stage1_generator.load_weights('weights/stage1_gen.h5')
		self.stage2_generator = build_stage2_generator()
		self.stage2_generator.compile(loss='binary_crossentropy', optimizer=self.stage2_generator_optimizer)

		self.stage2_discriminator = build_stage2_discriminator()
		self.stage2_discriminator.compile(loss='binary_crossentropy', optimizer=self.stage2_discriminator_optimizer)

		self.ca_network = build_ca_network()
		self.ca_network.compile(loss='binary_crossentropy', optimizer='Adam')

		self.embedding_compressor = build_embedding_compressor()
		self.embedding_compressor.compile(loss='binary_crossentropy', optimizer='Adam')

		self.stage2_adversarial = stage2_adversarial_network(self.stage2_discriminator, self.stage2_generator, self.stage1_generator)
		self.stage2_adversarial.compile(loss=['binary_crossentropy', adversarial_loss], loss_weights=[1, 2.0], optimizer=self.stage2_generator_optimizer)	

		self.checkpoint2 = tf.train.Checkpoint(
        	generator_optimizer=self.stage2_generator_optimizer,
        	discriminator_optimizer=self.stage2_discriminator_optimizer,
        	generator=self.stage2_generator,
        	discriminator=self.stage2_discriminator,
        	generator1=self.stage1_generator)

	def visualize_stage2(self):
		"""Running Tensorboard visualizations.
		"""
		tb = TensorBoard(log_dir="logs/".format(time.time()))
		tb.set_model(self.stage2_generator)
		tb.set_model(self.stage2_discriminator)

	def train_stage2(self):
		"""Trains Stage 2 StackGAN.
		"""
		x_high_train, y_high_train, high_train_embeds = load_data(filename_path=filename_path_train, class_id_path=class_id_path_train,
      dataset_path=dataset_path, embeddings_path=embeddings_path_train, size=(256, 256))

		x_high_test, y_high_test, high_test_embeds = load_data(filename_path=filename_path_test, class_id_path=class_id_path_test, 
      dataset_path=dataset_path, embeddings_path=embeddings_path_test, size=(256, 256))

		x_low_train, y_low_train, low_train_embeds = load_data(filename_path=filename_path_train, class_id_path=class_id_path_train,
      dataset_path=dataset_path, embeddings_path=embeddings_path_train, size=(64, 64))

		x_low_test, y_low_test, low_test_embeds = load_data(filename_path=filename_path_test, class_id_path=class_id_path_test, 
      dataset_path=dataset_path, embeddings_path=embeddings_path_test, size=(64, 64))

		real = np.ones((self.batch_size, 1), dtype='float') * 0.9
		fake = np.zeros((self.batch_size, 1), dtype='float') * 0.1

		for epoch in range(self.epochs):
			print(f'Epoch: {epoch}')

			gen_loss = []
			disc_loss = []

			num_batches = int(x_high_train.shape[0] / self.batch_size)

			for i in range(num_batches):

				latent_space = np.random.normal(0, 1, size=(self.batch_size, self.z_dim))
				embedding_text = high_train_embeds[i * self.batch_size:(i + 1) * self.batch_size]
				compressed_embedding = self.embedding_compressor.predict_on_batch(embedding_text)
				compressed_embedding = np.reshape(compressed_embedding, (-1, 1, 1, self.conditioning_dim))
				compressed_embedding = np.tile(compressed_embedding, (1, 4, 4, 1))

				image_batch = x_high_train[i * self.batch_size:(i+1) * self.batch_size]
				image_batch = (image_batch - 127.5) / 127.5
				
				low_res_fakes, _ = self.stage1_generator.predict([embedding_text, latent_space], verbose=3)
				high_res_fakes, _ = self.stage2_generator.predict([embedding_text, low_res_fakes], verbose=3)

				discriminator_loss = self.stage2_discriminator.train_on_batch([image_batch, compressed_embedding],
					np.reshape(real, (self.batch_size, 1)))

				discriminator_loss_gen = self.stage2_discriminator.train_on_batch([high_res_fakes, compressed_embedding],
					np.reshape(fake, (self.batch_size, 1)))

				discriminator_loss_fake = self.stage2_discriminator.train_on_batch([image_batch[:(self.batch_size-1)], compressed_embedding[1:]],
					np.reshape(fake[1:], (self.batch_size - 1, 1)))

				d_loss = 0.5 * np.add(discriminator_loss, 0.5 * np.add(discriminator_loss_gen, discriminator_loss_fake))
				disc_loss.append(d_loss)

				print(f'Discriminator Loss: {d_loss}')

				g_loss = self.stage2_adversarial.train_on_batch([embedding_text, latent_space, compressed_embedding],
					[K.ones((self.batch_size, 1)) * 0.9, K.ones((self.batch_size, 256)) * 0.9])
				gen_loss.append(g_loss)

				print(f'Generator Loss: {g_loss}')

				if epoch % 5 == 0:
					latent_space = np.random.normal(0, 1, size=(self.batch_size, self.z_dim))
					embedding_batch = high_test_embeds[0 : self.batch_size]

					low_fake_images, _ = self.stage1_generator.predict([embedding_batch, latent_space], verbose=3)
					high_fake_images, _ = self.stage2_generator.predict([embedding_batch, low_fake_images], verbose=3)

					for i, image in enumerate(high_fake_images[:10]):
					    save_image(image, f'results_stage2/gen_{epoch}_{i}.png')

				if epoch % 10 == 0:
					self.stage2_generator.save_weights('weights/stage2_gen.h5')
					self.stage2_discriminator.save_weights("weights/stage2_disc.h5")
					self.ca_network.save_weights('weights/stage2_ca.h5')
					self.embedding_compressor.save_weights('weights/stage2_embco.h5')
					self.stage2_adversarial.save_weights('weights/stage2_adv.h5')

		self.stage2_generator.save_weights('weights/stage2_gen.h5')
		self.stage2_discriminator.save_weights("weights/stage2_disc.h5")


In [None]:
stage2 = StackGanStage2()
stage2.train_stage2()