In [1]:
import os

import tensorflow as tf
import numpy as np
import glob
import argparse
from classes.PGGAN import PGGAN
from utils.callbacks import WandbImagesPGGAN
import wandb
import tensorflow.keras as keras
from os.path import join as opj
from wandb.keras import WandbCallback


In [2]:

# Where model checkpoints and architecture plots are written.
checkpoint_path = "models/PGGAN_celebA"
# Run metadata recorded by wandb alongside the training metrics.
config = dict(dataset="celebA", type="PG-GAN")

# Authenticate with wandb, then open a run in the experiment project.
wandb.login()
wandb.init(project="TorVergataExperiment-Generative", config=config)

[34m[1mwandb[0m: Currently logged in as: [33mmatteoferrante[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.7 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [None]:
# Batch size per network depth: BS_list[d] is used when training at depth d,
# i.e. at a (4 * 2**d) x (4 * 2**d) resolution. Larger images -> smaller batches.
BS_list = [256, 128, 64, 32]

# Batch size for the initial (depth 0, 4x4) training stage.
BS = BS_list[0]

In [None]:

# NOISE_DIM: latent dimension handed to PGGAN (latent_dim).
# EPOCHS_PER_RES: epochs run in every phase (initial stage, each fade-in, each
# stabilization). With steps_per_epoch = len(train_imgs) // BS, every phase sees
# roughly EPOCHS_PER_RES passes over the training set.
NOISE_DIM, EPOCHS_PER_RES = 128, 16


## Dataloaders

In [None]:
def load_images(imagePath, target_size=(128, 128)):
    """Read a JPEG file from disk and return it as a float image scaled to [0, 1].

    Args:
        imagePath: path of a JPEG file (Python str or scalar string tensor, as
            produced by mapping over a `tf.data.Dataset` of file paths).
        target_size: (height, width) the decoded image is resized to. Defaults to
            (128, 128), the largest resolution reached by the progressive schedule.

    Returns:
        float32 image tensor of shape (height, width, 3) with values in [0, 1].
        Only the image is returned; extra per-sample information (e.g. attributes)
        would have to be added here and the callers updated accordingly.
    """
    raw_bytes = tf.io.read_file(imagePath)
    # channels=3 forces RGB output even for grayscale JPEGs.
    image = tf.image.decode_jpeg(raw_bytes, channels=3)
    # tf.image.resize returns float32, so dividing by 255 maps uint8 pixels to [0, 1].
    image = tf.image.resize(image, target_size) / 255.0
    return image

In [None]:
# Resolution helper

def resize(img, target_size=(4, 4), antialias=True):
    """Resize an image, or a batch of images, to `target_size`.

    Args:
        img: image tensor; `tf.image.resize` accepts a single (H, W, C) image or a
            (N, H, W, C) batch.
        target_size: (height, width) of the output. The default (4, 4) is the
            resolution of the initial training stage.
        antialias: apply an anti-aliasing filter when downsampling. Without it,
            shrinking the 128x128 source images to 4x4/8x8 bilinearly samples only
            a handful of source pixels, which yields heavily aliased real images.

    Returns:
        float32 tensor with the spatial dimensions replaced by `target_size`.
    """
    return tf.image.resize(img, target_size, antialias=antialias)

In [None]:
# Directory containing the CelebA JPEG files.
# NOTE(review): `images_dir` was referenced but never defined in this notebook, so the
# cell failed on a fresh kernel. It now defaults to data/celebA and can be overridden
# with the CELEBA_DIR environment variable — TODO confirm the real location.
images_dir = os.environ.get("CELEBA_DIR", opj("data", "celebA"))

print("[INFO] loading image paths...")
# `paths.list_images` (imutils) was never imported; collect files with the already
# imported glob module instead. load_images only decodes JPEG, so only .jpg/.jpeg
# files are kept. Sorting makes the train/val/test split reproducible across runs.
imagePaths = sorted(
    path
    for path in glob.glob(opj(images_dir, "**", "*"), recursive=True)
    if path.lower().endswith((".jpg", ".jpeg"))
)
assert imagePaths, f"No JPEG images found under {images_dir!r}"

train_len = int(0.8 * len(imagePaths))
val_len = int(0.1 * len(imagePaths))
test_len = int(0.1 * len(imagePaths))

train_imgs = imagePaths[:train_len]                        # 80% for training
val_imgs = imagePaths[train_len:train_len + val_len]       # 10% for validation
test_imgs = imagePaths[train_len + val_len:]               # remaining ~10% for testing

print(f"[TRAINING]\t {len(train_imgs)}\n[VALIDATION]\t {len(val_imgs)}\n[TEST]\t\t {len(test_imgs)}")

In [None]:
# `AUTOTUNE` was used below but never defined.
AUTOTUNE = tf.data.AUTOTUNE


def make_dataset(image_paths, batch_size, target_size=(4, 4), shuffle=True, repeat=True):
    """Build a batched tf.data pipeline of images loaded from `image_paths`.

    Args:
        image_paths: list of image file paths.
        batch_size: number of images per batch.
        target_size: (height, width) each image is resized to; (4, 4) is the
            resolution of the initial training stage.
        shuffle: shuffle elements (re-shuffled on every pass over the data).
        repeat: repeat the dataset indefinitely (needed with steps_per_epoch).

    Returns:
        A `tf.data.Dataset` yielding float image batches.
    """
    ds = tf.data.Dataset.from_tensor_slices(image_paths)
    # The dataset holds file paths: decode them before resizing.
    ds = ds.map(load_images, num_parallel_calls=AUTOTUNE)
    ds = ds.map(lambda img: resize(img, target_size), num_parallel_calls=AUTOTUNE)
    # Cache the small, already-resized tensors. Shuffling must come *after* cache,
    # otherwise the first epoch's order is cached and replayed on every epoch.
    ds = ds.cache()
    if shuffle:
        ds = ds.shuffle(1024)
    if repeat:
        ds = ds.repeat()
    return ds.batch(batch_size).prefetch(AUTOTUNE)


# TRAINING
train_dataset = make_dataset(train_imgs, BS)
ts = len(train_imgs) // BS

# VALIDATION
val_dataset = make_dataset(val_imgs, BS)
vs = len(val_imgs) // BS

# TEST (finite: iterated exactly once)
test_dataset = make_dataset(test_imgs, BS, repeat=False)

In [None]:
# One Adam optimizer per network, both configured identically.
# NOTE(review): per the original note, per-layer learning-rate equalization is done
# by the WeightScaling scheme inside the PGGAN layers — confirm in classes/PGGAN.
generator_optimizer = keras.optimizers.Adam(
    learning_rate=0.001,
    beta_1=0.0,
    beta_2=0.99,
    epsilon=1e-8,
)
discriminator_optimizer = keras.optimizers.Adam(
    learning_rate=0.001,
    beta_1=0.0,
    beta_2=0.99,
    epsilon=1e-8,
)

pgan = PGGAN(latent_dim=NOISE_DIM, d_steps=1)

# Project image callback (presumably logs generated samples) + wandb's Keras logger.
callbacks = [WandbImagesPGGAN(), WandbCallback()]

pgan.compile(d_optimizer=discriminator_optimizer, g_optimizer=generator_optimizer)

os.makedirs(checkpoint_path, exist_ok=True)


In [None]:
os.makedirs(checkpoint_path, exist_ok=True)

# Depth 0: train the initial generator/discriminator pair on the 4x4 dataset.
pgan.fit(train_dataset, steps_per_epoch=ts, epochs=EPOCHS_PER_RES, callbacks=callbacks)
pgan.save_weights(opj(checkpoint_path, "checkpoint_path_ndepth_0_weights_celebA.h5"))

# Architecture diagrams of both networks at the current depth.
for net_name, net in (("generator", pgan.generator), ("discriminator", pgan.discriminator)):
    tf.keras.utils.plot_model(
        net,
        to_file=opj(checkpoint_path, f"{net_name}_{pgan.n_depth}.png"),
        show_shapes=True,
    )


In [None]:

# Progressive growing: for each new depth, first train a fade-in phase, then a
# stabilization phase. Depth d trains on (4 * 2**d) x (4 * 2**d) images, so depths
# 1..5 cover 8x8 .. 128x128.

# Checkpoints are tagged with the dataset name from the wandb config
# (they were previously hard-coded as "cifar" although this run trains on celebA).
dataset_tag = config["dataset"]


def _make_train_dataset(image_paths, batch_size, target_size):
    """Shuffled, repeating, batched images from `image_paths`, resized to `target_size`."""
    ds = tf.data.Dataset.from_tensor_slices(image_paths)
    # Shuffle the cheap path strings; repeat() below re-shuffles on every pass.
    ds = ds.shuffle(buffer_size=1024)
    # The dataset holds file paths, so images must be decoded before resizing.
    ds = ds.map(load_images, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.map(lambda img: resize(img, target_size), num_parallel_calls=tf.data.AUTOTUNE)
    # No .cache(): at the larger resolutions the decoded float images may be too
    # large to keep in memory.
    return ds.repeat().batch(batch_size).prefetch(tf.data.AUTOTUNE)


def _plot_models(suffix=""):
    """Write architecture diagrams of the current generator and discriminator."""
    for net_name, net in (("generator", pgan.generator), ("discriminator", pgan.discriminator)):
        tf.keras.utils.plot_model(
            net,
            to_file=opj(checkpoint_path, f"{net_name}_{pgan.n_depth}{suffix}.png"),
            show_shapes=True,
        )


def _compile_and_fit(dataset, steps):
    """Re-compile the PGGAN after its architecture changed, then train one phase."""
    # NOTE(review): the same optimizer instances are reused after the set of trainable
    # variables changes; newer Keras optimizers (TF >= 2.11) reject this — confirm TF version.
    pgan.compile(d_optimizer=discriminator_optimizer, g_optimizer=generator_optimizer)
    pgan.fit(dataset, steps_per_epoch=steps, epochs=EPOCHS_PER_RES, callbacks=callbacks)


def _save_phase_weights(n_depth, phase=""):
    """Save per-network .h5 weights and a full TF-format checkpoint for one phase.

    `phase` is "" for the fade-in phase and "stabilized_" for the stabilization phase.
    """
    pgan.generator.save_weights(
        opj(checkpoint_path, f"generator_{phase}ndepth_{n_depth}_weights_{dataset_tag}.h5"))
    pgan.discriminator.save_weights(
        opj(checkpoint_path, f"discriminator_{phase}ndepth_{n_depth}_weights_{dataset_tag}.h5"))
    try:
        pgan.save_weights(
            opj(checkpoint_path, f"checkpoint_path_{phase}ndepth_{n_depth}_weights_{dataset_tag}.ckpt"),
            save_format="tf",
        )
    except Exception as err:  # best effort: the per-network .h5 files are already saved
        print(f"[WARNING] Could not save weights! ({err})")


for n_depth in range(1, 6):
    print(f"[INFO] Fading phase for {n_depth}")
    # Set current level (depth).
    pgan.n_depth = n_depth

    side = 4 * 2 ** n_depth
    new_dim = (side, side)

    # Dataset redefinition at the new resolution. BS_list has fewer entries than
    # there are depths, so deeper levels reuse its last (smallest) batch size
    # instead of raising IndexError.
    BS = BS_list[min(n_depth, len(BS_list) - 1)]
    ts = len(train_imgs) // BS
    train_dataset = _make_train_dataset(train_imgs, BS, new_dim)

    # Fade-in phase: enlarge both networks, then train them.
    pgan.fade_in_generator()
    pgan.fade_in_discriminator()
    _plot_models()
    _compile_and_fit(train_dataset, ts)
    _save_phase_weights(n_depth)

    # Stabilization phase at the same resolution.
    print(f"[INFO] Stabilizing phase for {n_depth}")
    pgan.stabilize_generator()
    pgan.stabilize_discriminator()
    _plot_models("_stabilized")
    _compile_and_fit(train_dataset, ts)
    _save_phase_weights(n_depth, phase="stabilized_")
