# GAN Example

In [1]:
import tensorflow as tf

from tensorflow.keras.layers import Input, Conv2D, Dense, Flatten, Dropout, BatchNormalization, MaxPooling2D, LeakyReLU
from tensorflow.keras.models import Model
from tensorflow.keras import ops
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.datasets import mnist

import numpy as np
import matplotlib.pyplot as plt
import sys, os

In [2]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()

In [3]:
# Centriamo in 0, tra -1 e 1
x_train = (x_train / 255.0) * 2 - 1 
x_test = (x_test / 255.0) * 2 - 1 

In [4]:
N, H, W = x_train.shape
D = H * W # Facciamo flattening
print(N, H, W, D)

60000 28 28 784


In [5]:
x_train = x_train.reshape(-1, D)
x_test = x_test.reshape(-1, D)

In [6]:
latent_dim = 100

## Build Models

In [7]:
def build_generator(latent_dim):
    i = Input(shape=(latent_dim,))
    o = Dense(256, activation=LeakyReLU(negative_slope=0.2))(i)
    #o = BatchNormalization(momentum=0.8)(o)
    o = Dense(512, activation=LeakyReLU(negative_slope=0.2))(o)
    #o = BatchNormalization(momentum=0.8)(o)
    o = Dense(1024, activation=LeakyReLU(negative_slope=0.2))(o)
    #o = BatchNormalization(momentum=0.8)(o)
    o = Dense(D, activation="tanh")(o) # Perchè tanh? Perchè noi stiamo usando immagini comprese tra -1 e 1, quindi dobbiamo generare valori nello stesso range

    model = Model(i, o)
    return model

In [8]:
def build_discriminator(image_size):
    i = Input(shape=(image_size,))
    o = Dense(64, activation=LeakyReLU(negative_slope=0.2))(i)
    o = Dense(16, activation=LeakyReLU(negative_slope=0.2))(o)
    o = Dense(1, activation="sigmoid")(o)

    model = Model(i, o)
    return model

In [9]:
discriminator = build_discriminator(D)
discriminator.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    optimizer=Adam(0.0006)
    #metrics=[tf.keras.metrics.Accuracy()]
)

discriminator.summary()

In [10]:
generator = build_generator(latent_dim)

generator.compile(
    loss=tf.keras.losses.BinaryCrossentropy(),
    optimizer=Adam(0.0001)
)

In [11]:
generator.summary()

## Train GAN

In [12]:
# confs
batch_size = 32
epochs = 30_000
sample_period = 200 # Usiamo questa variabile per generare e salvare delle immagini ogni x epochs

# Creiamo delle labels di dimensione "batch_size"
ones = np.ones(batch_size) # Quando trainiamo il generatore (fake images) o quando trainiamo il discriminatore (real images)
zeros = np.zeros(batch_size) # Quando trainiamo il discriminatore (fake images)

# Storiamo le losses for fun
d_losses = []
g_losses = []

# Storiamo le imagini per epoch
if not os.path.isdir("results"):
    os.mkdir("results")

In [13]:
# Con questa funzione campioniamo lo spazio latente per generare immagini -> le salviamo in file anche
def sample_images(epoch):
    rows, cols = 5, 5 
    noise = np.random.randn(rows*cols, latent_dim) # Creiamo 25 vettori latenti, in poche parole. N x latent_dim
    imgs = generator.predict(noise)

    # riscaliamo le immagini tra 0 e 1 facendo operazioni inverse
    imgs = imgs * 0.5 + 0.5

    fig, axes = plt.subplots(rows, cols)
    
    for i in range(rows):
        for j in range(cols):
            idx = i*cols +  j
            axes[i, j].imshow(imgs[idx].reshape(H, W), cmap="gray")
            axes[i, j].axis("off")

    fig.savefig(f"results/{epoch}.png")
    plt.close()

In [14]:
binary_crossentropy = tf.keras.losses.BinaryCrossentropy()
accuracy = tf.keras.metrics.Accuracy()

In [15]:
# Training Loop

for epoch in range(epochs):
    ###########################
    ### TRAIN DISCRIMINATOR ###
    ###########################

    # Immagini Reali
    idx = np.random.randint(0, x_train.shape[0], batch_size) # Campioniamo batch_size interi tra 0 e x_train.shape
    real_images = x_train[idx]

    # Immagini Fake
    noise = np.random.randn(batch_size, latent_dim)
    fake_images = generator.predict(noise, verbose=False)

    concat_images = ops.concatenate([real_images, fake_images], axis=0)
    concat_labels = ops.concatenate([ones, zeros], axis=0)

    with tf.GradientTape() as tape:
        pred = discriminator(concat_images)
        d_loss = binary_crossentropy(pred, concat_labels)
    grads = tape.gradient(d_loss, discriminator.trainable_weights)
    discriminator.optimizer.apply_gradients(zip(grads, discriminator.trainable_weights))
    d_acc = accuracy(pred, concat_labels)

    ###########################
    ##### TRAIN GENERATOR #####
    ###########################
    
    # Generiamo immagini fake e passiamo come label "1". Questo perchè se il discriminatore predice real (1) la loss al generatore sarà 0
    # altrimenti sarà il contrario: se il discriminatore predice fake (0), la loss al generatore sarà alta
    noise = np.random.randn(batch_size, latent_dim)
    noise2 = np.random.randn(batch_size, latent_dim)
    noise = ops.concatenate([noise, noise2], axis=0)
    labels = ops.concatenate([ones, ones], axis=0)
    with tf.GradientTape() as tape:
        g_pred = discriminator(generator(noise))
        g_loss = binary_crossentropy(g_pred, labels)
    grads = tape.gradient(g_loss, generator.trainable_weights)
    generator.optimizer.apply_gradients(zip(grads, generator.trainable_weights))

    g_losses.append(g_loss.numpy())
    d_losses.append(d_loss.numpy())

    if epoch % 20 == 0:
        print(f"Epoch {epoch+1}/{epochs} ---- Discriminator: loss -> {d_loss}; acc -> {d_acc} ||| Generator: loss -> {g_loss}")

    if epoch % sample_period == 0:
        sample_images(epoch)

Epoch 1/30000 ---- Discriminator: loss -> 6.803585529327393 ||| Generator: loss -> 7.497342586517334
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 71ms/step
Epoch 21/30000 ---- Discriminator: loss -> 7.993697643280029 ||| Generator: loss -> 0.12782606482505798
Epoch 41/30000 ---- Discriminator: loss -> 8.025588989257812 ||| Generator: loss -> 0.09540340304374695
Epoch 61/30000 ---- Discriminator: loss -> 6.558298110961914 ||| Generator: loss -> 8.006669044494629
Epoch 81/30000 ---- Discriminator: loss -> 0.06761403381824493 ||| Generator: loss -> 15.849763870239258
Epoch 101/30000 ---- Discriminator: loss -> 0.029545657336711884 ||| Generator: loss -> 15.902130126953125
Epoch 121/30000 ---- Discriminator: loss -> 0.01904294826090336 ||| Generator: loss -> 15.909524917602539
Epoch 141/30000 ---- Discriminator: loss -> 0.04239913821220398 ||| Generator: loss -> 15.893034934997559
Epoch 161/30000 ---- Discriminator: loss -> 0.059228766709566116 ||| Generator: loss -> 15.90

KeyboardInterrupt: 