## VAE model on CelebA dataset

In [10]:
import os
import sys

from classes.VAE import VAE
from utils.callbacks import WandbImagesVAE, SaveGeneratorWeights, SaveVAEWeights
import tensorflow as tf
from tensorflow import keras
import numpy as np
import wandb
from wandb.keras import WandbCallback
from imutils import paths


wandb.login()




True

## Model definition and configurations

In [11]:
encoder_architecture=[(1,64),(1,128),(1,256),(1,512)]
decoder_architecture=[(1,64),(1,128),(1,256),(1,512)]

BS=256
g=VAE((128,128,3),
      latent_dim=100,
      encoder_architecture=encoder_architecture,
      decoder_architecture=decoder_architecture,
      output_channels=3)


config={"dataset":"celebA", "type":"VAE","encoder_architecture":encoder_architecture,"decoder_architecture":decoder_architecture}
config.update(g.get_dict())



In [19]:
images_dir=r"home/matteo/NeuroGEN/Dataset/Img/img_align_celeba"

#other important definitions

EPOCHS=50
BS=64
INIT_LR=1e-3

config["epochs"]=EPOCHS
config["BS"]=BS
config["init_lr"]=INIT_LR


In [13]:
#set the second GPU
os.environ["CUDA_VISIBLE_DEVICES"]="0"

print(os.environ.get("CUDA_VISIBLE_DEVICES"))

0


## Dataloaders

In [14]:
def load_images(imagePath):
    # read the image from disk, decode it, resize it, and scale the
    # pixels intensities to the range [0, 1]
    image = tf.io.read_file(imagePath)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, (128, 128)) / 255.0

    #eventually load other information like attributes here
    
    # return the image and the extra info
    
    
    return image

In [None]:
wandb.init(project="TorVergataExperiment-Generative",config=config,name="VAE_CelebA")

In [15]:
print("[INFO] loading image paths...")
imagePaths = list(paths.list_images(images_dir))


train_len=int(0.8*len(imagePaths))
val_len=int(0.1*len(imagePaths))
test_len=int(0.1*len(imagePaths))

train_imgs=imagePaths[:train_len]                                #      80% for training
val_imgs=imagePaths[train_len:train_len+val_len]                 #      10% for validation
test_imgs=imagePaths[train_len+val_len:]                         #      10% for testing

print(f"[TRAINING]\t {len(train_imgs)}\n[VALIDATION]\t {len(val_imgs)}\n[TEST]\t {len(test_imgs)}")

[INFO] loading image paths...
[TRAINING]	 0
[VALIDATION]	 0
[TEST]	 0


In [None]:
#TRAINING 

train_dataset = tf.data.Dataset.from_tensor_slices(train_imgs)
train_dataset = (train_dataset
    .shuffle(1024)
    .map(load_images, num_parallel_calls=AUTOTUNE)
    .cache()
    .repeat()
    .batch(BS)
    .prefetch(AUTOTUNE)
)

##VALIDATION

val_dataset = tf.data.Dataset.from_tensor_slices(val_imgs)
val_dataset = (val_dataset
    .shuffle(1024)
    .map(load_images, num_parallel_calls=AUTOTUNE)
    .cache()
    .repeat()
    .batch(BS)
    .prefetch(AUTOTUNE)
)

## TEST

test_dataset = tf.data.Dataset.from_tensor_slices(test_imgs)
test_dataset = (test_dataset
    .shuffle(1024)
    .map(load_images, num_parallel_calls=AUTOTUNE)
    .cache()
    .repeat()
    .batch(BS)
    .prefetch(AUTOTUNE)
)

In [None]:
print(f"[INFO] Visual check images in dataset")
X=[]

for el in train_dataset:
    X.append(el)
    break
    
fig, axs=plt.subplots(6,6,figsize=(8,8))



for (i,ax) in enumerate(axs.ravel()):
    ax.imshow(X[0][i])

## Model Checkpoint

In [None]:
os.makedirs("models/vae",exist_ok=True)
model_check=SaveVAEWeights(filepath="models/vae")


callbacks=[
    WandbImagesVAE(test_dataset),
    WandbCallback(),
    model_check,
]


## Model Training

In [None]:
g.compile(optimizer=keras.optimizers.Adam(learning_rate=INIT_LR))
g.fit(train_dataset,validation_data=test_dataset,steps_per_epoch=ts,validation_steps=vs,epochs=40,callbacks=callbacks)
