In [None]:
"""
train_vae.py
Full VAE training script for 64x64 anime faces.
Assumes images are in the folder set by DATA_DIR (default: ./Datasets/anim_images/, flat folder).
Saves:
  - checkpoints/decoder.h5
  - checkpoints/encoder.h5
  - checkpoints/vae_epoch_{epoch:02d}.h5 (full VAE checkpoint, written every epoch)
  - samples/sample_epoch_{epoch}.png
"""

import os
import math
import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, Model

# --------- Config ----------
DATA_DIR = "Datasets/anim_images"       # change if needed
IMG_SIZE = 64             # images are resized to IMG_SIZE x IMG_SIZE (RGB)
BATCH_SIZE = 128
LATENT_DIM = 128          # length of the latent vector z
EPOCHS = 30
CHECKPOINT_DIR = "checkpoints"          # model checkpoints and final encoder/decoder
SAMPLES_DIR = "samples"                 # per-epoch sample grids
AUTOTUNE = tf.data.AUTOTUNE
# --------------------------

# Create both output folders up front (no error if they already exist).
for _out_dir in (CHECKPOINT_DIR, SAMPLES_DIR):
    os.makedirs(_out_dir, exist_ok=True)

# --------- Utility: load file paths ----------
def list_image_files(folder, exts=(".jpg", ".jpeg", ".png", ".bmp")):
    """Return the paths of image files found directly inside *folder*.

    The folder is not searched recursively. Extensions are matched
    case-insensitively, entries that are not regular files (e.g. a
    sub-directory called ``foo.png``) are skipped, and the result is sorted so
    it does not depend on the filesystem's directory-listing order.

    Args:
        folder: directory to scan.
        exts: extension or tuple of extensions (with leading dot) to accept.

    Returns:
        Sorted list of ``os.path.join(folder, name)`` paths.

    Raises:
        FileNotFoundError: if *folder* does not exist (from ``os.listdir``).
    """
    if isinstance(exts, str):
        exts = (exts,)
    suffixes = tuple(ext.lower() for ext in exts)
    paths = [
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.lower().endswith(suffixes)
    ]
    # Only keep real files; tf.io.read_file would fail on a directory.
    return sorted(path for path in paths if os.path.isfile(path))

# Fail early with a readable message (instead of a FileNotFoundError traceback)
# when the configured folder is missing, and always name the folder actually read.
if not os.path.isdir(DATA_DIR):
    raise SystemExit(f"Image folder not found: {DATA_DIR}. Set DATA_DIR to your image folder.")
image_paths = list_image_files(DATA_DIR)
print(f"Found {len(image_paths)} images in {DATA_DIR}")
if not image_paths:
    raise SystemExit(f"No images found. Put your images in {DATA_DIR}/")

# --------- tf.data pipeline ----------
def decode_and_resize(filename):
    """Load one image file as a float32 RGB tensor of shape (IMG_SIZE, IMG_SIZE, 3) in [0, 1].

    ``channels=3`` forces RGB output and ``expand_animations=False`` keeps the
    result 3-D (only the first frame of an animated image is decoded).
    Resizing uses bilinear interpolation with antialiasing, so source images
    larger than IMG_SIZE are downsampled without aliasing artifacts.

    Args:
        filename: scalar string tensor holding the image path.
    """
    raw = tf.io.read_file(filename)
    image = tf.image.decode_image(raw, channels=3, expand_animations=False)
    image = tf.image.convert_image_dtype(image, tf.float32)  # uint8 -> [0, 1]
    image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE], antialias=True)
    # The image is the BinaryCrossentropy target in training, so guarantee
    # [0, 1] even if interpolation rounding nudges a value past the bounds.
    return tf.clip_by_value(image, 0.0, 1.0)

# Shuffle the (cheap) path strings *before* decoding. A buffer spanning the whole
# dataset gives a full reshuffle every epoch while holding only strings, whereas a
# 10000-element buffer of decoded 64x64x3 float32 images needs ~490 MB and must be
# refilled with decoded images before the first batch of each epoch.
path_ds = tf.data.Dataset.from_tensor_slices(image_paths)
img_ds = path_ds.shuffle(
    buffer_size=len(image_paths), reshuffle_each_iteration=True
).map(decode_and_resize, num_parallel_calls=AUTOTUNE)
dataset = img_ds.batch(BATCH_SIZE).prefetch(AUTOTUNE)

# --------- VAE model (Encoder, Sampling, Decoder) ----------
# Encoder: three stride-2 "same" convolutions, each halving the spatial size
# (64 -> 32 -> 16 -> 8 for IMG_SIZE=64), then a 256-unit dense layer feeding
# two heads: the mean and the log-variance of the latent Gaussian.
encoder_inputs = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
x = encoder_inputs
for _filters in (32, 64, 128):
    x = layers.Conv2D(_filters, 3, strides=2, padding="same", activation="relu")(x)
x = layers.Flatten()(x)
x = layers.Dense(256, activation="relu")(x)

z_mean = layers.Dense(LATENT_DIM, name="z_mean")(x)
z_log_var = layers.Dense(LATENT_DIM, name="z_log_var")(x)

def sampling_layer(args):
    """Draw z = z_mean + exp(0.5 * z_log_var) * eps with eps ~ N(0, I).

    Sampling this way (reparameterisation) keeps z differentiable with respect
    to both inputs.

    Args:
        args: pair ``(z_mean, z_log_var)`` of same-shaped tensors
            (presumably ``(batch, latent_dim)`` here).

    Returns:
        A tensor with the same shape and dtype as ``z_mean``.
    """
    z_mean, z_log_var = args
    # Take the noise shape/dtype from z_mean instead of the global LATENT_DIM:
    # works for any latent width or float dtype and does not rely on module
    # state (which may matter if a Lambda wrapping this is deserialized elsewhere).
    eps = tf.random.normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
    return z_mean + tf.exp(0.5 * z_log_var) * eps

# Sample z from the two heads via sampling_layer, then wrap everything into the
# encoder model. Output order (z_mean, z_log_var, z) is what VAE.train_step unpacks.
z = layers.Lambda(function=sampling_layer, output_shape=(LATENT_DIM,))([z_mean, z_log_var])
encoder = Model(inputs=encoder_inputs, outputs=[z_mean, z_log_var, z], name="encoder")
encoder.summary()

# Decoder: project z to an 8x8x128 feature map, upsample with three stride-2
# transposed convolutions (8 -> 16 -> 32 -> 64), and map to 3 channels with a
# sigmoid so pixel values lie in [0, 1].
latent_inputs = layers.Input(shape=(LATENT_DIM,))
x = layers.Dense(8 * 8 * 128, activation="relu")(latent_inputs)
x = layers.Reshape((8, 8, 128))(x)
for _filters in (128, 64, 32):
    x = layers.Conv2DTranspose(_filters, 3, strides=2, padding="same", activation="relu")(x)
decoder_outputs = layers.Conv2D(3, 3, activation="sigmoid", padding="same")(x)
decoder = Model(inputs=latent_inputs, outputs=decoder_outputs, name="decoder")
decoder.summary()

# VAE as a Model subclass with custom train_step
class VAE(Model):
    """Variational autoencoder built from a separate encoder and decoder.

    The encoder must return ``(z_mean, z_log_var, z)`` for a batch of images.
    The decoder must map ``z`` to images that can be flattened to
    ``img_size * img_size * 3`` values.

    Args:
        encoder: model returning ``[z_mean, z_log_var, z]``.
        decoder: model mapping latent vectors back to images.
        img_size: side length of the square RGB images.
        beta: weight of the KL term. 1.0 gives the standard ELBO; other
            values give a beta-VAE objective.
        **kwargs: forwarded to ``Model.__init__`` (e.g. ``name``).
    """

    def __init__(self, encoder, decoder, img_size=IMG_SIZE, beta=1.0, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.img_size = img_size
        self.beta = beta
        # Running means, so logs / History report epoch averages instead of
        # the loss of the last batch only.
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.recon_loss_tracker = tf.keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset them at each epoch start.
        return [self.total_loss_tracker, self.recon_loss_tracker, self.kl_loss_tracker]

    def compile(self, optimizer, recon_loss_fn, **kwargs):
        """Configure training.

        Args:
            optimizer: Keras optimizer. It is handed to ``Model.compile`` so
                Keras builds and tracks it.
            recon_loss_fn: callable ``(y_true, y_pred)`` returning one loss per
                flattened image.
            **kwargs: forwarded to ``Model.compile``.
        """
        super().compile(optimizer=optimizer, **kwargs)
        self.recon_loss_fn = recon_loss_fn

    def train_step(self, data):
        images = data
        num_values = self.img_size * self.img_size * 3
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(images, training=True)
            reconstruction = self.decoder(z, training=True)

            # Reconstruction loss. Assumes recon_loss_fn returns a per-image mean
            # over the last axis, as BinaryCrossentropy with no reduction does.
            # Multiplying by the value count turns that into a per-image sum,
            # then we average over the batch.
            recon_loss = self.recon_loss_fn(
                tf.reshape(images, [-1, num_values]),
                tf.reshape(reconstruction, [-1, num_values]),
            )
            recon_loss = tf.reduce_mean(recon_loss) * num_values

            # KL(N(z_mean, exp(z_log_var)) || N(0, I)): sum over latent dims, mean
            # over the batch. This puts it on the same per-image scale as recon_loss.
            # Averaging over latent dims too would shrink it by a factor of latent_dim.
            kl_per_dim = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_per_dim, axis=1))

            total_loss = recon_loss + self.beta * kl_loss

        grads = tape.gradient(total_loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))

        self.total_loss_tracker.update_state(total_loss)
        self.recon_loss_tracker.update_state(recon_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.recon_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

# --------- Compile VAE ----------
vae = VAE(encoder, decoder, img_size=IMG_SIZE)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
# No reduction: the loss returns one value per image (mean over the last axis),
# and VAE.train_step does the batch reduction itself. The string "none" is the
# value of Reduction.NONE and is accepted across Keras versions, whereas the
# Reduction enum may not be exposed by every version.
recon_loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=False, reduction="none")
vae.compile(optimizer=optimizer, recon_loss_fn=recon_loss_fn)

# --------- Callbacks: save checkpoints & sample images ----------
class SampleCallback(tf.keras.callbacks.Callback):
    """After every epoch, decode latent vectors and save them as a PNG grid.

    Files are written as ``<sample_dir>/sample_epoch_<N>.png``, where N is the
    1-based epoch number. The trained model must expose a ``decoder``
    attribute.

    Args:
        sample_dir: existing directory to write the grids to.
        latent_dim: length of the latent vectors fed to the decoder.
        num_samples: images per grid. Any positive count is laid out in a
            near-square grid, and unused cells are left blank.
        seed: optional seed. When given, the same latent vectors are decoded
            every epoch so grids are comparable across epochs. When ``None``,
            fresh vectors are drawn from ``np.random`` each epoch.
    """

    def __init__(self, sample_dir, latent_dim, num_samples=8, seed=None):
        super().__init__()
        if num_samples < 1:
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")
        self.sample_dir = sample_dir
        self.latent_dim = latent_dim
        self.num_samples = num_samples
        self.seed = seed
        self._fixed_z = None
        if seed is not None:
            rng = np.random.default_rng(seed)
            self._fixed_z = rng.normal(size=(num_samples, latent_dim)).astype("float32")

    def _latent_batch(self):
        """Return the (num_samples, latent_dim) latent batch to decode this epoch."""
        if self._fixed_z is not None:
            return self._fixed_z
        return np.random.normal(size=(self.num_samples, self.latent_dim)).astype("float32")

    def on_epoch_end(self, epoch, logs=None):
        # verbose=0 stops predict() from printing a "1/1" progress bar every epoch.
        generated = self.model.decoder.predict(self._latent_batch(), verbose=0)

        # Near-square grid that fits every sample. The previous int(sqrt(n))
        # square silently dropped samples whenever n was not a perfect square.
        cols = math.ceil(math.sqrt(self.num_samples))
        rows = math.ceil(self.num_samples / cols)
        # squeeze=False keeps axs 2-D even for 1x1 / 1xN grids.
        fig, axs = plt.subplots(rows, cols, figsize=(cols * 2, rows * 2), squeeze=False)
        try:
            for idx, ax in enumerate(axs.flat):
                if idx < self.num_samples:
                    ax.imshow(generated[idx])
                ax.axis("off")
            save_path = os.path.join(self.sample_dir, f"sample_epoch_{epoch+1}.png")
            fig.tight_layout()
            fig.savefig(save_path)
        finally:
            # Always release the figure, even if saving fails.
            plt.close(fig)

# Full-model checkpoint after every epoch. Keras fills "{epoch:02d}" with the
# 1-based epoch number.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    os.path.join(CHECKPOINT_DIR, "vae_epoch_{epoch:02d}.h5"),
    save_weights_only=False,
    save_freq="epoch",
)

# 16 samples -> a 4x4 grid per epoch.
sample_cb = SampleCallback(sample_dir=SAMPLES_DIR, latent_dim=LATENT_DIM, num_samples=16)

# --------- Train ----------
print("Starting training...")
training_callbacks = [sample_cb, checkpoint_cb]
vae.fit(dataset, epochs=EPOCHS, callbacks=training_callbacks)

# Save the encoder and decoder as separate files so each can be loaded on its own.
for _fname, _part in (("encoder.h5", encoder), ("decoder.h5", decoder)):
    _part.save(os.path.join(CHECKPOINT_DIR, _fname))
print("Training complete. Saved models and samples.")


Found 21551 images in Datasets/anim_images


Starting training...
Epoch 1/30
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 53ms/step




[1m169/169[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 93ms/step - kl_loss: 12.1871 - loss: 7481.4741 - reconstruction_loss: 7469.2871
Epoch 2/30
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step




[1m169/169[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 93ms/step - kl_loss: 5.1205 - loss: 6848.8784 - reconstruction_loss: 6843.7578
Epoch 3/30
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step




[1m169/169[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 95ms/step - kl_loss: 4.9960 - loss: 6776.1719 - reconstruction_loss: 6771.1758
Epoch 4/30
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step




[1m169/169[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 94ms/step - kl_loss: 5.0657 - loss: 6630.5405 - reconstruction_loss: 6625.4746
Epoch 5/30
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step




[1m169/169[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 95ms/step - kl_loss: 5.8830 - loss: 6517.7881 - reconstruction_loss: 6511.9053
Epoch 6/30
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step




[1m169/169[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 96ms/step - kl_loss: 5.4895 - loss: 6501.2041 - reconstruction_loss: 6495.7148
Epoch 7/30
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step




[1m169/169[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 96ms/step - kl_loss: 5.0814 - loss: 6514.0898 - reconstruction_loss: 6509.0083
Epoch 8/30
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step




[1m169/169[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 99ms/step - kl_loss: 5.2015 - loss: 6437.4155 - reconstruction_loss: 6432.2139
Epoch 9/30
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step




[1m169/169[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 97ms/step - kl_loss: 4.9806 - loss: 6369.7256 - reconstruction_loss: 6364.7451
Epoch 10/30
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step




[1m169/169[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 105ms/step - kl_loss: 5.0445 - loss: 6241.8906 - reconstruction_loss: 6236.8462
Epoch 11/30
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step




[1m169/169[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 97ms/step - kl_loss: 4.9693 - loss: 6379.7368 - reconstruction_loss: 6374.7676
Epoch 12/30
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step




[1m169/169[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 96ms/step - kl_loss: 4.8828 - loss: 6308.7646 - reconstruction_loss: 6303.8818
Epoch 13/30
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step




[1m169/169[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 96ms/step - kl_loss: 4.9901 - loss: 6284.0146 - reconstruction_loss: 6279.0244
Epoch 14/30
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step




[1m169/169[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 100ms/step - kl_loss: 4.7958 - loss: 6354.6636 - reconstruction_loss: 6349.8677
Epoch 15/30
[1m139/169[0m [32m━━━━━━━━━━━━━━━━[0m[37m━━━━[0m [1m2s[0m 99ms/step - kl_loss: 4.7920 - loss: 6312.2881 - reconstruction_loss: 6307.4961