# 🧱 WGAN - Bricks Data

## Train your own Wasserstein GAN on the **BRICKS DATASET**

Code adapted from:
- [WGAN-GP tutorial](https://keras.io/examples/generative/wgan_gp/)
- DCGAN - Bricks Data
- WGAN - CelebA Faces

In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np

import tensorflow as tf
from tensorflow.keras import (
    layers,
    models,
    callbacks,
    utils,
    metrics,
    optimizers,
)

import os
import sys
# Make the project root importable so `notebooks.utils` resolves.
# Raw string: the Windows path contains backslashes that are NOT meant as escape
# sequences (`\_`, `\D`, `\m`, `\g` are invalid escapes and trigger
# DeprecationWarning / SyntaxWarning in a normal string literal).
module_path = os.path.abspath(r'c:\_AOStuff\Development\ml-ai\generative2')
if module_path not in sys.path:
    sys.path.append(module_path)

print(sys.path)

from notebooks.utils import display, sample_batch

## Customize Setup

In [None]:
! pip install --no-index -f https://github.com/dreoporto/ptmlib/releases ptmlib

In [None]:
from ptmlib.time import Stopwatch, AlertSounds
import ptmlib.charts as pch

In [None]:
# main_stopwatch times the whole notebook run; stopwatch is restarted around
# individual long-running cells (data loading, training).
main_stopwatch, stopwatch = Stopwatch(), Stopwatch()
main_stopwatch.start()

In [None]:
# AOPORTO MODS

import os
# import sys

# module_path = os.path.abspath('/app/notebooks/03_vae/03_vae_faces')
# if module_path not in sys.path:
#     sys.path.append(module_path)

# if not os.path.exists('./output'):
#     os.mkdir('./output')

def make_missing_directories(directory_file_full_path: str) -> None:
    """
    Create any missing directories in the directory part of the given path.

    Only ``os.path.dirname(directory_file_full_path)`` is created, so
    ``./output/buzz.txt`` creates ``./output`` (never the file itself). To create
    a directory path itself, pass it with a trailing separator, e.g. ``./output/``.
    Prints a message only when something was actually created.

    :param directory_file_full_path: full path for the directory or file we will be using
    :return: None
    """
    directory_path = os.path.dirname(directory_file_full_path)

    if directory_path == '':
        # bare file name, no directory component: nothing to create
        return

    try:
        # EAFP: a separate exists() check followed by makedirs() can race with
        # another process creating the same folder in between.
        os.makedirs(directory_path)
    except FileExistsError:
        # all folders already exist
        return

    print(f'one or more folders were created: {directory_path}')


make_missing_directories('./output/buzz.txt')


## 0. Parameters <a name="parameters"></a>

In [None]:
# --- Image geometry: square grayscale images ---
IMAGE_SIZE = 64
CHANNELS = 1

# --- Data pipeline ---
BATCH_SIZE = 64  # other value noted: 512

# --- Model ---
Z_DIM = 128  # latent dimension; other values noted: 100, 192

# --- Optimisation (Adam; Keras' default beta_2 is 0.999, 0.9 is used here) ---
LEARNING_RATE = 0.0002
ADAM_BETA_1 = 0.5
ADAM_BETA_2 = 0.9
EPOCHS = 20  # other value noted: 300

# --- WGAN-GP: critic updates per generator update, gradient-penalty weight ---
CRITIC_STEPS = 3
GP_WEIGHT = 10.0

# True -> restore weights from ./checkpoint/checkpoint.ckpt before training
LOAD_MODEL = False

# NOTE(review): NOISE_PARAM is not referenced by any other cell in this notebook
NOISE_PARAM = 0.1

# Display only every DISPLAY_RATE-th epoch's sample grid (the first epoch is always
# shown); images are still generated and saved for every epoch!
DISPLAY_RATE = 2

## 1. Prepare the data <a name="prepare"></a>

In [None]:
stopwatch.start()

# Load the Lego brick images as grayscale IMAGE_SIZE x IMAGE_SIZE batches.
# validation_split=0.75 combined with subset="training" keeps only the remaining
# 25% of the files -- this is done purely to shorten training time; the 75%
# "validation" part is never loaded or used.
train_data = utils.image_dataset_from_directory(
    directory="../../../data/lego-brick-images/dataset/",
    labels=None,
    color_mode="grayscale",
    image_size=(IMAGE_SIZE, IMAGE_SIZE),
    interpolation="bilinear",
    batch_size=BATCH_SIZE,
    shuffle=True,
    seed=42,  # Keras requires a seed when shuffling together with validation_split
    validation_split=0.75,
    subset="training",
)

stopwatch.stop()

In [None]:
# Preprocess the data
def preprocess(img):
    """
    Cast images to float32 and rescale pixel values to the range [-1, 1].

    Assumes raw pixel values in [0, 255]; [-1, 1] matches the generator's tanh output.
    No reshaping is performed.
    """
    pixels = tf.cast(img, "float32")
    return (pixels - 127.5) / 127.5


# Elements are unlabeled image batches, so the function can be mapped directly.
train = train_data.map(preprocess)

In [None]:
# Grab a sample of (preprocessed) brick images from the training set for display
train_sample = sample_batch(train)

In [None]:
# Exploration / debugging: inspect the structure of the batched training dataset.
# Size of whatever sample_batch returned (see notebooks.utils)
print(len(train_sample))

# take(1) -> a new dataset containing only the first element (one batch)
my_take = train.take(1)

print(len(my_take))
print(type(my_take))

# Pull that single element (a batch tensor) out of the dataset.
# NOTE(review): shuffle=True when loading, so this batch likely differs from train_sample.
my_sample = my_take.get_single_element()

print(len(my_sample))
print(type(my_sample))
# presumably (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, CHANNELS) -- confirm from the output
print(tf.shape(my_sample))

# first image of the batch (values scaled to [-1, 1] by preprocess)
print(my_sample[0])

In [None]:
# Plot the sampled training images (display is imported from notebooks.utils)
display(train_sample)

## 2. Build the WGAN-GP <a name="build"></a>

In [None]:
# Critic: maps an (IMAGE_SIZE, IMAGE_SIZE, CHANNELS) image to one unbounded score
# (no sigmoid -- the Wasserstein loss uses raw scores). There is no batch
# normalization, in line with the WGAN-GP paper's advice for the critic.
critic_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, CHANNELS))
x = layers.Conv2D(64, kernel_size=4, strides=2, padding="same")(critic_input)
x = layers.LeakyReLU(0.2)(x)
x = layers.Conv2D(128, kernel_size=4, strides=2, padding="same")(x)
# Explicit 0.2 slope like every other activation here; a bare LeakyReLU()
# silently falls back to the Keras default slope of 0.3.
x = layers.LeakyReLU(0.2)(x)
x = layers.Dropout(0.3)(x)
x = layers.Conv2D(256, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Dropout(0.3)(x)
x = layers.Conv2D(512, kernel_size=4, strides=2, padding="same")(x)
x = layers.LeakyReLU(0.2)(x)
x = layers.Dropout(0.3)(x)
# With IMAGE_SIZE=64 the four stride-2 convs give 4x4; this 4x4 "valid" conv -> 1x1x1
x = layers.Conv2D(1, kernel_size=4, strides=1, padding="valid")(x)
critic_output = layers.Flatten()(x)

critic = models.Model(critic_input, critic_output)
critic.summary()

In [None]:
# Generator: latent vector of size Z_DIM -> (IMAGE_SIZE, IMAGE_SIZE, CHANNELS) image
# in [-1, 1]. Each hidden stage is ConvTranspose (no bias; BatchNorm follows) ->
# BatchNorm -> LeakyReLU(0.2). Spatial size: 1 -> 4 -> 8 -> 16 -> 32, then the tanh
# output layer upsamples to 64.
generator_input = layers.Input(shape=(Z_DIM,))
x = layers.Reshape((1, 1, Z_DIM))(generator_input)

for _filters, _strides, _padding in (
    (512, 1, "valid"),
    (256, 2, "same"),
    (128, 2, "same"),
    (64, 2, "same"),
):
    x = layers.Conv2DTranspose(
        _filters, kernel_size=4, strides=_strides, padding=_padding, use_bias=False
    )(x)
    x = layers.BatchNormalization(momentum=0.9)(x)
    x = layers.LeakyReLU(0.2)(x)

generator_output = layers.Conv2DTranspose(
    CHANNELS, kernel_size=4, strides=2, padding="same", activation="tanh"
)(x)
generator = models.Model(generator_input, generator_output)
generator.summary()

In [None]:
class WGANGP(models.Model):
    """
    Wasserstein GAN with gradient penalty (WGAN-GP).

    Wraps a critic and a generator. Each ``train_step`` performs ``critic_steps``
    critic updates followed by a single generator update on one batch of real images.
    """

    def __init__(self, critic, generator, latent_dim, critic_steps, gp_weight):
        """
        :param critic: model mapping a batch of images to one score per image
        :param generator: model mapping latent vectors of size ``latent_dim`` to images
        :param latent_dim: dimensionality of the generator's latent input
        :param critic_steps: critic updates performed per generator update
        :param gp_weight: weight of the gradient-penalty term in the critic loss
        """
        super().__init__()
        self.critic = critic
        self.generator = generator
        self.latent_dim = latent_dim
        self.critic_steps = critic_steps
        self.gp_weight = gp_weight

    def compile(self, c_optimizer, g_optimizer):
        """
        Configure training with separate optimizers for critic and generator.

        :param c_optimizer: optimizer applied to the critic's trainable variables
        :param g_optimizer: optimizer applied to the generator's trainable variables
        """
        super().compile()
        self.c_optimizer = c_optimizer
        self.g_optimizer = g_optimizer
        # Running means of the losses, reported by fit() under these names
        self.c_wass_loss_metric = metrics.Mean(name="c_wass_loss")
        self.c_gp_metric = metrics.Mean(name="c_gp")
        self.c_loss_metric = metrics.Mean(name="c_loss")
        self.g_loss_metric = metrics.Mean(name="g_loss")

    @property
    def metrics(self):
        # Listing the metrics here lets Keras reset them at the start of each epoch.
        return [
            self.c_loss_metric,
            self.c_wass_loss_metric,
            self.c_gp_metric,
            self.g_loss_metric,
        ]

    def gradient_penalty(self, batch_size, real_images, fake_images):
        """
        Compute the WGAN-GP gradient penalty on random real/fake interpolations.

        For each sample a point x_hat on the straight line between the real and the
        fake image is chosen; the penalty is mean((||d critic(x_hat) / d x_hat||_2 - 1)^2).

        :param batch_size: number of images in the batch
        :param real_images: batch of real images
        :param fake_images: batch of generated images, same shape as ``real_images``
        :return: scalar penalty tensor
        """
        # alpha must lie in [0, 1] so x_hat interpolates between the real and the
        # fake image, as WGAN-GP specifies; a standard-normal alpha would also
        # extrapolate beyond both endpoints (including negative weights).
        alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
        diff = fake_images - real_images
        interpolated = real_images + alpha * diff

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            pred = self.critic(interpolated, training=True)

        grads = gp_tape.gradient(pred, [interpolated])[0]
        # per-sample L2 norm of the gradient over height, width and channels
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        return tf.reduce_mean((norm - 1.0) ** 2)

    def train_step(self, real_images):
        """
        Run one WGAN-GP training step on a batch of real images.

        :param real_images: batch of real images, same shape as the generator output
        :return: dict mapping metric names to their running means
        """
        batch_size = tf.shape(real_images)[0]

        # --- Critic: minimise E[critic(fake)] - E[critic(real)] + gp_weight * GP ---
        for _ in range(self.critic_steps):
            random_latent_vectors = tf.random.normal(
                shape=(batch_size, self.latent_dim)
            )

            with tf.GradientTape() as tape:
                fake_images = self.generator(
                    random_latent_vectors, training=True
                )
                fake_predictions = self.critic(fake_images, training=True)
                real_predictions = self.critic(real_images, training=True)

                c_wass_loss = tf.reduce_mean(fake_predictions) - tf.reduce_mean(
                    real_predictions
                )
                c_gp = self.gradient_penalty(
                    batch_size, real_images, fake_images
                )
                c_loss = c_wass_loss + c_gp * self.gp_weight

            c_gradient = tape.gradient(c_loss, self.critic.trainable_variables)
            self.c_optimizer.apply_gradients(
                zip(c_gradient, self.critic.trainable_variables)
            )

        # --- Generator: maximise the critic's score on fresh fakes ---
        random_latent_vectors = tf.random.normal(
            shape=(batch_size, self.latent_dim)
        )
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_latent_vectors, training=True)
            fake_predictions = self.critic(fake_images, training=True)
            g_loss = -tf.reduce_mean(fake_predictions)

        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables)
        )

        # Critic losses reported here come from the last of the critic_steps updates
        self.c_loss_metric.update_state(c_loss)
        self.c_wass_loss_metric.update_state(c_wass_loss)
        self.c_gp_metric.update_state(c_gp)
        self.g_loss_metric.update_state(g_loss)

        return {m.name: m.result() for m in self.metrics}

In [None]:
# Wire the generator and critic defined above into a WGAN-GP training wrapper.
wgangp = WGANGP(
    generator=generator,
    critic=critic,
    latent_dim=Z_DIM,
    gp_weight=GP_WEIGHT,
    critic_steps=CRITIC_STEPS,
)

In [None]:
# Optionally resume: restore the weights the ModelCheckpoint callback writes to
# ./checkpoint/checkpoint.ckpt (controlled by LOAD_MODEL in the parameters cell).
# NOTE(review): Keras 3 only loads `.weights.h5` / `.keras` files; this `.ckpt`
# path assumes TF2-era tf.keras -- confirm the installed version.
if LOAD_MODEL:
    wgangp.load_weights("./checkpoint/checkpoint.ckpt")

## 3. Train the GAN <a name="train"></a>

In [None]:
# Compile the GAN

def _new_adam():
    """Build a fresh Adam optimizer from the notebook's hyperparameters."""
    return optimizers.Adam(
        learning_rate=LEARNING_RATE, beta_1=ADAM_BETA_1, beta_2=ADAM_BETA_2
    )


# Critic and generator each get their own optimizer instance, so their Adam
# moment estimates are kept separately.
wgangp.compile(c_optimizer=_new_adam(), g_optimizer=_new_adam())

In [None]:
# Save the WGANGP weights at the end of every epoch (reloaded when LOAD_MODEL is True)
model_checkpoint_callback = callbacks.ModelCheckpoint(
    filepath="./checkpoint/checkpoint.ckpt",
    save_weights_only=True,
    save_freq="epoch",
    verbose=0,
)

# Write training logs for TensorBoard under ./logs
tensorboard_callback = callbacks.TensorBoard(log_dir="./logs")


class ImageGenerator(callbacks.Callback):
    """
    After every epoch, generate a grid of sample images, save it and optionally show it.

    :param num_img: number of images generated per epoch
    :param latent_dim: size of each latent vector; must match the generator input
    :param display_rate: show the grid every ``display_rate`` epochs (the first epoch
        is always shown); every epoch's grid is saved to ./output regardless
    """

    def __init__(self, num_img, latent_dim, display_rate):
        # Initialise the Keras Callback base-class state before adding our own.
        super().__init__()
        self.num_img = num_img
        self.latent_dim = latent_dim
        self.display_rate = display_rate

    def on_epoch_end(self, epoch, logs=None):
        # only show every X images (images are still generated and saved!)
        # using epoch + 1 since notebook output shows 'Epoch 1/x'
        # always show first image (epoch == 0)
        show_image = epoch == 0 or (epoch + 1) % self.display_rate == 0

        random_latent_vectors = tf.random.normal(
            shape=(self.num_img, self.latent_dim)
        )
        generated_images = self.model.generator(random_latent_vectors)
        # map the generator's tanh range [-1, 1] back to pixel values [0, 255]
        generated_images = generated_images * 127.5 + 127.5
        generated_images = generated_images.numpy()
        display(
            generated_images,
            save_to="./output/WGAN_GP_BRICKS__generated_img_%03d.png" % (epoch + 1),
            show_image=show_image,
        )

In [None]:
stopwatch.start()

# Per-epoch hooks: checkpoint weights, log to TensorBoard, save/show sample images.
# (For a quick smoke test, add steps_per_epoch=2 to the fit call.)
epoch_callbacks = [
    model_checkpoint_callback,
    tensorboard_callback,
    ImageGenerator(num_img=10, latent_dim=Z_DIM, display_rate=DISPLAY_RATE),
]
history = wgangp.fit(train, epochs=EPOCHS, callbacks=epoch_callbacks)

stopwatch.stop()

### Save the Models 💃🕺

In [None]:
# Save the final generator and critic models.
# NOTE(review): Keras 3 requires a `.keras` suffix for model.save(); these
# extension-less paths assume TF2-era tf.keras (SavedModel) -- confirm.
for _net, _path in ((generator, "./models/generator"), (critic, "./models/critic")):
    _net.save(_path)

## Generate images

In [None]:
# Decode standard-normal latent vectors with the trained generator and plot them.
num_samples = 10
z_sample = np.random.normal(loc=0.0, scale=1.0, size=(num_samples, Z_DIM))
imgs = wgangp.generator.predict(z_sample)  # tanh output, values in [-1, 1]
display(imgs)

In [None]:
# Stop the notebook-wide timer started in the setup cell
main_stopwatch.stop()

In [None]:
# Plot the training history for the "loss" metric (ptmlib chart helper).
# NOTE(review): WGANGP.train_step logs c_loss, c_wass_loss, c_gp and g_loss;
# there is no plain "loss" key -- confirm how show_history_chart resolves this name.
pch.show_history_chart(history, "loss", save_fig_enabled=False)

In [None]:
# Plot the gradient-penalty history (ptmlib chart helper).
# NOTE(review): the penalty metric is logged as "c_gp", not "gp" -- confirm how
# show_history_chart resolves this name.
pch.show_history_chart(history, "gp", save_fig_enabled=False)