In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import math

# Load the image dataset (Monet paintings for now)

In [2]:
# Target size of every training image (and of the images the generator must produce).
IMAGE_HEIGHT = IMAGE_WIDTH = 256
IMAGE_CHANNELS = 3

# Alternative dataset used earlier: "./Celeb_images/" (loaded with batch_size=128).

# label_mode=None: the dataset yields only batches of images, no class labels.
dataset = tf.keras.utils.image_dataset_from_directory(
    "./art_images/monet_jpg/",
    label_mode=None,
    image_size=(IMAGE_HEIGHT, IMAGE_WIDTH),
    batch_size=8,
    shuffle=True,
    color_mode="rgb",
    smart_resize=True,
)

# The loader yields pixel values in [0, 255], but the generator's final layer is a
# sigmoid, so generated images live in [0, 1]. Rescale the real images to the same
# range; otherwise the discriminator can separate real from fake by magnitude alone.
dataset = dataset.map(lambda images: images / 255.0)

Found 300 files belonging to 1 classes.
Metal device set to: Apple M1 Max

systemMemory: 64.00 GB
maxCacheSize: 24.00 GB



# GANS

## steps:
 - Build a generator
 - Build a discriminator
 - Using both build the GAN
    - GAN will use the weights of the discriminator to understand how to fool it
    - We want to optimize both the discriminator and the generator: both need to get good at what they do in order to succeed

In [3]:
# Reset the global state Keras has accumulated so far before the models below are built.
tf.keras.backend.clear_session()

# Generator

The generator takes a random sample from the latent space and generates an image. In other applications the output could instead be text, audio, or some other data format.

In [4]:
# Each point of the latent space is a vector with 128 components.
LATENT_DIMS = 128

# Number of upsampling (transposed-convolution) stages in the generator.
NUMBER_CONVTRASPOSELAYERS = 3

CHANNELS_INIT = 128
# Spatial size of the first feature map: the target image size divided by
# 2 once per upsampling stage, rounded up.
GEN_INIT_IMAGE_HEIGHT = GEN_INIT_IMAGE_WIDTH = math.ceil(
    IMAGE_HEIGHT / (2 ** NUMBER_CONVTRASPOSELAYERS)
)

KERNEL_SIZE = 4
STRIDE = 2

X_latent_input = tf.keras.Input(shape=(LATENT_DIMS,))

# Project the latent vector to a flat tensor, then fold it into a small feature map.
initial_feature_map_shape = (GEN_INIT_IMAGE_HEIGHT, GEN_INIT_IMAGE_WIDTH, CHANNELS_INIT)
flat_units = GEN_INIT_IMAGE_HEIGHT * GEN_INIT_IMAGE_WIDTH * CHANNELS_INIT
X = tf.keras.layers.Dense(units=flat_units)(X_latent_input)
X = tf.keras.layers.Reshape(initial_feature_map_shape)(X)

# Filter count doubles at every upsampling stage: CHANNELS_INIT, 2x, 4x, ...
upsampling_filters = [CHANNELS_INIT * (2 ** stage) for stage in range(NUMBER_CONVTRASPOSELAYERS)]
for stage_filters in upsampling_filters:
    X = tf.keras.layers.Conv2DTranspose(
        stage_filters,
        kernel_size=KERNEL_SIZE,
        strides=STRIDE,
        padding="same",
    )(X)
    X = tf.keras.layers.LeakyReLU(alpha=0.2)(X)

# Final convolution maps the features to IMAGE_CHANNELS channels; sigmoid keeps values in [0, 1].
X_outputs = tf.keras.layers.Conv2D(
    IMAGE_CHANNELS,
    kernel_size=5,
    padding="same",
    activation="sigmoid",
)(X)

generator_model = tf.keras.Model(inputs=X_latent_input, outputs=X_outputs, name="generator")

In [5]:
# Print the generator's layers, output shapes and parameter counts.
generator_model.summary()

Model: "generator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 128)]             0         
                                                                 
 dense (Dense)               (None, 131072)            16908288  
                                                                 
 reshape (Reshape)           (None, 32, 32, 128)       0         
                                                                 
 conv2d_transpose (Conv2DTra  (None, 64, 64, 128)      262272    
 nspose)                                                         
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 64, 64, 128)       0         
                                                                 
 conv2d_transpose_1 (Conv2DT  (None, 128, 128, 256)    524544    
 ranspose)                                               

# Discriminator

The discriminator is just a simple model that classifies an image into one of two classes: fake or real.

In [6]:
# Number of strided convolutions applied AFTER the first one.
NUMBER_DISCRIMINATOR_CONVLAYERS = 2

INIT_DISCRIMINATOR_FILTERS = 64

X_discriminator_input = tf.keras.Input(shape=(IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS))

# First downsampling stage.
X = tf.keras.layers.Conv2D(
    filters=INIT_DISCRIMINATOR_FILTERS,
    kernel_size=KERNEL_SIZE,
    strides=STRIDE,
    padding="same",
)(X_discriminator_input)
X = tf.keras.layers.LeakyReLU(alpha=0.2)(X)

# Remaining downsampling stages. The loop is sized by the discriminator's own
# constant; it previously borrowed NUMBER_CONVTRASPOSELAYERS-1 from the generator,
# which only coincidentally had the same value (2).
for _ in range(NUMBER_DISCRIMINATOR_CONVLAYERS):
    # In the book the filter count never changes between these layers; that may be
    # a mistake, but it is kept for now. The alternative would be to grow it per
    # layer: INIT_DISCRIMINATOR_FILTERS * (2 ** (layer_index + 1)).
    X = tf.keras.layers.Conv2D(
        filters=INIT_DISCRIMINATOR_FILTERS * 2,
        kernel_size=KERNEL_SIZE,
        strides=STRIDE,
        padding="same",
    )(X)
    X = tf.keras.layers.LeakyReLU(alpha=0.2)(X)

X = tf.keras.layers.Flatten()(X)

X = tf.keras.layers.Dropout(0.4)(X)

# Single sigmoid unit: the real/fake score for the input image.
X_output_real_or_fake = tf.keras.layers.Dense(units=1, activation="sigmoid")(X)

discriminator_model = tf.keras.Model(
    inputs=X_discriminator_input,
    outputs=X_output_real_or_fake,
    name="discriminator",
)

In [7]:
# Print the discriminator's layers, output shapes and parameter counts.
discriminator_model.summary()

Model: "discriminator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 256, 256, 3)]     0         
                                                                 
 conv2d_1 (Conv2D)           (None, 128, 128, 64)      3136      
                                                                 
 leaky_re_lu_3 (LeakyReLU)   (None, 128, 128, 64)      0         
                                                                 
 conv2d_2 (Conv2D)           (None, 64, 64, 128)       131200    
                                                                 
 leaky_re_lu_4 (LeakyReLU)   (None, 64, 64, 128)       0         
                                                                 
 conv2d_3 (Conv2D)           (None, 32, 32, 128)       262272    
                                                                 
 leaky_re_lu_5 (LeakyReLU)   (None, 32, 32, 128)     

# Adversarial network model

Well, this is the complicated part.

In [8]:
class GAN(tf.keras.Model):
    """Adversarial model that trains a generator and a discriminator together.

    Label convention used in `train_step`: 1 = generated (fake) image,
    0 = real image.
    """

    def __init__(self, generator, discriminator, latent_dims, **kwargs):
        """Store the two sub-models and create the loss trackers.

        Args:
            generator: model called on a batch of latent vectors to produce images.
            discriminator: model called on a batch of images to produce one score each.
            latent_dims: length of each random latent vector fed to the generator.
            **kwargs: forwarded to `tf.keras.Model.__init__`.
        """
        super().__init__(**kwargs)
        self.generator = generator
        self.discriminator = discriminator
        # Size of the random normal vectors sampled to feed the generator.
        self.latent_dims = latent_dims
        self.d_loss_metric = tf.keras.metrics.Mean(name="discriminator_loss")
        self.g_loss_metric = tf.keras.metrics.Mean(name="generator_loss")

    def compile(self, d_optimizer, g_optimizer, loss_function, **kwargs):
        """Configure training; each sub-model gets its own optimizer.

        Args:
            d_optimizer: optimizer applied to the discriminator's weights.
            g_optimizer: optimizer applied to the generator's weights.
            loss_function: callable `loss(labels, predictions)` used for both steps.
            **kwargs: forwarded to `tf.keras.Model.compile` (previously these
                were accepted but silently discarded).
        """
        super().compile(**kwargs)
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_function = loss_function

    @property
    def metrics(self):
        """Metrics tracked by this model: the two running mean losses."""
        return [self.d_loss_metric, self.g_loss_metric]

    def train_step(self, batch_data):
        """Run one discriminator update followed by one generator update.

        Args:
            batch_data: a batch of real images.

        Returns:
            dict with the running means of the discriminator and generator losses.
        """
        batch_size = tf.shape(batch_data)[0]

        # Sample a batch of random vectors from a normal distribution over the latent space.
        latent_random_vector = tf.random.normal(shape=(batch_size, self.latent_dims))

        # training=True is passed explicitly on every sub-model call in this method so
        # that training-only layers (e.g. Dropout) run in training mode; inside a custom
        # train_step the flag is not supplied automatically.
        batch_of_generated_images = self.generator(latent_random_vector, training=True)

        # Build one batch holding the fake images followed by the real images.
        combined_batch_of_images = tf.concat([batch_of_generated_images, batch_data], axis=0)

        # The first half are fake (label 1), the second half are real (label 0).
        labels = tf.concat([tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0)

        # Trick: add a little uniform noise to the labels.
        # TODO: would gaussian noise work better here?
        labels += 0.05 * tf.random.uniform(shape=tf.shape(labels))

        # --- First, optimize the DISCRIMINATOR ---
        with tf.GradientTape() as tape:
            discriminator_response = self.discriminator(combined_batch_of_images, training=True)
            # How well, on average, the discriminator tells fake from real.
            d_loss = self.loss_function(labels, discriminator_response)

        gradients_discriminator = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(gradients_discriminator, self.discriminator.trainable_weights)
        )

        # --- Then, optimize the GENERATOR ---
        # The images generated above cannot be reused: the generator's forward pass
        # must happen inside the GradientTape for its gradients to be recorded.
        sample_random_vector_z = tf.random.normal(shape=(batch_size, self.latent_dims))
        # Label the generated images as "real" (0): minimizing this loss pushes the
        # generator toward images the discriminator scores as real.
        misleading_labels = tf.zeros((batch_size, 1))

        with tf.GradientTape() as tape:
            predictions = self.discriminator(
                self.generator(sample_random_vector_z, training=True), training=True
            )
            g_loss = self.loss_function(misleading_labels, predictions)

        # Key point: the discriminator's loss is differentiated with respect to the
        # GENERATOR's weights only, so only the generator is updated here.
        gradients_generator = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(
            zip(gradients_generator, self.generator.trainable_weights)
        )

        self.d_loss_metric.update_state(d_loss)
        self.g_loss_metric.update_state(g_loss)

        return {"discriminator_loss": self.d_loss_metric.result(),
                "generator_loss": self.g_loss_metric.result()}

# Callback to see what is going on while training

In [9]:
class GANMonitor(tf.keras.callbacks.Callback):
    """Callback that saves a few generated images to disk after every epoch."""

    def __init__(self, num_img=3, latent_dim=LATENT_DIMS, output_dir="./gen_gan_images"):
        """
        Args:
            num_img: number of images to generate and save per epoch.
            latent_dim: length of each random latent vector fed to the generator.
            output_dir: directory the PNG files are written to; it is created
                if it does not exist.
        """
        super().__init__()
        self.num_img = num_img
        self.latent_dim = latent_dim
        self.output_dir = output_dir

    def on_epoch_end(self, epoch, logs=None):
        """Generate `num_img` images with the model's generator and save them as PNGs."""
        # Make sure the target directory exists; saving would fail otherwise.
        tf.io.gfile.makedirs(self.output_dir)

        random_latent_vectors = tf.random.normal(
            shape=(self.num_img, self.latent_dim))
        generated_images = self.model.generator(random_latent_vectors)
        # Scale by 255 and convert to a NumPy array (the original called .numpy()
        # but discarded the result).
        generated_images = (generated_images * 255).numpy()
        for i, image_array in enumerate(generated_images):
            img = tf.keras.utils.array_to_img(image_array)
            img.save(f"{self.output_dir}/generated_img_{epoch:03d}_{i}.png")

In [10]:
# Reset Keras' global state, then wire the generator and discriminator into one GAN model.
tf.keras.backend.clear_session()
gan = GAN(
    generator=generator_model,
    discriminator=discriminator_model,
    latent_dims=LATENT_DIMS,
)

In [11]:
# The discriminator uses a learning rate 10x smaller than the generator's.
discriminator_optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=1e-5)
generator_optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=1e-4)

gan.compile(
    d_optimizer=discriminator_optimizer,
    g_optimizer=generator_optimizer,
    loss_function=tf.keras.losses.BinaryCrossentropy(),
)

In [12]:
# Train for 100 epochs; the monitor callback saves 3 sample images after each epoch.
image_monitor = GANMonitor(num_img=3, latent_dim=LATENT_DIMS)
gan.fit(
    dataset,
    epochs=100,
    callbacks=[image_monitor],
)

Epoch 1/100


2023-05-20 23:12:20.844928: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78/100
Epoch 7

<keras.callbacks.History at 0x2c2456c20>