# Generating images with the Proposed Models

This notebook measures the time a model takes to generate an image. The
reported time is the fastest of 100 runs rather than the average, because we
want to know how fast inference is in an environment dedicated to the task.

## Loading the Models

In [1]:
import tensorflow as tf

# SavedModel directories, relative to the notebook's working directory.
_model_dirs = {
    'pix2pix': 'models/pix2pix/front-to-right',
    'stargan': 'models/stargan',
    'collagan': 'models/collagan',
}

# One zero-argument loader per model, so a single model can be reloaded on demand.
# compile=False loads the network without compiling it (no optimizer/loss),
# which is all that inference needs. The default argument binds each directory
# at definition time instead of late-binding the loop variable.
model_loaders = {
    name: (lambda model_dir=model_dir: tf.keras.models.load_model(model_dir, compile=False))
    for name, model_dir in _model_dirs.items()
}

# Load eagerly, in the same order as before: pix2pix, then stargan, then collagan.
pix2pix, stargan, collagan = (model_loaders[name]() for name in ('pix2pix', 'stargan', 'collagan'))


## Counting Layers

In [4]:
def count_layers(model):
    """Return the number of entries in ``model.layers``.

    Args:
        model: Any object with a sized ``layers`` attribute, such as a
            ``tf.keras.Model``. For Keras models this lists only top-level
            layers, so a nested sub-model counts as a single layer.

    Returns:
        ``len(model.layers)`` as an int.
    """
    layers = model.layers
    return len(layers)

# (model name, layer count) pairs, in the order the models were loaded.
number_of_layers = [
    (name, count_layers(model_obj))
    for name, model_obj in (('pix2pix', pix2pix), ('stargan', stargan), ('collagan', collagan))
]

# Fixed-width text table: name left-aligned in 10 columns, count right-aligned in 8.
header_format = "{:<10} | {:>8}"
row_format = "{:<10} | {:>8d}"
table_lines = [header_format.format('Model', 'Layers')]
table_lines.extend(row_format.format(name, count) for name, count in number_of_layers)
print('\n'.join(table_lines))



## Loading the Data

In [7]:
def load_image(dataset, domain, index, split='test'):
    """Load one PNG sprite and scale its pixel values to [-1, 1].

    Reads ``datasets/{dataset}/{split}/{domain}/{index}.png`` relative to the
    notebook's working directory.

    Args:
        dataset: Dataset folder under ``datasets/`` (e.g. ``'rpg-maker-xp'``).
        domain: Domain sub-folder (e.g. ``'2-front'``).
        index: Image number; the file name is ``{index}.png``.
        split: Dataset split folder. Defaults to ``'test'``.

    Returns:
        A float32 tensor of shape (height, width, 4) with values in [-1, 1].
    """
    path = f'datasets/{dataset}/{split}/{domain}/{index}.png'
    raw_bytes = tf.io.read_file(path)
    # channels=4 forces RGBA output, so every image has the same channel
    # count regardless of how the PNG was saved.
    image = tf.io.decode_png(raw_bytes, channels=4)
    image = tf.cast(image, tf.float32)
    # Map uint8 [0, 255] linearly onto [-1, 1].
    return image / 127.5 - 1.0

def load_character(index, dataset='rpg-maker-xp'):
    """Load the four directional views of one character.

    Args:
        index: Character number within the dataset's test split.
        dataset: Dataset folder under ``datasets/``; defaults to ``'rpg-maker-xp'``.

    Returns:
        A dict with keys ``'back'``, ``'left'``, ``'front'`` and ``'right'``
        (in that order), each mapped to the tensor returned by ``load_image``.
    """
    # Folder names carry a numeric domain prefix: 0=back, 1=left, 2=front, 3=right.
    domain_folders = (
        ('back', '0-back'),
        ('left', '1-left'),
        ('front', '2-front'),
        ('right', '3-right'),
    )
    return {
        view: load_image(dataset, folder, index)
        for view, folder in domain_folders
    }


## Generating Images

In [6]:
from ipywidgets import interact
from matplotlib import pyplot as plt

In [8]:
@interact(index=(0, 43))
def generate_pix2pix(index=0):
    """Translate a character's front view into its right view with pix2pix and plot both.

    Args:
        index: Character number in the test split; the slider covers 0-43.
    """
    character = load_character(index)
    source_image = character['front']

    # Add a batch axis for the model, then drop size-1 dimensions from the output.
    # training=True is kept exactly as the notebook originally called the generator.
    generated = pix2pix(source_image[tf.newaxis, ...], training=True)
    target_image = tf.squeeze(generated)

    fig, (source_ax, target_ax) = plt.subplots(1, 2, figsize=(8, 4))

    # Images are stored in [-1, 1]; x * 0.5 + 0.5 maps them to [0, 1] for imshow.
    source_ax.imshow(source_image * 0.5 + 0.5, interpolation='nearest')
    source_ax.set_title(f'Source Image: {index}')
    source_ax.axis('off')

    # Generator output is not guaranteed to stay in [-1, 1]; clip explicitly
    # rather than letting imshow clip it with a warning.
    target_ax.imshow(tf.clip_by_value(target_image * 0.5 + 0.5, 0.0, 1.0), interpolation='nearest')
    target_ax.set_title(f'Generated: {index}')
    target_ax.axis('off')

    # Lay out before showing: tight_layout() after plt.show() does not affect
    # the figure that has already been rendered.
    fig.tight_layout()
    plt.show()
        


In [9]:
@interact(index=(0, 43))
def generate_stargan(index=0):
    """Translate a character's front view into its right view with StarGAN and plot both.

    Args:
        index: Character number in the test split; the slider covers 0-43.
    """
    character = load_character(index)

    source_image = character['front']
    # Domain indices follow the dataset folder prefixes: 2 = front, 3 = right.
    source_domain = 2
    target_domain = 3

    # Inputs: batched image, target domain, source domain (order as the model expects).
    # NOTE(review): unlike the other generators in this notebook, StarGAN is
    # called without training=True — confirm this is intended.
    target_image = stargan([
        source_image[tf.newaxis, ...],
        tf.constant([[target_domain]]),
        tf.constant([[source_domain]]),
    ])
    target_image = tf.squeeze(target_image)

    fig, (source_ax, target_ax) = plt.subplots(1, 2, figsize=(8, 4))

    # Images are stored in [-1, 1]; x * 0.5 + 0.5 maps them to [0, 1] for imshow.
    source_ax.imshow(source_image * 0.5 + 0.5, interpolation='nearest')
    source_ax.set_title(f'Source Image: {index}')
    source_ax.axis('off')

    # Generator output is not guaranteed to stay in [-1, 1]; clip explicitly
    # rather than letting imshow clip it with a warning.
    target_ax.imshow(tf.clip_by_value(target_image * 0.5 + 0.5, 0.0, 1.0), interpolation='nearest')
    target_ax.set_title(f'Generated: {index}')
    target_ax.axis('off')

    # Lay out before showing: tight_layout() after plt.show() does not affect
    # the figure that has already been rendered.
    fig.tight_layout()
    plt.show()



In [10]:
@interact(index=(0, 43))
def generate_collagan(index=0):
    """Generate a character's right view with CollaGAN from its other three views.

    Plots the four model inputs (with the zeroed target slot) on the left half
    and the generated image on the right half.

    Args:
        index: Character number in the test split; the slider covers 0-43.
    """
    character = load_character(index)

    # Stack order back/left/front/right follows the folder prefixes 0-3; the
    # slot for the target view (index 3, 'right') is replaced with zeros.
    source_images = tf.stack([
        character['back'],
        character['left'],
        character['front'],
        tf.zeros_like(character['right']),
    ])
    target_domain = 3

    target_image = collagan([source_images[tf.newaxis, ...], tf.constant([[target_domain]])], training=True)
    target_image = tf.squeeze(target_image)

    # 2x4 grid: the four inputs fill the left 2x2 block, the output spans the right half.
    fig = plt.figure(figsize=(8, 4))
    grid = fig.add_gridspec(2, 4)

    # An invisible axes spanning the left half carries only the group title.
    # It is created through the Figure API because pyplot.subplot() in older
    # Matplotlib releases deletes existing axes that overlap a new subplot,
    # which silently removed this title.
    title_ax = fig.add_subplot(grid[:, :2])
    title_ax.set_title(f'Source Images: {index}')
    title_ax.axis('off')

    input_cells = (grid[0, 0], grid[0, 1], grid[1, 0], grid[1, 1])
    for i, cell in enumerate(input_cells):
        input_ax = fig.add_subplot(cell)
        # Images are stored in [-1, 1]; x * 0.5 + 0.5 maps them to [0, 1].
        input_ax.imshow(source_images[i] * 0.5 + 0.5, interpolation='nearest')
        input_ax.axis('off')

    target_ax = fig.add_subplot(grid[:, 2:])
    # Generator output is not guaranteed to stay in [-1, 1]; clip explicitly
    # rather than letting imshow clip it with a warning.
    target_ax.imshow(tf.clip_by_value(target_image * 0.5 + 0.5, 0.0, 1.0), interpolation='nearest')
    target_ax.set_title(f'Generated: {index}')
    target_ax.axis('off')

    # Lay out before showing: tight_layout() after plt.show() does not affect
    # the figure that has already been rendered.
    fig.tight_layout()
    plt.show()
