In [1]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt

In [29]:
# Define the generator
def build_generator(NOISE_DIM):
    """Build the DCGAN generator.

    Maps a batch of noise vectors of length ``NOISE_DIM`` to a batch of
    28x28 single-channel images. Spatial size goes 7 -> 7 -> 14 -> 28 through
    the transposed convolutions.

    Args:
        NOISE_DIM: Length of the latent noise vector.

    Returns:
        An uncompiled ``tf.keras.Sequential`` whose output has shape
        (batch, 28, 28, 1), the same shape as the training images.
    """
    def norm_act():
        # BatchNorm + LeakyReLU pair that follows every hidden layer.
        return [layers.BatchNormalization(), layers.LeakyReLU()]

    stack = [layers.Dense(7 * 7 * 256, use_bias=False, input_shape=(NOISE_DIM,))]
    stack += norm_act()
    stack.append(layers.Reshape((7, 7, 256)))

    # Two hidden upsampling stages: (filters, stride).
    for filters, stride in ((128, 1), (64, 2)):
        stack.append(
            layers.Conv2DTranspose(filters, (5, 5), strides=(stride, stride),
                                   padding='same', use_bias=False)
        )
        stack += norm_act()

    # tanh keeps pixels in [-1, 1], matching how the training images are scaled.
    stack.append(
        layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same',
                               use_bias=False, activation='tanh')
    )
    return tf.keras.Sequential(stack)

In [30]:
# Define the discriminator
def build_discriminator(input_shape=(28, 28, 1)):
    """Build the DCGAN discriminator, a CNN that scores images as real or fake.

    Args:
        input_shape: Shape of one image as (height, width, channels). The default
            matches the 28x28x1 MNIST images and the generator's output.

    Returns:
        An uncompiled ``tf.keras.Sequential`` mapping a batch of images to one
        raw logit per image, with shape (batch, 1).
    """
    model = tf.keras.Sequential([
        layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=tuple(input_shape)),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        layers.Flatten(),
        # No activation: the output is an unbounded logit, not a value in [0, 1].
        # Pair it with BinaryCrossentropy(from_logits=True), which applies the sigmoid
        # internally in a numerically stable way. If you add activation='sigmoid'
        # here, switch the loss to from_logits=False.
        layers.Dense(1),
        ])
    return model

In [40]:
class GAN:
    """DCGAN wrapper holding a generator, a discriminator, their losses and optimizers.

    Args:
        latent_dim: Length of the noise vector fed to the generator.
    """

    def __init__(self, latent_dim):
        self.latent_dim = latent_dim
        self.generator = build_generator(latent_dim)
        self.discriminator = build_discriminator()

        # build_discriminator ends in a linear Dense(1), so its outputs are logits;
        # from_logits=True applies the sigmoid inside the loss.
        self.cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

        # Separate optimizers so each network keeps its own Adam state.
        self.generator_optimizer = tf.keras.optimizers.Adam(1e-4)
        self.discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

    def generator_loss(self, fake_output):
        """Loss for the generator: how far the discriminator is from calling fakes real (label 1)."""
        return self.cross_entropy(tf.ones_like(fake_output), fake_output)

    def discriminator_loss(self, real_output, fake_output):
        """Loss for the discriminator: real images should score 1, generated images 0."""
        real_loss = self.cross_entropy(tf.ones_like(real_output), real_output)
        fake_loss = self.cross_entropy(tf.zeros_like(fake_output), fake_output)
        return real_loss + fake_loss

    @tf.function  # optional but will improve runtime speed.
    def train_step(self, images):
        """Run one simultaneous generator/discriminator update on a batch of real images.

        Args:
            images: Batch of real images in the discriminator's input shape.

        Returns:
            Tuple ``(gen_loss, disc_loss)`` of scalar loss tensors for this batch.
        """
        batch_size = tf.shape(images)[0]

        # One latent vector per real image.
        noise = tf.random.normal([batch_size, self.latent_dim])

        # Both tapes record the same forward pass, so both networks are updated
        # from the same generated batch.
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            generated_images = self.generator(noise, training=True)

            real_output = self.discriminator(images, training=True)
            fake_output = self.discriminator(generated_images, training=True)

            gen_loss = self.generator_loss(fake_output)
            disc_loss = self.discriminator_loss(real_output, fake_output)

        gen_gradients = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
        disc_gradients = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)

        self.generator_optimizer.apply_gradients(zip(gen_gradients, self.generator.trainable_variables))
        self.discriminator_optimizer.apply_gradients(zip(disc_gradients, self.discriminator.trainable_variables))

        return gen_loss, disc_loss

    def generate_and_save_images(self, epoch, test_input, save=False, show=True):
        """Render generator samples for ``test_input`` as a grid of grayscale images.

        Args:
            epoch: Epoch number, used in the figure title and the saved filename.
            test_input: Noise batch of shape (n, latent_dim). Reusing the same batch
                across calls shows how the generator's output evolves.
            save: If True, write the figure to ``image_at_epoch_{epoch:04d}.png``.
            show: If True, display the figure (inline in a notebook).
        """
        # The generator ends in tanh, i.e. values in [-1, 1]; map them to [0, 1] for display.
        predictions = np.asarray(self.generator(test_input, training=False)) * 0.5 + 0.5
        num_images = predictions.shape[0]

        # Size the grid to the batch instead of assuming 16 images. A fixed 4x4 grid
        # leaves empty cells for smaller batches and raises for larger ones.
        cols = max(1, int(np.ceil(np.sqrt(num_images))))
        rows = max(1, int(np.ceil(num_images / cols)))
        fig, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows), squeeze=False)

        for ax in axes.flat:
            ax.axis('off')
        for ax, image in zip(axes.flat, predictions):
            # Fixed colour range so contrast is comparable from epoch to epoch.
            ax.imshow(image[:, :, 0], cmap='gray', vmin=0.0, vmax=1.0)
        fig.suptitle(f'Epoch {epoch}')

        if save:
            fig.savefig(f'image_at_epoch_{epoch:04d}.png')
        if show:
            plt.show()
        # Release the figure so repeated calls during training don't accumulate memory.
        plt.close(fig)

In [32]:
# Load MNIST; only the training images are used (labels and the test split are discarded).
(train_images, _), (_, _) = tf.keras.datasets.mnist.load_data()
# Append a channel axis, (N, 28, 28) -> (N, 28, 28, 1), and switch to float32.
train_images = np.expand_dims(train_images, axis=-1).astype('float32')
# Rescale pixels from [0, 255] to [-1, 1], the range of the generator's tanh output.
train_images = (train_images - 127.5) / 127.5

In [33]:
# Set training parameters
SEED = 42                     # global seed for reproducible weights, noise and shuffling
BUFFER_SIZE = 2000            # number of training images used AND the shuffle-buffer size
BATCH_SIZE = 32
EPOCHS = 50
noise_dim = 100               # length of the generator's latent vector
num_examples_to_generate = 4  # samples rendered at each progress checkpoint

# Seed Python's `random`, NumPy and TensorFlow in one call so that Restart & Run All
# reproduces the same initial weights, noise and dataset shuffling.
# (Some GPU kernels can still be nondeterministic.)
tf.keras.utils.set_random_seed(SEED)

In [42]:
# tf.data pipeline over the first BUFFER_SIZE training images:
# slice into individual examples, shuffle them, then group into batches.
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_images[:BUFFER_SIZE])
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

In [35]:
# Build the generator/discriminator pair for a noise_dim-length latent space.
gan = GAN(latent_dim=noise_dim)

In [36]:
# Fixed noise batch ("seed"), reused at every checkpoint so the rendered samples
# show how the generator's output for the same inputs changes during training.
seed = tf.random.normal(shape=(num_examples_to_generate, noise_dim))

In [38]:
# Sanity-check the fixed noise batch: one latent vector per sample to render.
assert tuple(seed.shape) == (num_examples_to_generate, noise_dim), seed.shape
seed.shape

TensorShape([4, 100])

In [None]:
# Training loop
LOG_EVERY = 10  # epochs between sample-image renders and loss reports

for epoch in range(EPOCHS):
    # Accumulate per-batch losses so the report reflects the whole epoch, not just
    # the last batch (which is the smaller remainder batch when sizes don't divide).
    gen_loss_total = disc_loss_total = 0.0
    num_batches = 0
    for image_batch in train_dataset:  # Train for each batch.
        gen_loss, disc_loss = gan.train_step(image_batch)
        gen_loss_total += float(gen_loss)
        disc_loss_total += float(disc_loss)
        num_batches += 1

    if num_batches == 0:
        raise ValueError('train_dataset yielded no batches; check BUFFER_SIZE and BATCH_SIZE.')

    # Render example images and report mean losses every LOG_EVERY epochs.
    if (epoch + 1) % LOG_EVERY == 0:
        gan.generate_and_save_images(epoch + 1, seed)
        print(f'Epoch {epoch+1}, Gen Loss: {gen_loss_total / num_batches:.4f}, '
              f'Disc Loss: {disc_loss_total / num_batches:.4f} '
              f'(mean over {num_batches} batches)')
