In [None]:
'''
data: MNIST, digit 6 only
  input: 28x28 pixels with values scaled to [0, 1]
  output: encoded and decoded (reconstructed) MNIST images

framework: tf.keras
model: VAE
  layers:
    encoder
    decoder
  params:
  hyperparams:
  algorithm: VAE-2 MLP

result: works; training is smoother and easier than with a vanilla GAN
  test:
  5 iterations (epochs) with batch size 32 over the dataset (~900 images) -> already converged. Generated images look really good
'''

In [None]:
import tensorflow as tf
import numpy as np
import struct
import matplotlib.pyplot as plt

# Load MNIST images and labels from raw IDX files.
# Image file: 16-byte big-endian header (magic, count, rows, cols), then one
# uint8 per pixel. Label file: 8-byte header (magic, count), then one uint8
# label per image.
with open("10kimages.idx3-ubyte", "rb") as file:
    magic, num, rows, cols = struct.unpack(">IIII", file.read(16))
    # 2051 (0x00000803) is the IDX magic for a 3-D uint8 array (MNIST images).
    if magic != 2051:
        raise ValueError(f"unexpected IDX image magic number {magic} (expected 2051)")
    data = np.frombuffer(file.read(), dtype=np.uint8)
    # One flattened row of rows*cols pixels per image, scaled from 0..255 to [0, 1].
    features = data.reshape(num, rows * cols).astype(np.float32) / 255.0

with open("10klabels.idx1-ubyte", "rb") as file:
    magic, num = struct.unpack(">II", file.read(8))
    # 2049 (0x00000801) is the IDX magic for a 1-D uint8 array (MNIST labels).
    if magic != 2049:
        raise ValueError(f"unexpected IDX label magic number {magic} (expected 2049)")
    labels = np.frombuffer(file.read(), dtype=np.uint8)

# Images and labels are paired by position, so the counts must agree.
if labels.shape[0] != features.shape[0]:
    raise ValueError(f"{labels.shape[0]} labels for {features.shape[0]} images")

# Keep only the images of the digit 6.
indices = np.where(labels == 6)[0]
images = features[indices]

In [None]:
image_size = 28*28   # number of pixels in a flattened 28x28 image
# Dimensionality of the latent vector z. The encoder outputs 2 * latent_size
# values per image: mu followed by log sigma^2.
latent_size = 28
batch_size = 32

# Seed the Python, NumPy and TensorFlow RNGs so that weight init, shuffling and
# latent sampling aim to be repeatable across Restart & Run All.
seed = 42
tf.keras.utils.set_random_seed(seed)

# Sanity check: show one of the selected digit-6 images.
a = images[4].reshape(28, 28)
plt.imshow(a, cmap="gray")

# Shuffle with a buffer as large as the dataset (a full shuffle), then batch.
dataset = tf.data.Dataset.from_tensor_slices(images).shuffle(images.shape[0]).batch(batch_size)

In [None]:
def models(hidden_units=(512, 256, 128), learning_rate=1e-3):
    """Build the VAE encoder/decoder MLPs and one Adam optimizer for each.

    The encoder maps a flattened image (``image_size`` values) through ReLU
    Dense layers of widths ``hidden_units`` to ``2 * latent_size`` linear
    outputs, which the training code splits into mu and log sigma^2. The
    decoder mirrors the encoder (hidden widths reversed) and ends in a sigmoid
    Dense layer of ``image_size`` units, so reconstructions lie in (0, 1).

    Results are stored in the module-level globals ``encoder``, ``enopti``,
    ``decoder`` and ``deopti``, which the other cells read directly.

    Args:
        hidden_units: encoder hidden-layer widths, input side first. The
            default reproduces the original 512-256-128 architecture.
        learning_rate: Adam learning rate for both optimizers (1e-3 is the
            Keras default used originally).
    """
    global encoder, enopti, decoder, deopti

    # Materialize once so the sequence can be walked forwards and backwards.
    hidden_units = tuple(hidden_units)

    encoder = tf.keras.models.Sequential(
        [tf.keras.Input(shape=(image_size,))]
        + [tf.keras.layers.Dense(units, activation="relu") for units in hidden_units]
        + [tf.keras.layers.Dense(latent_size * 2)])
    enopti = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    decoder = tf.keras.models.Sequential(
        [tf.keras.Input(shape=(latent_size,))]
        + [tf.keras.layers.Dense(units, activation="relu") for units in reversed(hidden_units)]
        + [tf.keras.layers.Dense(image_size, activation="sigmoid")])
    deopti = tf.keras.optimizers.Adam(learning_rate=learning_rate)

# Build the networks once; later cells use the globals set by models().
models()

def parameterization_trick(mu, log_sigma2):
    """Draw z ~ N(mu, sigma^2) using the reparameterization trick.

    Sampling is written as ``z = mu + sigma * eps`` with ``eps ~ N(0, I)``, so
    gradients flow back into ``mu`` and ``log_sigma2`` through the
    deterministic part of the expression.

    Args:
        mu: posterior mean.
        log_sigma2: posterior log-variance, same shape as ``mu``.

    Returns:
        A sample with the shape and dtype of ``mu``.
    """
    mu = tf.convert_to_tensor(mu)
    # Match mu's dtype; tf.random.normal defaults to float32, which would
    # clash with float16/float64 inputs in the addition below.
    eps = tf.random.normal(shape=tf.shape(mu), dtype=mu.dtype)
    # sigma = exp(0.5 * log sigma^2) = sqrt(sigma^2)
    sigma = tf.exp(0.5 * log_sigma2)
    return mu + sigma * eps


In [None]:
# train() below uses parameterization_trick(mu, log_sigma2), which is defined
# once in the model cell above (right after models()).

def train(iterations, log_every=1):
    """Train the VAE for ``iterations`` passes (epochs) over ``dataset``.

    For each batch, the encoder output is split into (mu, log sigma^2), and a
    latent z is sampled with ``parameterization_trick`` and decoded. The loss
    is the binary cross-entropy reconstruction term plus the KL divergence to
    N(0, I). A single gradient tape covers both networks; the gradients are
    split and applied with ``enopti`` (encoder) and ``deopti`` (decoder).

    Args:
        iterations: number of epochs to run.
        log_every: print the epoch loss every ``log_every`` epochs
            (0 disables printing).

    Returns:
        List of length ``iterations`` with each epoch's mean per-batch loss.

    Raises:
        ValueError: if ``dataset`` yields no batches.
    """
    epoch_losses = []
    # The variable lists are fixed once the models are built; compute them once.
    variables = encoder.trainable_variables + decoder.trainable_variables
    n_encoder_vars = len(encoder.trainable_variables)

    for iteration in range(iterations):
        total_loss = 0.0
        num_batches = 0
        for batch in dataset:
            with tf.GradientTape() as tape:
                z_notrick = encoder(batch)
                # First latent_size outputs are mu, the rest are log sigma^2.
                mu, log_sigma2 = z_notrick[..., :latent_size], z_notrick[..., latent_size:]
                z = parameterization_trick(mu, log_sigma2)
                reconstructed_x = decoder(z)

                reconstruction_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(batch, reconstructed_x))
                # Closed-form KL(N(mu, sigma^2) || N(0, 1)).
                # NOTE(review): both terms are means (over pixels and over latent
                # dims) rather than sums as in the ELBO, so the KL term is
                # effectively weighted by image_size / latent_size relative to
                # reconstruction. Kept as-is to preserve the trained behaviour.
                kl_loss = -0.5 * tf.reduce_mean(1 + log_sigma2 - tf.square(mu) - tf.exp(log_sigma2))
                loss = reconstruction_loss + kl_loss
            gradients = tape.gradient(loss, variables)
            enopti.apply_gradients(zip(gradients[:n_encoder_vars], encoder.trainable_variables))
            deopti.apply_gradients(zip(gradients[n_encoder_vars:], decoder.trainable_variables))

            total_loss += loss.numpy()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("dataset yielded no batches; nothing to train on")
        # Each `loss` is already a per-batch mean, so average over batches.
        avg_loss = total_loss / num_batches
        epoch_losses.append(avg_loss)

        if log_every and iteration % log_every == 0:
            print(f"Iteration {iteration}: Loss = {avg_loss:.4f}")

    return epoch_losses

# Five passes over the digit-6 dataset.
losses = train(5)

In [None]:
def ploss(loss_history=None):
    """Plot the per-epoch training loss curve on a new figure.

    Args:
        loss_history: sequence of loss values. Defaults to the global
            ``losses`` produced by the training cell.

    Returns:
        The matplotlib Axes the curve was drawn on.
    """
    if loss_history is None:
        loss_history = losses
    fig, ax = plt.subplots()
    ax.plot(loss_history, label="Loss")
    ax.set_xlabel("Iterations")
    ax.set_ylabel("Loss")
    ax.set_title("Training Loss Curve")
    ax.legend()
    ax.grid()
    return ax

# Trailing ';' suppresses the Axes repr in the notebook output.
ploss();

In [None]:
def gauss(num_samples=1000):
    """Histogram the entries of samples drawn from the standard-normal prior.

    Draws ``num_samples`` vectors of length ``latent_size`` from N(0, I) and
    pools all of their entries into a single 28-bin histogram.
    """
    prior_draws = tf.random.normal(shape=(num_samples, latent_size))
    pooled_values = prior_draws.numpy().ravel()

    plt.hist(pooled_values, bins=28)
    plt.ylabel("Frequency")
    plt.xlabel("Value z")

gauss()

In [None]:
def gauss_image(rows, cols):
    """Decode random prior samples and show them in a ``rows`` x ``cols`` grid.

    Each image is generated from z ~ N(0, I) with shape
    ``(rows * cols, latent_size)``, passed through ``decoder`` and reshaped
    to 28x28 for display.
    """
    num_images = rows * cols
    z = tf.random.normal(shape=(num_images, latent_size))
    generated = decoder(z).numpy().reshape(num_images, 28, 28)

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, where
    # plt.subplots would otherwise return a bare Axes that has no `.flat`.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 2, rows * 2), squeeze=False)
    for ax, image in zip(axes.flat, generated):
        ax.imshow(image, cmap="gray")
        ax.axis("off")

gauss_image(2, 4)

In [None]:
def my_z(idx=0):
    """Plot a histogram of one sampled latent vector for ``images[idx]``.

    The image is encoded into (mu, log sigma^2), a single z is drawn with
    ``parameterization_trick``, and its ``latent_size`` entries are
    histogrammed (28 bins) on a new 6x4 figure.
    """
    # A list index keeps the batch axis (shape (1, image_size)). Unlike
    # images[idx:idx+1], it also works for negative idx (e.g. -1) instead of
    # silently producing an empty batch.
    img = images[[idx]]
    z_notrick = encoder(img)
    mu, log_sigma2 = z_notrick[..., :latent_size], z_notrick[..., latent_size:]
    z = parameterization_trick(mu, log_sigma2).numpy().flatten()

    plt.figure(figsize=(6, 4))
    plt.hist(z, bins=28)
    plt.xlabel("Value z")
    plt.ylabel("Frequency")

my_z(0)

In [None]:
def my_z_image(idx=0):
    """Encode ``images[idx]``, sample a latent z and show the reconstruction.

    The image is encoded into (mu, log sigma^2), z is drawn with
    ``parameterization_trick`` and decoded, and the decoder output is shown as
    a 28x28 grayscale image.
    """
    # A list index keeps the batch axis and also handles negative idx
    # (a slice images[idx:idx+1] would be empty for idx=-1).
    img = images[[idx]]
    z_notrick = encoder(img)
    mu, log_sigma2 = z_notrick[..., :latent_size], z_notrick[..., latent_size:]
    z = parameterization_trick(mu, log_sigma2)

    reconstructed_x = decoder(z)
    # Reshape to the fixed 28x28 image size (as gauss_image does) instead of
    # relying on the rows/cols globals left over from the IDX header cell.
    reconstructed_img = reconstructed_x.numpy().reshape(28, 28)

    plt.figure(figsize=(4, 4))
    plt.imshow(reconstructed_img, cmap="gray")
    plt.axis("off")
    plt.show()

my_z_image(0)