# GAN Example

In [1]:
import tensorflow as tf

from tensorflow.keras.layers import Input, Conv2D, Dense, Flatten, Dropout, BatchNormalization, MaxPooling2D, LeakyReLU
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras import ops
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.datasets import mnist
from tensorflow.keras.metrics import Mean

import numpy as np
import matplotlib.pyplot as plt
import sys, os

In [2]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()

In [3]:
# Centriamo in 0, tra -1 e 1
x_train = (x_train / 255.0) * 2 - 1 
x_test = (x_test / 255.0) * 2 - 1 

In [4]:
N, H, W = x_train.shape
D = H * W # Facciamo flattening
print(N, H, W, D)

60000 28 28 784


In [5]:
x_train = x_train.reshape(-1, D)
x_test = x_test.reshape(-1, D)

In [6]:
latent_dim = 100

## Build Models

In [7]:
generator = Sequential(
    [
        Input(shape=(latent_dim,)),
        Dense(256, activation=LeakyReLU(negative_slope=0.2)),
        BatchNormalization(momentum=0.7),
        Dense(512, activation=LeakyReLU(negative_slope=0.2)),
        BatchNormalization(momentum=0.7),
        Dense(1024, activation=LeakyReLU(negative_slope=0.2)),
        BatchNormalization(momentum=0.7),
        Dense(D, activation="tanh")
    ],
    name="generator",
)
generator.summary()

In [8]:
discriminator = Sequential(
    [
        Input(shape=(D,)),
        Dense(512, activation=LeakyReLU(negative_slope=0.2)),
        Dense(256, activation=LeakyReLU(negative_slope=0.2)),
        Dense(1, activation="sigmoid")
    ],
    name="discriminator",
)
discriminator.summary()

In [9]:
class GAN(Model):
    def __init__(self, discriminator, generator, latent_dim):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.seed_generator = tf.keras.random.SeedGenerator(42)

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn
        self.d_loss_metric = Mean(name="d_loss")
        self.g_loss_metric = Mean(name="g_loss")

    @property
    def metrics(self):
        return [self.d_loss_metric, self.g_loss_metric]

    def train_step(self, real_images):
        # Sample random points in the latent space
        batch_size = ops.shape(real_images)[0]
        random_latent_vectors = tf.keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )

        # Decode them to fake images
        generated_images = self.generator(random_latent_vectors)

        # Combine them with real images
        combined_images = ops.concatenate([generated_images, real_images], axis=0)

        # Assemble labels discriminating real from fake images
        labels = ops.concatenate(
            [ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0
        )
        # Add random noise to the labels - important trick!
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample random points in the latent space
        random_latent_vectors = tf.keras.random.normal(
            shape=(batch_size, self.latent_dim), seed=self.seed_generator
        )

        # Assemble labels that say "all real images"
        misleading_labels = ops.zeros((batch_size, 1))

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        with tf.GradientTape() as tape:
            predictions = self.discriminator(self.generator(random_latent_vectors))
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # Update metrics
        self.d_loss_metric.update_state(d_loss)
        self.g_loss_metric.update_state(g_loss)
        return {
            "d_loss": self.d_loss_metric.result(),
            "g_loss": self.g_loss_metric.result(),
        }

In [10]:
class GANMonitor(tf.keras.callbacks.Callback):

    def on_epoch_end(self, epoch, logs=None):
        rows, cols = 5, 5 
        noise = np.random.randn(rows*cols, latent_dim) # Creiamo 25 vettori latenti, in poche parole. N x latent_dim
        imgs = generator.predict(noise)
    
        imgs = imgs * 0.5 + 0.5
        fig, axes = plt.subplots(rows, cols)
        
        for i in range(rows):
            for j in range(cols):
                idx = i*cols +  j
                axes[i, j].imshow(imgs[idx].reshape(H, W), cmap="gray")
                axes[i, j].axis("off")
    
        fig.savefig(f"results/{epoch}.png")
        plt.close()

## Train GAN

In [11]:
# confs
batch_size = 32
epochs = 30_000
sample_period = 200 # Usiamo questa variabile per generare e salvare delle immagini ogni x epochs

epochs = 25  # In practice, use ~100 epochs

gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
    d_optimizer=Adam(learning_rate=0.0001),
    g_optimizer=Adam(learning_rate=0.0001),
    loss_fn=tf.keras.losses.BinaryCrossentropy(),
)

gan.fit(
    x_train, epochs=epochs, callbacks=[GANMonitor()]
)

Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 94ms/steps/step - d_loss: 0.4090 - g_loss: 2.006
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m35s[0m 17ms/step - d_loss: 0.4090 - g_loss: 2.0063
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 30ms/steps/step - d_loss: 0.5690 - g_loss: 1.761
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 17ms/step - d_loss: 0.5689 - g_loss: 1.7616
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/steps/step - d_loss: 0.3650 - g_loss: 2.270
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 18ms/step - d_loss: 0.3650 - g_loss: 2.2710
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/steps/step - d_loss: 0.2611 - g_loss: 3.060
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m33s[0m 17ms/step - d_loss: 0.2612 - g_loss: 3.0604
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[3

KeyboardInterrupt: 