# Competition 3: Team 21

112062649 王俊皓

112062650 廖士傑

##  Reverse Image Caption

In [252]:
# project settings
# enc / gen / dis: True means using a more complex setting, False means using a setting that is closer to template
# enc_do_batchnorm: encoder will perform batch normalization to make the text encoding more variant
# dis_backbone: the type of backbone in discriminator, can be 'resnet', 'vgg', 'simple'
# delete_checkpoint: deletes all checkpoint files before training

class experimental_settings:
    def __init__(self,
                 enc=True,
                 enc_do_batchnorm=False,
                 gen=True,
                 dis=True,
                 dis_backbone='simple',
                 delete_checkpoint=False):
        self.enc = enc
        self.gen = gen
        self.dis = dis
        self.dis_backbone = dis_backbone
        self.enc_do_batchnorm = enc_do_batchnorm
        self.delete_checkpoint = delete_checkpoint # not implemented yet
        
        # ============================ #
        # automatic
        # ============================ #
        
        self.caption_type = 'sentence' if self.enc else 'id'


expSettings = experimental_settings()

## Import

In [253]:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras import layers
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
import string
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import random
import time
from pathlib import Path
import math

import re
from IPython import display

GPU check

In [254]:
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Restrict TensorFlow to only use the first GPU
        tf.config.experimental.set_visible_devices(gpus[0], 'GPU')

        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)

1 Physical GPUs, 1 Logical GPUs


## Loading Data

In [255]:
dictionary_path = './dictionary'
vocab = np.load(dictionary_path + '/vocab.npy')
print('there are {} vocabularies in total'.format(len(vocab)))

word2Id_dict = dict(np.load(dictionary_path + '/word2Id.npy'))
id2word_dict = dict(np.load(dictionary_path + '/id2Word.npy'))
print('Word to id mapping, for example: %s -> %s' % ('flower', word2Id_dict['flower']))
print('Id to word mapping, for example: %s -> %s' % ('1', id2word_dict['1']))
print('Tokens: <PAD>: %s; <RARE>: %s' % (word2Id_dict['<PAD>'], word2Id_dict['<RARE>']))

there are 5427 vocabularies in total
Word to id mapping, for example: flower -> 1
Id to word mapping, for example: 1 -> flower
Tokens: <PAD>: 5427; <RARE>: 5428


In [256]:
def sent2IdList(line, MAX_SEQ_LENGTH=20):
    MAX_SEQ_LIMIT = MAX_SEQ_LENGTH
    padding = 0
    
    # data preprocessing, remove all puntuation in the texts
    prep_line = re.sub('[%s]' % re.escape(string.punctuation), ' ', line.rstrip())
    prep_line = prep_line.replace('-', ' ')
    prep_line = prep_line.replace('-', ' ')
    prep_line = prep_line.replace('  ', ' ')
    prep_line = prep_line.replace('.', '')
    tokens = prep_line.split(' ')
    tokens = [
        tokens[i] for i in range(len(tokens))
        if tokens[i] != ' ' and tokens[i] != ''
    ]
    l = len(tokens)
    padding = MAX_SEQ_LIMIT - l
    
    # make sure length of each text is equal to MAX_SEQ_LENGTH, and replace the less common word with <RARE> token
    for i in range(padding):
        tokens.append('<PAD>')
    line = [
        word2Id_dict[tokens[k]]
        if tokens[k] in word2Id_dict else word2Id_dict['<RARE>']
        for k in range(len(tokens))
    ]

    return line

text = "the flower shown has yellow anther red pistil and bright red petals."
print(text)
print(sent2IdList(text))

the flower shown has yellow anther red pistil and bright red petals.
['9', '1', '82', '5', '11', '70', '20', '31', '3', '29', '20', '2', '5427', '5427', '5427', '5427', '5427', '5427', '5427', '5427']


In [257]:
@tf.function
def id2Sent(ids):
    return " ".join([id2word_dict[idx] for idx in ids]).strip()

#def batch_id2Sent(batch_ids):
    #return [id2Sent(ids) for ids in batch_ids]
    
def batch_id2Sent(batch_ids):
    def process_single(ids):
        # Convert a single tensor of IDs to a sentence
        ids = ids.numpy()  # Convert Tensor to NumPy
        sentence = " ".join([id2word_dict.get(idx, "<UNK>") for idx in ids])  # Handle unknown IDs
        return sentence

    # Use tf.py_function to apply Python function inside the TensorFlow graph
    sentences = tf.map_fn(
        lambda ids: tf.py_function(process_single, [ids], tf.string),
        batch_ids,
        fn_output_signature=tf.string
    )
    return sentences


print(sent2IdList(text))
print(id2Sent(sent2IdList(text)))

['9', '1', '82', '5', '11', '70', '20', '31', '3', '29', '20', '2', '5427', '5427', '5427', '5427', '5427', '5427', '5427', '5427']
tf.Tensor(b'the flower shown has yellow anther red pistil and bright red petals <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>', shape=(), dtype=string)


In [258]:
data_path = './dataset'
text2ImgData = pd.read_pickle(data_path + '/text2ImgData.pkl')
num_training_sample = len(text2ImgData)
n_images_train = num_training_sample
print('There are %d image in training data' % (n_images_train))

There are 7370 image in training data


In [259]:
def caption2string(cap):
    output = []
    for sen in cap:
        s = " ".join([id2word_dict[idx] for idx in sen]).strip()
        output.append(s.split(' <PAD>')[0])
    return output

# adding caption as strings
text2ImgData['Captions_string'] = text2ImgData['Captions'].apply(caption2string)

In [260]:
text2ImgData.head(5)

Unnamed: 0_level_0,Captions,ImagePath,Captions_string
ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
6734,"[[9, 2, 17, 9, 1, 6, 14, 13, 18, 3, 41, 8, 11,...",./102flowers/image_06734.jpg,[the petals of the flower are pink in color an...
6736,"[[4, 1, 5, 12, 2, 3, 11, 31, 28, 68, 106, 132,...",./102flowers/image_06736.jpg,[this flower has white petals and yellow pisti...
6737,"[[9, 2, 27, 4, 1, 6, 14, 7, 12, 19, 5427, 5427...",./102flowers/image_06737.jpg,[the petals on this flower are pink with white...
6738,"[[9, 1, 5, 8, 54, 16, 38, 7, 12, 116, 325, 3, ...",./102flowers/image_06738.jpg,[the flower has a smooth purple petal with whi...
6739,"[[4, 12, 1, 5, 29, 11, 19, 7, 26, 70, 5427, 54...",./102flowers/image_06739.jpg,[this white flower has bright yellow stamen wi...


In [261]:
text2ImgData['Captions_string'][:1].tolist()

[['the petals of the flower are pink in color and have a yellow center',
  'this flower is pink and white in color with petals that are multi colored',
  'the purple petals have shades of white with white anther and filament',
  'this flower has large pink petals and a white stigma in the center',
  'this flower has petals that are pink and has a yellow stamen',
  'a flower with short and wide petals that is light purple',
  'this flower has small pink petals with a yellow center',
  'this flower has large rounded pink petals with curved edges and purple veins',
  'this flower has purple petals as well as a white stamen']]

In [262]:
# in this competition, you have to generate image in size 64x64x3
IMAGE_SIZE = 64
IMAGE_HEIGHT = IMAGE_SIZE
IMAGE_WIDTH = IMAGE_SIZE
IMAGE_CHANNEL = 3

def training_data_generator(caption, image_path, caption_type='id'):
    # load in the image according to image path
    img = tf.io.read_file(image_path)
    img = tf.image.decode_image(img, channels=3)
    img = tf.image.convert_image_dtype(img, tf.float32)
    img.set_shape([None, None, 3])
    img = tf.image.resize(img, size=[IMAGE_HEIGHT, IMAGE_WIDTH])
    img.set_shape([IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL])
    if caption_type == 'id':
        caption = tf.cast(caption, tf.int32)
    elif caption_type == 'sentence':
        caption = tf.convert_to_tensor(caption, dtype=tf.string)

    return img, caption

def dataset_generator(filenames, batch_size, data_generator, caption_type='id'):
    # load the training data into two NumPy arrays
    if filenames != None:
        df = pd.read_pickle(filenames)
    else:
        df = text2ImgData
    
    if caption_type == 'id':
        captions = df['Captions'].values
    elif caption_type == 'sentence':
        captions = df['Captions_string'].values
    else:
        raise ValueError('for dataset_generator, caption_type= should be \'id\' or \'sentence\'.')
        
    caption = []
    # each image has 1 to 10 corresponding captions
    # we choose one of them randomly for training
    
    # ============================================ #
    # TODO: augmentation
    # idea 1 (difficulty: easy)
    #     training data has multiple captions, right now it picks a random one.
    #     we can make it so that every caption is an entry and multiple captions link to the same image.
    # idea 2 (difficulty: medium)
    #     after text embedding, use the average of 2 caption embeddings to generate a new caption.
    #     the data does not need to have an image tied to it, it just have the label 0 (fake image).
    # ============================================ #
    for i in range(len(captions)):
        caption.append(random.choice(captions[i]))
    caption = np.asarray(caption)
    
    if caption_type == 'id':
        caption = caption.astype(np.int)
        
    image_path = df['ImagePath'].values
    
    # assume that each row of `features` corresponds to the same row as `labels`.
    assert caption.shape[0] == image_path.shape[0]
    
    datagen_func = lambda cap, img: data_generator(cap, img, caption_type=caption_type)
    
    dataset = tf.data.Dataset.from_tensor_slices((caption, image_path))
    dataset = dataset.map(datagen_func, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.shuffle(len(caption)).batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    return dataset

In [263]:
BATCH_SIZE = 64
dataset = dataset_generator(
    #data_path + '/text2ImgData.pkl',
    None,
    BATCH_SIZE, 
    training_data_generator, 
    caption_type=expSettings.caption_type)

In [264]:
# dataset testing ground
for img, cap in dataset.take(1):
    print("Image shape:", img.numpy().shape)
    print("Caption shape:", cap.numpy().shape)

Image shape: (64, 64, 64, 3)
Caption shape: (64,)


In [265]:
from tensorflow.keras.layers import Conv2DTranspose, Conv2D, BatchNormalization, LeakyReLU, Dense, Dropout
from tensorflow.keras.initializers import HeNormal

# custom layers
class flattened_dense(tf.keras.layers.Layer):
    """
    a dense layer that is made compatible with convolution layers
    by flattening the input first and followed by a dense layer.
    """
    def __init__(self, channels=64, kernel_initializer="glorot_uniform"):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()
        self.dense = tf.keras.layers.Dense(channels, kernel_initializer=kernel_initializer)
        
    def call(self, inputs):
        fl = self.flatten(inputs)
        return self.dense(fl)
    
class conv_block(tf.keras.layers.Layer):
    """
    a convolution layer with batch normalization and leaky relu activation
    """
    def __init__(self, filters=128, kernel_size=1, strides=1, kernel_initializer=HeNormal()):
        super().__init__()
        self.conv = Conv2D(filters=filters,
                           kernel_size = (kernel_size, kernel_size),
                           strides=(strides, strides),
                           padding='same',
                           kernel_initializer=kernel_initializer)
        self.bn = BatchNormalization()
        self.activation = LeakyReLU(alpha=0.1)
    
    def call(self, inputs):
        return self.activation(self.bn(self.conv(inputs)))
    
class deconv_block(tf.keras.layers.Layer):
    """
    a deconvolution layer with batch normalization and leaky relu activation
    """
    def __init__(self, filters=128, kernel_size=4, strides=2, kernel_initializer='glorot_uniform'):
        super().__init__()
        self.deconv = Conv2DTranspose(filters=filters,
                                    kernel_size = (kernel_size, kernel_size),
                                    strides=(strides, strides),
                                    padding='same',
                                    kernel_initializer=kernel_initializer)
        self.bn = BatchNormalization()
        self.activation = LeakyReLU(alpha=0.1)
    
    def call(self, inputs):
        return self.activation(self.bn(self.deconv(inputs)))
    
class dense_block(tf.keras.layers.Layer):
    """
    a dense layer with batch normalization and leaky relu activation
    """
    def __init__(self, filters=128, kernel_initializer='glorot_uniform'):
        super().__init__()
        self.d = Dense(filters, kernel_initializer=kernel_initializer)
        self.bn = BatchNormalization()
        self.activation = LeakyReLU(alpha=0.1)
    
    def call(self, inputs):
        outputs = self.d(inputs)
        outputs = self.bn(outputs)
        return self.activation(outputs)

class noise_layer(tf.keras.layers.Layer):
    """
    a layer that adds noise
    """
    def __init__(self, mean=0, stddev=0.01):
        super().__init__()
        self.m = mean
        self.s = stddev
    
    def call(self, inputs):
        noise = tf.random.normal(list(inputs.shape), self.m, self.s)
        outputs = inputs + noise
        return output

In [266]:
import tensorflow_hub as hub

class TextEncoder(tf.keras.Model):
    """
    Encode text (a caption) into hidden representation
    input: text, which is a list of ids
    output: embedding, or hidden representation of input text in dimension of RNN_HIDDEN_SIZE
    """
    def __init__(self, hparas, experimental=False, do_batchnorm=False):
        super(TextEncoder, self).__init__()
        self.exp=experimental
        self.do_batchnorm = do_batchnorm
        self.hparas = hparas
        self.batch_size = self.hparas['BATCH_SIZE']
        
        # embedding with tensorflow API
        self.embedding = layers.Embedding(self.hparas['VOCAB_SIZE'], self.hparas['EMBED_DIM'])
        # RNN, here we use GRU cell, another common RNN cell similar to LSTM
        self.gru = layers.GRU(self.hparas['RNN_HIDDEN_SIZE'],
                              return_sequences=True,
                              return_state=True,
                              recurrent_initializer='glorot_uniform')
        if self.exp:
            self.embed = hub.load('./checkpoints/universal_sentence_encoder')
    
    def call(self, text, hidden):
        if self.exp:
            with tf.device('/CPU:0'): # TODO if you find a way to use GPU, go for it.
                output_last = self.embed(text)
                
            state = hidden # not updating state for compatibility reasons
            
        else:
            text = self.embedding(text)
            output, state = self.gru(text, initial_state = hidden)
            output_last = output[:, -1, :]
        
        # normalization in-batch
        if self.do_batchnorm:
            mean = tf.reduce_mean(output_last, axis=0, keepdims=True)  # Mean across the batch
            std = tf.math.reduce_std(output_last, axis=0, keepdims=True)  # Std across the batch
            normalized = (output_last - mean) / (std + 1e-6)  # Avoid division by zero
        else:
            normalized = output_last
        
        return normalized, state
    
    def initialize_hidden_state(self):
        return tf.zeros((self.hparas['BATCH_SIZE'], self.hparas['RNN_HIDDEN_SIZE']))

In [267]:
class Generator(tf.keras.Model):
    """
    Generate fake image based on given text(hidden representation) and noise z
    input: text and noise
    output: fake image with size 64*64*3
    """
    def __init__(self, hparas, experimental=False):
        super(Generator, self).__init__()
        self.exp = experimental
        self.hparas = hparas
        self.flatten = tf.keras.layers.Flatten()
        self.d1 = tf.keras.layers.Dense(self.hparas['DENSE_DIM'], kernel_initializer="glorot_uniform")
        self.d2 = tf.keras.layers.Dense(64*64*3, kernel_initializer="glorot_uniform")
        if self.exp:
            self.deconv_depth = int(math.log(IMAGE_SIZE, 2)) - 1
            self.starter = dense_block(filters=2*2*512)
            self.deconv = [
                deconv_block(filters=512, kernel_initializer=HeNormal()),
                #conv_block(filters=512, kernel_size=1, strides=1),
                deconv_block(filters=256, kernel_initializer=HeNormal()),
                #conv_block(filters=256, kernel_size=1, strides=1),
                #Dropout(0.2),
                deconv_block(filters=128, kernel_initializer=HeNormal()),
                #conv_block(filters=128, kernel_size=1, strides=1),
                #Dropout(0.2),
                #noise_layer(),
                deconv_block(filters=64, kernel_initializer=HeNormal()),
                #conv_block(filters=64, kernel_size=1, strides=1),
                deconv_block(filters=32, kernel_initializer=HeNormal()),
                #noise_layer(),
                conv_block(filters=32, kernel_size=1, strides=1)
            ]
            self.headf = conv_block(filters=3, kernel_size=1, strides=1)
            
    def call(self, text, noise_z, debug_output=False, training=False):
        # deconvolution
        if self.exp:
            noisy_text = tf.concat([text, noise_z], axis=1) * 10 # amplify the input a bit, they seem fairly close to 0.
            img = self.starter(noisy_text)
            
            img = tf.reshape(img, [-1, 2, 2, 512])
            debug = []
            for layer in self.deconv:
                if isinstance(layer, tf.keras.layers.Dropout) and training:
                    continue
                debug.append(img)
                img = layer(img)

            img = self.headf(img)
            logits = tf.reshape(img, [-1, IMAGE_SIZE, IMAGE_SIZE, 3])
            output = tf.nn.tanh(logits)
        
        # concatenate input text and random noise
        else:
            text = self.flatten(text)
            text = self.d1(text)
            text = tf.nn.leaky_relu(text)
            text_concat = tf.concat([noise_z, text], axis=1)
            text_concat = self.d2(text_concat)
        
            logits = tf.reshape(text_concat, [-1, 64, 64, 3])
            output = tf.nn.tanh(logits)
            debug_output = output
        
        if debug_output:
            return logits, output, debug
        else:
            return logits, output

In [268]:
from tensorflow.keras.applications import ResNet50, VGG16

class Discriminator(tf.keras.Model):
    """
    Differentiate the real and fake image
    input: image and corresponding text
    output: labels, the real image should be 1, while the fake should be 0
    """
    def __init__(self, hparas, experimental=False, backbone='simple'):
        super(Discriminator, self).__init__()
        self.exp = experimental
        self.bbtype = backbone
        self.hparas = hparas
        if self.exp:
            if self.bbtype == 'resnet':
                self.resnet_base = ResNet50(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), weights='imagenet', include_top=False)
                for layer in self.resnet_base.layers:
                    layer.trainable = False
            elif self.bbtype == 'vgg':
                self.vgg_base = VGG16(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), weights='imagenet', include_top=False)
                for layer in self.vgg_base.layers:
                    layer.trainable = False
            elif self.bbtype == 'simple':
                self.conv = [
                    conv_block(filters=512, kernel_size=3, strides=2),
                    #conv_block(filters=512, kernel_size=3, strides=1),
                    conv_block(filters=512, kernel_size=1, strides=1),
                    Dropout(0.5),
                    conv_block(filters=256, kernel_size=3, strides=2),
                    #conv_block(filters=256, kernel_size=3, strides=1),
                    conv_block(filters=256, kernel_size=1, strides=1),
                    Dropout(0.5),
                    conv_block(filters=128, kernel_size=3, strides=2),
                    #conv_block(filters=128, kernel_size=3, strides=1),
                    conv_block(filters=128, kernel_size=1, strides=1),
                    conv_block(filters=64, kernel_size=3, strides=2),
                    #conv_block(filters=64, kernel_size=3, strides=1),
                    conv_block(filters=64, kernel_size=1, strides=1)
                ]
                # more simple one
                #self.conv1 = conv_block(filters=256, kernel_size=3, strides=1)
                #self.conv2 = conv_block(filters=64, kernel_size=3, strides=1)
        self.flatten = tf.keras.layers.Flatten()
        self.d_text = tf.keras.layers.Dense(self.hparas['DENSE_DIM'])
        self.d_img = tf.keras.layers.Dense(self.hparas['DENSE_DIM'])
        self.d = tf.keras.layers.Dense(1)
    
    def call(self, img, text, training=False):
        text = self.flatten(text)
        text = self.d_text(text)
        text = tf.nn.leaky_relu(text)
        
        if self.exp:
            if self.bbtype == 'resnet':
                img = self.resnet_base(img)
            elif self.bbtype == 'vgg':
                img = self.vgg_base(img)
            elif self.bbtype == 'simple':
                for layer in self.conv: # see init for spec
                    if isinstance(layer, Dropout) and training:
                        continue
                    img = layer(img)
        img = self.flatten(img)
        img = self.d_img(img)
        img = tf.nn.leaky_relu(img)
        
        # concatenate image with paired text
        img_text = tf.concat([text, img], axis=1)
        
        logits = self.d(img_text)
        output = tf.nn.sigmoid(logits)
        
        return logits, output

Parameters and settings

In [269]:
hparas = {
    'MAX_SEQ_LENGTH': 20,                     # maximum sequence length
    'EMBED_DIM': 256,                         # word embedding dimension
    'VOCAB_SIZE': len(word2Id_dict),          # size of dictionary of captions
    'RNN_HIDDEN_SIZE': 128,                   # number of RNN neurons
    'Z_DIM': 512,                             # random noise z dimension
    'DENSE_DIM': 128,                         # number of neurons in dense layer
    'IMAGE_SIZE': [64, 64, 3],                # render image size
    'BATCH_SIZE': 64,
    'LR_GEN': 1e-4,
    'LR_DIS': 1e-5,
    'LR_DECAY': 0.5,                          # unused
    'BETA_1': 0.5,
    'N_EPOCH': 1000,                            # number of epoch for demo
    'N_SAMPLE': num_training_sample,          # size of training data
    'CHECKPOINTS_DIR': './checkpoints/demo',  # checkpoint path
    'PRINT_FREQ': 1                           # printing frequency of loss
}

In [270]:
text_encoder = TextEncoder(hparas, 
                           experimental=expSettings.enc,
                           do_batchnorm=expSettings.enc_do_batchnorm)

generator = Generator(hparas,
                      experimental=expSettings.gen)

discriminator = Discriminator(hparas,
                              experimental=expSettings.dis,
                              backbone=expSettings.dis_backbone)

In [271]:
# test text encoder
for img, cap in dataset.take(1):
    print("Image shape:", img.numpy().shape)
    print("Caption shape:", cap.shape)
    with tf.device('/CPU:0'):
        output, _ = text_encoder(cap, text_encoder.initialize_hidden_state())
        print("Caption embed shape:", output.shape)

Image shape: (64, 64, 64, 3)
Caption shape: (64,)
Caption embed shape: (64, 512)


In [272]:
# This method returns a helper function to compute cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [273]:
from tensorflow.nn import sigmoid_cross_entropy_with_logits

def discriminator_loss(real_logits, fake_logits):
    # output value of real image should be 1
    real_loss = sigmoid_cross_entropy_with_logits(tf.ones_like(real_logits), real_logits)
    # output value of fake image should be 0
    fake_loss = sigmoid_cross_entropy_with_logits(tf.zeros_like(fake_logits), fake_logits)
    total_loss = tf.reduce_mean(real_loss) + tf.reduce_mean(fake_loss)
                                 
    return total_loss
def generator_loss(fake_output):
    # output value of fake image should be 0
    return tf.reduce_mean(sigmoid_cross_entropy_with_logits(tf.ones_like(fake_output), fake_output))

In [274]:
# we use seperated optimizers for training generator and discriminator
generator_optimizer = tf.keras.optimizers.Adam(hparas['LR_GEN'], clipvalue=0.1)
discriminator_optimizer = tf.keras.optimizers.Adam(hparas['LR_DIS'], clipvalue=0.1)

In [275]:
if expSettings.delete_checkpoint:
    for f in os.listdir(hparas['CHECKPOINTS_DIR']):
        file_path = os.path.join(hparas['CHECKPOINTS_DIR'], f)
        if os.path.isfile(file_path):
            os.unlink(file_path)

# one benefit of tf.train.Checkpoint() API is we can save everything seperately
checkpoint_dir = hparas['CHECKPOINTS_DIR']
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 text_encoder=text_encoder,
                                 generator=generator,
                                 discriminator=discriminator)
ckptManager = tf.train.CheckpointManager(
    checkpoint, directory=hparas['CHECKPOINTS_DIR'], max_to_keep=5)

In [276]:
@tf.function
def train_step(real_image, caption, hidden, imshow=False):
    # random noise for generator
    noise = tf.random.normal(shape=[hparas['BATCH_SIZE'], hparas['Z_DIM']], mean=0.0, stddev=0.1)
    
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        text_embed, hidden = text_encoder(caption, hidden)
        _, fake_image = generator(text_embed, noise, training=True)
        if imshow:
            plt.imshow(fake_image[0])

        real_logits, real_output = discriminator(real_image, text_embed, training=True)
        fake_logits, fake_output = discriminator(fake_image, text_embed, training=True)

        g_loss = generator_loss(fake_logits)
        d_loss = discriminator_loss(real_logits, fake_logits)

    grad_g = gen_tape.gradient(g_loss, generator.trainable_variables)
    grad_d = disc_tape.gradient(d_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(grad_g, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(grad_d, discriminator.trainable_variables))
    
    return g_loss, d_loss

In [277]:
@tf.function
def test_step(caption, noise, hidden):
    text_embed, hidden = text_encoder(caption, hidden)
    _, fake_image = generator(text_embed, noise)
    return fake_image

Sample Debugging (unused)

In [278]:
def merge(images, size):
    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * size[0], w * size[1], 3))
    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx // size[1]
        img[j*h:j*h+h, i*w:i*w+w, :] = image
    return img

def imsave(images, size, path):
    # getting the pixel values between [0, 1] to save it
    return plt.imsave(path, merge(images, size)*0.5 + 0.5)

def save_images(images, size, image_path):
    return imsave(images, size, image_path)

In [279]:
def sample_generator(caption, batch_size, caption_type='id'):
    if caption_type == 'sentence':
        caption = caption2string(caption)
    caption = np.asarray(caption)
    if caption_type == 'id':
        caption = caption.astype(np.int)
    dataset = tf.data.Dataset.from_tensor_slices(caption)
    dataset = dataset.batch(batch_size)
    return dataset

In [280]:
ni = int(np.ceil(np.sqrt(hparas['BATCH_SIZE'])))
sample_size = hparas['BATCH_SIZE']
sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(sample_size, hparas['Z_DIM'])).astype(np.float32)
sample_sentence = ["the flower shown has yellow anther red pistil and bright red petals."] * int(sample_size/ni) + \
                  ["this flower has petals that are yellow, white and purple and has dark lines"] * int(sample_size/ni) + \
                  ["the petals on this flower are white with a yellow center"] * int(sample_size/ni) + \
                  ["this flower has a lot of small round pink petals."] * int(sample_size/ni) + \
                  ["this flower is orange in color, and has petals that are ruffled and rounded."] * int(sample_size/ni) + \
                  ["the flower has yellow petals and the center of it is brown."] * int(sample_size/ni) + \
                  ["this flower has petals that are blue and white."] * int(sample_size/ni) +\
                  ["these white flowers have petals that start off white in color and end in a white towards the tips."] * int(sample_size/ni)

for i, sent in enumerate(sample_sentence):
    sample_sentence[i] = sent2IdList(sent)
sample_sentence = sample_generator(sample_sentence, hparas['BATCH_SIZE'], caption_type=expSettings.caption_type)

In [281]:
# test the sample dataset
for cap in sample_sentence.take(1):
    print("Caption shape:", cap.numpy().shape)
    emb, _ = text_encoder(cap, text_encoder.initialize_hidden_state())
    print("Caption embeddings:", emb)

Caption shape: (64,)
Caption embeddings: tf.Tensor(
[[-0.01346336  0.06881763 -0.05529514 ... -0.01518281 -0.00633275
   0.01500697]
 [-0.01346336  0.06881763 -0.05529514 ... -0.01518281 -0.00633275
   0.01500697]
 [-0.01346336  0.06881763 -0.05529514 ... -0.01518281 -0.00633275
   0.01500697]
 ...
 [-0.00790838  0.06233827  0.01107207 ... -0.02398296 -0.03450323
  -0.00868433]
 [-0.00790838  0.06233827  0.01107207 ... -0.02398296 -0.03450323
  -0.00868433]
 [-0.00790838  0.06233827  0.01107207 ... -0.02398296 -0.03450323
  -0.00868433]], shape=(64, 512), dtype=float32)


In [282]:
if not os.path.exists('samples/demo'):
    os.makedirs('samples/demo')

Training and testing

In [283]:
def train(dataset, epochs):
    # hidden state of RNN
    hidden = text_encoder.initialize_hidden_state()
    steps_per_epoch = int(hparas['N_SAMPLE']/hparas['BATCH_SIZE'])
    
    lowest_gen_loss = 1e10
    
    for epoch in range(hparas['N_EPOCH']):
        g_total_loss = 0
        d_total_loss = 0
        batchid = 0
        start = time.time()
        imshow = False
        
        for image, caption in dataset:
            batchid += 1
            g_loss, d_loss = train_step(image, caption, hidden, imshow=imshow)
            imshow = False
            g_total_loss += g_loss
            d_total_loss += d_loss
        
        time_tuple = time.localtime()
        time_string = time.strftime("%m/%d/%Y, %H:%M:%S", time_tuple)
            
        print("Epoch {}, gen_loss: {:.4f}, disc_loss: {:.4f}".format(epoch+1,
                                                                     g_total_loss/steps_per_epoch,
                                                                     d_total_loss/steps_per_epoch))
        print('Time for epoch {} is {:.4f} sec'.format(epoch+1, time.time()-start))
        
        # save the model if lower gen loss is achieved
        if g_total_loss < lowest_gen_loss:
            lowest_gen_loss = g_total_loss
            ckptManager.save()
            print('new lowest. saving model.')
        elif g_total_loss < 3 * lowest_gen_loss:
            ckptManager.save()
            print('within save threshold. saving model.')
        else:
            lowest_gen_loss *= 1.02
            
        print('======================================')
        
        # visualization
        if (epoch + 1) % hparas['PRINT_FREQ'] == 0:
            for caption in sample_sentence:
                fake_image = test_step(caption, sample_seed, hidden)
            save_images(fake_image,
                        [ni, ni],
                        'samples/demo/train_{:02d}.jpg'.format(epoch))

In [284]:
train(dataset, hparas['N_EPOCH'])

  "shape. This may consume a large amount of memory." % value)


Epoch 1, gen_loss: 0.1284, disc_loss: 3.3784
Time for epoch 1 is 27.7097 sec
new lowest. saving model.
Epoch 2, gen_loss: 0.0845, disc_loss: 3.4380
Time for epoch 2 is 21.2816 sec
new lowest. saving model.
Epoch 3, gen_loss: 0.0875, disc_loss: 3.3005
Time for epoch 3 is 20.9237 sec
within save threshold. saving model.
Epoch 4, gen_loss: 0.1053, disc_loss: 3.0467
Time for epoch 4 is 20.5971 sec
within save threshold. saving model.
Epoch 5, gen_loss: 0.1471, disc_loss: 2.6813
Time for epoch 5 is 20.7887 sec
within save threshold. saving model.
Epoch 6, gen_loss: 0.2270, disc_loss: 2.2334
Time for epoch 6 is 20.8328 sec
within save threshold. saving model.
Epoch 7, gen_loss: 0.3566, disc_loss: 1.8722
Time for epoch 7 is 20.9663 sec
Epoch 8, gen_loss: 0.7347, disc_loss: 1.2040
Time for epoch 8 is 20.7251 sec
Epoch 9, gen_loss: 1.1386, disc_loss: 0.6872
Time for epoch 9 is 20.3112 sec
Epoch 10, gen_loss: 1.6893, disc_loss: 0.4304
Time for epoch 10 is 21.1583 sec
Epoch 11, gen_loss: 2.0646, 

Epoch 69, gen_loss: 4.6061, disc_loss: 0.0265
Time for epoch 69 is 20.5390 sec
Epoch 70, gen_loss: 4.8323, disc_loss: 0.0204
Time for epoch 70 is 20.9139 sec
Epoch 71, gen_loss: 4.9794, disc_loss: 0.0170
Time for epoch 71 is 20.6760 sec
Epoch 72, gen_loss: 5.6523, disc_loss: 0.0092
Time for epoch 72 is 21.3417 sec
Epoch 73, gen_loss: 5.9206, disc_loss: 0.0066
Time for epoch 73 is 20.4680 sec
Epoch 74, gen_loss: 5.7473, disc_loss: 0.0071
Time for epoch 74 is 20.6460 sec
Epoch 75, gen_loss: 5.6143, disc_loss: 0.0083
Time for epoch 75 is 20.8871 sec
Epoch 76, gen_loss: 5.8806, disc_loss: 0.0070
Time for epoch 76 is 20.9470 sec
Epoch 77, gen_loss: 6.0536, disc_loss: 0.0057
Time for epoch 77 is 20.4525 sec
Epoch 78, gen_loss: 6.0497, disc_loss: 0.0053
Time for epoch 78 is 21.1631 sec
Epoch 79, gen_loss: 5.8313, disc_loss: 0.0064
Time for epoch 79 is 20.8635 sec
Epoch 80, gen_loss: 5.9285, disc_loss: 0.0059
Time for epoch 80 is 20.4034 sec
Epoch 81, gen_loss: 5.1887, disc_loss: 0.0125
Time f

Epoch 133, gen_loss: 1.1335, disc_loss: 1.4498
Time for epoch 133 is 21.1792 sec
within save threshold. saving model.
Epoch 134, gen_loss: 1.4859, disc_loss: 0.9185
Time for epoch 134 is 20.7825 sec
within save threshold. saving model.
Epoch 135, gen_loss: 1.1529, disc_loss: 1.1007
Time for epoch 135 is 21.0880 sec
within save threshold. saving model.
Epoch 136, gen_loss: 1.1289, disc_loss: 1.1617
Time for epoch 136 is 20.6531 sec
within save threshold. saving model.
Epoch 137, gen_loss: 1.0073, disc_loss: 1.2800
Time for epoch 137 is 21.2845 sec
within save threshold. saving model.
Epoch 138, gen_loss: 1.1280, disc_loss: 1.1722
Time for epoch 138 is 21.0310 sec
within save threshold. saving model.
Epoch 139, gen_loss: 1.1426, disc_loss: 1.0851
Time for epoch 139 is 21.0789 sec
within save threshold. saving model.
Epoch 140, gen_loss: 1.2164, disc_loss: 0.9615
Time for epoch 140 is 21.1192 sec
within save threshold. saving model.
Epoch 141, gen_loss: 1.1551, disc_loss: 1.1099
Time for 

Epoch 186, gen_loss: 0.7560, disc_loss: 1.4722
Time for epoch 186 is 20.8358 sec
within save threshold. saving model.
Epoch 187, gen_loss: 0.7207, disc_loss: 1.5089
Time for epoch 187 is 20.8620 sec
within save threshold. saving model.
Epoch 188, gen_loss: 0.6877, disc_loss: 1.5495
Time for epoch 188 is 20.6146 sec
within save threshold. saving model.
Epoch 189, gen_loss: 0.6790, disc_loss: 1.5841
Time for epoch 189 is 21.0043 sec
within save threshold. saving model.
Epoch 190, gen_loss: 0.6668, disc_loss: 1.6028
Time for epoch 190 is 20.6026 sec
within save threshold. saving model.
Epoch 191, gen_loss: 0.6570, disc_loss: 1.5870
Time for epoch 191 is 20.8110 sec
within save threshold. saving model.
Epoch 192, gen_loss: 0.6648, disc_loss: 1.6545
Time for epoch 192 is 20.3922 sec
within save threshold. saving model.
Epoch 193, gen_loss: 0.7423, disc_loss: 1.4017
Time for epoch 193 is 21.2582 sec
within save threshold. saving model.
Epoch 194, gen_loss: 0.7420, disc_loss: 1.4985
Time for 

within save threshold. saving model.
Epoch 239, gen_loss: 0.7094, disc_loss: 1.4324
Time for epoch 239 is 20.6060 sec
within save threshold. saving model.
Epoch 240, gen_loss: 0.7136, disc_loss: 1.3824
Time for epoch 240 is 21.0482 sec
within save threshold. saving model.
Epoch 241, gen_loss: 0.7452, disc_loss: 1.3237
Time for epoch 241 is 21.1610 sec
within save threshold. saving model.
Epoch 242, gen_loss: 0.7253, disc_loss: 1.3536
Time for epoch 242 is 21.3688 sec
within save threshold. saving model.
Epoch 243, gen_loss: 0.6924, disc_loss: 1.4206
Time for epoch 243 is 20.8970 sec
within save threshold. saving model.
Epoch 244, gen_loss: 0.7029, disc_loss: 1.4221
Time for epoch 244 is 20.8799 sec
within save threshold. saving model.
Epoch 245, gen_loss: 0.7284, disc_loss: 1.3574
Time for epoch 245 is 20.6661 sec
within save threshold. saving model.
Epoch 246, gen_loss: 0.6933, disc_loss: 1.4077
Time for epoch 246 is 20.6609 sec
within save threshold. saving model.
Epoch 247, gen_loss

Epoch 291, gen_loss: 0.8226, disc_loss: 1.3584
Time for epoch 291 is 20.7830 sec
within save threshold. saving model.
Epoch 292, gen_loss: 0.8316, disc_loss: 1.2835
Time for epoch 292 is 20.8381 sec
within save threshold. saving model.
Epoch 293, gen_loss: 0.8011, disc_loss: 1.3857
Time for epoch 293 is 20.8906 sec
within save threshold. saving model.
Epoch 294, gen_loss: 0.7860, disc_loss: 1.4022
Time for epoch 294 is 20.3751 sec
within save threshold. saving model.
Epoch 295, gen_loss: 0.8086, disc_loss: 1.3252
Time for epoch 295 is 21.0043 sec
within save threshold. saving model.
Epoch 296, gen_loss: 0.8512, disc_loss: 1.2489
Time for epoch 296 is 20.4423 sec
within save threshold. saving model.
Epoch 297, gen_loss: 0.7779, disc_loss: 1.3622
Time for epoch 297 is 20.8457 sec
within save threshold. saving model.
Epoch 298, gen_loss: 0.7876, disc_loss: 1.3320
Time for epoch 298 is 20.4629 sec
within save threshold. saving model.
Epoch 299, gen_loss: 0.7751, disc_loss: 1.3801
Time for 

within save threshold. saving model.
Epoch 344, gen_loss: 1.0426, disc_loss: 1.0442
Time for epoch 344 is 20.6872 sec
within save threshold. saving model.
Epoch 345, gen_loss: 1.0079, disc_loss: 1.0202
Time for epoch 345 is 20.8093 sec
within save threshold. saving model.
Epoch 346, gen_loss: 1.1644, disc_loss: 0.9672
Time for epoch 346 is 20.5950 sec
within save threshold. saving model.
Epoch 347, gen_loss: 0.9538, disc_loss: 1.1995
Time for epoch 347 is 20.4738 sec
within save threshold. saving model.
Epoch 348, gen_loss: 0.8948, disc_loss: 1.1243
Time for epoch 348 is 20.5761 sec
within save threshold. saving model.
Epoch 349, gen_loss: 1.0512, disc_loss: 1.0213
Time for epoch 349 is 20.8516 sec
within save threshold. saving model.
Epoch 350, gen_loss: 1.0759, disc_loss: 1.1133
Time for epoch 350 is 20.6400 sec
within save threshold. saving model.
Epoch 351, gen_loss: 1.1834, disc_loss: 0.8624
Time for epoch 351 is 20.4670 sec
within save threshold. saving model.
Epoch 352, gen_loss

Epoch 396, gen_loss: 0.8539, disc_loss: 1.2827
Time for epoch 396 is 20.5377 sec
within save threshold. saving model.
Epoch 397, gen_loss: 0.8964, disc_loss: 1.1872
Time for epoch 397 is 20.5089 sec
within save threshold. saving model.
Epoch 398, gen_loss: 0.8001, disc_loss: 1.3166
Time for epoch 398 is 20.6294 sec
within save threshold. saving model.
Epoch 399, gen_loss: 0.8418, disc_loss: 1.2668
Time for epoch 399 is 20.6136 sec
within save threshold. saving model.
Epoch 400, gen_loss: 0.7962, disc_loss: 1.2640
Time for epoch 400 is 20.9253 sec
within save threshold. saving model.
Epoch 401, gen_loss: 0.8525, disc_loss: 1.2578
Time for epoch 401 is 20.5705 sec
within save threshold. saving model.
Epoch 402, gen_loss: 0.7622, disc_loss: 1.4027
Time for epoch 402 is 20.6204 sec
within save threshold. saving model.
Epoch 403, gen_loss: 0.7139, disc_loss: 1.4332
Time for epoch 403 is 20.6768 sec
within save threshold. saving model.
Epoch 404, gen_loss: 0.7434, disc_loss: 1.4117
Time for 

within save threshold. saving model.
Epoch 449, gen_loss: 0.8057, disc_loss: 1.2523
Time for epoch 449 is 20.6878 sec
within save threshold. saving model.
Epoch 450, gen_loss: 0.7648, disc_loss: 1.3889
Time for epoch 450 is 28.6299 sec
within save threshold. saving model.
Epoch 451, gen_loss: 0.8545, disc_loss: 1.3392
Time for epoch 451 is 20.5572 sec
within save threshold. saving model.
Epoch 452, gen_loss: 0.8026, disc_loss: 1.2822
Time for epoch 452 is 20.6200 sec
within save threshold. saving model.
Epoch 453, gen_loss: 0.7807, disc_loss: 1.4106
Time for epoch 453 is 20.8581 sec
within save threshold. saving model.
Epoch 454, gen_loss: 0.7577, disc_loss: 1.3705
Time for epoch 454 is 20.8143 sec
within save threshold. saving model.
Epoch 455, gen_loss: 0.8177, disc_loss: 1.2804
Time for epoch 455 is 20.4318 sec
within save threshold. saving model.
Epoch 456, gen_loss: 0.8229, disc_loss: 1.3484
Time for epoch 456 is 20.5941 sec
within save threshold. saving model.
Epoch 457, gen_loss

Epoch 501, gen_loss: 0.8350, disc_loss: 1.2337
Time for epoch 501 is 20.5761 sec
within save threshold. saving model.
Epoch 502, gen_loss: 0.8078, disc_loss: 1.2829
Time for epoch 502 is 20.3986 sec
within save threshold. saving model.
Epoch 503, gen_loss: 0.9092, disc_loss: 1.1872
Time for epoch 503 is 20.5459 sec
within save threshold. saving model.
Epoch 504, gen_loss: 0.8282, disc_loss: 1.2192
Time for epoch 504 is 20.5600 sec
within save threshold. saving model.
Epoch 505, gen_loss: 0.9316, disc_loss: 1.2160
Time for epoch 505 is 20.6988 sec
within save threshold. saving model.
Epoch 506, gen_loss: 0.8465, disc_loss: 1.3166
Time for epoch 506 is 20.4481 sec
within save threshold. saving model.
Epoch 507, gen_loss: 0.8503, disc_loss: 1.2441
Time for epoch 507 is 20.6326 sec
within save threshold. saving model.
Epoch 508, gen_loss: 0.8587, disc_loss: 1.2769
Time for epoch 508 is 20.5397 sec
within save threshold. saving model.
Epoch 509, gen_loss: 0.8544, disc_loss: 1.3061
Time for 

within save threshold. saving model.
Epoch 554, gen_loss: 0.9261, disc_loss: 1.1082
Time for epoch 554 is 20.5653 sec
within save threshold. saving model.
Epoch 555, gen_loss: 0.9225, disc_loss: 1.1299
Time for epoch 555 is 20.5638 sec
within save threshold. saving model.
Epoch 556, gen_loss: 0.9549, disc_loss: 1.1846
Time for epoch 556 is 20.7892 sec
within save threshold. saving model.
Epoch 557, gen_loss: 0.7740, disc_loss: 1.4143
Time for epoch 557 is 20.4740 sec
within save threshold. saving model.
Epoch 558, gen_loss: 0.8896, disc_loss: 1.2166
Time for epoch 558 is 20.5950 sec
within save threshold. saving model.
Epoch 559, gen_loss: 0.8033, disc_loss: 1.3691
Time for epoch 559 is 20.4994 sec
within save threshold. saving model.
Epoch 560, gen_loss: 0.9048, disc_loss: 1.2088
Time for epoch 560 is 20.6653 sec
within save threshold. saving model.
Epoch 561, gen_loss: 0.8669, disc_loss: 1.3296
Time for epoch 561 is 20.7014 sec
within save threshold. saving model.
Epoch 562, gen_loss

Epoch 606, gen_loss: 1.0626, disc_loss: 1.0280
Time for epoch 606 is 20.6515 sec
within save threshold. saving model.
Epoch 607, gen_loss: 1.0234, disc_loss: 1.0189
Time for epoch 607 is 20.4289 sec
within save threshold. saving model.
Epoch 608, gen_loss: 0.9691, disc_loss: 1.1445
Time for epoch 608 is 20.9309 sec
within save threshold. saving model.
Epoch 609, gen_loss: 0.9537, disc_loss: 1.2484
Time for epoch 609 is 20.5219 sec
within save threshold. saving model.
Epoch 610, gen_loss: 0.9710, disc_loss: 1.1035
Time for epoch 610 is 20.5163 sec
within save threshold. saving model.
Epoch 611, gen_loss: 0.9569, disc_loss: 1.1694
Time for epoch 611 is 20.6980 sec
within save threshold. saving model.
Epoch 612, gen_loss: 0.9932, disc_loss: 1.1075
Time for epoch 612 is 20.4149 sec
within save threshold. saving model.
Epoch 613, gen_loss: 0.9804, disc_loss: 1.1700
Time for epoch 613 is 20.6701 sec
within save threshold. saving model.
Epoch 614, gen_loss: 1.0430, disc_loss: 1.0256
Time for 

within save threshold. saving model.
Epoch 659, gen_loss: 1.1259, disc_loss: 1.0043
Time for epoch 659 is 20.7670 sec
within save threshold. saving model.
Epoch 660, gen_loss: 1.0120, disc_loss: 1.1347
Time for epoch 660 is 20.6016 sec
within save threshold. saving model.
Epoch 661, gen_loss: 1.0914, disc_loss: 1.0322
Time for epoch 661 is 20.4310 sec
within save threshold. saving model.
Epoch 662, gen_loss: 1.0691, disc_loss: 0.9885
Time for epoch 662 is 20.7910 sec
within save threshold. saving model.
Epoch 663, gen_loss: 1.0941, disc_loss: 1.0128
Time for epoch 663 is 20.9820 sec
within save threshold. saving model.
Epoch 664, gen_loss: 0.9439, disc_loss: 1.2053
Time for epoch 664 is 20.5734 sec
within save threshold. saving model.
Epoch 665, gen_loss: 0.9540, disc_loss: 1.1371
Time for epoch 665 is 20.8763 sec
within save threshold. saving model.
Epoch 666, gen_loss: 1.0698, disc_loss: 1.1543
Time for epoch 666 is 20.5539 sec
within save threshold. saving model.
Epoch 667, gen_loss

Epoch 711, gen_loss: 0.9725, disc_loss: 1.1017
Time for epoch 711 is 20.6392 sec
within save threshold. saving model.
Epoch 712, gen_loss: 0.9817, disc_loss: 1.2519
Time for epoch 712 is 20.5055 sec
within save threshold. saving model.
Epoch 713, gen_loss: 0.8797, disc_loss: 1.3406
Time for epoch 713 is 20.6642 sec
within save threshold. saving model.
Epoch 714, gen_loss: 0.8677, disc_loss: 1.3266
Time for epoch 714 is 20.4961 sec
within save threshold. saving model.
Epoch 715, gen_loss: 0.8535, disc_loss: 1.2375
Time for epoch 715 is 20.6348 sec
within save threshold. saving model.
Epoch 716, gen_loss: 0.9469, disc_loss: 1.2250
Time for epoch 716 is 20.4109 sec
within save threshold. saving model.
Epoch 717, gen_loss: 0.8659, disc_loss: 1.3885
Time for epoch 717 is 20.4948 sec
within save threshold. saving model.
Epoch 718, gen_loss: 0.8510, disc_loss: 1.3615
Time for epoch 718 is 20.7681 sec
within save threshold. saving model.
Epoch 719, gen_loss: 0.8644, disc_loss: 1.2296
Time for 

within save threshold. saving model.
Epoch 764, gen_loss: 0.9726, disc_loss: 1.1589
Time for epoch 764 is 20.6041 sec
within save threshold. saving model.
Epoch 765, gen_loss: 1.0310, disc_loss: 1.1343
Time for epoch 765 is 20.4857 sec
within save threshold. saving model.
Epoch 766, gen_loss: 0.9673, disc_loss: 1.2264
Time for epoch 766 is 21.0472 sec
within save threshold. saving model.
Epoch 767, gen_loss: 1.0124, disc_loss: 1.0384
Time for epoch 767 is 20.6002 sec
within save threshold. saving model.
Epoch 768, gen_loss: 0.9103, disc_loss: 1.2596
Time for epoch 768 is 20.6499 sec
within save threshold. saving model.
Epoch 769, gen_loss: 0.9616, disc_loss: 1.1680
Time for epoch 769 is 20.5899 sec
within save threshold. saving model.
Epoch 770, gen_loss: 0.9336, disc_loss: 1.1764
Time for epoch 770 is 20.7788 sec
within save threshold. saving model.
Epoch 771, gen_loss: 1.0630, disc_loss: 1.0919
Time for epoch 771 is 20.6407 sec
within save threshold. saving model.
Epoch 772, gen_loss

Epoch 816, gen_loss: 1.0749, disc_loss: 1.0028
Time for epoch 816 is 20.4441 sec
within save threshold. saving model.
Epoch 817, gen_loss: 1.0871, disc_loss: 1.1041
Time for epoch 817 is 20.5520 sec
within save threshold. saving model.
Epoch 818, gen_loss: 0.9539, disc_loss: 1.2748
Time for epoch 818 is 20.8083 sec
within save threshold. saving model.
Epoch 819, gen_loss: 0.9150, disc_loss: 1.3200
Time for epoch 819 is 20.4415 sec
within save threshold. saving model.
Epoch 820, gen_loss: 0.8952, disc_loss: 1.2181
Time for epoch 820 is 20.5998 sec
within save threshold. saving model.
Epoch 821, gen_loss: 0.9556, disc_loss: 1.3694
Time for epoch 821 is 20.4892 sec
within save threshold. saving model.
Epoch 822, gen_loss: 0.9407, disc_loss: 1.2249
Time for epoch 822 is 20.6224 sec
within save threshold. saving model.
Epoch 823, gen_loss: 0.8124, disc_loss: 1.4099
Time for epoch 823 is 20.7824 sec
within save threshold. saving model.
Epoch 824, gen_loss: 0.9576, disc_loss: 1.2561
Time for 

within save threshold. saving model.
Epoch 869, gen_loss: 0.6991, disc_loss: 1.5136
Time for epoch 869 is 20.7670 sec
within save threshold. saving model.
Epoch 870, gen_loss: 0.9061, disc_loss: 1.1626
Time for epoch 870 is 20.6647 sec
within save threshold. saving model.
Epoch 871, gen_loss: 0.9853, disc_loss: 1.2186
Time for epoch 871 is 20.5973 sec
within save threshold. saving model.
Epoch 872, gen_loss: 1.0054, disc_loss: 1.1252
Time for epoch 872 is 20.6505 sec
within save threshold. saving model.
Epoch 873, gen_loss: 0.8864, disc_loss: 1.2313
Time for epoch 873 is 21.0332 sec
within save threshold. saving model.
Epoch 874, gen_loss: 0.9014, disc_loss: 1.2401
Time for epoch 874 is 20.6729 sec
within save threshold. saving model.
Epoch 875, gen_loss: 0.9199, disc_loss: 1.2576
Time for epoch 875 is 20.5851 sec
within save threshold. saving model.
Epoch 876, gen_loss: 0.8903, disc_loss: 1.2997
Time for epoch 876 is 20.5219 sec
within save threshold. saving model.
Epoch 877, gen_loss

Epoch 921, gen_loss: 0.9516, disc_loss: 1.1707
Time for epoch 921 is 21.6191 sec
within save threshold. saving model.
Epoch 922, gen_loss: 1.0175, disc_loss: 1.1009
Time for epoch 922 is 21.2031 sec
within save threshold. saving model.
Epoch 923, gen_loss: 0.9830, disc_loss: 1.0785
Time for epoch 923 is 21.2757 sec
within save threshold. saving model.
Epoch 924, gen_loss: 1.0873, disc_loss: 1.0070
Time for epoch 924 is 21.5980 sec
within save threshold. saving model.
Epoch 925, gen_loss: 0.8774, disc_loss: 1.2971
Time for epoch 925 is 20.9738 sec
within save threshold. saving model.
Epoch 926, gen_loss: 1.0260, disc_loss: 1.1052
Time for epoch 926 is 21.9399 sec
within save threshold. saving model.
Epoch 927, gen_loss: 0.9217, disc_loss: 1.2932
Time for epoch 927 is 21.0538 sec
within save threshold. saving model.
Epoch 928, gen_loss: 0.9918, disc_loss: 1.1062
Time for epoch 928 is 20.7435 sec
within save threshold. saving model.
Epoch 929, gen_loss: 1.0991, disc_loss: 1.0290
Time for 

within save threshold. saving model.
Epoch 974, gen_loss: 0.9529, disc_loss: 1.2130
Time for epoch 974 is 20.7610 sec
within save threshold. saving model.
Epoch 975, gen_loss: 1.1002, disc_loss: 1.0279
Time for epoch 975 is 20.9420 sec
within save threshold. saving model.
Epoch 976, gen_loss: 1.1785, disc_loss: 0.9068
Time for epoch 976 is 20.6569 sec
within save threshold. saving model.
Epoch 977, gen_loss: 1.1015, disc_loss: 1.0540
Time for epoch 977 is 20.7559 sec
within save threshold. saving model.
Epoch 978, gen_loss: 1.2981, disc_loss: 0.9862
Time for epoch 978 is 20.4558 sec
within save threshold. saving model.
Epoch 979, gen_loss: 1.1123, disc_loss: 1.0618
Time for epoch 979 is 20.8580 sec
within save threshold. saving model.
Epoch 980, gen_loss: 0.9880, disc_loss: 1.3035
Time for epoch 980 is 20.8043 sec
within save threshold. saving model.
Epoch 981, gen_loss: 1.1162, disc_loss: 1.0788
Time for epoch 981 is 20.5611 sec
within save threshold. saving model.
Epoch 982, gen_loss

In [285]:
def test_caption2string(ls):
    return " ".join([id2word_dict[idx] for idx in ls]).strip().split(' <PAD>')[0]

def testing_data_generator(caption, index, caption_type='id'):
    if caption_type == 'id':
        caption = tf.cast(caption, tf.float32)
    return caption, index

def testing_dataset_generator(batch_size, data_generator, caption_type='id'):
    data = pd.read_pickle('./dataset/testData.pkl')
    
    if caption_type == 'sentence':
        data['Captions_string'] = data['Captions'].apply(test_caption2string)
        captions = data['Captions_string'].values
    elif caption_type == 'id':
        captions = data['Captions'].values
        
    caption = []
    for i in range(len(captions)):
        caption.append(captions[i])
    caption = np.asarray(caption)
    
    if caption_type == 'id':
        caption = caption.astype(np.int)
        
    datagen_func = lambda cap, img: data_generator(cap, img, caption_type=caption_type)
        
    index = data['ID'].values
    index = np.asarray(index)
    
    dataset = tf.data.Dataset.from_tensor_slices((caption, index))
    dataset = dataset.map(datagen_func, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.repeat().batch(batch_size)
    
    return dataset

In [286]:
testing_dataset = testing_dataset_generator(hparas['BATCH_SIZE'], testing_data_generator, caption_type=expSettings.caption_type)

In [287]:
# test the testing dataset
for cap, img in testing_dataset.take(1):
    print("Image shape:", img.numpy().shape)
    print("Caption shape:", cap.numpy().shape)

Image shape: (64,)
Caption shape: (64,)


In [288]:
data = pd.read_pickle('./dataset/testData.pkl')
captions = data['Captions'].values

NUM_TEST = len(captions)
EPOCH_TEST = int(NUM_TEST / hparas['BATCH_SIZE'])

In [289]:
if not os.path.exists('./inference/demo'):
    os.makedirs('./inference/demo')

In [290]:
def inference(dataset):
    hidden = text_encoder.initialize_hidden_state()
    sample_size = hparas['BATCH_SIZE']
    sample_seed = np.random.normal(loc=0.0, scale=0.1, size=(sample_size, hparas['Z_DIM'])).astype(np.float32)
    print(sample_seed[0:3, :])
    
    step = 0
    start = time.time()
    for captions, idx in dataset:
        if step > EPOCH_TEST:
            break
        
        fake_image = test_step(captions, sample_seed, hidden)
        step += 1
        for i in range(hparas['BATCH_SIZE']):
            plt.imsave('./inference/demo/inference_{:04d}.jpg'.format(idx[i]), fake_image[i].numpy()*0.5 + 0.5)
            
            if i == 0 and step == 1: 
                #print(captions)
                text_embed_t, hidden_t = text_encoder(captions, hidden)
                #print(text_embed_t)
                print(fake_image[0:1, 0:5, 0:5, :])
                img_logits, _, debug = generator(text_embed_t, sample_seed, debug_output=True)
                print(debug[0][0:3, 0:5, 0:5, 0:5])
                print(debug[1][0:3, 0:5, 0:5, 0:5])
                print(debug[2][0:3, 0:5, 0:5, 0:5])
                print(debug[3][0:3, 0:5, 0:5, 0:5])
                print(debug[4][0:3, 0:5, 0:5, 0:5])
                print(img_logits[0:3, 0:5, 0:5, :])
                pred_logit, pred = discriminator(fake_image, text_embed_t)
                print(pred_logit)
                
            
    print('Time for inference is {:.4f} sec'.format(time.time()-start))

In [291]:
#checkpoint.restore(checkpoint_dir + f'/ckpt-50')
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint:
    print(f"Restoring from {latest_checkpoint}")
    checkpoint.restore(latest_checkpoint)
else:
    print("No checkpoint found.")

Restoring from ./checkpoints/demo\ckpt-891


In [292]:
inference(testing_dataset)

[[-1.76590890e-01 -3.62670831e-02 -5.88861965e-02 ... -1.77671053e-02
   1.72591850e-01  3.02505605e-02]
 [-8.34661871e-02  3.57552022e-02  1.00501038e-05 ...  1.96811911e-02
  -4.91746794e-03  1.02292605e-01]
 [-3.03284191e-02  1.11609578e-01  4.10925858e-02 ...  2.25200132e-02
  -1.43049583e-01  1.30498037e-01]]
tf.Tensor(
[[[[ 0.09060197  0.12390514  0.10596064]
   [ 0.13115035  0.13739027  0.09484114]
   [ 0.0974274   0.07918742  0.06754562]
   [ 0.08504594  0.07736155  0.05447872]
   [ 0.06964721  0.04440857  0.02915025]]

  [[ 0.11267444  0.09133928  0.10059573]
   [-0.02394733 -0.07115721 -0.01652582]
   [ 0.07183935  0.05187262  0.04409298]
   [-0.03050993 -0.08891862 -0.02702138]
   [ 0.05989148  0.03450191  0.03503729]]

  [[ 0.0426882   0.03691719  0.044337  ]
   [ 0.04776705  0.03963728  0.03349791]
   [ 0.05904258  0.04055943  0.0366561 ]
   [ 0.05226436  0.04220465  0.03731665]
   [ 0.06617492  0.0443062   0.03730223]]

  [[ 0.02373693  0.02568267  0.04221135]
   [-0.0545

tf.Tensor(
[[ 0.7330474 ]
 [ 0.6364398 ]
 [-0.80292803]
 [-1.9607987 ]
 [-0.03298634]
 [ 0.6996903 ]
 [-0.10833185]
 [ 0.31654674]
 [ 0.6588001 ]
 [-0.09422082]
 [ 0.5578816 ]
 [-0.05290177]
 [-0.11813443]
 [-0.19716708]
 [ 0.0263162 ]
 [ 0.29462236]
 [ 0.07541651]
 [ 0.03791507]
 [-1.301384  ]
 [ 0.01705541]
 [ 0.2800377 ]
 [-0.03757518]
 [-0.1446654 ]
 [ 0.571592  ]
 [ 0.16298202]
 [ 0.3853119 ]
 [-0.3280385 ]
 [ 0.28889844]
 [-0.01825701]
 [-0.13764481]
 [ 0.96193516]
 [-1.4912162 ]
 [ 0.37523788]
 [-0.08545094]
 [-0.69500977]
 [-0.08711087]
 [-0.3767596 ]
 [-0.01851731]
 [-0.10279977]
 [-1.6155019 ]
 [-0.08037608]
 [ 1.3862095 ]
 [-0.14765243]
 [ 0.01453615]
 [-0.06687289]
 [-1.7731142 ]
 [ 0.40818334]
 [ 0.6745443 ]
 [ 0.5628932 ]
 [ 0.67801124]
 [-0.2800907 ]
 [ 0.6571805 ]
 [-0.2994401 ]
 [-0.07143586]
 [-0.12705807]
 [ 0.4541952 ]
 [-0.15048783]
 [ 0.6623275 ]
 [ 0.23542286]
 [-0.0028184 ]
 [ 0.75507957]
 [-0.67771035]
 [ 0.07533414]
 [-0.8166626 ]], shape=(64, 1), dtype=float3

In [293]:
for var in generator.starter.variables:
    print(var.name, var.numpy())

generator_8/dense_block_8/dense_50/kernel:0 [[-0.04498814 -0.23571591  0.05061483 ... -0.07926965 -0.06415971
  -0.08199308]
 [ 0.05829757  0.21634753  0.02969032 ...  0.02306199  0.16618606
   0.02117582]
 [ 0.03605476 -0.06150118  0.07170997 ...  0.16611832  0.10298007
   0.02841652]
 ...
 [-0.00112796  0.00504179  0.00708564 ... -0.02090279 -0.00069093
  -0.02018782]
 [ 0.01602916 -0.018421   -0.00192075 ... -0.01871504  0.00557119
  -0.00656456]
 [-0.02177091  0.01424693  0.00438795 ... -0.01453981 -0.02062936
   0.01695392]]
generator_8/dense_block_8/dense_50/bias:0 [-8.6428727e-06 -6.6864864e-05  1.8376473e-05 ... -1.8091770e-05
 -8.9352479e-06 -3.2261182e-06]
generator_8/dense_block_8/batch_normalization_128/gamma:0 [0.9478327  0.94250154 1.071583   ... 1.2639962  1.1563675  1.3050067 ]
generator_8/dense_block_8/batch_normalization_128/beta:0 [ 0.14056788 -0.32281476  0.19005923 ... -0.0251978  -0.03710499
 -0.2400612 ]
generator_8/dense_block_8/batch_normalization_128/moving_me

In [294]:
for var in discriminator.variables:
    print(var.name, var.numpy())

discriminator_8/conv_block_82/conv2d_82/kernel:0 [[[[-2.2866993e-01  2.4309792e-01 -5.7593203e-04 ...  5.0765681e-01
    -1.8299033e-01  7.4656770e-02]
   [-4.9829939e-01 -4.8093405e-01 -1.8522252e-01 ... -5.3269726e-01
     3.6122376e-01  2.9018033e-01]
   [-1.3226913e-01 -1.8454133e-02  3.7295622e-01 ... -4.2382467e-01
     1.7720783e-01  9.5634125e-03]]

  [[ 1.2448568e-01 -3.8422987e-02  6.4998500e-02 ...  1.7700158e-02
    -3.2616064e-01 -2.0630422e-01]
   [ 3.8726839e-01 -2.8708821e-02 -1.5146399e-01 ...  2.0324646e-01
     3.4477569e-02  6.3980669e-03]
   [-1.6381900e-01  3.6336344e-01 -4.8993415e-01 ... -2.5162059e-01
     1.1920709e-01 -1.4812057e-01]]

  [[-6.4685792e-02  1.2510049e-01  6.7997612e-02 ... -2.0516430e-01
    -1.2416184e-02 -4.4040132e-01]
   [ 1.1142460e-01 -8.6674616e-02  2.5821075e-01 ... -3.1199932e-01
     8.7472536e-02 -2.3021638e-01]
   [ 1.6607173e-01 -3.4756833e-01 -1.9862036e-01 ...  4.5215395e-01
     1.2540741e-01 -1.5188457e-01]]]


 [[[ 3.5060322e-

In [295]:
%cd ./testing
!python inception_score.py ../inference/demo output.csv 21
%cd ../

C:\Users\User\Courses\24aut_deep_learning\24aut-deep-learning\deep-learning-comp3\testing
1 Physical GPUs, 1 Logical GPUs
--------------Evaluation Success-----------------
C:\Users\User\Courses\24aut_deep_learning\24aut-deep-learning\deep-learning-comp3


In [296]:
text_encoder.summary()

Model: "text_encoder_8"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding_8 (Embedding)     multiple                  0 (unused)
                                                                 
 gru_8 (GRU)                 multiple                  0 (unused)
                                                                 
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________


In [297]:
generator.summary()

Model: "generator_8"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten_16 (Flatten)        multiple                  0 (unused)
                                                                 
 dense_48 (Dense)            multiple                  0 (unused)
                                                                 
 dense_49 (Dense)            multiple                  0 (unused)
                                                                 
 dense_block_8 (dense_block)  multiple                 2107392   
                                                                 
 deconv_block_40 (deconv_blo  multiple                 4196864   
 ck)                                                             
                                                                 
 deconv_block_41 (deconv_blo  multiple                 2098432   
 ck)                                                   

In [298]:
discriminator.summary()

Model: "discriminator_8"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv_block_82 (conv_block)  multiple                  16384     
                                                                 
 conv_block_83 (conv_block)  multiple                  264704    
                                                                 
 dropout_12 (Dropout)        multiple                  0         
                                                                 
 conv_block_84 (conv_block)  multiple                  1180928   
                                                                 
 conv_block_85 (conv_block)  multiple                  66816     
                                                                 
 dropout_13 (Dropout)        multiple                  0         
                                                                 
 conv_block_86 (conv_block)  multiple              