# Settings

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%env TF_KERAS = 1
import os
import sys

# Platform path separator; reused later when building the experiments path.
sep_local = os.path.sep

# Relative path that climbs six directory levels from the current working directory.
_six_levels_up = sep_local.join(['..'] * 6)

# Linux setup: change the working directory (presumably to the repository root, so the
# project packages imported below resolve - TODO confirm).
# Windows alternative noted by the author: sys.path.append(<same relative path>) instead.
os.chdir(_six_levels_up)

print(sep_local)
print(os.getcwd())

In [None]:
import tensorflow as tf
print(tf.__version__)

# Dataset loading

In [None]:
dataset_name='atari_pacman'

In [None]:
# Directory containing the source frames. This is an absolute, machine-specific Linux path;
# the commented line below is the author's Windows alternative.
images_dir = '/home/azeghost/datasets/.mspacman/atari_v1/screens/mspacman' #Linux
#images_dir = 'C:\\projects\\pokemon\DS06\\'
# Forwarded to create_image_lists as validation_pct.
validation_percentage = 25
# Forwarded to create_image_lists as the accepted image format.
valid_format = 'png'

In [None]:
from training.generators.from_images.file_image_generator import create_image_lists, get_generators

In [None]:
# Index the images under images_dir; the result is handed to get_generators later on.
# validation_pct is presumably the share of images reserved for validation - TODO confirm.
# NOTE(review): the keyword `valid_imgae_formats` looks misspelled, but it has to match the
# parameter name declared by create_image_lists; confirm against that signature before renaming.
imgs_list = create_image_lists(
    image_dir=images_dir, 
    validation_pct=validation_percentage, 
    valid_imgae_formats=valid_format,
    verbose=0 
)

In [None]:
# Integer divisor applied to both spatial sizes below (1 keeps 160 and 210 unchanged).
scale=1
# Per-image shape passed to get_generators and used as the model input shape.
image_size=(160//scale, 210//scale, 3)

# batch_size * EPIS_LEN is the leading dimension declared for every dataset element
# and the batch_size passed to the model constructor further down.
batch_size = 10
EPIS_LEN = 10  # forwarded to get_generators as episode_len
EPIS_SHIFT = 5  # forwarded to get_generators as episode_shift

inputs_shape= image_size
latents_dim = 30  # width of the encoder's final Dense layer
intermediate_dim = 30  # only intermediate_dim // 2 is used, as the decoder reshape depth

In [None]:
# Create the training and testing generators from the indexed image lists.
# With class_mode='episode_flat', the dataset cell below declares both yielded elements as
# (batch_size * EPIS_LEN,) + image_size; presumably episodes are flattened into the batch
# axis - TODO confirm in get_generators.
training_generator, testing_generator = get_generators(
    images_list=imgs_list, 
    image_dir=images_dir, 
    image_size=image_size, 
    batch_size=batch_size, 
    class_mode='episode_flat', 
    episode_len=EPIS_LEN, 
    episode_shift=EPIS_SHIFT
)

In [None]:
import tensorflow as tf

In [None]:
# Every element is an (input, target) pair of float32 tensors with the same static shape:
# batch_size * EPIS_LEN flattened frames of image_size each.
_flat_batch_shape = tf.TensorShape((batch_size * EPIS_LEN,) + image_size)
_dataset_signature = {
    'output_types': (tf.float32, tf.float32),
    'output_shapes': (_flat_batch_shape, _flat_batch_shape),
}

# The lambdas look the generators up by name when the dataset is iterated, as before.
train_ds = tf.data.Dataset.from_generator(lambda: training_generator, **_dataset_signature)

test_ds = tf.data.Dataset.from_generator(lambda: testing_generator, **_dataset_signature)

In [None]:
# Probe the first training batch for its maximum input value; fall back to 1.0 when the
# dataset yields nothing. A later cell re-assigns _instance_scale to 1.0.
_instance_scale = 1.0
data = next(iter(train_ds), None)
if data is not None:
    _instance_scale = float(data[0].numpy().max())

In [None]:
# Overrides the value measured from the first training batch in the previous cell.
_instance_scale = 1.0

In [None]:
import numpy as np
from collections.abc import Iterable

# Flattened size of one input: the product of the dimensions when inputs_shape is a
# shape tuple, or the value itself when it is already a scalar size. The scalar branch
# keeps _outputs_shape defined in both cases (it was previously left unset for scalars).
if isinstance(inputs_shape, Iterable):
    _outputs_shape = np.prod(inputs_shape)
else:
    _outputs_shape = inputs_shape

In [None]:
inputs_shape

# Model's Layers definition

In [None]:
# Base number of convolution filters; the narrower layers use units // 5.
units = 30

# Feature-map shape the decoder's Dense output is reshaped to before upsampling.
# The first two transposed convolutions below use padding="SAME" with strides (2, 3) and
# (2, 2), i.e. x4 along the first axis and x6 along the second, so starting from
# image_size // (4, 6) restores image_size[0] x image_size[1]. This is exact only when
# those sizes are divisible by 4 and 6 respectively.
c = (image_size[0] // 4, image_size[1] // 6, intermediate_dim // 2)

# Inference stack: two strided convolutions, flatten, then a linear projection to latents_dim.
enc_lays = [
    tf.keras.layers.Conv2D(filters=units, kernel_size=3, strides=(2, 2), activation='relu'),
    tf.keras.layers.Conv2D(filters=units // 5, kernel_size=3, strides=(2, 3), activation='relu'),
    tf.keras.layers.Flatten(),
    # No activation
    tf.keras.layers.Dense(latents_dim),
]

# Generative stack: dense layer sized to fill `c`, reshape, two upsampling transposed
# convolutions mirroring the encoder strides, then a stride-1 layer producing 3 channels.
dec_lays = [
    # np.prod replaces np.product, which is deprecated and removed in NumPy 2.0.
    tf.keras.layers.Dense(units=np.prod(c), activation=tf.nn.relu),
    tf.keras.layers.Reshape(target_shape=c),
    tf.keras.layers.Conv2DTranspose(filters=units // 5, kernel_size=3, strides=(2, 3), padding="SAME", activation='relu'),
    tf.keras.layers.Conv2DTranspose(filters=units, kernel_size=3, strides=(2, 2), padding="SAME", activation='relu'),

    # No activation
    tf.keras.layers.Conv2DTranspose(filters=3, kernel_size=3, strides=(1, 1), padding="SAME"),
]

# Model definition

In [None]:
# Model identifier: the dataset name followed by the architecture tag.
model_name = f'{dataset_name}Conv_LT_IAAE'
# All artefacts of this run are written below experiments/<model_name>.
experiments_dir = sep_local.join(['experiments', model_name])

In [None]:
from training.adversarial_basic.inference_adversarial.transformative.AAE import AAE as AE

In [None]:
# Re-binds inputs_shape to image_size; the hyperparameter cell above already made the same assignment.
inputs_shape=image_size

In [None]:
# Specification of the encoder: image-shaped input, latent-sized output.
_inference_spec = {
    'name': 'inference',
    'inputs_shape': inputs_shape,
    'outputs_shape': latents_dim,
    'layers': enc_lays,
}

# Specification of the decoder: latent-sized input, image-shaped output.
_generative_spec = {
    'name': 'generative',
    'inputs_shape': latents_dim,
    'outputs_shape': inputs_shape,
    'layers': dec_lays,
}

# Passed to the model constructor as variables_params, encoder first.
variables_params = [_inference_spec, _generative_spec]

In [None]:
from utils.data_and_files.file_utils import create_if_not_exist

In [None]:
_restore = os.path.join(experiments_dir, 'var_save_dir')

In [None]:
create_if_not_exist(_restore)
_restore

In [None]:
# To restore a previously trained model, pass filepath=_restore (instead of None) when constructing the model below.

In [None]:
from statistical.basic_adversarial_losses import (
    create_inference_discriminator_real_losses,
    create_inference_discriminator_fake_losses,
    create_inference_generator_fake_losses,
)

# Maps each adversarial output name to the factory that builds its loss;
# handed to ae.compile as adversarial_losses.
inference_discriminator_losses = dict([
    ('inference_discriminator_real_outputs', create_inference_discriminator_real_losses),
    ('inference_discriminator_fake_outputs', create_inference_discriminator_fake_losses),
    ('inference_generator_fake_outputs', create_inference_generator_fake_losses),
])

In [None]:
# Instantiate the adversarial autoencoder from the layer specifications above.
# batch_size is batch_size * EPIS_LEN with episode_len=1, matching the leading dimension
# declared for the dataset elements (episodes arrive already flattened into the batch axis).
# filepath=None means no saved variables are passed in; see the restore note above.
ae = AE( 
    name=model_name,
    latents_dim=latents_dim,
    batch_size=batch_size * EPIS_LEN,
    episode_len= 1, 
    variables_params=variables_params, 
    filepath=None
    )

In [None]:
# Weights for the adversarial terms: discr2gen_rate is passed as 'discriminator_weight'
# and gen2trad_rate as 'generator_weight'.
discr2gen_rate = 0.001
gen2trad_rate = 0.1

# Compile with the adversarial loss factories defined earlier and the weights above.
ae.compile(
    adversarial_losses=inference_discriminator_losses,
    adversarial_weights={'generator_weight': gen2trad_rate, 'discriminator_weight': discr2gen_rate}
)

In [None]:

from training.callbacks.sample_generation import SampleGeneration
from training.callbacks.save_model import ModelSaver

In [None]:
# Stop training once 'loss' has not improved by more than min_delta for 12 consecutive epochs.
# NOTE(review): this monitors 'loss' rather than 'val_loss' even though fit() is given
# validation data - confirm that is intended. restore_best_weights=False keeps the weights
# from the last epoch rather than the best one.
es = tf.keras.callbacks.EarlyStopping(
    monitor='loss', 
    min_delta=1e-12, 
    patience=12, 
    verbose=1, 
    restore_best_weights=False
)

In [None]:
ms = ModelSaver(filepath=_restore)

In [None]:
# Directory for the training log, under the experiment directory.
csv_dir = os.path.join(experiments_dir, 'csv_dir')
create_if_not_exist(csv_dir)
# From here on csv_dir holds the log FILE path (<model name>.csv), not the directory.
csv_dir = os.path.join(csv_dir, ae.name+'.csv')
# append=True adds to an existing log file instead of overwriting it.
csv_log = tf.keras.callbacks.CSVLogger(csv_dir, append=True)
# Bare expression: displays the log file path as the cell output.
csv_dir

In [None]:
# Directory passed to the SampleGeneration callback as its filepath.
image_gen_dir = os.path.join(experiments_dir, 'image_gen_dir')
create_if_not_exist(image_gen_dir)

In [None]:
# Sample-generation callback writing into image_gen_dir; gen_freq=5 is presumably the
# number of epochs between generations - TODO confirm in SampleGeneration.
sg = SampleGeneration(latents_shape=latents_dim, filepath=image_gen_dir, gen_freq=5, save_img=True, gray_plot=False)

In [None]:
import numpy as np

In [None]:
# Train on train_ds and validate on test_ds with the early-stopping, model-saving,
# CSV-logging and sample-generation callbacks. epochs=int(1e6) is effectively unbounded,
# so training is expected to end through the EarlyStopping callback `es`.
# NOTE(review): in stock tf.keras, `workers` / `use_multiprocessing` apply only to generator
# or Sequence inputs, while x is a tf.data.Dataset here; AE.fit is project code (it accepts
# input_kw), so confirm these two arguments are actually honoured.
# NOTE(review): steps_per_epoch and validation_steps are both 10,000 - confirm the
# underlying generators can supply that many batches per epoch.
ae.fit(
    x=train_ds,
    input_kw=None,
    steps_per_epoch=int(1e4),
    epochs=int(1e6), 
    verbose=2,
    callbacks=[ es, ms, csv_log, sg],
    workers=-1,
    use_multiprocessing=True,
    validation_data=test_ds,
    validation_steps=int(1e4)
)

# Model Evaluation

## inception_score

In [None]:
from evaluation.generativity_metrics.inception_metrics import inception_score

In [None]:
# Unpack the two values returned by inception_score for the trained model and print them.
is_mean, is_sigma = inception_score(ae, tolerance_threshold=1e-6, max_iteration=200)
print(f'inception_score mean: {is_mean}, sigma: {is_sigma}')

## Frechet_inception_distance

In [None]:
from evaluation.generativity_metrics.inception_metrics import frechet_inception_distance

In [None]:
# Evaluated against training_generator with max_iteration=10, whereas the other metric
# cells use 200.
# NOTE(review): the result is stored as `fis_score` although the metric is the Frechet
# inception distance.
fis_score = frechet_inception_distance(ae, training_generator, tolerance_threshold=1e-6, max_iteration=10, batch_size=32)
print(f'frechet inception distance: {fis_score}')

## perceptual_path_length_score

In [None]:
from evaluation.generativity_metrics.perceptual_path_length import perceptual_path_length_score

In [None]:
# Evaluated against training_generator with batch_size=32; the returned value is printed as-is.
ppl_mean_score = perceptual_path_length_score(ae, training_generator, tolerance_threshold=1e-6, max_iteration=200, batch_size=32)
print(f'perceptual path length score: {ppl_mean_score}')

## precision score

In [None]:
from evaluation.generativity_metrics.precision_recall import precision_score

In [None]:
# Evaluated against training_generator; the leading underscore keeps the result from
# shadowing the imported precision_score function.
_precision_score = precision_score(ae, training_generator, tolerance_threshold=1e-6, max_iteration=200)
print(f'precision score: {_precision_score}')

## recall score

In [None]:
from evaluation.generativity_metrics.precision_recall import recall_score

In [None]:
# Evaluated against training_generator; the leading underscore keeps the result from
# shadowing the imported recall_score function.
_recall_score = recall_score(ae, training_generator, tolerance_threshold=1e-6, max_iteration=200)
print(f'recall score: {_recall_score}')

# Image Generation

## image reconstruction

### Training and testing datasets

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from training.generators.image_generation_testing import reconstruct_from_a_batch

In [None]:
from utils.data_and_files.file_utils import create_if_not_exist
# Output folder for reconstructions of training batches, under the experiment directory.
save_dir = os.path.join(experiments_dir, 'reconstruct_training_images_like_a_batch_dir')
create_if_not_exist(save_dir)

# Hand the model, the training generator and the output folder to the reconstruction helper.
reconstruct_from_a_batch(ae, training_generator, save_dir)

In [None]:
from utils.data_and_files.file_utils import create_if_not_exist
# Output folder for reconstructions of testing batches; save_dir is re-bound for this cell.
save_dir = os.path.join(experiments_dir, 'reconstruct_testing_images_like_a_batch_dir')
create_if_not_exist(save_dir)

# Same helper as the previous cell, now fed from the testing generator.
reconstruct_from_a_batch(ae, testing_generator, save_dir)

## with Randomness

In [None]:
from training.generators.image_generation_testing import generate_images_like_a_batch

In [None]:
from utils.data_and_files.file_utils import create_if_not_exist
# Output folder for images generated from training batches; save_dir is re-bound for this cell.
save_dir = os.path.join(experiments_dir, 'generate_training_images_like_a_batch_dir')
create_if_not_exist(save_dir)

# Hand the model, the training generator and the output folder to the generation helper.
generate_images_like_a_batch(ae, training_generator, save_dir)

In [None]:
from utils.data_and_files.file_utils import create_if_not_exist
# Output folder for images generated from testing batches; save_dir is re-bound for this cell.
save_dir = os.path.join(experiments_dir, 'generate_testing_images_like_a_batch_dir')
create_if_not_exist(save_dir)

# Same helper as the previous cell, now fed from the testing generator.
generate_images_like_a_batch(ae, testing_generator, save_dir)

### Complete Randomness

In [None]:
from training.generators.image_generation_testing import generate_images_randomly

In [None]:
from utils.data_and_files.file_utils import create_if_not_exist
# Output folder for randomly generated images; save_dir is re-bound for this cell.
save_dir = os.path.join(experiments_dir, 'random_synthetic_dir')
create_if_not_exist(save_dir)

# The testing generator is still passed in; how the helper uses it is not visible here.
generate_images_randomly(ae, testing_generator, save_dir)

### Stacked inputs, outputs, and predictions

In [None]:
from training.generators.image_generation_testing import predict_from_a_batch

In [None]:
from utils.data_and_files.file_utils import create_if_not_exist
# Output folder for the prediction images; save_dir is re-bound for this cell.
save_dir = os.path.join(experiments_dir, 'predictions')
create_if_not_exist(save_dir)

# Hand the model, the testing generator and the output folder to the prediction helper.
predict_from_a_batch(ae, testing_generator, save_dir)

