In [1]:
# Implementation of GAN on any suitable dataset.
import tensorflow as tf
from tensorflow import keras
from keras import layers
import numpy as np

In [2]:
# Load the MNIST dataset
(x_train, _), (_, _) = keras.datasets.mnist.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [3]:
# Normalize the input data
x_train = x_train.astype("float32") / 255.0

In [4]:
# Reshape the input data to add channel dimension for grayscale images
x_train = x_train.reshape((-1, 28, 28, 1))

In [5]:
# Define the generator model
generator = keras.Sequential(
    [
        keras.Input(shape=(100,)),
        layers.Dense(7 * 7 * 128),
        layers.Reshape((7, 7, 128)),
        layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"),
        layers.Conv2DTranspose(64, kernel_size=4, strides=2, padding="same"),
        layers.Conv2D(1, kernel_size=7, padding="same", activation="sigmoid"),
    ],
    name="generator",
)
generator.summary()

Metal device set to: Apple M2

systemMemory: 8.00 GB
maxCacheSize: 2.67 GB

Model: "generator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               (None, 6272)              633472    
                                                                 
 reshape (Reshape)           (None, 7, 7, 128)         0         
                                                                 
 conv2d_transpose (Conv2DTra  (None, 14, 14, 128)      262272    
 nspose)                                                         
                                                                 
 conv2d_transpose_1 (Conv2DT  (None, 28, 28, 64)       131136    
 ranspose)                                                       
                                                                 
 conv2d (Conv2D)             (None, 28, 28, 1)         3137      
                                               

In [6]:
# Define the discriminator model
discriminator = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(64, kernel_size=3, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(128, kernel_size=3, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Flatten(),
        layers.Dense(1),
    ],
    name="discriminator",
)
discriminator.summary()

Model: "discriminator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_1 (Conv2D)           (None, 14, 14, 64)        640       
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 14, 14, 64)        0         
                                                                 
 conv2d_2 (Conv2D)           (None, 7, 7, 128)         73856     
                                                                 
 leaky_re_lu_1 (LeakyReLU)   (None, 7, 7, 128)         0         
                                                                 
 flatten (Flatten)           (None, 6272)              0         
                                                                 
 dense_1 (Dense)             (None, 1)                 6273      
                                                                 
Total params: 80,769
Trainable params: 80,769
Non-tra

In [7]:
# Combine the generator and discriminator into a GAN
gan = keras.Sequential([generator, discriminator])

# Compile the discriminator (as a standalone model)
discriminator.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
)

# Compile the GAN
gan.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.0003),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
)



In [8]:
# Training loop
batch_size = 32
epochs = 10
steps_per_epoch = x_train.shape[0] // batch_size

In [9]:
for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")
    for step in range(steps_per_epoch):
        # Generate random noise as input to the generator
        noise = tf.random.normal(shape=(batch_size, 100))

        # Generate images using the generator
        generated_images = generator(noise)

        # Create a batch by sampling real images from the training set
        real_images = x_train[np.random.choice(x_train.shape[0], size=batch_size, replace=False)]

        # Concatenate real and generated images
        combined_images = tf.concat([generated_images, real_images], axis=0)

        # Labels for generated and real images
        labels = tf.concat(
            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
        )

        # Add random noise to the labels (important trick)
        labels += 0.05 * tf.random.uniform(labels.shape)

        # Train the discriminator
        d_loss = discriminator.train_on_batch(combined_images, labels)

        # Generate noise as input to the GAN
        noise = tf.random.normal(shape=(batch_size, 100))

        # Labels for generated images (trick the discriminator)
        misleading_labels = tf.zeros((batch_size, 1))

        # Train the GAN (generator only)
        g_loss = gan.train_on_batch(noise, misleading_labels)

    # Print the losses
    print(f"Discriminator loss: {d_loss:.4f}")
    print(f"Generator loss: {g_loss:.4f}")
    print()

Epoch 1/10


2023-07-28 23:59:04.308451: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


Discriminator loss: 0.8380
Generator loss: 0.3462

Epoch 2/10


KeyboardInterrupt: 

In [None]:
# Generate some images using the trained generator
num_samples = 10
noise = tf.random.normal(shape=(num_samples, 100))
generated_images = generator.predict(noise)

# Display the generated images
import matplotlib.pyplot as plt

fig, axs = plt.subplots(1, num_samples, figsize=(20, 2))
fig.suptitle("Generated Images")

for i in range(num_samples):
    axs[i].imshow(generated_images[i].reshape(28, 28), cmap="gray")
    axs[i].axis("off")

plt.show()