<a href="https://colab.research.google.com/github/kiarashkh/GAN_cat_image_generation/blob/main/GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf
from tensorflow.keras import layers

def make_generator(latent_dim=100):
    """Build the DCGAN-style generator as a Keras Sequential model.

    A noise vector of length ``latent_dim`` is projected by a bias-free Dense
    layer, reshaped into an 8x8x256 feature map, and then passed through three
    stride-2 transposed convolutions (8 -> 16 -> 32 -> 64). The last layer
    emits 3 channels through ``tanh``.

    Args:
        latent_dim: Length of the input noise vector.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    # Settings shared by every transposed convolution in the network.
    upsample = dict(kernel_size=(5, 5), strides=(2, 2), padding="same", use_bias=False)

    # Stem: project the noise and reshape it to (H, W, C) = (8, 8, 256).
    stack = [
        layers.Dense(8*8*256, use_bias=False, input_shape=(latent_dim,)),
        layers.BatchNormalization(),
        layers.ReLU(),
        layers.Reshape((8, 8, 256)),
    ]

    # Two intermediate upsampling blocks: 8x8 -> 16x16 -> 32x32.
    for channels in (128, 64):
        stack.extend([
            layers.Conv2DTranspose(channels, **upsample),
            layers.BatchNormalization(),
            layers.ReLU(),
        ])

    # Output block: 32x32 -> 64x64 with 3 channels (RGB) squashed by tanh.
    stack.append(layers.Conv2DTranspose(3, activation="tanh", **upsample))

    return tf.keras.Sequential(stack)





In [4]:
def make_discriminator(image_shape=(64, 64, 3)):
    """Build the convolutional discriminator as a Keras Sequential model.

    Three stride-2 5x5 convolutions (64, 128, 256 filters), each followed by
    LeakyReLU(0.2) and Dropout(0.3), feed a Flatten and a single-unit Dense
    layer. The Dense layer has no activation, so the model outputs a raw logit.

    Args:
        image_shape: ``(height, width, channels)`` of the input images.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model producing one logit per image.
    """
    conv_opts = dict(kernel_size=(5, 5), strides=(2, 2), padding="same")

    # First downsampling block also declares the input shape.
    stack = [
        layers.Conv2D(64, input_shape=image_shape, **conv_opts),
        layers.LeakyReLU(alpha=0.2),
        layers.Dropout(0.3),
    ]

    # Remaining downsampling blocks; each stride-2 conv halves the spatial size.
    for width in (128, 256):
        stack.extend([
            layers.Conv2D(width, **conv_opts),
            layers.LeakyReLU(alpha=0.2),
            layers.Dropout(0.3),
        ])

    # Classification head: a logit, meant to be paired with from_logits=True.
    stack.extend([layers.Flatten(), layers.Dense(1)])

    return tf.keras.Sequential(stack)


In [5]:
# Instantiate both networks once as notebook-level globals; the training step
# and the visualisation call further down use these objects directly.
# The discriminator is built with make_discriminator's default image_shape.
generator = make_generator(latent_dim=100)
discriminator = make_discriminator()






In [6]:
# Shared binary cross-entropy that takes raw logits (from_logits=True);
# both loss functions below call this same object.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# Loss functions
def discriminator_loss(real_output, fake_output):
    """Return the discriminator loss for one batch.

    Args:
        real_output: Discriminator logits for real images.
        fake_output: Discriminator logits for generated images.

    Returns:
        Sum of the cross-entropy of ``real_output`` against all-ones targets
        and the cross-entropy of ``fake_output`` against all-zeros targets.
    """
    targets_for_real = tf.ones_like(real_output)    # real → 1
    targets_for_fake = tf.zeros_like(fake_output)   # fake → 0
    loss_on_real = cross_entropy(targets_for_real, real_output)
    loss_on_fake = cross_entropy(targets_for_fake, fake_output)
    return loss_on_real + loss_on_fake

def generator_loss(fake_output):
    """Return the generator loss for one batch.

    The generator is scored against all-ones targets, i.e. it is rewarded when
    the discriminator labels its samples as real.

    Args:
        fake_output: Discriminator logits for generated images.

    Returns:
        Cross-entropy of ``fake_output`` against all-ones targets.
    """
    wanted_labels = tf.ones_like(fake_output)
    return cross_entropy(wanted_labels, fake_output)


In [9]:
# Typical DCGAN-style hyperparameters
LR = 2e-4          # learning rate (0.0002)
BETA_1 = 0.5       # Adam beta1 (important for GAN stability)

# One Adam instance per network, both configured identically; the training
# step applies generator and discriminator gradients through their own optimizer.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=LR, beta_1=BETA_1)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=LR, beta_1=BETA_1)


In [10]:
@tf.function
def train_step(real_images):
    """Run one simultaneous generator/discriminator update on a batch.

    Uses the notebook-level ``generator``, ``discriminator``,
    ``generator_optimizer`` and ``discriminator_optimizer`` objects.

    Args:
        real_images: Batch of real images; the leading dimension is the
            batch size and determines how many fake images are generated.

    Returns:
        Tuple ``(gen_loss, disc_loss)`` of scalar loss tensors for this batch.
    """
    # Size the noise from the actual inputs rather than from notebook globals:
    # the batch size comes from the incoming tensor and the latent size from
    # the generator's declared input shape, so this cell has no dependency on
    # constants defined elsewhere (or not at all).
    batch_size = tf.shape(real_images)[0]
    latent_dim = generator.input_shape[-1]
    noise = tf.random.normal([batch_size, latent_dim])

    # 1) compute both losses under separate tapes so each network's gradients
    #    can be taken independently from the same forward pass
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fake_images = generator(noise, training=True)

        real_logits = discriminator(real_images, training=True)
        fake_logits = discriminator(fake_images, training=True)

        gen_loss = generator_loss(fake_logits)
        disc_loss = discriminator_loss(real_logits, fake_logits)

    # 2) gradients of each loss w.r.t. its own network's variables
    gen_grads = gen_tape.gradient(gen_loss, generator.trainable_variables)
    disc_grads = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    # 3) optional: gradient clipping (global norm 5.0) to stabilize
    gen_grads, _ = tf.clip_by_global_norm(gen_grads, 5.0)
    disc_grads, _ = tf.clip_by_global_norm(disc_grads, 5.0)

    # 4) apply gradients with the optimizers we defined
    generator_optimizer.apply_gradients(zip(gen_grads, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(disc_grads, discriminator.trainable_variables))

    return gen_loss, disc_loss


In [None]:
IMG_SIZE = 64   # we want 64x64 for DCGAN

def preprocess_image(path):
    """Load one image file and convert it to a normalized float tensor.

    Args:
        path: Path of the image file to read.

    Returns:
        Tensor resized to ``IMG_SIZE`` x ``IMG_SIZE`` with 3 channels, scaled
        from [0, 255] to [-1, 1] to match the generator's tanh output range.
    """
    raw_bytes = tf.io.read_file(path)
    # channels=3 forces an RGB result
    pixels = tf.image.decode_jpeg(raw_bytes, channels=3)
    pixels = tf.image.resize(pixels, [IMG_SIZE, IMG_SIZE])
    # [0,255] → [-1,1]
    return (pixels - 127.5) / 127.5


In [None]:
BATCH_SIZE = 64

# Input pipeline, built one stage at a time:
# 1) enumerate every JPEG inside ./cats (file order shuffled)
cat_paths = tf.data.Dataset.list_files("./cats/*.jpg", shuffle=True)

# 2) decode + resize + normalize each file in parallel
cat_dataset = cat_paths.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)

# 3) shuffle decoded images through a 1000-element buffer
cat_dataset = cat_dataset.shuffle(buffer_size=1000)

# 4) fixed-size batches; the trailing partial batch is dropped
cat_dataset = cat_dataset.batch(BATCH_SIZE, drop_remainder=True)

# 5) overlap data preparation with training
cat_dataset = cat_dataset.prefetch(tf.data.AUTOTUNE)


In [None]:
import matplotlib.pyplot as plt

# Sanity check of the input pipeline: pull a single batch, print its shape
# and display its first image.
for batch in cat_dataset.take(1):
    print(batch.shape)  # (64, 64, 64, 3) → batch of 64 images
    plt.imshow((batch[0] + 1) / 2.0)  # convert back from [-1,1] → [0,1]
    plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np

def show_real_and_fake(generator, dataset, epoch, latent_dim=100):
    """Plot two real and two generated images in a 2x2 grid.

    Args:
        generator: Model called on a ``[2, latent_dim]`` noise tensor with
            ``training=False`` to produce the fake images.
        dataset: Iterable of image batches; the first two images of its first
            batch are shown as the real examples.
        epoch: Value shown in the figure title.
        latent_dim: Length of each noise vector fed to ``generator``.
    """
    # First two images of one real batch.
    sample_batch = next(iter(dataset))
    reals = sample_batch[:2]

    # Two freshly generated images.
    z = tf.random.normal([2, latent_dim])
    fakes = generator(z, training=False)

    # Map [-1,1] → [0,1] so imshow renders the colours correctly.
    reals = (reals + 1) / 2.0
    fakes = (fakes + 1) / 2.0

    fig, axes = plt.subplots(2, 2, figsize=(6, 6))

    # Top row: real images; bottom row: fake images.
    panels = [
        (reals, 0, "Real"),
        (reals, 1, "Real"),
        (fakes, 0, "Fake"),
        (fakes, 1, "Fake"),
    ]
    for ax, (images, idx, label) in zip(axes.flatten(), panels):
        ax.imshow(images[idx].numpy())
        ax.set_title(label)
        ax.axis("off")

    plt.suptitle(f"Epoch {epoch}")
    plt.show()


In [None]:
EPOCHS = 20  # number of full passes over cat_dataset

for epoch in range(1, EPOCHS+1):
    # One combined generator/discriminator update per batch.
    for real_batch in cat_dataset:
        g_loss, d_loss = train_step(real_batch)

    # NOTE: g_loss / d_loss hold the values from the LAST batch of the epoch,
    # not an average over the epoch.
    print(f"Epoch {epoch} | Generator loss: {g_loss:.4f}, Discriminator loss: {d_loss:.4f}")

    # At the end of epoch: show 2 real + 2 fake
    # (latent_dim here must match the latent size the generator was built with)
    show_real_and_fake(generator, cat_dataset, epoch, latent_dim=100)
