<a href="https://colab.research.google.com/github/nmanohar40693/MSAI-531-A02/blob/main/gan_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
print("TensorFlow version:", tf.__version__)

TensorFlow version: 2.19.0


In [None]:
from tensorflow.keras import layers, models, optimizers
import numpy as np
import matplotlib.pyplot as plt
import os


In [None]:
# Load MNIST dataset
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()

# Normalize images to [-1, 1]
x_train = (x_train.astype(np.float32) - 127.5) / 127.5

# Add channel dimension
x_train = np.expand_dims(x_train, axis=-1)

print("Training data shape:", x_train.shape)


Training data shape: (60000, 28, 28, 1)


In [None]:
BUFFER_SIZE = x_train.shape[0]
BATCH_SIZE = 128

dataset = tf.data.Dataset.from_tensor_slices(x_train)
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

print("Dataset ready")


Dataset ready


In [None]:
def build_generator():
    model = models.Sequential(name="Generator")

    model.add(layers.Dense(256, input_dim=100))
    model.add(layers.LeakyReLU(alpha=0.2))

    model.add(layers.Dense(512))
    model.add(layers.LeakyReLU(alpha=0.2))

    model.add(layers.Dense(1024))
    model.add(layers.LeakyReLU(alpha=0.2))

    model.add(layers.Dense(28 * 28, activation='tanh'))
    model.add(layers.Reshape((28, 28, 1)))

    return model


In [None]:
generator = build_generator()
generator.summary()


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


In [None]:
def build_discriminator():
    model = models.Sequential(name="Discriminator")

    model.add(layers.Flatten(input_shape=(28, 28, 1)))

    model.add(layers.Dense(512))
    model.add(layers.LeakyReLU(alpha=0.2))

    model.add(layers.Dense(256))
    model.add(layers.LeakyReLU(alpha=0.2))

    model.add(layers.Dense(1, activation='sigmoid'))

    return model


In [None]:
discriminator = build_discriminator()
discriminator.summary()


  super().__init__(**kwargs)


In [None]:
discriminator.compile(
    loss='binary_crossentropy',
    optimizer=optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
    metrics=['accuracy']
)


In [None]:
discriminator.trainable = False


In [None]:
gan_input = layers.Input(shape=(100,))
generated_image = generator(gan_input)
gan_output = discriminator(generated_image)

gan = models.Model(gan_input, gan_output, name="GAN")

gan.compile(
    loss='binary_crossentropy',
    optimizer=optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
)

gan.summary()


In [None]:
def save_generated_images(epoch, generator, examples=25):
    noise = np.random.normal(0, 1, (examples, 100))
    generated_images = generator.predict(noise, verbose=0)

    # Rescale from [-1, 1] to [0, 1]
    generated_images = 0.5 * generated_images + 0.5

    fig = plt.figure(figsize=(5, 5))
    for i in range(examples):
        plt.subplot(5, 5, i + 1)
        plt.imshow(generated_images[i, :, :, 0], cmap='gray')
        plt.axis('off')

    os.makedirs("gan_images", exist_ok=True)
    plt.savefig(f"gan_images/epoch_{epoch}.png")
    plt.close()


In [None]:
save_generated_images(999, generator)

In [None]:
EPOCHS = 100          # keep small for now
LATENT_DIM = 100

for epoch in range(EPOCHS):
    for real_images in dataset:
        batch_size = real_images.shape[0]

        # =========================
        # 1. Train Discriminator
        # =========================
        discriminator.trainable = True

        noise = np.random.normal(0, 1, (batch_size, LATENT_DIM))
        fake_images = generator.predict(noise, verbose=0)

        real_labels = np.ones((batch_size, 1))
        fake_labels = np.zeros((batch_size, 1))

        discriminator.train_on_batch(real_images, real_labels)
        discriminator.train_on_batch(fake_images, fake_labels)

        # =========================
        # 2. Train Generator
        # =========================
        discriminator.trainable = False

        noise = np.random.normal(0, 1, (batch_size, LATENT_DIM))
        misleading_labels = np.ones((batch_size, 1))

        gan.train_on_batch(noise, misleading_labels)

    # save images every 5 epochs
    if epoch % 5 == 0:
        save_generated_images(epoch, generator)
        print(f"Saved images at epoch {epoch}")


Saved images at epoch 0
Saved images at epoch 5
Saved images at epoch 10
Saved images at epoch 15
Saved images at epoch 20
Saved images at epoch 25
Saved images at epoch 30
Saved images at epoch 35
Saved images at epoch 40
Saved images at epoch 45
Saved images at epoch 50
Saved images at epoch 55
Saved images at epoch 60
Saved images at epoch 65
Saved images at epoch 70
Saved images at epoch 75
Saved images at epoch 80
Saved images at epoch 85
Saved images at epoch 90
Saved images at epoch 95


In [None]:
!ls /content/gan_images


ls: cannot access '/content/gan_images': No such file or directory
