# Competition 3: Team 21

112062649 王俊皓

112062650 廖士傑

##  Reverse Image Caption

In [634]:
class experimental_settings:
    def __init__(self,
                 enc=True,
                 gen=True,
                 dis=True,
                 enc_do_batchnorm=False,
                 delete_checkpoint=False):
        self.enc = enc
        self.gen = gen
        self.dis = dis
        self.enc_do_batchnorm = enc_do_batchnorm
        self.delete_checkpoint = delete_checkpoint # not implemented yet
        
        # ============================ #
        # automatic
        # ============================ #
        
        self.caption_type = 'sentence' if self.enc else 'id'


expSettings = experimental_settings(enc=True,
                                    gen=True,
                                    dis=True,
                                    delete_checkpoint=True)

## Import

In [635]:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras import layers
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
import string
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import random
import time
from pathlib import Path
import math

import re
from IPython import display

GPU check

In [636]:
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Restrict TensorFlow to only use the first GPU
        tf.config.experimental.set_visible_devices(gpus[0], 'GPU')

        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)

1 Physical GPUs, 1 Logical GPUs


## Loading Data

In [637]:
dictionary_path = './dictionary'
vocab = np.load(dictionary_path + '/vocab.npy')
print('there are {} vocabularies in total'.format(len(vocab)))

word2Id_dict = dict(np.load(dictionary_path + '/word2Id.npy'))
id2word_dict = dict(np.load(dictionary_path + '/id2Word.npy'))
print('Word to id mapping, for example: %s -> %s' % ('flower', word2Id_dict['flower']))
print('Id to word mapping, for example: %s -> %s' % ('1', id2word_dict['1']))
print('Tokens: <PAD>: %s; <RARE>: %s' % (word2Id_dict['<PAD>'], word2Id_dict['<RARE>']))

there are 5427 vocabularies in total
Word to id mapping, for example: flower -> 1
Id to word mapping, for example: 1 -> flower
Tokens: <PAD>: 5427; <RARE>: 5428


In [638]:
def sent2IdList(line, MAX_SEQ_LENGTH=20):
    MAX_SEQ_LIMIT = MAX_SEQ_LENGTH
    padding = 0
    
    # data preprocessing, remove all puntuation in the texts
    prep_line = re.sub('[%s]' % re.escape(string.punctuation), ' ', line.rstrip())
    prep_line = prep_line.replace('-', ' ')
    prep_line = prep_line.replace('-', ' ')
    prep_line = prep_line.replace('  ', ' ')
    prep_line = prep_line.replace('.', '')
    tokens = prep_line.split(' ')
    tokens = [
        tokens[i] for i in range(len(tokens))
        if tokens[i] != ' ' and tokens[i] != ''
    ]
    l = len(tokens)
    padding = MAX_SEQ_LIMIT - l
    
    # make sure length of each text is equal to MAX_SEQ_LENGTH, and replace the less common word with <RARE> token
    for i in range(padding):
        tokens.append('<PAD>')
    line = [
        word2Id_dict[tokens[k]]
        if tokens[k] in word2Id_dict else word2Id_dict['<RARE>']
        for k in range(len(tokens))
    ]

    return line

text = "the flower shown has yellow anther red pistil and bright red petals."
print(text)
print(sent2IdList(text))

the flower shown has yellow anther red pistil and bright red petals.
['9', '1', '82', '5', '11', '70', '20', '31', '3', '29', '20', '2', '5427', '5427', '5427', '5427', '5427', '5427', '5427', '5427']


In [639]:
@tf.function
def id2Sent(ids):
    return " ".join([id2word_dict[idx] for idx in ids]).strip()

#def batch_id2Sent(batch_ids):
    #return [id2Sent(ids) for ids in batch_ids]
    
def batch_id2Sent(batch_ids):
    def process_single(ids):
        # Convert a single tensor of IDs to a sentence
        ids = ids.numpy()  # Convert Tensor to NumPy
        sentence = " ".join([id2word_dict.get(idx, "<UNK>") for idx in ids])  # Handle unknown IDs
        return sentence

    # Use tf.py_function to apply Python function inside the TensorFlow graph
    sentences = tf.map_fn(
        lambda ids: tf.py_function(process_single, [ids], tf.string),
        batch_ids,
        fn_output_signature=tf.string
    )
    return sentences


print(sent2IdList(text))
print(id2Sent(sent2IdList(text)))

['9', '1', '82', '5', '11', '70', '20', '31', '3', '29', '20', '2', '5427', '5427', '5427', '5427', '5427', '5427', '5427', '5427']
tf.Tensor(b'the flower shown has yellow anther red pistil and bright red petals <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>', shape=(), dtype=string)


In [640]:
data_path = './dataset'
text2ImgData = pd.read_pickle(data_path + '/text2ImgData.pkl')
num_training_sample = len(text2ImgData)
n_images_train = num_training_sample
print('There are %d image in training data' % (n_images_train))

There are 7370 image in training data


In [641]:
def caption2string(cap):
    output = []
    for sen in cap:
        s = " ".join([id2word_dict[idx] for idx in sen]).strip()
        output.append(s.split(' <PAD>')[0])
    return output

# adding caption as strings
text2ImgData['Captions_string'] = text2ImgData['Captions'].apply(caption2string)

In [642]:
text2ImgData.head(5)

Unnamed: 0_level_0,Captions,ImagePath,Captions_string
ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
6734,"[[9, 2, 17, 9, 1, 6, 14, 13, 18, 3, 41, 8, 11,...",./102flowers/image_06734.jpg,[the petals of the flower are pink in color an...
6736,"[[4, 1, 5, 12, 2, 3, 11, 31, 28, 68, 106, 132,...",./102flowers/image_06736.jpg,[this flower has white petals and yellow pisti...
6737,"[[9, 2, 27, 4, 1, 6, 14, 7, 12, 19, 5427, 5427...",./102flowers/image_06737.jpg,[the petals on this flower are pink with white...
6738,"[[9, 1, 5, 8, 54, 16, 38, 7, 12, 116, 325, 3, ...",./102flowers/image_06738.jpg,[the flower has a smooth purple petal with whi...
6739,"[[4, 12, 1, 5, 29, 11, 19, 7, 26, 70, 5427, 54...",./102flowers/image_06739.jpg,[this white flower has bright yellow stamen wi...


In [643]:
text2ImgData['Captions_string'][:1].tolist()

[['the petals of the flower are pink in color and have a yellow center',
  'this flower is pink and white in color with petals that are multi colored',
  'the purple petals have shades of white with white anther and filament',
  'this flower has large pink petals and a white stigma in the center',
  'this flower has petals that are pink and has a yellow stamen',
  'a flower with short and wide petals that is light purple',
  'this flower has small pink petals with a yellow center',
  'this flower has large rounded pink petals with curved edges and purple veins',
  'this flower has purple petals as well as a white stamen']]

In [644]:
# in this competition, you have to generate image in size 64x64x3
IMAGE_SIZE = 64
IMAGE_HEIGHT = IMAGE_SIZE
IMAGE_WIDTH = IMAGE_SIZE
IMAGE_CHANNEL = 3

def training_data_generator(caption, image_path, caption_type='id'):
    # load in the image according to image path
    img = tf.io.read_file(image_path)
    img = tf.image.decode_image(img, channels=3)
    img = tf.image.convert_image_dtype(img, tf.float32)
    img.set_shape([None, None, 3])
    img = tf.image.resize(img, size=[IMAGE_HEIGHT, IMAGE_WIDTH])
    img.set_shape([IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL])
    if caption_type == 'id':
        caption = tf.cast(caption, tf.int32)
    elif caption_type == 'sentence':
        caption = tf.convert_to_tensor(caption, dtype=tf.string)

    return img, caption

def dataset_generator(filenames, batch_size, data_generator, caption_type='id'):
    # load the training data into two NumPy arrays
    if filenames != None:
        df = pd.read_pickle(filenames)
    else:
        df = text2ImgData
    
    if caption_type == 'id':
        captions = df['Captions'].values
    elif caption_type == 'sentence':
        captions = df['Captions_string'].values
    else:
        raise ValueError('for dataset_generator, caption_type= should be \'id\' or \'sentence\'.')
        
    caption = []
    # each image has 1 to 10 corresponding captions
    # we choose one of them randomly for training
    
    # ============================================ #
    # TODO: augmentation
    # idea 1 (difficulty: easy)
    #     training data has multiple captions, right now it picks a random one.
    #     we can make it so that every caption is an entry and multiple captions link to the same image.
    # idea 2 (difficulty: medium)
    #     after text embedding, use the average of 2 caption embeddings to generate a new caption.
    #     the data does not need to have an image tied to it, it just have the label 0 (fake image).
    # ============================================ #
    for i in range(len(captions)):
        caption.append(random.choice(captions[i]))
    caption = np.asarray(caption)
    
    if caption_type == 'id':
        caption = caption.astype(np.int)
        
    image_path = df['ImagePath'].values
    
    # assume that each row of `features` corresponds to the same row as `labels`.
    assert caption.shape[0] == image_path.shape[0]
    
    datagen_func = lambda cap, img: data_generator(cap, img, caption_type=caption_type)
    
    dataset = tf.data.Dataset.from_tensor_slices((caption, image_path))
    dataset = dataset.map(datagen_func, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.shuffle(len(caption)).batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    return dataset

In [645]:
BATCH_SIZE = 64
dataset = dataset_generator(
    #data_path + '/text2ImgData.pkl',
    None,
    BATCH_SIZE, 
    training_data_generator, 
    caption_type=expSettings.caption_type)

In [646]:
# dataset testing ground
for img, cap in dataset.take(1):
    print("Image shape:", img.numpy().shape)
    print("Caption shape:", cap.numpy().shape)

Image shape: (64, 64, 64, 3)
Caption shape: (64,)


In [647]:
from tensorflow.keras.layers import Conv2DTranspose, Conv2D, BatchNormalization, LeakyReLU, Dense
from tensorflow.keras.initializers import HeNormal

# custom layers
class flattened_dense(tf.keras.layers.Layer):
    """
    a dense layer that is made compatible with convolution layers
    by flattening the input first and followed by a dense layer.
    """
    def __init__(self, channels=64, kernel_initializer="glorot_uniform"):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()
        self.dense = tf.keras.layers.Dense(channels, kernel_initializer=kernel_initializer)
        
    def call(self, inputs):
        fl = self.flatten(inputs)
        return self.dense(fl)
    
class conv_block(tf.keras.layers.Layer):
    """
    a convolution layer with batch normalization and leaky relu activation
    """
    def __init__(self, filters=128, kernel_size=1, strides=1, kernel_initializer='glorot_uniform'):
        super().__init__()
        self.conv = Conv2D(filters=filters,
                           kernel_size = (kernel_size, kernel_size),
                           strides=(strides, strides),
                           padding='same',
                           kernel_initializer=kernel_initializer)
        self.bn = BatchNormalization()
        self.activation = LeakyReLU(alpha=0.1)
    
    def call(self, inputs):
        return self.activation(self.bn(self.conv(inputs)))
    
class deconv_block(tf.keras.layers.Layer):
    """
    a deconvolution layer with batch normalization and leaky relu activation
    """
    def __init__(self, filters=128, kernel_size=4, strides=2, kernel_initializer='glorot_uniform'):
        super().__init__()
        self.deconv = Conv2DTranspose(filters=filters,
                                    kernel_size = (kernel_size, kernel_size),
                                    strides=(strides, strides),
                                    padding='same',
                                    kernel_initializer=kernel_initializer)
        self.bn = BatchNormalization()
        self.activation = LeakyReLU(alpha=0.1)
    
    def call(self, inputs):
        return self.activation(self.bn(self.deconv(inputs)))
    
class dense_block(tf.keras.layers.Layer):
    """
    a dense layer with batch normalization and leaky relu activation
    """
    def __init__(self, filters=128, kernel_initializer='glorot_uniform'):
        super().__init__()
        self.d = Dense(filters, kernel_initializer=kernel_initializer)
        self.bn = BatchNormalization()
        self.activation = LeakyReLU(alpha=0.1)
    
    def call(self, inputs):
        outputs = self.d(inputs)
        outputs = self.bn(outputs)
        return self.activation(outputs)

In [648]:
import tensorflow_hub as hub

class TextEncoder(tf.keras.Model):
    """
    Encode text (a caption) into hidden representation
    input: text, which is a list of ids
    output: embedding, or hidden representation of input text in dimension of RNN_HIDDEN_SIZE
    """
    def __init__(self, hparas, experimental=False, do_batchnorm=False):
        super(TextEncoder, self).__init__()
        self.exp=experimental
        self.do_batchnorm = do_batchnorm
        self.hparas = hparas
        self.batch_size = self.hparas['BATCH_SIZE']
        
        # embedding with tensorflow API
        self.embedding = layers.Embedding(self.hparas['VOCAB_SIZE'], self.hparas['EMBED_DIM'])
        # RNN, here we use GRU cell, another common RNN cell similar to LSTM
        self.gru = layers.GRU(self.hparas['RNN_HIDDEN_SIZE'],
                              return_sequences=True,
                              return_state=True,
                              recurrent_initializer='glorot_uniform')
        if self.exp:
            self.embed = hub.load('./checkpoints/universal_sentence_encoder')
    
    def call(self, text, hidden):
        if self.exp:
            with tf.device('/CPU:0'): # TODO if you find a way to use GPU, go for it.
                output_last = self.embed(text)
                
            state = hidden # not updating state for compatibility reasons
            
        else:
            text = self.embedding(text)
            output, state = self.gru(text, initial_state = hidden)
            output_last = output[:, -1, :]
        
        # normalization in-batch
        if self.do_batchnorm:
            mean = tf.reduce_mean(output_last, axis=0, keepdims=True)  # Mean across the batch
            std = tf.math.reduce_std(output_last, axis=0, keepdims=True)  # Std across the batch
            normalized = (output_last - mean) / (std + 1e-6)  # Avoid division by zero
        else:
            normalized = output_last
        
        return normalized, state
    
    def initialize_hidden_state(self):
        return tf.zeros((self.hparas['BATCH_SIZE'], self.hparas['RNN_HIDDEN_SIZE']))

In [649]:
class Generator(tf.keras.Model):
    """
    Generate fake image based on given text(hidden representation) and noise z
    input: text and noise
    output: fake image with size 64*64*3
    """
    def __init__(self, hparas, experimental=False):
        super(Generator, self).__init__()
        self.exp = experimental
        self.hparas = hparas
        self.flatten = tf.keras.layers.Flatten()
        self.d1 = tf.keras.layers.Dense(self.hparas['DENSE_DIM'], kernel_initializer="glorot_uniform")
        self.d2 = tf.keras.layers.Dense(64*64*3, kernel_initializer="glorot_uniform")
        if self.exp:
            self.deconv_depth = int(math.log(IMAGE_SIZE, 2)) - 1
            self.starter = dense_block(filters=2*2*256)
            self.deconv = [deconv_block(filters=(2 ** (8 - i)), kernel_initializer=HeNormal()) for i in range(self.deconv_depth)]
            self.headf = conv_block(filters=3, kernel_size=1, strides=1)
            
    def call(self, text, noise_z, debug_output=False):
        # deconvolution
        if self.exp:
            noisy_text = tf.concat([text, noise_z], axis=1) * 10 # amplify the input a bit, they seem fairly close to 0.
            img = self.starter(noisy_text)
            
            img = tf.reshape(img, [-1, 2, 2, 256])
            debug = []
            for i in range(self.deconv_depth):
                debug.append(img)
                img = self.deconv[i](img)

            img = self.headf(img)
            logits = tf.reshape(img, [-1, IMAGE_SIZE, IMAGE_SIZE, 3])
            output = tf.nn.tanh(logits)
        
        # concatenate input text and random noise
        else:
            text = self.flatten(text)
            text = self.d1(text)
            text = tf.nn.leaky_relu(text)
            text_concat = tf.concat([noise_z, text], axis=1)
            text_concat = self.d2(text_concat)
        
            logits = tf.reshape(text_concat, [-1, 64, 64, 3])
            output = tf.nn.tanh(logits)
            debug_output = output
        
        if debug_output:
            return logits, output, debug
        else:
            return logits, output

In [650]:
from tensorflow.keras.applications import ResNet50

class Discriminator(tf.keras.Model):
    """
    Differentiate the real and fake image
    input: image and corresponding text
    output: labels, the real image should be 1, while the fake should be 0
    """
    def __init__(self, hparas, experimental=False):
        super(Discriminator, self).__init__()
        self.exp = experimental
        self.hparas = hparas
        if self.exp:
            #self.resnet_base = ResNet50(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), weights='imagenet', include_top=False)
            #for layer in self.resnet_base.layers:
                #layer.trainable = False
            self.conv1 = conv_block(filters=256, kernel_size=3, strides=1)
            self.conv2 = conv_block(filters=64, kernel_size=3, strides=1)
        self.flatten = tf.keras.layers.Flatten()
        self.d_text = tf.keras.layers.Dense(self.hparas['DENSE_DIM'])
        self.d_img = tf.keras.layers.Dense(self.hparas['DENSE_DIM'])
        self.d = tf.keras.layers.Dense(1)
    
    def call(self, img, text):
        text = self.flatten(text)
        text = self.d_text(text)
        text = tf.nn.leaky_relu(text)
        
        if self.exp:
            #img = self.resnet_base(img)
            img = self.conv1(img)
            img = self.conv2(img)
        img = self.flatten(img)
        img = self.d_img(img)
        img = tf.nn.leaky_relu(img)
        
        # concatenate image with paired text
        img_text = tf.concat([text, img], axis=1)
        
        logits = self.d(img_text)
        output = tf.nn.sigmoid(logits)
        
        return logits, output

Parameters and settings

In [651]:
hparas = {
    'MAX_SEQ_LENGTH': 20,                     # maximum sequence length
    'EMBED_DIM': 256,                         # word embedding dimension
    'VOCAB_SIZE': len(word2Id_dict),          # size of dictionary of captions
    'RNN_HIDDEN_SIZE': 128,                   # number of RNN neurons
    'Z_DIM': 512,                             # random noise z dimension
    'DENSE_DIM': 128,                         # number of neurons in dense layer
    'IMAGE_SIZE': [64, 64, 3],                # render image size
    'BATCH_SIZE': 64,
    'LR_GEN': 1e-3,
    'LR_DIS': 1e-4,
    'LR_DECAY': 0.5,                          # unused
    'BETA_1': 0.5,
    'N_EPOCH': 1000,                            # number of epoch for demo
    'N_SAMPLE': num_training_sample,          # size of training data
    'CHECKPOINTS_DIR': './checkpoints/demo',  # checkpoint path
    'PRINT_FREQ': 1                           # printing frequency of loss
}

In [652]:
text_encoder = TextEncoder(hparas, 
                           experimental=expSettings.enc,
                           do_batchnorm=expSettings.enc_do_batchnorm)

generator = Generator(hparas,
                      experimental=expSettings.gen)

discriminator = Discriminator(hparas,
                              experimental=expSettings.dis)

In [653]:
# test text encoder
for img, cap in dataset.take(1):
    print("Image shape:", img.numpy().shape)
    print("Caption shape:", cap.shape)
    with tf.device('/CPU:0'):
        output, _ = text_encoder(cap, 0)
        print("Caption embed shape:", output.shape)

Image shape: (64, 64, 64, 3)
Caption shape: (64,)
Caption embed shape: (64, 512)


In [654]:
# This method returns a helper function to compute cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [655]:
def discriminator_loss(real_logits, fake_logits):
    # output value of real image should be 1
    real_loss = cross_entropy(tf.ones_like(real_logits), real_logits)
    # output value of fake image should be 0
    fake_loss = cross_entropy(tf.zeros_like(fake_logits), fake_logits)
    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output):
    # output value of fake image should be 0
    return cross_entropy(tf.ones_like(fake_output), fake_output)

In [656]:
# we use seperated optimizers for training generator and discriminator
generator_optimizer = tf.keras.optimizers.Adam(hparas['LR_GEN'], clipvalue=2.0)
discriminator_optimizer = tf.keras.optimizers.Adam(hparas['LR_DIS'], clipvalue=0.1)

In [657]:
if expSettings.delete_checkpoint:
    for f in os.listdir(hparas['CHECKPOINTS_DIR']):
        file_path = os.path.join(hparas['CHECKPOINTS_DIR'], f)
        if os.path.isfile(file_path):
            os.unlink(file_path)

# one benefit of tf.train.Checkpoint() API is we can save everything seperately
checkpoint_dir = hparas['CHECKPOINTS_DIR']
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 text_encoder=text_encoder,
                                 generator=generator,
                                 discriminator=discriminator)
ckptManager = tf.train.CheckpointManager(
    checkpoint, directory=hparas['CHECKPOINTS_DIR'], max_to_keep=5)

In [658]:
@tf.function
def train_step(real_image, caption, hidden, imshow=False):
    # random noise for generator
    noise = tf.random.normal(shape=[hparas['BATCH_SIZE'], hparas['Z_DIM']], mean=0.0, stddev=0.1)
    
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        text_embed, hidden = text_encoder(caption, hidden)
        _, fake_image = generator(text_embed, noise)
        if imshow:
            plt.imshow(fake_image[0])

        real_logits, real_output = discriminator(real_image, text_embed)
        fake_logits, fake_output = discriminator(fake_image, text_embed)

        g_loss = generator_loss(fake_output)
        d_loss = discriminator_loss(real_logits, fake_logits)

    grad_g = gen_tape.gradient(g_loss, generator.trainable_variables)
    grad_d = disc_tape.gradient(d_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(grad_g, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(grad_d, discriminator.trainable_variables))
    
    return g_loss, d_loss

In [659]:
@tf.function
def test_step(caption, noise, hidden):
    text_embed, hidden = text_encoder(caption, hidden)
    _, fake_image = generator(text_embed, noise)
    return fake_image

Sample Debugging (unused)

In [660]:
def merge(images, size):
    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * size[0], w * size[1], 3))
    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx // size[1]
        img[j*h:j*h+h, i*w:i*w+w, :] = image
    return img

def imsave(images, size, path):
    # getting the pixel values between [0, 1] to save it
    return plt.imsave(path, merge(images, size)*0.5 + 0.5)

def save_images(images, size, image_path):
    return imsave(images, size, image_path)

In [661]:
def sample_generator(caption, batch_size, caption_type='id'):
    if caption_type == 'sentence':
        caption = caption2string(caption)
    caption = np.asarray(caption)
    if caption_type == 'id':
        caption = caption.astype(np.int)
    dataset = tf.data.Dataset.from_tensor_slices(caption)
    dataset = dataset.batch(batch_size)
    return dataset

In [662]:
ni = int(np.ceil(np.sqrt(hparas['BATCH_SIZE'])))
sample_size = hparas['BATCH_SIZE']
sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(sample_size, hparas['Z_DIM'])).astype(np.float32)
sample_sentence = ["the flower shown has yellow anther red pistil and bright red petals."] * int(sample_size/ni) + \
                  ["this flower has petals that are yellow, white and purple and has dark lines"] * int(sample_size/ni) + \
                  ["the petals on this flower are white with a yellow center"] * int(sample_size/ni) + \
                  ["this flower has a lot of small round pink petals."] * int(sample_size/ni) + \
                  ["this flower is orange in color, and has petals that are ruffled and rounded."] * int(sample_size/ni) + \
                  ["the flower has yellow petals and the center of it is brown."] * int(sample_size/ni) + \
                  ["this flower has petals that are blue and white."] * int(sample_size/ni) +\
                  ["these white flowers have petals that start off white in color and end in a white towards the tips."] * int(sample_size/ni)

for i, sent in enumerate(sample_sentence):
    sample_sentence[i] = sent2IdList(sent)
sample_sentence = sample_generator(sample_sentence, hparas['BATCH_SIZE'], caption_type=expSettings.caption_type)

In [663]:
# test the sample dataset
for cap in sample_sentence.take(1):
    print("Caption shape:", cap.numpy().shape)
    emb, _ = text_encoder(cap, None)
    print("Caption embeddings:", emb)

Caption shape: (64,)
Caption embeddings: tf.Tensor(
[[-0.01346336  0.06881763 -0.05529514 ... -0.01518281 -0.00633275
   0.01500697]
 [-0.01346336  0.06881763 -0.05529514 ... -0.01518281 -0.00633275
   0.01500697]
 [-0.01346336  0.06881763 -0.05529514 ... -0.01518281 -0.00633275
   0.01500697]
 ...
 [-0.00790838  0.06233827  0.01107207 ... -0.02398296 -0.03450323
  -0.00868433]
 [-0.00790838  0.06233827  0.01107207 ... -0.02398296 -0.03450323
  -0.00868433]
 [-0.00790838  0.06233827  0.01107207 ... -0.02398296 -0.03450323
  -0.00868433]], shape=(64, 512), dtype=float32)


In [664]:
if not os.path.exists('samples/demo'):
    os.makedirs('samples/demo')

Training and testing

In [665]:
def train(dataset, epochs):
    # hidden state of RNN
    hidden = text_encoder.initialize_hidden_state()
    steps_per_epoch = int(hparas['N_SAMPLE']/hparas['BATCH_SIZE'])
    
    for epoch in range(hparas['N_EPOCH']):
        g_total_loss = 0
        d_total_loss = 0
        start = time.time()
        imshow = False
        
        for image, caption in dataset:
            g_loss, d_loss = train_step(image, caption, hidden, imshow=imshow)
            imshow = False
            g_total_loss += g_loss
            d_total_loss += d_loss
            
        time_tuple = time.localtime()
        time_string = time.strftime("%m/%d/%Y, %H:%M:%S", time_tuple)
            
        print("Epoch {}, gen_loss: {:.4f}, disc_loss: {:.4f}".format(epoch+1,
                                                                     g_total_loss/steps_per_epoch,
                                                                     d_total_loss/steps_per_epoch))
        print('Time for epoch {} is {:.4f} sec'.format(epoch+1, time.time()-start))
        
        print('======================================')
        
        # save the model
        if True:
            ckptManager.save()
        
        # visualization
        if (epoch + 1) % hparas['PRINT_FREQ'] == 0:
            for caption in sample_sentence:
                fake_image = test_step(caption, sample_seed, hidden)
            save_images(fake_image, [ni, ni], 'samples/demo/train_{:02d}.jpg'.format(epoch))

In [666]:
train(dataset, hparas['N_EPOCH'])

  "shape. This may consume a large amount of memory." % value)


Epoch 1, gen_loss: 0.6016, disc_loss: 1.2478
Time for epoch 1 is 41.5400 sec
Epoch 2, gen_loss: 0.6682, disc_loss: 0.3734
Time for epoch 2 is 27.1890 sec
Epoch 3, gen_loss: 0.6774, disc_loss: 0.2520
Time for epoch 3 is 26.4374 sec
Epoch 4, gen_loss: 0.4882, disc_loss: 2.6493
Time for epoch 4 is 26.1994 sec
Epoch 5, gen_loss: 0.4476, disc_loss: 1.7436
Time for epoch 5 is 25.2123 sec
Epoch 6, gen_loss: 0.4663, disc_loss: 1.5268
Time for epoch 6 is 25.3700 sec
Epoch 7, gen_loss: 0.4686, disc_loss: 1.4509
Time for epoch 7 is 25.6592 sec
Epoch 8, gen_loss: 0.4756, disc_loss: 1.4349
Time for epoch 8 is 25.9640 sec
Epoch 9, gen_loss: 0.4673, disc_loss: 1.4188
Time for epoch 9 is 26.0919 sec
Epoch 10, gen_loss: 0.4730, disc_loss: 1.4159
Time for epoch 10 is 25.3220 sec
Epoch 11, gen_loss: 0.4692, disc_loss: 1.4185
Time for epoch 11 is 25.2707 sec
Epoch 12, gen_loss: 0.4745, disc_loss: 1.4034
Time for epoch 12 is 25.0686 sec
Epoch 13, gen_loss: 0.4700, disc_loss: 1.4318
Time for epoch 13 is 25.

Epoch 71, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 71 is 25.2902 sec
Epoch 72, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 72 is 24.5724 sec
Epoch 73, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 73 is 24.8404 sec
Epoch 74, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 74 is 24.8449 sec
Epoch 75, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 75 is 25.1368 sec
Epoch 76, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 76 is 25.7529 sec
Epoch 77, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 77 is 25.1122 sec
Epoch 78, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 78 is 24.9009 sec
Epoch 79, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 79 is 25.1153 sec
Epoch 80, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 80 is 24.5983 sec
Epoch 81, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 81 is 24.8122 sec
Epoch 82, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 82 is 25.7323 sec
Epoch 83, gen_loss: 0.6931, disc_loss: 0.0000
Time f

Epoch 140, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 140 is 25.2296 sec
Epoch 141, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 141 is 25.5632 sec
Epoch 142, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 142 is 25.8658 sec
Epoch 143, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 143 is 24.8975 sec
Epoch 144, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 144 is 25.0770 sec
Epoch 145, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 145 is 24.6022 sec
Epoch 146, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 146 is 25.0293 sec
Epoch 147, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 147 is 25.3141 sec
Epoch 148, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 148 is 25.1373 sec
Epoch 149, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 149 is 25.0335 sec
Epoch 150, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 150 is 24.7834 sec
Epoch 151, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 151 is 24.8829 sec
Epoch 152, gen_loss: 0.6931,

Epoch 209, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 209 is 24.4928 sec
Epoch 210, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 210 is 24.6867 sec
Epoch 211, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 211 is 25.0904 sec
Epoch 212, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 212 is 25.1508 sec
Epoch 213, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 213 is 25.7769 sec
Epoch 214, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 214 is 25.4227 sec
Epoch 215, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 215 is 25.0832 sec
Epoch 216, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 216 is 25.3719 sec
Epoch 217, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 217 is 24.5401 sec
Epoch 218, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 218 is 25.2200 sec
Epoch 219, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 219 is 25.5985 sec
Epoch 220, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 220 is 24.7471 sec
Epoch 221, gen_loss: 0.6931,

Epoch 278, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 278 is 24.8829 sec
Epoch 279, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 279 is 24.6671 sec
Epoch 280, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 280 is 25.5843 sec
Epoch 281, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 281 is 24.9300 sec
Epoch 282, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 282 is 24.9442 sec
Epoch 283, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 283 is 25.4953 sec
Epoch 284, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 284 is 24.6973 sec
Epoch 285, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 285 is 25.3738 sec
Epoch 286, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 286 is 25.2020 sec
Epoch 287, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 287 is 24.8612 sec
Epoch 288, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 288 is 24.8689 sec
Epoch 289, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 289 is 24.7579 sec
Epoch 290, gen_loss: 0.6931,

Epoch 347, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 347 is 24.8023 sec
Epoch 348, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 348 is 24.6599 sec
Epoch 349, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 349 is 25.3430 sec
Epoch 350, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 350 is 25.2267 sec
Epoch 351, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 351 is 24.8108 sec
Epoch 352, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 352 is 25.4222 sec
Epoch 353, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 353 is 24.7967 sec
Epoch 354, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 354 is 25.0349 sec
Epoch 355, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 355 is 24.8340 sec
Epoch 356, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 356 is 24.3762 sec
Epoch 357, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 357 is 25.3111 sec
Epoch 358, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 358 is 26.0092 sec
Epoch 359, gen_loss: 0.6931,

Epoch 416, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 416 is 24.5186 sec
Epoch 417, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 417 is 24.7183 sec
Epoch 418, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 418 is 24.9083 sec
Epoch 419, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 419 is 24.5504 sec
Epoch 420, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 420 is 25.3769 sec
Epoch 421, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 421 is 24.8516 sec
Epoch 422, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 422 is 25.0270 sec
Epoch 423, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 423 is 24.9721 sec
Epoch 424, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 424 is 24.8761 sec
Epoch 425, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 425 is 25.2567 sec
Epoch 426, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 426 is 25.7530 sec
Epoch 427, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 427 is 24.9717 sec
Epoch 428, gen_loss: 0.6931,

Epoch 485, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 485 is 25.0354 sec
Epoch 486, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 486 is 24.7497 sec
Epoch 487, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 487 is 24.4038 sec
Epoch 488, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 488 is 24.5250 sec
Epoch 489, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 489 is 25.6546 sec
Epoch 490, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 490 is 24.8781 sec
Epoch 491, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 491 is 24.8888 sec
Epoch 492, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 492 is 24.6202 sec
Epoch 493, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 493 is 24.7230 sec
Epoch 494, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 494 is 24.7520 sec
Epoch 495, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 495 is 24.5731 sec
Epoch 496, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 496 is 24.7539 sec
Epoch 497, gen_loss: 0.6931,

Epoch 554, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 554 is 25.3071 sec
Epoch 555, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 555 is 25.1931 sec
Epoch 556, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 556 is 25.5291 sec
Epoch 557, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 557 is 25.3081 sec
Epoch 558, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 558 is 24.9676 sec
Epoch 559, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 559 is 24.5102 sec
Epoch 560, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 560 is 25.3340 sec
Epoch 561, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 561 is 25.0071 sec
Epoch 562, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 562 is 24.7358 sec
Epoch 563, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 563 is 24.4354 sec
Epoch 564, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 564 is 25.4522 sec
Epoch 565, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 565 is 24.5869 sec
Epoch 566, gen_loss: 0.6931,

Epoch 623, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 623 is 24.5730 sec
Epoch 624, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 624 is 25.2028 sec
Epoch 625, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 625 is 24.6240 sec
Epoch 626, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 626 is 24.6118 sec
Epoch 627, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 627 is 25.1500 sec
Epoch 628, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 628 is 24.2388 sec
Epoch 629, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 629 is 24.5730 sec
Epoch 630, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 630 is 24.6568 sec
Epoch 631, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 631 is 25.0801 sec
Epoch 632, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 632 is 24.4569 sec
Epoch 633, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 633 is 25.5512 sec
Epoch 634, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 634 is 24.6063 sec
Epoch 635, gen_loss: 0.6931,

Epoch 692, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 692 is 25.2313 sec
Epoch 693, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 693 is 24.5053 sec
Epoch 694, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 694 is 25.4425 sec
Epoch 695, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 695 is 24.8048 sec
Epoch 696, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 696 is 25.3962 sec
Epoch 697, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 697 is 24.4190 sec
Epoch 698, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 698 is 25.1655 sec
Epoch 699, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 699 is 24.6479 sec
Epoch 700, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 700 is 24.5501 sec
Epoch 701, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 701 is 24.4978 sec
Epoch 702, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 702 is 24.7574 sec
Epoch 703, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 703 is 24.5775 sec
Epoch 704, gen_loss: 0.6931,

Epoch 761, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 761 is 24.6132 sec
Epoch 762, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 762 is 24.1999 sec
Epoch 763, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 763 is 24.3389 sec
Epoch 764, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 764 is 24.4272 sec
Epoch 765, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 765 is 24.3404 sec
Epoch 766, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 766 is 24.4181 sec
Epoch 767, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 767 is 24.5013 sec
Epoch 768, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 768 is 24.2288 sec
Epoch 769, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 769 is 24.4340 sec
Epoch 770, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 770 is 24.3284 sec
Epoch 771, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 771 is 24.7258 sec
Epoch 772, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 772 is 24.4663 sec
Epoch 773, gen_loss: 0.6931,

Epoch 830, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 830 is 24.2629 sec
Epoch 831, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 831 is 24.1672 sec
Epoch 832, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 832 is 24.4127 sec
Epoch 833, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 833 is 24.3922 sec
Epoch 834, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 834 is 24.6042 sec
Epoch 835, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 835 is 24.5469 sec
Epoch 836, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 836 is 24.4034 sec
Epoch 837, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 837 is 24.2638 sec
Epoch 838, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 838 is 24.1900 sec
Epoch 839, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 839 is 24.2345 sec
Epoch 840, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 840 is 24.2879 sec
Epoch 841, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 841 is 24.4465 sec
Epoch 842, gen_loss: 0.6931,

Epoch 899, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 899 is 24.2511 sec
Epoch 900, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 900 is 24.5892 sec
Epoch 901, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 901 is 24.7122 sec
Epoch 902, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 902 is 24.5102 sec
Epoch 903, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 903 is 24.2473 sec
Epoch 904, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 904 is 24.8383 sec
Epoch 905, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 905 is 24.3196 sec
Epoch 906, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 906 is 24.3981 sec
Epoch 907, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 907 is 24.2156 sec
Epoch 908, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 908 is 24.6339 sec
Epoch 909, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 909 is 24.7737 sec
Epoch 910, gen_loss: 0.6931, disc_loss: 0.0000
Time for epoch 910 is 24.2436 sec
Epoch 911, gen_loss: 0.6931,

KeyboardInterrupt: 

In [667]:
def test_caption2string(ls):
    return " ".join([id2word_dict[idx] for idx in ls]).strip().split(' <PAD>')[0]

def testing_data_generator(caption, index, caption_type='id'):
    if caption_type == 'id':
        caption = tf.cast(caption, tf.float32)
    return caption, index

def testing_dataset_generator(batch_size, data_generator, caption_type='id'):
    data = pd.read_pickle('./dataset/testData.pkl')
    
    if caption_type == 'sentence':
        data['Captions_string'] = data['Captions'].apply(test_caption2string)
        captions = data['Captions_string'].values
    elif caption_type == 'id':
        captions = data['Captions'].values
        
    caption = []
    for i in range(len(captions)):
        caption.append(captions[i])
    caption = np.asarray(caption)
    
    if caption_type == 'id':
        caption = caption.astype(np.int)
        
    datagen_func = lambda cap, img: data_generator(cap, img, caption_type=caption_type)
        
    index = data['ID'].values
    index = np.asarray(index)
    
    dataset = tf.data.Dataset.from_tensor_slices((caption, index))
    dataset = dataset.map(datagen_func, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.repeat().batch(batch_size)
    
    return dataset

In [668]:
testing_dataset = testing_dataset_generator(hparas['BATCH_SIZE'], testing_data_generator, caption_type=expSettings.caption_type)

In [669]:
# test the testing dataset
for cap, img in testing_dataset.take(1):
    print("Image shape:", img.numpy().shape)
    print("Caption shape:", cap.numpy().shape)

Image shape: (64,)
Caption shape: (64,)


In [670]:
data = pd.read_pickle('./dataset/testData.pkl')
captions = data['Captions'].values

NUM_TEST = len(captions)
EPOCH_TEST = int(NUM_TEST / hparas['BATCH_SIZE'])

In [671]:
if not os.path.exists('./inference/demo'):
    os.makedirs('./inference/demo')

In [672]:
def inference(dataset):
    hidden = text_encoder.initialize_hidden_state()
    sample_size = hparas['BATCH_SIZE']
    sample_seed = np.random.normal(loc=0.0, scale=0.1, size=(sample_size, hparas['Z_DIM'])).astype(np.float32)
    print(sample_seed[0:3, :])
    
    step = 0
    start = time.time()
    for captions, idx in dataset:
        if step > EPOCH_TEST:
            break
        
        fake_image = test_step(captions, sample_seed, hidden)
        step += 1
        for i in range(hparas['BATCH_SIZE']):
            plt.imsave('./inference/demo/inference_{:04d}.jpg'.format(idx[i]), fake_image[i].numpy()*0.5 + 0.5)
            
            if i == 0 and step == 1: 
                #print(captions)
                text_embed_t, hidden_t = text_encoder(captions, hidden)
                #print(text_embed_t)
                print(fake_image[0:1, 0:5, 0:5, :])
                img_logits, _, debug = generator(text_embed_t, sample_seed, debug_output=True)
                print(debug[0][0:3, 0:5, 0:5, 0:5])
                print(debug[1][0:3, 0:5, 0:5, 0:5])
                print(debug[2][0:3, 0:5, 0:5, 0:5])
                print(debug[3][0:3, 0:5, 0:5, 0:5])
                print(debug[4][0:3, 0:5, 0:5, 0:5])
                print(img_logits[0:3, 0:5, 0:5, :])
                pred_logit, pred = discriminator(fake_image, text_embed_t)
                #print(pred_logit)
                
            
    print('Time for inference is {:.4f} sec'.format(time.time()-start))

In [673]:
#checkpoint.restore(checkpoint_dir + f'/ckpt-50')
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint:
    print(f"Restoring from {latest_checkpoint}")
    checkpoint.restore(latest_checkpoint)
else:
    print("No checkpoint found.")

Restoring from ./checkpoints/demo\ckpt-963


In [674]:
inference(testing_dataset)

[[-0.25735587  0.02635855  0.07098211 ...  0.11278663  0.08181195
   0.04097688]
 [-0.24220537 -0.10393385 -0.04679523 ... -0.13002335  0.01695522
  -0.08721623]
 [-0.09068998 -0.0250848   0.10610484 ... -0.23031023  0.22233571
  -0.11565182]]
tf.Tensor(
[[[[ 0.05371411  0.07407826  0.01101015]
   [-0.23510526 -0.02537002 -0.23780264]
   [ 0.7734425  -0.0782879  -0.09224059]
   [-0.09640063 -0.11736421 -0.18235064]
   [-0.3867997  -0.30161613 -0.31821412]]

  [[-0.02428844 -0.00901427 -0.01478982]
   [-0.8219737  -0.74098504 -0.751086  ]
   [-0.92784774 -0.48945045 -0.78000706]
   [-0.9865599  -0.8891686  -0.94828457]
   [-0.9906214  -0.8981045  -0.8246935 ]]

  [[-0.0325912  -0.01237387 -0.02101737]
   [-0.79343396 -0.20068747 -0.65321845]
   [-0.96634734 -0.7707148  -0.8599473 ]
   [-0.9650226  -0.7101485  -0.8881296 ]
   [-0.9939072  -0.91076356 -0.9596795 ]]

  [[-0.04732414 -0.02123832 -0.02856526]
   [-0.9788062  -0.84277856 -0.97331697]
   [-0.99145055 -0.60533476 -0.93648994]
 

Time for inference is 4.2369 sec


In [675]:
for var in generator.starter.variables:
    print(var.name, var.numpy())

generator_14/dense_block_14/dense_90/kernel:0 [[-0.01977373  0.0257942   0.02203238 ...  0.04614797  0.03235683
  -0.03627039]
 [ 0.02929055 -0.05542327 -0.03276846 ... -0.00220141 -0.04890126
  -0.03045244]
 [ 0.01022856 -0.05964788 -0.0528914  ... -0.00563824  0.03960341
  -0.03896577]
 ...
 [ 0.03309955 -0.00184861  0.01047124 ...  0.04798859  0.08177074
  -0.05541556]
 [-0.03271682 -0.02365159 -0.01691599 ... -0.02826412 -0.02848577
   0.00127587]
 [-0.03394162 -0.02770595  0.04155768 ... -0.00242701 -0.03895332
   0.02669442]]
generator_14/dense_block_14/dense_90/bias:0 [-0.00803091 -0.00534232 -0.00174159 ... -0.01185318  0.00154424
 -0.0158552 ]
generator_14/dense_block_14/batch_normalization_153/gamma:0 [0.950655   0.9673666  0.9671377  ... 0.92883193 0.98732436 0.94109476]
generator_14/dense_block_14/batch_normalization_153/beta:0 [-0.00784559 -0.00570957 -0.00161753 ... -0.01253364  0.00215867
 -0.01634489]
generator_14/dense_block_14/batch_normalization_153/moving_mean:0 [0.

In [676]:
%cd ./testing
!python inception_score.py ../inference/demo output.csv 21
%cd ../

C:\Users\User\Courses\24aut_deep_learning\24aut-deep-learning\deep-learning-comp3\testing
1 Physical GPUs, 1 Logical GPUs
--------------Evaluation Success-----------------
C:\Users\User\Courses\24aut_deep_learning\24aut-deep-learning\deep-learning-comp3


In [677]:
text_encoder.summary()

Model: "text_encoder_14"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding_14 (Embedding)    multiple                  0 (unused)
                                                                 
 gru_14 (GRU)                multiple                  0 (unused)
                                                                 
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________


In [678]:
generator.summary()

Model: "generator_14"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten_28 (Flatten)        multiple                  0 (unused)
                                                                 
 dense_88 (Dense)            multiple                  0 (unused)
                                                                 
 dense_89 (Dense)            multiple                  0 (unused)
                                                                 
 dense_block_14 (dense_block  multiple                 1053696   
 )                                                               
                                                                 
 deconv_block_70 (deconv_blo  multiple                 1049856   
 ck)                                                             
                                                                 
 deconv_block_71 (deconv_blo  multiple                

In [679]:
discriminator.summary()

Model: "discriminator_14"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv_block_77 (conv_block)  multiple                  8192      
                                                                 
 conv_block_78 (conv_block)  multiple                  147776    
                                                                 
 flatten_29 (Flatten)        multiple                  0         
                                                                 
 dense_91 (Dense)            multiple                  65664     
                                                                 
 dense_92 (Dense)            multiple                  33554560  
                                                                 
 dense_93 (Dense)            multiple                  257       
                                                                 
Total params: 33,776,449
Trainable params: 33,775,