**WARNING**: This notebook does not work anymore after upgrading the environment to tf>2.16. The saved models are not keras models anymore, hence, there is no model.layers property to count the layers.

**DECISION**: Calculating the number of layers in this notebook/project is very unlikely to be necessary or worth the effort of rolling back the tf version and this codebase. Hence, we're leaving it as is. Counting can be done manually in the multi-domain repository, or using model.summary() there. To measure the inference time, we might want to use this notebook -- but we'll probably not have a hard time to adapt it to the new models/environment.

# Generating images with the Proposed Models

This notebook measures the time a model takes to generate an image. The time 
is selected as the quickest among 100 runs. It is not the average, as we
want to know how fast inference takes in an environment that is dedicated
to the task.

## Loading the Models

In [None]:
import tensorflow as tf
from ModelProxy import Pix2PixModelProxy, StarGANModelProxy, CollaGANModelProxy

model_loaders = {
    'pix2pix': lambda: tf.keras.models.load_model('models/pix2pix/front-to-right', compile=False),
    'stargan': lambda: tf.keras.models.load_model('models/stargan', compile=False),
    'collagan': lambda: tf.keras.models.load_model('models/collagan', compile=False)
}
# model_loaders = {
#     'pix2pix': lambda: Pix2PixModelProxy('models/pix2pix'),
#     'stargan': lambda: StarGANModelProxy('models/stargan'),
#     'collagan': lambda: CollaGANModelProxy('models/collagan', is_legacy_tf_saved_model=True)
# }

pix2pix = model_loaders['pix2pix']()
stargan = model_loaders['stargan']()
collagan = model_loaders['collagan']()

In [26]:
# from collections import deque
#
# def count_layers(module, debug=False):
#     return len(module.graph.get_operations())
#
# count_layers(pix2pix._select_model(2,3).model.signatures['serving_default'], debug=True)


40

## Counting Layers

In [15]:
def count_layers_from_keras_model(model):
    return len(model.layers)

def count_layers(model):
    visited = {id(model)}
    unique_layers = set()
    stack = [model]

    while stack:
        obj = stack.pop()

        # Count non-container layers (Leaf layers like Dense, Conv2D, etc.)
        print("type(obj).__name__:", type(obj).__name__)
        if isinstance(obj, tf.keras.layers.Layer) and not isinstance(obj, tf.keras.Model):
            unique_layers.add(id(obj))

        # Traverse trackable children (internal API usage)
        if hasattr(obj, '_trackable_children'):
            children = obj._trackable_children().values()
            print(f"-- {len(children)} children")
            for child in children:
                if id(child) not in visited:
                    visited.add(id(child))
                    stack.append(child)


    print(f"Visited {len(visited)} objects in the model.")
    return len(unique_layers)

number_of_layers = [
    ('pix2pix', count_layers(pix2pix._select_model(2, 3).model)),
    ('stargan', count_layers(stargan.model)),
    ('collagan', count_layers(collagan.model))
]

header_format = "{:<10} | {:>8}"
row_format = "{:<10} | {:>8d}"
print(header_format.format('Model', 'Layers'))
for model, layers in number_of_layers:
    print(row_format.format(model, layers))



type(obj).__name__: _UserObject
-- 42 children
type(obj).__name__: _DictWrapper
-- 0 children
type(obj).__name__: _SignatureMap
-- 0 children
type(obj).__name__: _WrapperFunction
-- 0 children
type(obj).__name__: RestoredFunction
-- 0 children
type(obj).__name__: RestoredFunction
-- 0 children
type(obj).__name__: _UserObject
-- 13 children
type(obj).__name__: _DictWrapper
-- 0 children
type(obj).__name__: _DictWrapper
-- 0 children
type(obj).__name__: ListWrapper
-- 0 children
type(obj).__name__: ListWrapper
-- 0 children
type(obj).__name__: ListWrapper
-- 20 children
type(obj).__name__: ListWrapper
-- 0 children
type(obj).__name__: ListWrapper
-- 0 children
type(obj).__name__: ListWrapper
-- 36 children
type(obj).__name__: UninitializedVariable
-- 0 children
type(obj).__name__: UninitializedVariable
-- 0 children
type(obj).__name__: UninitializedVariable
-- 0 children
type(obj).__name__: UninitializedVariable
-- 0 children
type(obj).__name__: UninitializedVariable
-- 0 children
type(o

## Loading the Data

In [3]:
def load_image(dataset, domain, index):
    path = f'datasets/{dataset}/test/{domain}/{index}.png'
    image = tf.io.read_file(path)
    image = tf.io.decode_png(image, channels=4)
    image = tf.cast(image, "float32")
    image = (image / 127.5) - 1.
    return image

def load_character(index):
    return {
    'back': load_image('rpg-maker-xp', '0-back', index),
    'left': load_image('rpg-maker-xp', '1-left', index),
    'front': load_image('rpg-maker-xp', '2-front', index),
    'right': load_image('rpg-maker-xp', '3-right', index)
}


## Generating Images

In [4]:
from ipywidgets import interact
from matplotlib import pyplot as plt

In [5]:
@interact(index=(0,43))
def generate_pix2pix(index=0):
    character = load_character(index)
    
    source_image = character['front']
    
    target_image = pix2pix(source_image[tf.newaxis, ...], training=True)
    target_image = tf.squeeze(target_image)
    
    fig = plt.figure(figsize=(8, 4))
    plt.subplot(1, 2, 1)
    plt.imshow(source_image * 0.5 + 0.5)
    plt.title(f'Source Image: {index}')
    plt.axis('off')
    
    plt.subplot(1, 2, 2)
    plt.imshow(target_image * 0.5 + 0.5)
    plt.title(f'Generated: {index}')
    plt.axis('off')
    plt.show()
    
    fig.tight_layout()
        


interactive(children=(IntSlider(value=0, description='index', max=43), Output()), _dom_classes=('widget-intera…

In [6]:
@interact(index=(0,43))
def generate_stargan(index=0):
    character = load_character(index)
    
    source_image = character['front']
    source_domain = 2
    target_domain = 3
    
    target_image = stargan([
        source_image[tf.newaxis, ...],
        tf.constant([[target_domain]]),
        tf.constant([[source_domain]]),
    ])
    target_image = tf.squeeze(target_image)
    
    fig = plt.figure(figsize=(8, 4))
    plt.subplot(1, 2, 1)
    plt.imshow(source_image * 0.5 + 0.5)
    plt.title(f'Source Image: {index}')
    plt.axis('off')
    
    plt.subplot(1, 2, 2)
    plt.imshow(target_image * 0.5 + 0.5)
    plt.title(f'Generated: {index}')
    plt.axis('off')
    plt.show()
    
    fig.tight_layout()



interactive(children=(IntSlider(value=0, description='index', max=43), Output()), _dom_classes=('widget-intera…

In [7]:
@interact(index=(0,43))
def generate_collagan(index=0):
    character = load_character(index)
    
    source_images = tf.stack([character['back'], character['left'], character['front'], tf.zeros_like(character['right'])])
    target_domain = 3

    target_image = collagan([source_images[tf.newaxis, ...], tf.constant([[target_domain]])], training=True)    
    target_image = tf.squeeze(target_image)
    
    fig = plt.figure(figsize=(8, 4))
    plt.subplot(1, 2, 1)
    plt.title(f'Source Images: {index}')
    plt.axis('off')
    
    for i, idx in enumerate([1, 2, 5, 6]):
        plt.subplot(2, 4, idx)
        plt.imshow(source_images[i] * 0.5 + 0.5, interpolation='nearest')
        plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.imshow(target_image * 0.5 + 0.5, interpolation='nearest')
    plt.title(f'Generated: {index}')
    plt.axis('off')
    plt.show()

    fig.tight_layout()


interactive(children=(IntSlider(value=0, description='index', max=43), Output()), _dom_classes=('widget-intera…