In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import layers
import matplotlib.image as mpimg
import numpy as np
import keras
import os
from collections import defaultdict

# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
# The device list is empty on CPU-only machines, so iterate rather than index
# element 0 (indexing an empty list raises IndexError).
physical_devices = tf.config.list_physical_devices('GPU')
for gpu_device in physical_devices:
    tf.config.experimental.set_memory_growth(gpu_device, True)

IndexError: ignored

In [None]:
# Discriminator: 28x28x1 image -> one unactivated score per image.
# Three stride-2 convolutions halve the spatial size each time
# (28 -> 14 -> 7 -> 4) before global max pooling and the final Dense(1).
discriminator = keras.Sequential(name="discriminator")
discriminator.add(keras.Input(shape=(28, 28, 1)))
discriminator.add(layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"))
discriminator.add(layers.LeakyReLU(alpha=0.2))
discriminator.add(layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"))
discriminator.add(layers.LeakyReLU(alpha=0.2))
discriminator.add(layers.Conv2D(256, (3, 3), strides=(2, 2), padding="same"))
discriminator.add(layers.LeakyReLU(alpha=0.2))
discriminator.add(layers.GlobalMaxPooling2D())
discriminator.add(layers.Dense(1))
discriminator.summary()

latent_dim = 128

# Generator: latent vector of size latent_dim -> 28x28x1 image in [0, 1].
generator = keras.Sequential(name="generator")
generator.add(keras.Input(shape=(latent_dim,)))
# Project the latent vector to 7 * 7 * 128 values so it can be reshaped
# into a 7x7 feature map with 128 channels.
generator.add(layers.Dense(7 * 7 * 128))
generator.add(layers.LeakyReLU(alpha=0.2))
generator.add(layers.Reshape((7, 7, 128)))
# Two stride-2 transposed convolutions upsample 7x7 -> 14x14 -> 28x28.
generator.add(layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"))
generator.add(layers.LeakyReLU(alpha=0.2))
generator.add(layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"))
generator.add(layers.LeakyReLU(alpha=0.2))
# Collapse to a single channel; sigmoid keeps pixel values in [0, 1].
generator.add(layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"))
generator.summary()

In [None]:
class GAN(keras.Model):
    """Keras model that trains a generator and a discriminator adversarially.

    Label convention used in ``train_step``: generated (fake) images are
    labelled 1 and real images are labelled 0 when training the discriminator.

    Args:
        discriminator: Model applied to batches of images; its predictions are
            compared against labels of shape ``(N, 1)``.
        generator: Model applied to batches of latent vectors of shape
            ``(N, latent_dim)``; its output is fed to the discriminator.
        latent_dim: Size of the latent vectors sampled for the generator.
    """

    def __init__(self, discriminator, generator, latent_dim=128):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Store one optimizer per sub-model and the shared loss function.

        Args:
            d_optimizer: Optimizer applied to the discriminator's weights.
            g_optimizer: Optimizer applied to the generator's weights.
            loss_fn: Callable ``loss_fn(labels, predictions)`` used for both
                the discriminator and the generator updates.
        """
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images):
        """Run one adversarial update: discriminator first, then generator.

        This overrides the per-batch hook that ``fit()`` calls.

        Args:
            real_images: A batch of real images, or a tuple whose first
                element is that batch.

        Returns:
            Dict with the discriminator loss (``"d_loss"``) and the generator
            loss (``"g_loss"``) for this batch.
        """
        # A dataset may yield tuples; only the first element is used.
        if isinstance(real_images, tuple):
            real_images = real_images[0]

        batch_size = tf.shape(real_images)[0]

        #########################
        ## TRAIN DISCRIMINATOR ##
        #########################

        # Sample random points in the latent space
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Generate images starting from random noise. This happens outside the
        # gradient tape because only the discriminator is updated in this phase.
        generated_images = self.generator(random_latent_vectors)

        # Concatenate fake and real images (fake first, then real)
        combined_images = tf.concat([generated_images, real_images], axis=0)

        # Matching labels in the same order: 1 = fake, 0 = real
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )

        # Add a small amount of uniform random noise (scaled by 0.05) to the labels
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        with tf.GradientTape() as tape:
            # Make discriminator predict the concatenated images
            predictions = self.discriminator(combined_images)
            # Loss between the discriminator's predictions and the noisy labels
            d_loss = self.loss_fn(labels, predictions)

        # Gradient of the loss above, wrt the weights of the discriminator
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)

        # Update the discriminator
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        #####################
        ## TRAIN GENERATOR ##
        #####################

        # Sample a fresh set of random points in the latent space
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        # Misleading labels: 0 means "real" in this label convention, so the
        # generator is rewarded when the discriminator calls its images real.
        misleading_labels = tf.zeros((batch_size, 1))

        with tf.GradientTape() as tape:
            # Make discriminator predict fake images from the generator
            predictions = self.discriminator(self.generator(random_latent_vectors))

            # If the discriminator was fooled and predicted "REAL", the loss is
            # small and the generator is barely changed; if it predicted
            # "FAKE", the loss is large and the generator is changed more.
            g_loss = self.loss_fn(misleading_labels, predictions)

        # Gradient of the loss above, wrt the weights of the generator only;
        # the discriminator's weights are not updated in this phase.
        grads = tape.gradient(g_loss, self.generator.trainable_weights)

        # Update the generator
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        return {"d_loss": d_loss, "g_loss": g_loss}

In [None]:
class GANMonitor(keras.callbacks.Callback):
    """Callback that saves a few generator samples as PNG files after each epoch.

    Files are named ``generated_img_{i}_{epoch}.png`` inside ``output_dir``.

    Args:
        num_img: Number of images to generate and save per epoch.
        latent_dim: Size of the latent vectors fed to the generator.
        output_dir: Directory the PNG files are written to; it is created
            if it does not exist.
    """

    def __init__(self, num_img=3, latent_dim=128, output_dir="images"):
        super().__init__()
        self.num_img = num_img
        self.latent_dim = latent_dim
        self.output_dir = output_dir

    def on_epoch_end(self, epoch, logs=None):
        """Generate ``num_img`` samples with the model's generator and save them."""
        # Saving into a missing directory would raise, so create it first.
        os.makedirs(self.output_dir, exist_ok=True)

        random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))
        generated_images = self.model.generator(random_latent_vectors)
        # Scale by 255 and keep the NumPy conversion (it was previously discarded).
        generated_images = (generated_images * 255).numpy()
        for i in range(self.num_img):
            img = keras.preprocessing.image.array_to_img(generated_images[i])
            file_name = "generated_img_{i}_{epoch}.png".format(i=i, epoch=epoch)
            img.save(os.path.join(self.output_dir, file_name))

In [None]:
# Prepare the dataset. The Fashion-MNIST training and test images are pooled;
# the labels are discarded. (The variable is called `all_digits`, but it holds
# Fashion-MNIST images, not digits.)
batch_size = 64
(x_train, _), (x_test, _) = tf.keras.datasets.fashion_mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
# Convert to float32 and divide by 255; the generator's sigmoid output also
# lies in [0, 1].
all_digits = all_digits.astype("float32") / 255.0
# Reshape to (N, 28, 28, 1), the input shape the discriminator expects.
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)


# Pass latent_dim explicitly so the GAN samples noise of the same size the
# generator was built for, instead of relying on the GAN's default of 128.
gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    # from_logits=True because the discriminator's last layer has no activation.
    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)

In [None]:
# Print the generator's layer-by-layer summary.
generator.summary()

In [None]:
# Print the discriminator's layer-by-layer summary.
discriminator.summary()

In [None]:
# Train on the full dataset for 30 epochs. After every epoch the GANMonitor
# callback saves 3 images generated from random latent vectors.
sample_saver = GANMonitor(num_img=3, latent_dim=latent_dim)
gan.fit(dataset, epochs=30, callbacks=[sample_saver])

In [None]:
# Index the saved samples as images_dict[epoch][image_number] -> file name.
root = "images/"
images_dict = defaultdict(dict)
for img in os.listdir(root):
    stem, ext = os.path.splitext(img)
    tokens = stem.split("_")
    # Expected names look like "generated_img_<image number>_<epoch>.png".
    # Skip anything else in the directory (hidden files, checkpoint folders)
    # instead of crashing on int() or on a missing token.
    if ext.lower() != ".png" or len(tokens) < 2:
        continue
    img_number, epoch = tokens[-2], tokens[-1]
    if not (img_number.isdecimal() and epoch.isdecimal()):
        continue
    images_dict[int(epoch)][int(img_number)] = img

In [None]:
# Grid of saved samples: one row per epoch, one column per image.
nepochs = 30
imgs = 3

fig, ax = plt.subplots(nrows=nepochs, ncols=imgs, figsize=(5, 30))
ax = np.ravel(ax)

# Walk the flattened axes in row-major order; divmod recovers the
# (epoch, image number) pair that each panel corresponds to.
for flat_idx, panel in enumerate(ax):
    epoch_idx, img_idx = divmod(flat_idx, imgs)
    img_path = os.path.join(root, images_dict[epoch_idx][img_idx])
    panel.imshow(mpimg.imread(img_path), cmap="Greys")
    panel.axis('off')

fig.tight_layout()