In [25]:
import os
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers
import matplotlib.pyplot as plt

In [26]:
# Define your path
IMG_DIR = r"C:\Users\mds60\OneDrive - hamilton.edu\Semester 8\Statistical Methods in Machine Learning\Project\img_align_celeba_small"
BATCH_SIZE = 64
IMG_SIZE = 64  # Final square resolution
LATENT_DIM = 100

In [32]:
# Preprocessing function
def preprocess_image(file_path):
    # Read image from disk
    img = tf.io.read_file(file_path)
    img = tf.io.decode_jpeg(img, channels=3)
    
    # Normalize to [-1, 1]
    img = (tf.cast(img, tf.float32) / 127.5) - 1.0
    
    return img

# Create dataset
dataset = (
    tf.data.Dataset.list_files(os.path.join(IMG_DIR, "*.jpg"), shuffle=True)
    .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(buffer_size=10000)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)


In [37]:
# -------------------------------------------------------------------------
# 2) GENERATOR for 64×64
# -------------------------------------------------------------------------
def build_generator():
    model = tf.keras.Sequential(name="Generator")
    # Dense → 8×8×256
    model.add(layers.Input(shape=(LATENT_DIM,)))
    model.add(layers.Dense(8*8*256, use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.ReLU())
    model.add(layers.Reshape((8, 8, 256)))  # → (8,8,256)

    # 8→16
    model.add(layers.Conv2DTranspose(128, 5, strides=2, padding="same", use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.ReLU())

    # 16→32
    model.add(layers.Conv2DTranspose(64, 5, strides=2, padding="same", use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.ReLU())

    # 32→64  ← final upsample to your target size
    model.add(layers.Conv2DTranspose(3,   # RGB output
                                     5,
                                     strides=2,
                                     padding="same",
                                     use_bias=False,
                                     activation="tanh"))
    return model

# -------------------------------------------------------------------------
# 3) DISCRIMINATOR for 64×64
# -------------------------------------------------------------------------
def build_discriminator():
    model = tf.keras.Sequential(name="Discriminator")
    model.add(layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3)))

    # 64→32
    model.add(layers.Conv2D(64, 5, strides=2, padding="same"))
    model.add(layers.LeakyReLU(0.2))
    model.add(layers.Dropout(0.3))

    # 32→16
    model.add(layers.Conv2D(128, 5, strides=2, padding="same"))
    model.add(layers.LeakyReLU(0.2))
    model.add(layers.Dropout(0.3))

    # 16→8
    model.add(layers.Conv2D(256, 5, strides=2, padding="same"))
    model.add(layers.LeakyReLU(0.2))
    model.add(layers.Dropout(0.3))

    # 8→4  (optional; you can stop here)
    model.add(layers.Conv2D(512, 5, strides=2, padding="same"))
    model.add(layers.LeakyReLU(0.2))
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))   # output logits
    return model

# -------------------------------------------------------------------------
# 4) INSTANTIATE
# -------------------------------------------------------------------------
generator     = build_generator()
discriminator = build_discriminator()


In [38]:
# -----------------------------------------------------------------------------
# 3) LOSS, OPTIMIZERS & METRICS
# -----------------------------------------------------------------------------
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_logits, fake_logits):
    real_loss = cross_entropy(tf.ones_like(real_logits), real_logits)
    fake_loss = cross_entropy(tf.zeros_like(fake_logits), fake_logits)
    return real_loss + fake_loss

def generator_loss(fake_logits):
    return cross_entropy(tf.ones_like(fake_logits), fake_logits)

gen_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
disc_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

# track metrics
d_loss_metric = tf.keras.metrics.Mean(name="d_loss")
g_loss_metric = tf.keras.metrics.Mean(name="g_loss")
d_accuracy   = tf.keras.metrics.BinaryAccuracy(name="d_accuracy")

In [39]:
# -----------------------------------------------------------------------------
# 4) TRAIN STEP
# -----------------------------------------------------------------------------
@tf.function
def train_step(real_images):
    batch_size = tf.shape(real_images)[0]
    noise = tf.random.normal([batch_size, LATENT_DIM])

    with tf.GradientTape() as gt_gen, tf.GradientTape() as gt_disc:
        fake_images = generator(noise, training=True)

        real_logits = discriminator(real_images, training=True)
        fake_logits = discriminator(fake_images, training=True)

        g_loss = generator_loss(fake_logits)
        d_loss = discriminator_loss(real_logits, fake_logits)

    grads_gen  = gt_gen.gradient(g_loss, generator.trainable_variables)
    grads_disc = gt_disc.gradient(d_loss, discriminator.trainable_variables)

    gen_optimizer.apply_gradients(zip(grads_gen,  generator.trainable_variables))
    disc_optimizer.apply_gradients(zip(grads_disc, discriminator.trainable_variables))

    # update metrics
    g_loss_metric(g_loss)
    d_loss_metric(d_loss)
    d_accuracy.update_state(tf.ones_like(real_logits), real_logits)
    d_accuracy.update_state(tf.zeros_like(fake_logits), fake_logits)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os

# Ensure output directory exists
os.makedirs("checkpoints", exist_ok=True)
os.makedirs("generated_images", exist_ok=True)

# Set random seed for consistent image generation
fixed_seed = tf.random.normal([16, LATENT_DIM])  # 4x4 grid

# Training
g_losses = []
d_losses = []
d_accs = []
EPOCHS = 50
WEIGHT_SAVE_FREQ = 1
PICTURE_GENERATION_FREQ = 1

def generate_and_save_images(model, epoch, test_input):
    predictions = model(test_input, training=False)

    # Scale pixel values back from [-1, 1] to [0, 1]
    predictions = (predictions + 1.0) / 2.0

    fig = plt.figure(figsize=(4, 4))

    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        plt.imshow(predictions[i])
        plt.axis("off")

    plt.suptitle(f"Epoch {epoch}")
    plt.tight_layout()
    plt.savefig(f"generated_images/image_epoch_{epoch:03d}.png")
    plt.close()

# Training loop
for epoch in range(1, EPOCHS + 1):
    g_loss_metric.reset_states()
    d_loss_metric.reset_states()
    d_accuracy.reset_states()

    for batch in dataset:
        train_step(batch)

    g_losses.append(float(g_loss_metric.result().numpy()))
    d_losses.append(float(d_loss_metric.result().numpy()))
    d_accs.append(float(d_accuracy.result().numpy()))

    print(f"Epoch {epoch:03d}  G_loss={g_losses[-1]:.4f}  "
          f"D_loss={d_losses[-1]:.4f}  D_acc={d_accs[-1]:.4f}")

    # Save model weights
    if epoch % WEIGHT_SAVE_FREQ == 0:
        generator.save_weights(f"checkpoints/generator_epoch_{epoch:03d}.h5")
        discriminator.save_weights(f"checkpoints/discriminator_epoch_{epoch:03d}.h5")

    # Save generated images
    if epoch % PICTURE_GENERATION_FREQ == 0:
        generate_and_save_images(generator, epoch, fixed_seed)


Epoch 001  G_loss=1.7115  D_loss=0.7919  D_acc=0.8175
Epoch 002  G_loss=1.5269  D_loss=0.8676  D_acc=0.7884
Epoch 003  G_loss=1.5780  D_loss=0.8534  D_acc=0.7954
Epoch 004  G_loss=1.6547  D_loss=0.8314  D_acc=0.8028
Epoch 005  G_loss=1.8187  D_loss=0.7555  D_acc=0.8259
Epoch 006  G_loss=1.9327  D_loss=0.7228  D_acc=0.8361
Epoch 007  G_loss=2.0155  D_loss=0.6992  D_acc=0.8433
Epoch 008  G_loss=1.9667  D_loss=0.7641  D_acc=0.8297


In [None]:
plt.figure()
plt.plot(range(1, EPOCHS + 1), g_losses)
plt.plot(range(1, EPOCHS + 1), d_losses)
plt.title("Generator and Discriminator Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend(["G_loss", "D_loss"])
plt.show()

# 4. Plot Discriminator Accuracy
plt.figure()
plt.plot(range(1, EPOCHS + 1), d_accs)
plt.title("Discriminator Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.show()