In [1]:
# lfw-created_2 (400 images of our data)
!gdown https://drive.google.com/uc?id=1NB1deCJ4XZPvqhTXiRD1AG6FgKHu-h1Y

Downloading...
From: https://drive.google.com/uc?id=1NB1deCJ4XZPvqhTXiRD1AG6FgKHu-h1Y
To: /content/lfw-created2.zip
0.00B [00:00, ?B/s]3.63MB [00:00, 114MB/s]


In [2]:
# faces_em_2 (embeddings of our data of 372)
!gdown https://drive.google.com/uc?id=1QrHl3tv7J89S9CyKz4m6_R8F4UXoKHh7

Downloading...
From: https://drive.google.com/uc?id=1QrHl3tv7J89S9CyKz4m6_R8F4UXoKHh7
To: /content/faces_em_2.zip
15.7MB [00:00, 25.4MB/s]


In [0]:
#! echo '0' > epoch.txt

In [4]:
# stage1_gen
!gdown https://drive.google.com/uc?id=1DpaUvhl7nd9BdVBeNEowBL6unQMWhRLg

Downloading...
From: https://drive.google.com/uc?id=1DpaUvhl7nd9BdVBeNEowBL6unQMWhRLg
To: /content/stage1_gen.h5
41.1MB [00:00, 57.5MB/s]


In [5]:
# stage2_dis
!gdown https://drive.google.com/uc?id=1d3OY4xQX6qGrxw42iqdOU771dHo4DR9o

Downloading...
From: https://drive.google.com/uc?id=1d3OY4xQX6qGrxw42iqdOU771dHo4DR9o
To: /content/stage1_dis.h5
12.4MB [00:00, 26.7MB/s]


In [0]:
from zipfile import ZipFile


with ZipFile("lfw-created2.zip", 'r') as zip2:
  zip2.extractall()
  zip2.close()
  
with ZipFile("faces_em_2.zip", 'r') as zip1:
  zip1.extractall()
  zip1.close()

In [7]:
!ls
!mkdir logs
!mkdir results
!ls

faces_em	lfw-created2	  sample_data	 stage1_gen.h5
faces_em_2.zip	lfw-created2.zip  stage1_dis.h5
faces_em	lfw-created2	  logs	   sample_data	  stage1_gen.h5
faces_em_2.zip	lfw-created2.zip  results  stage1_dis.h5


In [8]:
import os
import pickle
import random
import time

import PIL
import numpy as np
import pandas as pd
import tensorflow as tf
from PIL import Image
from keras import Input, Model
from keras import backend as K
from keras.callbacks import TensorBoard
from keras.layers import Dense, LeakyReLU, BatchNormalization, ReLU, Reshape, UpSampling2D, Conv2D, Activation, \
    concatenate, Flatten, Lambda, Concatenate, ZeroPadding2D
from keras.layers import add
from keras.optimizers import Adam
from matplotlib import pyplot as plt

Using TensorFlow backend.


In [0]:
def build_ca_model():
    """
    Get conditioning augmentation model.
    Takes an embedding of shape (1024,) and returns a tensor of shape (256,)
    """
    input_layer = Input(shape=(1024,))
    x = Dense(256)(input_layer)
    x = LeakyReLU(alpha=0.2)(x)
    model = Model(inputs=[input_layer], outputs=[x])
    return model


In [0]:
def build_embedding_compressor_model():
    """
    Build embedding compressor model
    """
    input_layer = Input(shape=(1024,))
    x = Dense(128)(input_layer)
    x = ReLU()(x)
    model = Model(inputs=[input_layer], outputs=[x])
    return model

In [0]:
def generate_c(x):
    mean = x[:, :128]
    log_sigma = x[:, 128:]

    stddev = K.exp(log_sigma)
    epsilon = K.random_normal(shape=K.constant((mean.shape[1],), dtype='int32'))
    c = stddev * epsilon + mean

    return c

In [0]:
def build_stage1_generator():
    """
    Builds a generator model used in Stage-I
    """
    input_layer = Input(shape=(1024,))
    x = Dense(256)(input_layer)
    mean_logsigma = LeakyReLU(alpha=0.2)(x)

    c = Lambda(generate_c)(mean_logsigma)

    input_layer2 = Input(shape=(100,))

    gen_input = Concatenate(axis=1)([c, input_layer2])

    x = Dense(128 * 8 * 4 * 4, use_bias=False)(gen_input)
    x = ReLU()(x)

    x = Reshape((4, 4, 128 * 8), input_shape=(128 * 8 * 4 * 4,))(x)

    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(512, kernel_size=3, padding="same", strides=1, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(256, kernel_size=3, padding="same", strides=1, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(128, kernel_size=3, padding="same", strides=1, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(64, kernel_size=3, padding="same", strides=1, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    x = Conv2D(3, kernel_size=3, padding="same", strides=1, use_bias=False)(x)
    x = Activation(activation='tanh')(x)

    stage1_gen = Model(inputs=[input_layer, input_layer2], outputs=[x, mean_logsigma])
    return stage1_gen

In [0]:
def residual_block(input):
    """
    Residual block in the generator network
    """
    x = Conv2D(128 * 4, kernel_size=(3, 3), padding='same', strides=1)(input)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    x = Conv2D(128 * 4, kernel_size=(3, 3), strides=1, padding='same')(x)
    x = BatchNormalization()(x)

    x = add([x, input])
    x = ReLU()(x)

    return x

In [0]:
def joint_block(inputs):
    c = inputs[0]
    x = inputs[1]

    c = K.expand_dims(c, axis=1)
    c = K.expand_dims(c, axis=1)
    c = K.tile(c, [1, 16, 16, 1])
    return K.concatenate([c, x], axis=3)

In [0]:
def build_stage2_generator():
    """
    Create Stage-II generator containing the CA Augmentation Network,
    the image encoder and the generator network
    """

    # 1. CA Augmentation Network
    input_layer = Input(shape=(1024,))
    input_lr_images = Input(shape=(64, 64, 3))

    ca = Dense(256)(input_layer)
    mean_logsigma = LeakyReLU(alpha=0.2)(ca)
    c = Lambda(generate_c)(mean_logsigma)

    # 2. Image Encoder
    x = ZeroPadding2D(padding=(1, 1))(input_lr_images)
    x = Conv2D(128, kernel_size=(3, 3), strides=1, use_bias=False)(x)
    x = ReLU()(x)

    x = ZeroPadding2D(padding=(1, 1))(x)
    x = Conv2D(256, kernel_size=(4, 4), strides=2, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    x = ZeroPadding2D(padding=(1, 1))(x)
    x = Conv2D(512, kernel_size=(4, 4), strides=2, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    # 3. Joint
    c_code = Lambda(joint_block)([c, x])

    x = ZeroPadding2D(padding=(1, 1))(c_code)
    x = Conv2D(512, kernel_size=(3, 3), strides=1, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    # 4. Residual blocks
    x = residual_block(x)
    x = residual_block(x)
    x = residual_block(x)
    x = residual_block(x)

    # 5. Upsampling blocks
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(512, kernel_size=3, padding="same", strides=1, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(256, kernel_size=3, padding="same", strides=1, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(128, kernel_size=3, padding="same", strides=1, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(64, kernel_size=3, padding="same", strides=1, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    x = Conv2D(3, kernel_size=3, padding="same", strides=1, use_bias=False)(x)
    x = Activation('tanh')(x)

    model = Model(inputs=[input_layer, input_lr_images], outputs=[x, mean_logsigma])
    return model

In [0]:
def build_stage2_discriminator():
    """
    Create Stage-II discriminator network
    """
    input_layer = Input(shape=(256, 256, 3))

    x = Conv2D(64, (4, 4), padding='same', strides=2, input_shape=(256, 256, 3), use_bias=False)(input_layer)
    x = LeakyReLU(alpha=0.2)(x)

    x = Conv2D(128, (4, 4), padding='same', strides=2, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)

    x = Conv2D(256, (4, 4), padding='same', strides=2, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)

    x = Conv2D(512, (4, 4), padding='same', strides=2, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)

    x = Conv2D(1024, (4, 4), padding='same', strides=2, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)

    x = Conv2D(2048, (4, 4), padding='same', strides=2, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)

    x = Conv2D(1024, (1, 1), padding='same', strides=1, use_bias=False)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)

    x = Conv2D(512, (1, 1), padding='same', strides=1, use_bias=False)(x)
    x = BatchNormalization()(x)

    x2 = Conv2D(128, (1, 1), padding='same', strides=1, use_bias=False)(x)
    x2 = BatchNormalization()(x2)
    x2 = LeakyReLU(alpha=0.2)(x2)

    x2 = Conv2D(128, (3, 3), padding='same', strides=1, use_bias=False)(x2)
    x2 = BatchNormalization()(x2)
    x2 = LeakyReLU(alpha=0.2)(x2)

    x2 = Conv2D(512, (3, 3), padding='same', strides=1, use_bias=False)(x2)
    x2 = BatchNormalization()(x2)

    added_x = add([x, x2])
    added_x = LeakyReLU(alpha=0.2)(added_x)

    input_layer2 = Input(shape=(4, 4, 128))

    merged_input = concatenate([added_x, input_layer2])

    x3 = Conv2D(64 * 8, kernel_size=1, padding="same", strides=1)(merged_input)
    x3 = BatchNormalization()(x3)
    x3 = LeakyReLU(alpha=0.2)(x3)
    x3 = Flatten()(x3)
    x3 = Dense(1)(x3)
    x3 = Activation('sigmoid')(x3)

    stage2_dis = Model(inputs=[input_layer, input_layer2], outputs=[x3])
    return stage2_dis

In [0]:
def build_adversarial_model(gen_model2, dis_model, gen_model1):
    """
    Create adversarial model
    """
    embeddings_input_layer = Input(shape=(1024, ))
    noise_input_layer = Input(shape=(100, ))
    compressed_embedding_input_layer = Input(shape=(4, 4, 128))

    gen_model1.trainable = False
    dis_model.trainable = False

    lr_images, mean_logsigma1 = gen_model1([embeddings_input_layer, noise_input_layer])
    hr_images, mean_logsigma2 = gen_model2([embeddings_input_layer, lr_images])
    valid = dis_model([hr_images, compressed_embedding_input_layer])

    model = Model(inputs=[embeddings_input_layer, noise_input_layer, compressed_embedding_input_layer], outputs=[valid, mean_logsigma2])
    return model


In [0]:
"""
Dataset loading related methods
"""


def load_class_ids(class_info_file_path):
    """
    Load class ids from class_info.pickle file
    """
    with open(class_info_file_path, 'rb') as f:
        class_ids = pickle.load(f, encoding='latin1')
        return class_ids


def load_embeddings(embeddings_file_path):
    """
    Function to load embeddings
    """
    with open(embeddings_file_path, 'rb') as f:
        embeddings = pickle.load(f, encoding='latin1')
        embeddings = np.array(embeddings)
        print('embeddings: ', embeddings.shape)
    return embeddings


def load_filenames(filenames_file_path):
    """
    Load filenames.pickle file and return a list of all file names
    """
    with open(filenames_file_path, 'rb') as f:
        filenames = pickle.load(f, encoding='latin1')
    return filenames

'''
def load_bounding_boxes(dataset_dir):
    """
    Load bounding boxes and return a dictionary of file names and corresponding bounding boxes
    """
    # Paths
    bounding_boxes_path = os.path.join(dataset_dir, 'bounding_boxes.txt')
    file_paths_path = os.path.join(dataset_dir, 'images.txt')

    # Read bounding_boxes.txt and images.txt file
    df_bounding_boxes = pd.read_csv(bounding_boxes_path,
                                    delim_whitespace=True, header=None).astype(int)
    df_file_names = pd.read_csv(file_paths_path, delim_whitespace=True, header=None)

    # Create a list of file names
    file_names = df_file_names[1].tolist()

    # Create a dictionary of file_names and bounding boxes
    filename_boundingbox_dict = {img_file[:-4]: [] for img_file in file_names[:2]}

    # Assign a bounding box to the corresponding image
    for i in range(0, len(file_names)):
        # Get the bounding box
        bounding_box = df_bounding_boxes.iloc[i][1:].tolist()
        key = file_names[i][:-4]
        filename_boundingbox_dict[key] = bounding_box

    return filename_boundingbox_dict
'''

def get_img(img_path, image_size):
    """
    Load and resize images
    """
    img = Image.open(img_path).convert('RGB')
    width, height = img.size
    '''
    if bbox is not None:
        R = int(np.maximum(bbox[2], bbox[3]) * 0.75)
        center_x = int((2 * bbox[0] + bbox[2]) / 2)
        center_y = int((2 * bbox[1] + bbox[3]) / 2)
        y1 = np.maximum(0, center_y - R)
        y2 = np.minimum(height, center_y + R)
        x1 = np.maximum(0, center_x - R)
        x2 = np.minimum(width, center_x + R)
        img = img.crop([x1, y1, x2, y2])
    '''
    img = img.resize(image_size, PIL.Image.BILINEAR)
    return img


def load_dataset(filenames_file_path, class_info_file_path, cub_dataset_dir, embeddings_file_path, image_size):
    filenames = load_filenames(filenames_file_path)
    class_ids = load_class_ids(class_info_file_path)
    #bounding_boxes = load_bounding_boxes(cub_dataset_dir)
    all_embeddings = load_embeddings(embeddings_file_path)

    X, y, embeddings = [], [], []

    print("All embeddings shape:", all_embeddings.shape)

    for index, filename in enumerate(filenames):
        #bounding_box = bounding_boxes[filename]

        try:
            # Load images
            img_name = '{}/{}'.format(cub_dataset_dir, filename)
            img = get_img(img_name, image_size)

            all_embeddings1 = all_embeddings[index, :, :]

            embedding_ix = random.randint(0, all_embeddings1.shape[0] - 1)
            embedding = all_embeddings1[embedding_ix, :]

            X.append(np.array(img))
            y.append(class_ids[index])
            embeddings.append(embedding)
        except Exception as e:
            print(e)

    X = np.array(X)
    y = np.array(y)
    embeddings = np.array(embeddings)

    return X, y, embeddings


In [0]:
"""
Loss functions
"""


def KL_loss(y_true, y_pred):
    mean = y_pred[:, :128]
    logsigma = y_pred[:, :128]
    loss = -logsigma + .5 * (-1 + K.exp(2. * logsigma) + K.square(mean))
    loss = K.mean(loss)
    return loss


def custom_generator_loss(y_true, y_pred):
    # Calculate binary cross entropy loss
    return K.binary_crossentropy(y_true, y_pred)


def write_log(callback, name, loss, batch_no):
    """
    Write training summary to TensorBoard
    """
    summary = tf.Summary()
    summary_value = summary.value.add()
    summary_value.simple_value = loss
    summary_value.tag = name
    callback.writer.add_summary(summary, batch_no)
    callback.writer.flush()


def save_rgb_img(img, path):
    """
    Save an rgb image
    """
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(img)
    ax.axis("off")
    ax.set_title("Image")

    plt.savefig(path)
    plt.close()


In [20]:
# stage2_gen 6
!gdown https://drive.google.com/uc?id=1Tczc0MmiMqOJ70bQl24_IpXd1GFjxHeH

Downloading...
From: https://drive.google.com/uc?id=1Tczc0MmiMqOJ70bQl24_IpXd1GFjxHeH
To: /content/stage2_gen.h5
115MB [00:02, 56.8MB/s] 


In [21]:
# stage2_dis 6
!gdown https://drive.google.com/uc?id=1o99TtXHEYp8eB8Er7I5W18JY8muBY96I

Downloading...
From: https://drive.google.com/uc?id=1o99TtXHEYp8eB8Er7I5W18JY8muBY96I
To: /content/stage2_dis.h5
194MB [00:03, 64.2MB/s]


In [0]:
! echo '358' > epoch.txt

In [0]:
if __name__ == '__main__':
    data_dir = "/content/faces_em"
    train_dir = data_dir + "/train"
    test_dir = data_dir + "/test"
    hr_image_size = (256, 256)
    lr_image_size = (64, 64)
    batch_size = 8
    z_dim = 100
    stage1_generator_lr = 0.0002
    stage1_discriminator_lr = 0.0002
    stage1_lr_decay_step = 600
    epochs = 1000
    condition_dim = 128

    embeddings_file_path_train = train_dir + "/embed_train.pickle"
    embeddings_file_path_test = test_dir + "/embed_test.pickle"

    filenames_file_path_train = train_dir + "/filenames_train.pickle"
    filenames_file_path_test = test_dir + "/filenames_test.pickle"

    class_info_file_path_train = train_dir + "/class_info.pickle"
    class_info_file_path_test = test_dir + "/class_info.pickle"

    cub_dataset_dir = "/content/lfw-created2"

    # Define optimizers
    dis_optimizer = Adam(lr=stage1_discriminator_lr, beta_1=0.5, beta_2=0.999)
    gen_optimizer = Adam(lr=stage1_generator_lr, beta_1=0.5, beta_2=0.999)

    """
    Load datasets
    """
    X_hr_train, y_hr_train, embeddings_train = load_dataset(filenames_file_path=filenames_file_path_train,
                                                            class_info_file_path=class_info_file_path_train,
                                                            cub_dataset_dir=cub_dataset_dir,
                                                            embeddings_file_path=embeddings_file_path_train,
                                                            image_size=(256, 256))

    X_hr_test, y_hr_test, embeddings_test = load_dataset(filenames_file_path=filenames_file_path_test,
                                                         class_info_file_path=class_info_file_path_test,
                                                         cub_dataset_dir=cub_dataset_dir,
                                                         embeddings_file_path=embeddings_file_path_test,
                                                         image_size=(256, 256))

    X_lr_train, y_lr_train, _ = load_dataset(filenames_file_path=filenames_file_path_train,
                                             class_info_file_path=class_info_file_path_train,
                                             cub_dataset_dir=cub_dataset_dir,
                                             embeddings_file_path=embeddings_file_path_train,
                                             image_size=(64, 64))

    X_lr_test, y_lr_test, _ = load_dataset(filenames_file_path=filenames_file_path_test,
                                           class_info_file_path=class_info_file_path_test,
                                           cub_dataset_dir=cub_dataset_dir,
                                           embeddings_file_path=embeddings_file_path_test,
                                           image_size=(64, 64))

    """
    Build and compile models
    """
    stage2_dis = build_stage2_discriminator()
    stage2_dis.compile(loss='binary_crossentropy', optimizer=dis_optimizer)

    stage1_gen = build_stage1_generator()
    stage1_gen.compile(loss="binary_crossentropy", optimizer=gen_optimizer)

    stage1_gen.load_weights("stage1_gen.h5")

    stage2_gen = build_stage2_generator()
    stage2_gen.compile(loss="binary_crossentropy", optimizer=gen_optimizer)

    embedding_compressor_model = build_embedding_compressor_model()
    embedding_compressor_model.compile(loss='binary_crossentropy', optimizer='adam')

    adversarial_model = build_adversarial_model(stage2_gen, stage2_dis, stage1_gen)
    adversarial_model.compile(loss=['binary_crossentropy', KL_loss], loss_weights=[1.0, 2.0],
                              optimizer=gen_optimizer, metrics=None)

    tensorboard = TensorBoard(log_dir="logs/".format(time.time()))
    tensorboard.set_model(stage2_gen)
    tensorboard.set_model(stage2_dis)

    # Generate an array containing real and fake values
    # Apply label smoothing
    real_labels = np.ones((batch_size, 1), dtype=float) * 0.9
    fake_labels = np.zeros((batch_size, 1), dtype=float) * 0.1

    # Resuming the training
    file1 = open("epoch.txt",'r') #directory change
    epoch_start = 0
    epoch_start = int(file1.read())
    if(epoch_start!=0):
      stage2_gen.load_weights("stage2_gen.h5") #directory change
      stage2_dis.load_weights("stage2_dis.h5") #directory change
      
    for epoch in range(epoch_start, epochs):
        print("========================================")
        print("Epoch is:", epoch)

        gen_losses = []
        dis_losses = []

        # Load data and train model
        number_of_batches = int(X_hr_train.shape[0] / batch_size)
        print("Number of batches:{}".format(number_of_batches))
        for index in range(number_of_batches):
            print("Batch:{}".format(index+1))

            # Create a noise vector
            z_noise = np.random.normal(0, 1, size=(batch_size, z_dim))
            X_hr_train_batch = X_hr_train[index * batch_size:(index + 1) * batch_size]
            embedding_batch = embeddings_train[index * batch_size:(index + 1) * batch_size]
            X_hr_train_batch = (X_hr_train_batch - 127.5) / 127.5

            # Generate fake images
            lr_fake_images, _ = stage1_gen.predict([embedding_batch, z_noise], verbose=3)
            hr_fake_images, _ = stage2_gen.predict([embedding_batch, lr_fake_images], verbose=3)

            """
            4. Generate compressed embeddings
            """
            compressed_embedding = embedding_compressor_model.predict_on_batch(embedding_batch)
            compressed_embedding = np.reshape(compressed_embedding, (-1, 1, 1, condition_dim))
            compressed_embedding = np.tile(compressed_embedding, (1, 4, 4, 1))

            """
            5. Train the discriminator model
            """
            dis_loss_real = stage2_dis.train_on_batch([X_hr_train_batch, compressed_embedding],
                                                      np.reshape(real_labels, (batch_size, 1)))
            dis_loss_fake = stage2_dis.train_on_batch([hr_fake_images, compressed_embedding],
                                                      np.reshape(fake_labels, (batch_size, 1)))
            dis_loss_wrong = stage2_dis.train_on_batch([X_hr_train_batch[:(batch_size - 1)], compressed_embedding[1:]],
                                                       np.reshape(fake_labels[1:], (batch_size-1, 1)))
            d_loss = 0.5 * np.add(dis_loss_real, 0.5 * np.add(dis_loss_wrong,  dis_loss_fake))
            print("d_loss:{}".format(d_loss))

            """
            Train the adversarial model
            """
            g_loss = adversarial_model.train_on_batch([embedding_batch, z_noise, compressed_embedding],
                                                                [K.ones((batch_size, 1)) * 0.9, K.ones((batch_size, 256)) * 0.9])

            print("g_loss:{}".format(g_loss))

            dis_losses.append(d_loss)
            gen_losses.append(g_loss)

        """
        Save losses to Tensorboard after each epoch
        """
        write_log(tensorboard, 'discriminator_loss', np.mean(dis_losses), epoch)
        #write_log(tensorboard, 'generator_loss', np.mean(gen_losses)[0], epoch)
        write_log(tensorboard, 'generator_loss', np.mean(gen_losses), epoch)

        # Generate and save images after every 2nd epoch
        if epoch % 2 == 0:
            # z_noise2 = np.random.uniform(-1, 1, size=(batch_size, z_dim))
            z_noise2 = np.random.normal(0, 1, size=(batch_size, z_dim))
            embedding_batch = embeddings_test[0:batch_size]

            lr_fake_images, _ = stage1_gen.predict([embedding_batch, z_noise2], verbose=3)
            hr_fake_images, _ = stage2_gen.predict([embedding_batch, lr_fake_images], verbose=3)

            # saving weights as a checkpoint
            stage2_gen.save_weights("stage2_gen.h5") #directory change
            stage2_dis.save_weights("stage2_dis.h5") #directory change
            file = open("epoch.txt",'w') #directory change
            file.write(str(epoch))
            
            # Save images
            for i, img in enumerate(hr_fake_images[:10]):
                save_rgb_img(img, "results/gen_{}_{}.png".format(epoch, i))

        if epoch % 4 == 0:
          
            import shutil
            shutil.make_archive("results", 'zip', "results")
            shutil.make_archive("logs", 'zip', "logs")
            
    # Saving the models
    stage2_gen.save_weights("stage2_gen.h5")
    stage2_dis.save_weights("stage2_dis.h5")

Instructions for updating:
Colocations handled automatically by placer.
embeddings:  (300, 1, 1024)
All embeddings shape: (300, 1, 1024)
embeddings:  (73, 1, 1024)
All embeddings shape: (73, 1, 1024)
embeddings:  (300, 1, 1024)
All embeddings shape: (300, 1, 1024)
embeddings:  (73, 1, 1024)
All embeddings shape: (73, 1, 1024)
Epoch is: 358
Number of batches:37
Batch:1
Instructions for updating:
Use tf.cast instead.


  'Discrepancy between trainable weights and collected trainable'


d_loss:0.16690214298432693
g_loss:[0.39230898, 0.3922538, 2.7596794e-05]
Batch:2
d_loss:0.16489921542233787
g_loss:[0.9033159, 0.90326905, 2.3410003e-05]
Batch:3
d_loss:0.28628901904448867
g_loss:[0.39628783, 0.396243, 2.2416605e-05]
Batch:4
d_loss:0.463080495595932
g_loss:[2.9876335, 2.9875913, 2.1136368e-05]
Batch:5
d_loss:0.18506181798875332
g_loss:[0.3647068, 0.3646589, 2.3956043e-05]
Batch:6
d_loss:0.18017483176663518
g_loss:[0.59029454, 0.5902454, 2.4556723e-05]
Batch:7
d_loss:0.1707058663596399
g_loss:[0.5099668, 0.50990146, 3.2677577e-05]
Batch:8
d_loss:0.16675298305926844
g_loss:[0.33648306, 0.33638388, 4.959319e-05]
Batch:9
d_loss:0.16887802464771084
g_loss:[0.34086463, 0.34077358, 4.5519922e-05]
Batch:10
d_loss:0.16444732126547024
g_loss:[1.5215632, 1.521497, 3.3058495e-05]
Batch:11
d_loss:0.16528231388656422
g_loss:[0.3471945, 0.34715888, 1.7802498e-05]
Batch:12
d_loss:0.16933421092107892
g_loss:[0.32894492, 0.32891634, 1.4287459e-05]
Batch:13
d_loss:0.1665554873761721
g_lo

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch is: 359
Number of batches:37
Batch:1
d_loss:0.16428913525305688
g_loss:[0.37904632, 0.37902564, 1.0334007e-05]
Batch:2
d_loss:0.1750691650668159
g_loss:[0.5371888, 0.5371698, 9.5016985e-06]
Batch:3
d_loss:0.17679546517319977
g_loss:[0.34278753, 0.3427669, 1.0310239e-05]
Batch:4
d_loss:0.7428201138973236
g_loss:[0.36628395, 0.3662638, 1.0076852e-05]
Batch:5
d_loss:0.17559250281192362
g_loss:[0.44354182, 0.4435251, 8.3568775e-06]
Batch:6
d_loss:0.16706668032566085
g_loss:[0.38558036, 0.38556623, 7.056643e-06]
Batch:7
d_loss:0.17112430441193283
g_loss:[0.38412604, 0.3841135, 6.277922e-06]
Batch:8
d_loss:0.1663431156775914
g_loss:[3.2315216, 3.2315102, 5.702591e-06]
Batch:9
d_loss:0.16489652528252918
g_loss:[0.3439087, 0.34388974, 9.47957e-06]
Batch:10
d_loss:0.16762544424273074
g_loss:[1.186135, 1.1861229, 6.099726e-06]
Batch:11
d_loss:0.16850861048442312
g_loss:[0.7856158, 0.78558004, 1.789334e-05]
Batch:12
d_loss:0.16574821341782808
g_loss:[1.7327244, 1.7327034, 1.0515991e-05]
Bat

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch is: 361
Number of batches:37
Batch:1
d_loss:0.1712478060508147
g_loss:[0.3942367, 0.39401573, 0.00011049841]
Batch:2
d_loss:0.16489135660231113
g_loss:[0.41242802, 0.4122141, 0.00010696412]
Batch:3
d_loss:0.16383568433229811
g_loss:[0.3956641, 0.39545718, 0.00010345699]
Batch:4
d_loss:0.5611960887908936
g_loss:[0.35281426, 0.35261416, 0.00010004475]
Batch:5
d_loss:0.16450932635052595
g_loss:[0.38034663, 0.38015312, 9.6758864e-05]
Batch:6
d_loss:0.17739890981465578
g_loss:[1.6605104, 1.6603225, 9.393359e-05]
Batch:7
d_loss:0.1703534775879234
g_loss:[0.3494528, 0.34926963, 9.158843e-05]
Batch:8
d_loss:0.17936442006612197
g_loss:[2.2971237, 2.2969453, 8.919567e-05]
Batch:9
d_loss:0.16687792031552817
g_loss:[0.3880892, 0.3878988, 9.520725e-05]
Batch:10
d_loss:0.17273831355851144
g_loss:[0.36764807, 0.36745316, 9.745105e-05]
Batch:11
d_loss:0.1663122421305161
g_loss:[2.0557108, 2.0555182, 9.633919e-05]
Batch:12
d_loss:0.17631741933291778
g_loss:[0.39029872, 0.39010924, 9.4746894e-05]


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch is: 363
Number of batches:37
Batch:1
d_loss:0.16506440707598813
g_loss:[0.3874404, 0.3873675, 3.646099e-05]
Batch:2
d_loss:0.16533103340043453
g_loss:[0.3343816, 0.33430964, 3.599142e-05]
Batch:3
d_loss:0.16451244079507887
g_loss:[0.3525255, 0.35245454, 3.548703e-05]
Batch:4
d_loss:0.4655824452638626
g_loss:[0.34283406, 0.3427642, 3.4923687e-05]
Batch:5
d_loss:0.16984194668475538
g_loss:[0.41275594, 0.41268718, 3.437587e-05]
Batch:6
d_loss:0.16801231789577287
g_loss:[0.44423848, 0.4441709, 3.379998e-05]
Batch:7
d_loss:0.16714832917205058
g_loss:[0.43335462, 0.4332882, 3.3217628e-05]
Batch:8
d_loss:0.16538195907196496
g_loss:[0.43263936, 0.432574, 3.2680335e-05]
Batch:9
d_loss:0.17425962167180842
g_loss:[0.33808887, 0.33802482, 3.2016793e-05]
Batch:10
d_loss:0.16799646746949293
g_loss:[0.43496525, 0.43490237, 3.144233e-05]
Batch:11
d_loss:0.16339198776404373
g_loss:[0.34772474, 0.3476631, 3.081964e-05]
Batch:12
d_loss:0.168012012320105
g_loss:[0.3668112, 0.36675072, 3.0230007e-05]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch is: 365
Number of batches:37
Batch:1
d_loss:0.1640242152425344
g_loss:[3.3771582, 3.3770645, 4.6847337e-05]
Batch:2
d_loss:0.1649928536025982
g_loss:[2.1706219, 2.1705046, 5.861948e-05]
Batch:3
d_loss:0.16491150589718018
g_loss:[0.38354406, 0.38343298, 5.5539116e-05]
Batch:4
d_loss:0.5165464282035828
g_loss:[0.34732863, 0.34723508, 4.6774512e-05]
Batch:5
d_loss:0.16754779996699654
g_loss:[0.4658753, 0.46579784, 3.872659e-05]
Batch:6
d_loss:0.165437693009153
g_loss:[0.37567636, 0.37560844, 3.3960983e-05]
Batch:7
d_loss:0.16479866389272502
g_loss:[1.8200766, 1.8200136, 3.1465832e-05]
Batch:8
d_loss:0.16314384082215838
g_loss:[1.9423829, 1.9423218, 3.0598574e-05]
Batch:9
d_loss:0.16799453858584457
g_loss:[0.33589244, 0.3358283, 3.2068023e-05]
Batch:10
d_loss:0.16442680658656172
g_loss:[0.3339347, 0.3338703, 3.22014e-05]
Batch:11
d_loss:0.16393510217312723
g_loss:[0.34166864, 0.34160691, 3.08666e-05]
Batch:12
d_loss:0.1664619327057153
g_loss:[0.3728063, 0.3727473, 2.950684e-05]
Batch

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch is: 367
Number of batches:37
Batch:1
d_loss:0.16293889271037187
g_loss:[0.35577333, 0.35567954, 4.689407e-05]
Batch:2
d_loss:0.16407459370930155
g_loss:[2.2184446, 2.2183573, 4.3643588e-05]
Batch:3
d_loss:0.16362731642038852
g_loss:[0.3514554, 0.35134268, 5.6361354e-05]
Batch:4
d_loss:0.5475154966115952
g_loss:[2.269145, 2.2690363, 5.43021e-05]
Batch:5
d_loss:0.1637003914656816
g_loss:[0.3585026, 0.3584215, 4.0542804e-05]
Batch:6
d_loss:0.1645809012989048
g_loss:[0.35812795, 0.35805506, 3.644167e-05]
Batch:7
d_loss:0.16297876298631309
g_loss:[0.36689943, 0.36683214, 3.365139e-05]
Batch:8
d_loss:0.16326816059154226
g_loss:[0.8715017, 0.8714384, 3.16563e-05]
Batch:9
d_loss:0.16383260803922894
g_loss:[0.45270967, 0.4526358, 3.6943275e-05]
Batch:10
d_loss:0.16604398150229827
g_loss:[0.44538823, 0.445315, 3.6617465e-05]
Batch:11
d_loss:0.16459004523494514
g_loss:[0.56121105, 0.5611454, 3.2824475e-05]
Batch:12
d_loss:0.16450532737508183
g_loss:[0.34395126, 0.3438852, 3.3018136e-05]
Bat

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch is: 369
Number of batches:37
Batch:1
d_loss:0.16389095356862526
g_loss:[0.34783113, 0.34768552, 7.280565e-05]
Batch:2
d_loss:0.17249898951558862
g_loss:[0.36386508, 0.36372173, 7.168113e-05]
Batch:3
d_loss:0.1701392270042561
g_loss:[0.33224535, 0.3321062, 6.957503e-05]
Batch:4
d_loss:0.5333874672651291
g_loss:[0.37730595, 0.3771718, 6.7063185e-05]
Batch:5
d_loss:0.17179387372743804
g_loss:[0.36343688, 0.36330718, 6.484601e-05]
Batch:6
d_loss:0.16451705567305908
g_loss:[0.37767512, 0.3775493, 6.2909654e-05]
Batch:7
d_loss:0.16390151978703216
g_loss:[0.39325666, 0.39313412, 6.126999e-05]
Batch:8
d_loss:0.163273410347756
g_loss:[0.39449096, 0.39437172, 5.96193e-05]
Batch:9
d_loss:0.1669355758040183
g_loss:[0.36997885, 0.36986268, 5.80855e-05]
Batch:10
d_loss:0.16867018326593097
g_loss:[3.170896, 3.1707828, 5.663725e-05]
Batch:11
d_loss:0.16443422053271206
g_loss:[0.33536664, 0.33525568, 5.5479042e-05]
Batch:12
d_loss:0.1662110098041012
g_loss:[1.5759608, 1.5758523, 5.4232838e-05]
Ba

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch is: 371
Number of batches:37
Batch:1
d_loss:0.1642035718105035
g_loss:[0.35935456, 0.35931468, 1.993196e-05]
Batch:2
d_loss:0.1641832474124385
g_loss:[0.34525982, 0.3452208, 1.9505927e-05]
Batch:3
d_loss:0.1637133390904637
g_loss:[0.34950155, 0.34946334, 1.9097854e-05]
Batch:4
d_loss:0.617321103811264
g_loss:[3.2809885, 3.280951, 1.8693641e-05]
Batch:5
d_loss:0.17414153384743258
g_loss:[0.33777326, 0.33773664, 1.831605e-05]
Batch:6
d_loss:0.1647338735347148
g_loss:[0.36440387, 0.36436796, 1.795156e-05]
Batch:7
d_loss:0.16419554337335285
g_loss:[0.34138426, 0.34134907, 1.7599034e-05]
Batch:8
d_loss:0.16460960515541956
g_loss:[0.3666028, 0.36656836, 1.7230503e-05]
Batch:9
d_loss:0.16341242306225467
g_loss:[0.35521328, 0.35517952, 1.6885748e-05]
Batch:10
d_loss:0.1683953869069228
g_loss:[0.3427331, 0.34269997, 1.6558382e-05]
Batch:11
d_loss:0.16749731134041212
g_loss:[0.36724368, 0.36721125, 1.620535e-05]
Batch:12
d_loss:0.16318871587645845
g_loss:[0.33516335, 0.33513165, 1.585975e-

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch is: 373
Number of batches:37
Batch:1
d_loss:0.16408441487146774
g_loss:[0.3511138, 0.35091013, 0.00010183331]
Batch:2
d_loss:0.16485256554551597
g_loss:[0.33073002, 0.33052087, 0.00010457945]
Batch:3
d_loss:0.16353590508151683
g_loss:[0.34015757, 0.33995593, 0.00010082046]
Batch:4
d_loss:0.5819201320409775
g_loss:[0.36061454, 0.36042598, 9.42755e-05]
Batch:5
d_loss:0.17231954681847128
g_loss:[0.36088148, 0.36070707, 8.720229e-05]
Batch:6
d_loss:0.16556629218393937
g_loss:[1.0823979, 1.0822351, 8.1432285e-05]
Batch:7
d_loss:0.16310452322795754
g_loss:[1.4031794, 1.4029359, 0.000121760735]
Batch:8
d_loss:0.16455151018453762
g_loss:[0.3327025, 0.33215502, 0.00027373695]
Batch:9
d_loss:0.16405430692975642
g_loss:[0.34720603, 0.3465755, 0.00031526174]
Batch:10
d_loss:0.1637687288748566
g_loss:[0.36032102, 0.35981983, 0.00025059836]
Batch:11
d_loss:0.16575239073790726
g_loss:[0.33664775, 0.33629757, 0.00017509202]
Batch:12
d_loss:0.1690066999217379
g_loss:[0.32983553, 0.32958096, 0.000

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Epoch is: 375
Number of batches:37
Batch:1
d_loss:3.301663190126419
g_loss:[2.4326308, 2.4325707, 3.0090789e-05]
Batch:2
d_loss:2.503634035587311
g_loss:[1.1202862, 1.120218, 3.412009e-05]
Batch:3
d_loss:1.483655035495758
g_loss:[2.362594, 2.3623228, 0.00013554066]
Batch:4
d_loss:1.977427363395691
g_loss:[3.5789752, 3.5783913, 0.000291896]
Batch:5
d_loss:1.5350624024868011
g_loss:[2.9111736, 2.9108613, 0.00015611073]
Batch:6
d_loss:0.6340857171453536
g_loss:[7.0484757, 7.048261, 0.00010728673]
Batch:7
d_loss:0.531386524438858
g_loss:[10.559972, 10.559793, 8.9223606e-05]
Batch:8
d_loss:0.8648100793361664
g_loss:[7.951413, 7.951076, 0.00016850642]
Batch:9
d_loss:0.47315138578414917
g_loss:[6.3540273, 6.353676, 0.00017567484]
Batch:10
d_loss:0.44996051490306854
g_loss:[7.071651, 7.0713167, 0.0001672201]
Batch:11
d_loss:0.8327319324016571
g_loss:[7.84013, 7.839649, 0.00024043182]
Batch:12
d_loss:1.1753108203411102
g_loss:[0.9974139, 0.9968907, 0.00026158572]
Batch:13
d_loss:2.4014017879962