# Competition 3: Team 21

112062649 王俊皓

112062650 廖士傑

##  Reverse Image Caption

In [437]:
# project settings
# aug_type: None does nothing; 'explode' considers all possible caption-image pair;
# enc / gen / dis: True means using a more complex setting, False means using a setting that is closer to template
# enc_do_batchnorm: encoder will perform batch normalization to make the text encoding more variant
# dis_backbone: the type of backbone in discriminator, can be 'resnet', 'vgg', 'simple'
# delete_checkpoint: deletes all checkpoint files before training

class experimental_settings:
    def __init__(self,
                 aug_type='explode',
                 enc=True,
                 enc_do_batchnorm=False,
                 gen=True,
                 dis=True,
                 dis_backbone='simple',
                 delete_checkpoint=False):
        self.aug_type = aug_type
        self.enc = enc
        self.gen = gen
        self.dis = dis
        self.dis_backbone = dis_backbone
        self.enc_do_batchnorm = enc_do_batchnorm
        self.delete_checkpoint = delete_checkpoint # not implemented yet
        
        # ============================ #
        # automatic
        # ============================ #
        
        self.caption_type = 'sentence' if self.enc else 'id'


expSettings = experimental_settings()

## Import

In [438]:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras import layers
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
import string
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL
import random
import time
from pathlib import Path
import math

import re
from IPython import display

GPU check

In [439]:
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Restrict TensorFlow to only use the first GPU
        tf.config.experimental.set_visible_devices(gpus[0], 'GPU')

        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)

1 Physical GPUs, 1 Logical GPUs


## Loading Data

In [440]:
dictionary_path = './dictionary'
vocab = np.load(dictionary_path + '/vocab.npy')
print('there are {} vocabularies in total'.format(len(vocab)))

word2Id_dict = dict(np.load(dictionary_path + '/word2Id.npy'))
id2word_dict = dict(np.load(dictionary_path + '/id2Word.npy'))
print('Word to id mapping, for example: %s -> %s' % ('flower', word2Id_dict['flower']))
print('Id to word mapping, for example: %s -> %s' % ('1', id2word_dict['1']))
print('Tokens: <PAD>: %s; <RARE>: %s' % (word2Id_dict['<PAD>'], word2Id_dict['<RARE>']))

there are 5427 vocabularies in total
Word to id mapping, for example: flower -> 1
Id to word mapping, for example: 1 -> flower
Tokens: <PAD>: 5427; <RARE>: 5428


In [441]:
def sent2IdList(line, MAX_SEQ_LENGTH=20):
    MAX_SEQ_LIMIT = MAX_SEQ_LENGTH
    padding = 0
    
    # data preprocessing, remove all puntuation in the texts
    prep_line = re.sub('[%s]' % re.escape(string.punctuation), ' ', line.rstrip())
    prep_line = prep_line.replace('-', ' ')
    prep_line = prep_line.replace('-', ' ')
    prep_line = prep_line.replace('  ', ' ')
    prep_line = prep_line.replace('.', '')
    tokens = prep_line.split(' ')
    tokens = [
        tokens[i] for i in range(len(tokens))
        if tokens[i] != ' ' and tokens[i] != ''
    ]
    l = len(tokens)
    padding = MAX_SEQ_LIMIT - l
    
    # make sure length of each text is equal to MAX_SEQ_LENGTH, and replace the less common word with <RARE> token
    for i in range(padding):
        tokens.append('<PAD>')
    line = [
        word2Id_dict[tokens[k]]
        if tokens[k] in word2Id_dict else word2Id_dict['<RARE>']
        for k in range(len(tokens))
    ]

    return line

text = "the flower shown has yellow anther red pistil and bright red petals."
print(text)
print(sent2IdList(text))

the flower shown has yellow anther red pistil and bright red petals.
['9', '1', '82', '5', '11', '70', '20', '31', '3', '29', '20', '2', '5427', '5427', '5427', '5427', '5427', '5427', '5427', '5427']


In [442]:
@tf.function
def id2Sent(ids):
    return " ".join([id2word_dict[idx] for idx in ids]).strip()

#def batch_id2Sent(batch_ids):
    #return [id2Sent(ids) for ids in batch_ids]
    
def batch_id2Sent(batch_ids):
    def process_single(ids):
        # Convert a single tensor of IDs to a sentence
        ids = ids.numpy()  # Convert Tensor to NumPy
        sentence = " ".join([id2word_dict.get(idx, "<UNK>") for idx in ids])  # Handle unknown IDs
        return sentence

    # Use tf.py_function to apply Python function inside the TensorFlow graph
    sentences = tf.map_fn(
        lambda ids: tf.py_function(process_single, [ids], tf.string),
        batch_ids,
        fn_output_signature=tf.string
    )
    return sentences


print(sent2IdList(text))
print(id2Sent(sent2IdList(text)))

['9', '1', '82', '5', '11', '70', '20', '31', '3', '29', '20', '2', '5427', '5427', '5427', '5427', '5427', '5427', '5427', '5427']
tf.Tensor(b'the flower shown has yellow anther red pistil and bright red petals <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>', shape=(), dtype=string)


In [443]:
data_path = './dataset'
text2ImgData = pd.read_pickle(data_path + '/text2ImgData.pkl')
num_training_sample = len(text2ImgData)
n_images_train = num_training_sample
print('There are %d image in training data' % (n_images_train))

There are 7370 image in training data


In [444]:
# data augmentation by including all possible captions

if expSettings.aug_type == 'explode':
    text2ImgData = text2ImgData.explode('Captions', ignore_index=True)
print(len(text2ImgData))
text2ImgData.head(5)

70504


Unnamed: 0,Captions,ImagePath
0,"[9, 2, 17, 9, 1, 6, 14, 13, 18, 3, 41, 8, 11, ...",./102flowers/image_06734.jpg
1,"[4, 1, 15, 14, 3, 12, 13, 18, 7, 2, 10, 6, 123...",./102flowers/image_06734.jpg
2,"[9, 16, 2, 41, 149, 17, 12, 7, 12, 70, 3, 120,...",./102flowers/image_06734.jpg
3,"[4, 1, 5, 26, 14, 2, 3, 8, 12, 30, 13, 9, 23, ...",./102flowers/image_06734.jpg
4,"[4, 1, 5, 2, 10, 6, 14, 3, 5, 8, 11, 19, 5427,...",./102flowers/image_06734.jpg


In [445]:
def caption2string(cap, aug_type=None):
    
    if aug_type == None:
        output = []
        for sen in cap:
            s = " ".join([id2word_dict[idx] for idx in sen]).strip()
            output.append(s.split(' <PAD>')[0])
            
    elif aug_type == 'explode':
        output = [" ".join([id2word_dict[idx] for idx in cap]).strip().split(' <PAD>')[0]]
        
    else:
        raise ValueError('aug_type must be None or \'explode\'.')
        
    return output

# adding caption as strings
c2s_fn = lambda cap: caption2string(cap, aug_type=expSettings.aug_type)
text2ImgData['Captions_string'] = text2ImgData['Captions'].apply(c2s_fn)

In [446]:
text2ImgData.head(5)

Unnamed: 0,Captions,ImagePath,Captions_string
0,"[9, 2, 17, 9, 1, 6, 14, 13, 18, 3, 41, 8, 11, ...",./102flowers/image_06734.jpg,[the petals of the flower are pink in color an...
1,"[4, 1, 15, 14, 3, 12, 13, 18, 7, 2, 10, 6, 123...",./102flowers/image_06734.jpg,[this flower is pink and white in color with p...
2,"[9, 16, 2, 41, 149, 17, 12, 7, 12, 70, 3, 120,...",./102flowers/image_06734.jpg,[the purple petals have shades of white with w...
3,"[4, 1, 5, 26, 14, 2, 3, 8, 12, 30, 13, 9, 23, ...",./102flowers/image_06734.jpg,[this flower has large pink petals and a white...
4,"[4, 1, 5, 2, 10, 6, 14, 3, 5, 8, 11, 19, 5427,...",./102flowers/image_06734.jpg,[this flower has petals that are pink and has ...


In [447]:
text2ImgData['Captions_string'][:1].tolist()

[['the petals of the flower are pink in color and have a yellow center']]

In [448]:
# in this competition, you have to generate image in size 64x64x3
IMAGE_SIZE = 64
IMAGE_HEIGHT = IMAGE_SIZE
IMAGE_WIDTH = IMAGE_SIZE
IMAGE_CHANNEL = 3

def training_data_generator(caption, image_path, caption_type='id'):
    # load in the image according to image path
    img = tf.io.read_file(image_path)
    img = tf.image.decode_image(img, channels=3)
    img = tf.image.convert_image_dtype(img, tf.float32)
    img.set_shape([None, None, 3])
    img = tf.image.resize(img, size=[IMAGE_HEIGHT, IMAGE_WIDTH])
    img.set_shape([IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL])
    if caption_type == 'id':
        caption = tf.cast(caption, tf.int32)
    elif caption_type == 'sentence':
        caption = tf.convert_to_tensor(caption, dtype=tf.string)

    return img, caption

def dataset_generator(filenames, batch_size, data_generator, caption_type='id'):
    # load the training data into two NumPy arrays
    if filenames != None:
        df = pd.read_pickle(filenames)
    else:
        df = text2ImgData
    
    if caption_type == 'id':
        captions = df['Captions'].values
    elif caption_type == 'sentence':
        captions = df['Captions_string'].values
    else:
        raise ValueError('for dataset_generator, caption_type= should be \'id\' or \'sentence\'.')
        
    caption = []
    # each image has 1 to 10 corresponding captions
    # we choose one of them randomly for training
    
    # ============================================ #
    # TODO: augmentation
    # idea 1 (difficulty: easy) (trying)
    #     training data has multiple captions, right now it picks a random one.
    #     we can make it so that every caption is an entry and multiple captions link to the same image.
    # idea 2 (difficulty: medium)
    #     after text embedding, use the average of 2 caption embeddings to generate a new caption.
    #     the data does not need to have an image tied to it, it just have the label 0 (fake image).
    # ============================================ #
    for i in range(len(captions)):
        if expSettings.aug_type == None:
            caption.append(random.choice(captions[i]))
        elif expSettings.aug_type == 'explode':
            caption.append(captions[i][0])
    caption = np.asarray(caption)
    
    if caption_type == 'id':
        caption = caption.astype(np.int)
        
    image_path = df['ImagePath'].values
    
    # assume that each row of `features` corresponds to the same row as `labels`.
    assert caption.shape[0] == image_path.shape[0]
    
    datagen_func = lambda cap, img: data_generator(cap, img, caption_type=caption_type)
    
    dataset = tf.data.Dataset.from_tensor_slices((caption, image_path))
    dataset = dataset.map(datagen_func, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.shuffle(min(1000, len(caption))).batch(batch_size, drop_remainder=True)
    dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    return dataset

In [449]:
BATCH_SIZE = 64
dataset = dataset_generator(
    #data_path + '/text2ImgData.pkl',
    None,
    BATCH_SIZE, 
    training_data_generator, 
    caption_type=expSettings.caption_type)

In [450]:
# dataset testing ground
for img, cap in dataset.take(1):
    print("Image shape:", img.numpy().shape)
    print("Caption shape:", cap.numpy().shape)

Image shape: (64, 64, 64, 3)
Caption shape: (64,)


In [451]:
from tensorflow.keras.layers import Conv2DTranspose, Conv2D, BatchNormalization, LeakyReLU, Dense, Dropout
from tensorflow.keras.initializers import HeNormal

# custom layers
class flattened_dense(tf.keras.layers.Layer):
    """
    a dense layer that is made compatible with convolution layers
    by flattening the input first and followed by a dense layer.
    """
    def __init__(self, channels=64, kernel_initializer="glorot_uniform"):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()
        self.dense = tf.keras.layers.Dense(channels, kernel_initializer=kernel_initializer)
        
    def call(self, inputs):
        fl = self.flatten(inputs)
        return self.dense(fl)
    
class conv_block(tf.keras.layers.Layer):
    """
    a convolution layer with batch normalization and leaky relu activation
    """
    def __init__(self, filters=128, kernel_size=1, strides=1, kernel_initializer=HeNormal()):
        super().__init__()
        self.conv = Conv2D(filters=filters,
                           kernel_size = (kernel_size, kernel_size),
                           strides=(strides, strides),
                           padding='same',
                           kernel_initializer=kernel_initializer)
        self.bn = BatchNormalization()
        self.activation = LeakyReLU(alpha=0.1)
    
    def call(self, inputs):
        return self.activation(self.bn(self.conv(inputs)))
    
class deconv_block(tf.keras.layers.Layer):
    """
    a deconvolution layer with batch normalization and leaky relu activation
    """
    def __init__(self, filters=128, kernel_size=4, strides=2, kernel_initializer='glorot_uniform'):
        super().__init__()
        self.deconv = Conv2DTranspose(filters=filters,
                                    kernel_size = (kernel_size, kernel_size),
                                    strides=(strides, strides),
                                    padding='same',
                                    kernel_initializer=kernel_initializer)
        self.bn = BatchNormalization()
        self.activation = LeakyReLU(alpha=0.1)
    
    def call(self, inputs):
        return self.activation(self.bn(self.deconv(inputs)))
    
class dense_block(tf.keras.layers.Layer):
    """
    a dense layer with batch normalization and leaky relu activation
    """
    def __init__(self, filters=128, kernel_initializer='glorot_uniform'):
        super().__init__()
        self.d = Dense(filters, kernel_initializer=kernel_initializer)
        self.bn = BatchNormalization()
        self.activation = LeakyReLU(alpha=0.1)
    
    def call(self, inputs):
        outputs = self.d(inputs)
        outputs = self.bn(outputs)
        return self.activation(outputs)

In [452]:
import tensorflow_hub as hub

class TextEncoder(tf.keras.Model):
    """
    Encode text (a caption) into hidden representation
    input: text, which is a list of ids
    output: embedding, or hidden representation of input text in dimension of RNN_HIDDEN_SIZE
    """
    def __init__(self, hparas, experimental=False, do_batchnorm=False):
        super(TextEncoder, self).__init__()
        self.exp=experimental
        self.do_batchnorm = do_batchnorm
        self.hparas = hparas
        self.batch_size = self.hparas['BATCH_SIZE']
        
        # embedding with tensorflow API
        self.embedding = layers.Embedding(self.hparas['VOCAB_SIZE'], self.hparas['EMBED_DIM'])
        # RNN, here we use GRU cell, another common RNN cell similar to LSTM
        self.gru = layers.GRU(self.hparas['RNN_HIDDEN_SIZE'],
                              return_sequences=True,
                              return_state=True,
                              recurrent_initializer='glorot_uniform')
        if self.exp:
            self.embed = hub.load('./checkpoints/universal_sentence_encoder')
    
    def call(self, text, hidden):
        if self.exp:
            with tf.device('/CPU:0'): # TODO if you find a way to use GPU, go for it.
                output_last = self.embed(text)
                
            state = hidden # not updating state for compatibility reasons
            
        else:
            text = self.embedding(text)
            output, state = self.gru(text, initial_state = hidden)
            output_last = output[:, -1, :]
        
        # normalization in-batch
        if self.do_batchnorm:
            mean = tf.reduce_mean(output_last, axis=0, keepdims=True)  # Mean across the batch
            std = tf.math.reduce_std(output_last, axis=0, keepdims=True)  # Std across the batch
            normalized = (output_last - mean) / (std + 1e-6)  # Avoid division by zero
        else:
            normalized = output_last
        
        return normalized, state
    
    def initialize_hidden_state(self):
        return tf.zeros((self.hparas['BATCH_SIZE'], self.hparas['RNN_HIDDEN_SIZE']))

In [453]:
class Generator(tf.keras.Model):
    """
    Generate fake image based on given text(hidden representation) and noise z
    input: text and noise
    output: fake image with size 64*64*3
    """
    def __init__(self, hparas, experimental=False):
        super(Generator, self).__init__()
        self.exp = experimental
        self.hparas = hparas
        self.flatten = tf.keras.layers.Flatten()
        self.d1 = tf.keras.layers.Dense(self.hparas['DENSE_DIM'], kernel_initializer="glorot_uniform")
        self.d2 = tf.keras.layers.Dense(64*64*3, kernel_initializer="glorot_uniform")
        if self.exp:
            self.deconv_depth = int(math.log(IMAGE_SIZE, 2)) - 1
            self.starter = dense_block(filters=2*2*512)
            self.deconv = [
                deconv_block(filters=512, kernel_initializer=HeNormal()),
                #conv_block(filters=512, kernel_size=1, strides=1),
                deconv_block(filters=256, kernel_initializer=HeNormal()),
                #conv_block(filters=256, kernel_size=1, strides=1),
                #Dropout(0.2),
                deconv_block(filters=128, kernel_initializer=HeNormal()),
                #conv_block(filters=128, kernel_size=1, strides=1),
                #Dropout(0.2),
                deconv_block(filters=64, kernel_initializer=HeNormal()),
                #conv_block(filters=64, kernel_size=1, strides=1),
                deconv_block(filters=32, kernel_initializer=HeNormal()),
                conv_block(filters=32, kernel_size=1, strides=1),
                deconv_block(filters=16, strides=1, kernel_initializer=HeNormal())
            ]
            self.headf = conv_block(filters=3, kernel_size=1, strides=1)
            
    def call(self, text, noise_z, debug_output=False, training=False):
        # deconvolution
        if self.exp:
            noisy_text = tf.concat([text, noise_z], axis=1) * 10 # amplify the input a bit, they seem fairly close to 0.
            img = self.starter(noisy_text)
            
            img = tf.reshape(img, [-1, 2, 2, 512])
            debug = []
            for layer in self.deconv:
                if isinstance(layer, tf.keras.layers.Dropout) and training:
                    continue
                debug.append(img)
                img = layer(img)

            img = self.headf(img)
            logits = tf.reshape(img, [-1, IMAGE_SIZE, IMAGE_SIZE, 3])
            output = tf.nn.tanh(logits)
        
        # concatenate input text and random noise
        else:
            text = self.flatten(text)
            text = self.d1(text)
            text = tf.nn.leaky_relu(text)
            text_concat = tf.concat([noise_z, text], axis=1)
            text_concat = self.d2(text_concat)
        
            logits = tf.reshape(text_concat, [-1, 64, 64, 3])
            output = tf.nn.tanh(logits)
            debug_output = output
        
        if debug_output:
            return logits, output, debug
        else:
            return logits, output

In [454]:
from tensorflow.keras.applications import ResNet50, VGG16

class Discriminator(tf.keras.Model):
    """
    Differentiate the real and fake image
    input: image and corresponding text
    output: labels, the real image should be 1, while the fake should be 0
    """
    def __init__(self, hparas, experimental=False, backbone='simple'):
        super(Discriminator, self).__init__()
        self.exp = experimental
        self.bbtype = backbone
        self.hparas = hparas
        if self.exp:
            if self.bbtype == 'resnet':
                self.resnet_base = ResNet50(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), weights='imagenet', include_top=False)
                for layer in self.resnet_base.layers:
                    layer.trainable = False
            elif self.bbtype == 'vgg':
                self.vgg_base = VGG16(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), weights='imagenet', include_top=False)
                for layer in self.vgg_base.layers:
                    layer.trainable = False
            elif self.bbtype == 'simple':
                self.conv = [
                    conv_block(filters=512, kernel_size=3, strides=2),
                    #conv_block(filters=512, kernel_size=3, strides=1),
                    conv_block(filters=512, kernel_size=1, strides=1),
                    Dropout(hparas['DIS_DROPOUT']),
                    conv_block(filters=256, kernel_size=3, strides=2),
                    #conv_block(filters=256, kernel_size=3, strides=1),
                    conv_block(filters=256, kernel_size=1, strides=1),
                    Dropout(hparas['DIS_DROPOUT']),
                    conv_block(filters=128, kernel_size=3, strides=2),
                    conv_block(filters=128, kernel_size=3, strides=1),
                    conv_block(filters=128, kernel_size=1, strides=1),
                    #Dropout(hparas['DIS_DROPOUT']),
                    conv_block(filters=64, kernel_size=3, strides=2),
                    #conv_block(filters=64, kernel_size=3, strides=1),
                    conv_block(filters=64, kernel_size=1, strides=1),
                    Dropout(hparas['DIS_DROPOUT'] * 2 / 3)
                ]
                # more simple one
                #self.conv1 = conv_block(filters=256, kernel_size=3, strides=1)
                #self.conv2 = conv_block(filters=64, kernel_size=3, strides=1)
        self.flatten = tf.keras.layers.Flatten()
        self.d_text = tf.keras.layers.Dense(self.hparas['DENSE_DIM'])
        self.d_img = tf.keras.layers.Dense(self.hparas['DENSE_DIM'])
        self.d = tf.keras.layers.Dense(1)
    
    def call(self, img, text, training=False):
        text = self.flatten(text)
        text = self.d_text(text)
        text = tf.nn.leaky_relu(text)
        
        if self.exp:
            if self.bbtype == 'resnet':
                img = self.resnet_base(img)
            elif self.bbtype == 'vgg':
                img = self.vgg_base(img)
            elif self.bbtype == 'simple':
                for layer in self.conv: # see init for spec
                    if isinstance(layer, Dropout) and training:
                        continue
                    img = layer(img)
        img = self.flatten(img)
        img = self.d_img(img)
        img = tf.nn.leaky_relu(img)
        
        # concatenate image with paired text
        img_text = tf.concat([text, img], axis=1)
        
        logits = self.d(img_text)
        output = tf.nn.sigmoid(logits)
        
        return logits, output

Parameters and settings

In [455]:
hparas = {
    'MAX_SEQ_LENGTH': 20,                     # maximum sequence length
    'EMBED_DIM': 256,                         # word embedding dimension
    'VOCAB_SIZE': len(word2Id_dict),          # size of dictionary of captions
    'RNN_HIDDEN_SIZE': 128,                   # number of RNN neurons
    'Z_DIM': 512,                             # random noise z dimension
    'DENSE_DIM': 128,                         # number of neurons in dense layer
    'IMAGE_SIZE': [64, 64, 3],                # render image size
    'BATCH_SIZE': 64,
    'DIS_DROPOUT': 0.5,                       # chance of dropout in discriminator
    'LR_GEN': 1e-4,
    'LR_DIS': 1e-5,
    'LR_DECAY': 0.5,                          # unused
    'CLIPNORM': 0.1,
    'BETA_1': 0.5,
    'N_EPOCH': 500,                           # number of epoch for demo
    'N_SAMPLE': num_training_sample,          # size of training data
    'CHECKPOINTS_DIR': './checkpoints/demo',  # checkpoint path
    'PRINT_FREQ': 1                           # printing frequency of loss
}

In [456]:
text_encoder = TextEncoder(hparas, 
                           experimental=expSettings.enc,
                           do_batchnorm=expSettings.enc_do_batchnorm)

generator = Generator(hparas,
                      experimental=expSettings.gen)

discriminator = Discriminator(hparas,
                              experimental=expSettings.dis,
                              backbone=expSettings.dis_backbone)

In [457]:
# test text encoder
for img, cap in dataset.take(1):
    print("Image shape:", img.numpy().shape)
    print("Caption shape:", cap.shape)
    with tf.device('/CPU:0'):
        output, _ = text_encoder(cap, text_encoder.initialize_hidden_state())
        print("Caption embed shape:", output.shape)

Image shape: (64, 64, 64, 3)
Caption shape: (64,)
Caption embed shape: (64, 512)


In [458]:
# This method returns a helper function to compute cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [459]:
from tensorflow.nn import sigmoid_cross_entropy_with_logits

def discriminator_loss(real_logits, fake_logits):
    # output value of real image should be 1
    real_loss = sigmoid_cross_entropy_with_logits(tf.ones_like(real_logits), real_logits)
    # output value of fake image should be 0
    fake_loss = sigmoid_cross_entropy_with_logits(tf.zeros_like(fake_logits), fake_logits)
    total_loss = tf.reduce_mean(real_loss) + tf.reduce_mean(fake_loss)
                                 
    return total_loss
def generator_loss(fake_output):
    # output value of fake image should be 0
    return tf.reduce_mean(sigmoid_cross_entropy_with_logits(tf.ones_like(fake_output), fake_output))

In [460]:
# we use seperated optimizers for training generator and discriminator
generator_optimizer = tf.keras.optimizers.Adam(hparas['LR_GEN'], clipvalue=hparas['CLIPNORM'])
discriminator_optimizer = tf.keras.optimizers.Adam(hparas['LR_DIS'], clipvalue=hparas['CLIPNORM'])

In [461]:
if expSettings.delete_checkpoint:
    for f in os.listdir(hparas['CHECKPOINTS_DIR']):
        file_path = os.path.join(hparas['CHECKPOINTS_DIR'], f)
        if os.path.isfile(file_path):
            os.unlink(file_path)

# one benefit of tf.train.Checkpoint() API is we can save everything seperately
checkpoint_dir = hparas['CHECKPOINTS_DIR']
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 text_encoder=text_encoder,
                                 generator=generator,
                                 discriminator=discriminator)
ckptManager = tf.train.CheckpointManager(
    checkpoint, directory=hparas['CHECKPOINTS_DIR'], max_to_keep=5)

In [462]:
@tf.function
def train_step(real_image, caption, hidden, imshow=False):
    # random noise for generator
    noise = tf.random.normal(shape=[hparas['BATCH_SIZE'], hparas['Z_DIM']], mean=0.0, stddev=0.1)
    
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        text_embed, hidden = text_encoder(caption, hidden)
        _, fake_image = generator(text_embed, noise, training=True)
        if imshow:
            plt.imshow(fake_image[0])

        real_logits, real_output = discriminator(real_image, text_embed, training=True)
        fake_logits, fake_output = discriminator(fake_image, text_embed, training=True)

        g_loss = generator_loss(fake_logits)
        d_loss = discriminator_loss(real_logits, fake_logits)

    grad_g = gen_tape.gradient(g_loss, generator.trainable_variables)
    grad_d = disc_tape.gradient(d_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(grad_g, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(grad_d, discriminator.trainable_variables))
    
    return g_loss, d_loss

In [463]:
@tf.function
def test_step(caption, noise, hidden):
    text_embed, hidden = text_encoder(caption, hidden)
    _, fake_image = generator(text_embed, noise)
    return fake_image

Sample Debugging (unused)

In [464]:
def merge(images, size):
    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * size[0], w * size[1], 3))
    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx // size[1]
        img[j*h:j*h+h, i*w:i*w+w, :] = image
    return img

def imsave(images, size, path):
    # getting the pixel values between [0, 1] to save it
    return plt.imsave(path, merge(images, size)*0.5 + 0.5)

def save_images(images, size, image_path):
    return imsave(images, size, image_path)

In [465]:
def sample_generator(caption, batch_size, caption_type='id'):
    if caption_type == 'sentence':
        caption = caption2string(caption)
    caption = np.asarray(caption)
    if caption_type == 'id':
        caption = caption.astype(np.int)
    dataset = tf.data.Dataset.from_tensor_slices(caption)
    dataset = dataset.batch(batch_size)
    return dataset

In [466]:
ni = int(np.ceil(np.sqrt(hparas['BATCH_SIZE'])))
sample_size = hparas['BATCH_SIZE']
sample_seed = np.random.normal(loc=0.0, scale=1.0, size=(sample_size, hparas['Z_DIM'])).astype(np.float32)
sample_sentence = ["the flower shown has yellow anther red pistil and bright red petals."] * int(sample_size/ni) + \
                  ["this flower has petals that are yellow, white and purple and has dark lines"] * int(sample_size/ni) + \
                  ["the petals on this flower are white with a yellow center"] * int(sample_size/ni) + \
                  ["this flower has a lot of small round pink petals."] * int(sample_size/ni) + \
                  ["this flower is orange in color, and has petals that are ruffled and rounded."] * int(sample_size/ni) + \
                  ["the flower has yellow petals and the center of it is brown."] * int(sample_size/ni) + \
                  ["this flower has petals that are blue and white."] * int(sample_size/ni) +\
                  ["these white flowers have petals that start off white in color and end in a white towards the tips."] * int(sample_size/ni)

for i, sent in enumerate(sample_sentence):
    sample_sentence[i] = sent2IdList(sent)
sample_sentence = sample_generator(sample_sentence, hparas['BATCH_SIZE'], caption_type=expSettings.caption_type)

In [467]:
# test the sample dataset
for cap in sample_sentence.take(1):
    print("Caption shape:", cap.numpy().shape)
    emb, _ = text_encoder(cap, text_encoder.initialize_hidden_state())
    print("Caption embeddings:", emb)

Caption shape: (64,)
Caption embeddings: tf.Tensor(
[[-0.01346336  0.06881763 -0.05529514 ... -0.01518281 -0.00633275
   0.01500697]
 [-0.01346336  0.06881763 -0.05529514 ... -0.01518281 -0.00633275
   0.01500697]
 [-0.01346336  0.06881763 -0.05529514 ... -0.01518281 -0.00633275
   0.01500697]
 ...
 [-0.00790838  0.06233827  0.01107207 ... -0.02398296 -0.03450323
  -0.00868433]
 [-0.00790838  0.06233827  0.01107207 ... -0.02398296 -0.03450323
  -0.00868433]
 [-0.00790838  0.06233827  0.01107207 ... -0.02398296 -0.03450323
  -0.00868433]], shape=(64, 512), dtype=float32)


In [468]:
if not os.path.exists('samples/demo'):
    os.makedirs('samples/demo')

Training and testing

In [469]:
def train(dataset, epochs):
    # hidden state of RNN
    hidden = text_encoder.initialize_hidden_state()
    steps_per_epoch = int(hparas['N_SAMPLE']/hparas['BATCH_SIZE'])
    
    lowest_gen_loss = 1e10
    
    for epoch in range(hparas['N_EPOCH']):
        g_total_loss = 0
        d_total_loss = 0
        batchid = 0
        start = time.time()
        imshow = False
        
        for image, caption in dataset:
            batchid += 1
            g_loss, d_loss = train_step(image, caption, hidden, imshow=imshow)
            imshow = False
            g_total_loss += g_loss
            d_total_loss += d_loss
        
        time_tuple = time.localtime()
        time_string = time.strftime("%m/%d/%Y, %H:%M:%S", time_tuple)
            
        print("Epoch {}, gen_loss: {:.4f}, disc_loss: {:.4f}".format(epoch+1,
                                                                     g_total_loss/steps_per_epoch,
                                                                     d_total_loss/steps_per_epoch))
        print('Time for epoch {} is {:.4f} sec'.format(epoch+1, time.time()-start))
        
        # save the model if lower gen loss is achieved
        if g_total_loss < lowest_gen_loss:
            lowest_gen_loss = g_total_loss
            ckptManager.save()
            print('new lowest. saving model.')
        elif g_total_loss < 3 * lowest_gen_loss:
            ckptManager.save()
            print('within save threshold. saving model.')
        else:
            lowest_gen_loss *= 1.02
            
        print('======================================')
        
        # visualization
        if (epoch + 1) % hparas['PRINT_FREQ'] == 0:
            for caption in sample_sentence:
                fake_image = test_step(caption, sample_seed, hidden)
            save_images(fake_image,
                        [ni, ni],
                        'samples/demo/train_{:03d}.jpg'.format(epoch))

In [470]:
train(dataset, hparas['N_EPOCH'])

  "shape. This may consume a large amount of memory." % value)


Epoch 1, gen_loss: 1.8861, disc_loss: 28.5298
Time for epoch 1 is 157.0279 sec
new lowest. saving model.
Epoch 2, gen_loss: 12.8019, disc_loss: 8.6626
Time for epoch 2 is 150.1114 sec
Epoch 3, gen_loss: 10.6726, disc_loss: 9.3638
Time for epoch 3 is 150.2461 sec
Epoch 4, gen_loss: 16.8685, disc_loss: 4.5566
Time for epoch 4 is 151.5259 sec
Epoch 5, gen_loss: 13.9472, disc_loss: 9.9693
Time for epoch 5 is 153.3798 sec
Epoch 6, gen_loss: 6.9875, disc_loss: 13.6581
Time for epoch 6 is 152.1112 sec
Epoch 7, gen_loss: 6.7130, disc_loss: 13.6902
Time for epoch 7 is 143.6160 sec
Epoch 8, gen_loss: 7.2897, disc_loss: 12.8268
Time for epoch 8 is 149.8868 sec
Epoch 9, gen_loss: 6.6670, disc_loss: 13.5848
Time for epoch 9 is 151.8409 sec
Epoch 10, gen_loss: 6.6551, disc_loss: 13.5345
Time for epoch 10 is 145.2812 sec
Epoch 11, gen_loss: 7.8587, disc_loss: 12.0351
Time for epoch 11 is 146.3786 sec
Epoch 12, gen_loss: 9.4297, disc_loss: 10.5063
Time for epoch 12 is 165.8329 sec
Epoch 13, gen_loss: 

Epoch 66, gen_loss: 17.1582, disc_loss: 6.8046
Time for epoch 66 is 139.8840 sec
Epoch 67, gen_loss: 18.4937, disc_loss: 6.7038
Time for epoch 67 is 140.1898 sec
Epoch 68, gen_loss: 17.5797, disc_loss: 7.1567
Time for epoch 68 is 140.0209 sec
Epoch 69, gen_loss: 16.2989, disc_loss: 7.7747
Time for epoch 69 is 139.9279 sec
within save threshold. saving model.
Epoch 70, gen_loss: 19.0292, disc_loss: 7.6593
Time for epoch 70 is 139.9626 sec
Epoch 71, gen_loss: 16.7195, disc_loss: 7.5823
Time for epoch 71 is 140.5164 sec
within save threshold. saving model.
Epoch 72, gen_loss: 17.1161, disc_loss: 7.6399
Time for epoch 72 is 140.2308 sec
within save threshold. saving model.
Epoch 73, gen_loss: 17.6940, disc_loss: 7.4323
Time for epoch 73 is 139.7159 sec
Epoch 74, gen_loss: 16.9063, disc_loss: 7.3340
Time for epoch 74 is 138.1340 sec
within save threshold. saving model.
Epoch 75, gen_loss: 17.1115, disc_loss: 8.3130
Time for epoch 75 is 139.5401 sec
within save threshold. saving model.
Epoch

Epoch 122, gen_loss: 18.3519, disc_loss: 7.6002
Time for epoch 122 is 142.8065 sec
within save threshold. saving model.
Epoch 123, gen_loss: 18.2381, disc_loss: 8.3729
Time for epoch 123 is 143.4180 sec
within save threshold. saving model.
Epoch 124, gen_loss: 19.3381, disc_loss: 7.8335
Time for epoch 124 is 143.0156 sec
within save threshold. saving model.
Epoch 125, gen_loss: 21.1643, disc_loss: 6.5650
Time for epoch 125 is 143.3200 sec
within save threshold. saving model.
Epoch 126, gen_loss: 19.8722, disc_loss: 7.2928
Time for epoch 126 is 142.8948 sec
within save threshold. saving model.
Epoch 127, gen_loss: 21.0080, disc_loss: 7.1238
Time for epoch 127 is 143.1770 sec
within save threshold. saving model.
Epoch 128, gen_loss: 22.7944, disc_loss: 6.7971
Time for epoch 128 is 144.0659 sec
Epoch 129, gen_loss: 20.2790, disc_loss: 6.5093
Time for epoch 129 is 143.4065 sec
within save threshold. saving model.
Epoch 130, gen_loss: 20.5188, disc_loss: 7.7445
Time for epoch 130 is 143.055

within save threshold. saving model.
Epoch 175, gen_loss: 21.0818, disc_loss: 7.4575
Time for epoch 175 is 143.2127 sec
within save threshold. saving model.
Epoch 176, gen_loss: 22.3250, disc_loss: 6.8384
Time for epoch 176 is 143.1688 sec
within save threshold. saving model.
Epoch 177, gen_loss: 22.5193, disc_loss: 7.1329
Time for epoch 177 is 143.2039 sec
within save threshold. saving model.
Epoch 178, gen_loss: 22.9843, disc_loss: 7.7611
Time for epoch 178 is 143.0438 sec
within save threshold. saving model.
Epoch 179, gen_loss: 23.5685, disc_loss: 7.2938
Time for epoch 179 is 143.3026 sec
Epoch 180, gen_loss: 21.0205, disc_loss: 8.7744
Time for epoch 180 is 143.8735 sec
within save threshold. saving model.
Epoch 181, gen_loss: 22.1262, disc_loss: 8.5132
Time for epoch 181 is 143.4687 sec
within save threshold. saving model.
Epoch 182, gen_loss: 24.6318, disc_loss: 7.2340
Time for epoch 182 is 143.4273 sec
Epoch 183, gen_loss: 20.9348, disc_loss: 8.3807
Time for epoch 183 is 143.566

KeyboardInterrupt: 

---

In [471]:
def test_caption2string(ls):
    return " ".join([id2word_dict[idx] for idx in ls]).strip().split(' <PAD>')[0]

def testing_data_generator(caption, index, caption_type='id'):
    if caption_type == 'id':
        caption = tf.cast(caption, tf.float32)
    return caption, index

def testing_dataset_generator(batch_size, data_generator, caption_type='id'):
    data = pd.read_pickle('./dataset/testData.pkl')
    
    if caption_type == 'sentence':
        data['Captions_string'] = data['Captions'].apply(test_caption2string)
        captions = data['Captions_string'].values
    elif caption_type == 'id':
        captions = data['Captions'].values
        
    caption = []
    for i in range(len(captions)):
        caption.append(captions[i])
    caption = np.asarray(caption)
    
    if caption_type == 'id':
        caption = caption.astype(np.int)
        
    datagen_func = lambda cap, img: data_generator(cap, img, caption_type=caption_type)
        
    index = data['ID'].values
    index = np.asarray(index)
    
    dataset = tf.data.Dataset.from_tensor_slices((caption, index))
    dataset = dataset.map(datagen_func, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.repeat().batch(batch_size)
    
    return dataset

In [472]:
testing_dataset = testing_dataset_generator(hparas['BATCH_SIZE'], testing_data_generator, caption_type=expSettings.caption_type)

In [473]:
# test the testing dataset
for cap, img in testing_dataset.take(1):
    print("Image shape:", img.numpy().shape)
    print("Caption shape:", cap.numpy().shape)

Image shape: (64,)
Caption shape: (64,)


In [474]:
data = pd.read_pickle('./dataset/testData.pkl')
captions = data['Captions'].values

NUM_TEST = len(captions)
EPOCH_TEST = int(NUM_TEST / hparas['BATCH_SIZE'])

In [475]:
if not os.path.exists('./inference/demo'):
    os.makedirs('./inference/demo')

In [476]:
def inference(dataset):
    hidden = text_encoder.initialize_hidden_state()
    sample_size = hparas['BATCH_SIZE']
    sample_seed = np.random.normal(loc=0.0, scale=0.1, size=(sample_size, hparas['Z_DIM'])).astype(np.float32)
    print(sample_seed[0:3, :])
    
    step = 0
    start = time.time()
    for captions, idx in dataset:
        if step > EPOCH_TEST:
            break
        
        fake_image = test_step(captions, sample_seed, hidden)
        step += 1
        for i in range(hparas['BATCH_SIZE']):
            plt.imsave('./inference/demo/inference_{:04d}.jpg'.format(idx[i]), fake_image[i].numpy()*0.5 + 0.5)
            
            if i == 0 and step == 1: 
                #print(captions)
                text_embed_t, hidden_t = text_encoder(captions, hidden)
                #print(text_embed_t)
                print(fake_image[0:1, 0:5, 0:5, :])
                img_logits, _, debug = generator(text_embed_t, sample_seed, debug_output=True)
                print(debug[0][0:3, 0:5, 0:5, 0:5])
                print(debug[1][0:3, 0:5, 0:5, 0:5])
                print(debug[2][0:3, 0:5, 0:5, 0:5])
                print(debug[3][0:3, 0:5, 0:5, 0:5])
                print(debug[4][0:3, 0:5, 0:5, 0:5])
                print(img_logits[0:3, 0:5, 0:5, :])
                pred_logit, pred = discriminator(fake_image, text_embed_t)
                print(pred_logit)
                
            
    print('Time for inference is {:.4f} sec'.format(time.time()-start))

In [477]:
#checkpoint.restore(checkpoint_dir + f'/ckpt-50')
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint:
    print(f"Restoring from {latest_checkpoint}")
    checkpoint.restore(latest_checkpoint)
else:
    print("No checkpoint found.")

Restoring from ./checkpoints/demo\ckpt-129


In [478]:
inference(testing_dataset)

[[ 0.10586753  0.1683734  -0.04186799 ... -0.08489064 -0.0329618
   0.03781696]
 [ 0.03922252  0.03296076  0.13910006 ...  0.13605477 -0.0755486
   0.09131289]
 [-0.09989461 -0.01007991  0.012931   ... -0.0011589  -0.02907026
  -0.09748437]]
tf.Tensor(
[[[[ 0.21720769  0.47280306  0.12604918]
   [ 0.17457196  0.40070975  0.07731792]
   [ 0.14772408  0.28354514  0.01530687]
   [ 0.10258732  0.34258884  0.03953101]
   [ 0.42752028  0.6867732   0.5655326 ]]

  [[-0.00210296  0.2229808  -0.01537495]
   [ 0.11066125  0.2839741   0.01735989]
   [ 0.17913605  0.293927    0.12320141]
   [ 0.37599295  0.5424648   0.36960796]
   [ 0.42181563  0.57904196  0.5068701 ]]

  [[ 0.28094682  0.5502275   0.19273959]
   [ 0.269157    0.43758902  0.16704144]
   [ 0.26127547  0.35547554  0.26871258]
   [ 0.5230407   0.7386627   0.4779848 ]
   [ 0.5606745   0.76229715  0.59335446]]

  [[ 0.23168376  0.3678933   0.11534017]
   [ 0.33075872  0.40716508  0.32358813]
   [ 0.41750866  0.49542183  0.4464383 ]
   

Time for inference is 3.3475 sec


In [479]:
for var in generator.starter.variables:
    print(var.name, var.numpy())

generator_11/dense_block_11/dense_68/kernel:0 [[ 1.04421824e-01  1.62826896e-01 -6.50692955e-02 ... -1.60113964e-02
   8.78449455e-02 -6.49436191e-02]
 [ 4.35370542e-02  6.03674501e-02 -2.23204881e-01 ... -4.61561568e-02
   3.26780863e-02 -1.17234230e-01]
 [-1.49605691e-03 -1.22202225e-01 -2.64721271e-02 ... -1.15404315e-01
  -5.03560603e-02 -3.60469073e-02]
 ...
 [-3.10842823e-02  1.68469623e-02  5.76655753e-02 ...  6.73484709e-03
   1.20996928e-03  8.23529437e-03]
 [-1.79251321e-02  2.56902073e-04 -3.41315903e-02 ...  1.04296429e-04
   6.83364645e-03  3.86595144e-03]
 [-2.72888411e-02 -3.79072465e-02 -3.89734954e-02 ... -5.98441251e-03
   6.52544051e-02  3.57083306e-02]]
generator_11/dense_block_11/dense_68/bias:0 [-2.8460745e-05 -1.7617382e-04  9.6550197e-05 ... -3.7512651e-05
 -8.0961553e-07  1.3375393e-04]
generator_11/dense_block_11/batch_normalization_203/gamma:0 [0.49476177 0.762858   0.6209676  ... 0.7114634  0.5913795  1.4854782 ]
generator_11/dense_block_11/batch_normalizati

In [480]:
for var in discriminator.variables:
    print(var.name, var.numpy())

discriminator_11/conv_block_136/conv2d_136/kernel:0 [[[[-5.93960881e-02 -2.95454621e-01  1.63213491e-01 ...  1.34212049e-02
     1.85690090e-01 -3.65044832e-01]
   [-2.69940883e-01 -4.83936846e-01  2.93261737e-01 ...  2.50138283e-01
    -2.58412927e-01  4.39943314e-01]
   [-2.37089887e-01  3.11541796e-01 -7.86949843e-02 ...  2.10212722e-01
     5.86252570e-01 -1.32484674e-01]]

  [[ 2.93688059e-01  3.39040816e-01  1.81222372e-02 ...  5.93196638e-02
     1.04732387e-01 -6.72897771e-02]
   [-3.89340013e-01 -2.27928653e-01 -3.93214785e-02 ... -3.32181424e-01
    -1.40229151e-01 -2.60800600e-01]
   [ 3.64359558e-01 -1.81910336e-01 -2.43829787e-01 ...  1.48140311e-01
    -3.28003377e-01  2.57840343e-02]]

  [[ 5.18545270e-01 -1.35839179e-01 -8.03554803e-02 ... -3.15286934e-01
     2.51228750e-01 -7.69500583e-02]
   [ 7.54195452e-02 -3.34208727e-01 -1.59700811e-01 ...  4.27713633e-01
     1.70536786e-01  1.00157253e-01]
   [ 1.28874974e-02  3.55607897e-01  3.75459135e-01 ...  1.52615532e-01


In [481]:
%cd ./testing
!python inception_score.py ../inference/demo output.csv 21
%cd ../

C:\Users\User\Courses\24aut_deep_learning\24aut-deep-learning\deep-learning-comp3\testing
1 Physical GPUs, 1 Logical GPUs
--------------Evaluation Success-----------------
C:\Users\User\Courses\24aut_deep_learning\24aut-deep-learning\deep-learning-comp3


In [482]:
text_encoder.summary()

Model: "text_encoder_11"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding_11 (Embedding)    multiple                  0 (unused)
                                                                 
 gru_11 (GRU)                multiple                  0 (unused)
                                                                 
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________


In [483]:
generator.summary()

Model: "generator_11"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten_22 (Flatten)        multiple                  0 (unused)
                                                                 
 dense_66 (Dense)            multiple                  0 (unused)
                                                                 
 dense_67 (Dense)            multiple                  0 (unused)
                                                                 
 dense_block_11 (dense_block  multiple                 2107392   
 )                                                               
                                                                 
 deconv_block_58 (deconv_blo  multiple                 4196864   
 ck)                                                             
                                                                 
 deconv_block_59 (deconv_blo  multiple                

In [484]:
discriminator.summary()

Model: "discriminator_11"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv_block_136 (conv_block)  multiple                 16384     
                                                                 
 conv_block_137 (conv_block)  multiple                 264704    
                                                                 
 dropout_28 (Dropout)        multiple                  0         
                                                                 
 conv_block_138 (conv_block)  multiple                 1180928   
                                                                 
 conv_block_139 (conv_block)  multiple                 66816     
                                                                 
 dropout_29 (Dropout)        multiple                  0         
                                                                 
 conv_block_140 (conv_block)  multiple            