# VQ-VAE Training on CelebA

## Phase 1: Train the VQ-VAE (Vector-Quantized Variational Autoencoder)

In [1]:
import itertools
import os
import sys

from classes.VQVAE import VQVAE
from classes.PixelCNN import PixelCNN,TfDistPixelCNN
from utils.callbacks import WandbImagesVQVAE, Save_VQVAE_Weights, Save_PixelCNN_Weights
import tensorflow as tf
from tensorflow import keras
import numpy as np
import wandb
from wandb.keras import WandbCallback
from tensorflow.data import AUTOTUNE

import argparse
from os.path import join as opj
from imutils import paths

In [2]:
# Expose only the first GPU (device index "0") to this process.
# NOTE(review): CUDA presumably reads this variable when TensorFlow first
# initialises its GPU devices, so it must be set before any TensorFlow op runs
# in this kernel — confirm if unexpected GPUs show up.
visible_gpus = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = visible_gpus
print(os.environ["CUDA_VISIBLE_DEVICES"])

0


In [3]:
# Authenticate with Weights & Biases before any run is started.
wandb.login()
phase = "VQ_VAE_Training"

# Run configuration sent to W&B; the hyper-parameters are appended to it later.
config = {"dataset": "celebA", "type": "VQ-VAE", "phase": phase}

# Root directory of the aligned CelebA images. Set the CELEBA_IMAGES_DIR
# environment variable to point elsewhere instead of editing this
# machine-specific default.
images_dir = os.environ.get(
    "CELEBA_IMAGES_DIR",
    r"C:\Users\matte\Dataset\tor_vergata\Dataset\Img\img_align_celeba",
)



[34m[1mwandb[0m: Currently logged in as: [33mmatteoferrante[0m (use `wandb login --relogin` to force relogin)


In [4]:
# Training hyper-parameters: batch size, number of epochs, initial learning rate.
BS = 256
EPOCHS = 10
INIT_LR = 1e-4

# Record the hyper-parameters in the run configuration logged to W&B.
config.update({"BS": BS, "EPOCHS": EPOCHS, "INIT_LR": INIT_LR})

## Dataloaders

In [5]:
def load_images(imagePath, image_size=(128, 128)):
    """Read one image file and return it resized, with pixels scaled to [0, 1].

    Args:
        imagePath: path of the image file (a scalar string tensor when called
            through ``tf.data.Dataset.map``).
        image_size: ``(height, width)`` the image is resized to. The default
            ``(128, 128)`` matches the VQ-VAE input shape used in this notebook.

    Returns:
        A ``(height, width, 3)`` tensor with values in [0, 1].
    """
    image = tf.io.read_file(imagePath)
    # decode_image (unlike decode_jpeg) also accepts PNG/BMP/GIF, which
    # imutils.paths.list_images can return. expand_animations=False keeps the
    # result rank-3, as tf.image.resize below requires.
    image = tf.io.decode_image(image, channels=3, expand_animations=False)
    # Resize, then divide by 255 to map 8-bit intensities into [0, 1].
    image = tf.image.resize(image, image_size) / 255.0

    # Extra per-image information (e.g. CelebA attributes) could be loaded here
    # and returned alongside the image.
    return image

In [6]:
print("[INFO] loading image paths...")
# Sort the listing so the split is reproducible. imutils.paths.list_images
# presumably follows os.walk/os.listdir order, which is filesystem-dependent,
# so an unsorted list can yield a different train/val/test split per machine.
imagePaths = sorted(paths.list_images(images_dir))

# 80% / 10% / 10% split. The test split receives every remaining path, so
# test_len is the remainder rather than int(0.1 * n), which can be smaller.
train_len = int(0.8 * len(imagePaths))
val_len = int(0.1 * len(imagePaths))
test_len = len(imagePaths) - train_len - val_len

train_imgs = imagePaths[:train_len]                      # 80% for training
val_imgs = imagePaths[train_len:train_len + val_len]     # 10% for validation
test_imgs = imagePaths[train_len + val_len:]             # remaining ~10% for testing

print(f"[TRAINING]\t {len(train_imgs)}\n[VALIDATION]\t {len(val_imgs)}\n[TEST]\t\t {len(test_imgs)}")

[INFO] loading image paths...
[TRAINING]	 113455
[VALIDATION]	 14181
[TEST]		 14183


In [7]:
def make_dataset(image_paths, batch_size, shuffle=False, repeat=False, shuffle_buffer=1024):
    """Build a batched, prefetched tf.data pipeline of images from file paths.

    Args:
        image_paths: list of image file paths.
        batch_size: number of images per batch.
        shuffle: reshuffle the cached images on every pass (training only).
        repeat: repeat the dataset indefinitely; the consumer must then bound
            each epoch with an explicit number of steps.
        shuffle_buffer: size of the shuffle buffer.

    Returns:
        A ``tf.data.Dataset`` yielding batches produced by ``load_images``.
    """
    ds = tf.data.Dataset.from_tensor_slices(image_paths)
    # NOTE(review): cache() keeps every decoded image in memory. At 128x128x3
    # float32 (~192 KiB each) the ~113k training images need roughly 21 GiB;
    # use cache(filename) if RAM is insufficient.
    ds = ds.map(load_images, num_parallel_calls=AUTOTUNE).cache()
    # Shuffle AFTER cache(): a shuffle placed before cache() runs only on the
    # first pass, after which every pass replays the same cached order.
    if shuffle:
        ds = ds.shuffle(shuffle_buffer)
    if repeat:
        ds = ds.repeat()
    return ds.batch(batch_size).prefetch(AUTOTUNE)


# TRAINING: reshuffled every pass and repeated; ts is presumably the
# steps_per_epoch for fit().
train_dataset = make_dataset(train_imgs, BS, shuffle=True, repeat=True)
ts = len(train_imgs) // BS

# VALIDATION: order does not affect the metrics, so it is not shuffled;
# repeated so that vs batches can be drawn every epoch.
val_dataset = make_dataset(val_imgs, BS, repeat=True)
vs = len(val_imgs) // BS

# TEST: a single, finite pass in a fixed order, so any per-epoch visualisation
# built from it presumably shows the same images every epoch.
test_dataset = make_dataset(test_imgs, BS)

## Model Definition

In [8]:
print("[INFO] Training VQ_VAE Model")

# One tuple per encoder/decoder stage; the decoder mirrors the encoder.
# NOTE(review): the tuples are presumably (block_type, filters) — confirm the
# meaning of the leading 0 in classes.VQVAE.
encoder_architecture = [(0, 64), (0, 128), (0, 256), (0, 384), (0, 512)]
decoder_architecture = [(0, 512), (0, 384), (0, 256), (0, 128), (0, 64)]

# NOTE(review): the recorded run of this cell failed inside VQVAE with
# "TypeError: Dimension value must be integer ... got value '4.0'". The debug
# print (4.0, 4.0, 512) suggests the latent spatial size is computed with true
# division (/) instead of floor division (//) in classes.VQVAE — fix it there.
g = VQVAE(
    (128, 128, 3),
    latent_dim=16,
    num_embeddings=128,
    train_variance=4,
    encoder_architecture=encoder_architecture,
    decoder_architecture=decoder_architecture,
)

# Model.summary() prints the table itself and returns None; wrapping it in
# print() would add a stray "None" line after each summary.
g.encoder.summary()
g.decoder.summary()


[INFO] Training VQ_VAE Model
[DEBUG] (4.0, 4.0, 512)


TypeError: Dimension value must be integer or None or have an __index__ method, got value '4.0' with type '<class 'float'>'

## Callbacks

In [None]:
# WandbCallback() requires an active W&B run (wandb.init must come first), but
# this cell precedes the wandb.init cell, so start the run here if none exists.
if wandb.run is None:
    wandb.init(project="TorVergataExperiment-Generative", config=config)

# Weight checkpointing. The suffix was "mnist" (apparently copied from an
# MNIST notebook); name the dataset this notebook actually trains on so these
# weights cannot be confused with, or overwrite, MNIST checkpoints in ../models.
model_check = Save_VQVAE_Weights(output_dir="../models", outname="vq_vae", endname="celeba")

# Stop after 3 epochs without improvement and roll back to the best weights.
# NOTE(review): this monitors the training loss; if fit() receives validation
# data, "val_loss" is probably the better signal — confirm the metric names the
# VQVAE model reports.
es = tf.keras.callbacks.EarlyStopping(
    monitor="loss",
    patience=3,
    restore_best_weights=True,
)

callbacks = [
    # Project callback; presumably logs reconstructions of test images to W&B.
    WandbImagesVQVAE(test_dataset, sample=False),
    WandbCallback(),
    model_check,
    es,
]


In [None]:
# Start the W&B run only if none is active yet, so repeated execution (or an
# earlier cell that already started the run) does not start another one.
if wandb.run is None:
    wandb.init(project="TorVergataExperiment-Generative", config=config)

## COMPILE AND TRAIN

In [None]:
# Compile the VQ-VAE with Adam at the configured initial learning rate.
g.compile(
    keras.optimizers.Adam(learning_rate=INIT_LR)
)
# TODO(review): this section is titled "COMPILE AND TRAIN", but no training
# call follows here; train_dataset/val_dataset, ts/vs, EPOCHS and callbacks
# are prepared but never passed to fit() in this section.


## Phase 2: Train the PixelCNN sampler



In [None]:
def map_models_weights(model_dir):
    """Locate saved VQ-VAE weight entries inside ``model_dir`` by name.

    Each directory entry is classified by a substring of its name, checked in
    this order: "encoder" -> ``"encoder"``; "generator" or "decoder" ->
    ``"decoder"``; "embeddings" -> ``"embeddings"``. Entries matching none
    are ignored. Because of the order, a name containing both "encoder" and
    "embeddings" is classified as the encoder.

    Entries are visited in sorted order. When several entries match the same
    role, the lexicographically last one is chosen deterministically instead of
    depending on the arbitrary order of ``os.listdir``.

    Args:
        model_dir: directory to scan (not recursive).

    Returns:
        dict mapping any of "encoder", "decoder", "embeddings" to the full
        path of the matching entry; roles without a match are absent.
    """
    weights = {}
    for name in sorted(os.listdir(model_dir)):
        if "encoder" in name:
            role = "encoder"
        elif ("generator" in name) or ("decoder" in name):
            role = "decoder"
        elif "embeddings" in name:
            role = "embeddings"
        else:
            continue
        weights[role] = opj(model_dir, name)
    return weights
