In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, LeakyReLU, Dropout, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt
import os

In [None]:
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Map raw pixel intensities from [0, 255] onto [-1, 1], the same range as the
# generator's final tanh activation, so real and generated images are comparable.
x_train, x_test = (split / 255.0 * 2 - 1 for split in (x_train, x_test))
print('x_train.shape:', x_train.shape)


x_train.shape: (60000, 28, 28)


In [None]:
# Flatten each H x W image into a D-dimensional vector for the fully-connected networks.
# Guarded on ndim so re-running this cell (when x_train is already flat) is a no-op
# instead of failing to unpack a 2-D shape into (N, H, W).
if x_train.ndim == 3:
    N, H, W = x_train.shape
    D = H * W
    x_train = x_train.reshape(-1, D)
    x_test = x_test.reshape(-1, D)

# Size of the random noise vector fed to the generator
latent_dim = 100

In [None]:
def build_generator(latent_dim, img_size=None):
    """Build the MLP generator that maps noise vectors to flattened images.

    Args:
        latent_dim: Length of the input noise vector.
        img_size: Length of the flattened output image. Defaults to the
            notebook-level ``D`` (H * W) so existing calls behave as before.

    Returns:
        A Keras ``Model`` from ``(latent_dim,)`` noise to ``(img_size,)``
        outputs squashed into [-1, 1] by tanh, matching the [-1, 1] scaling
        applied to the training images.
    """
    if img_size is None:
        img_size = D  # backward-compatible default: module-level image size
    i = Input(shape=(latent_dim,))
    x = i
    # Three hidden blocks of Dense -> LeakyReLU -> BatchNorm with growing width
    for units in (256, 512, 1024):
        x = Dense(units)(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = BatchNormalization(momentum=0.8)(x)
    x = Dense(img_size, activation="tanh")(x)
    return Model(i, x)


In [None]:
def build_discriminator(img_size):
    """Build the MLP discriminator that scores flattened images.

    Args:
        img_size: Length of each flattened input image vector.

    Returns:
        A Keras ``Model`` mapping ``(img_size,)`` inputs to a single sigmoid
        probability that the input is a real image.
    """
    inputs = Input(shape=(img_size,))
    hidden = inputs
    # Two hidden Dense layers of decreasing width, each followed by LeakyReLU
    for units in (512, 256):
        hidden = Dense(units)(hidden)
        hidden = LeakyReLU(alpha=0.2)(hidden)
    validity = Dense(1, activation='sigmoid')(hidden)
    return Model(inputs, validity)


In [None]:
# Discriminator: trained directly on real (label 1) vs. generated (label 0) batches.
discriminator = build_discriminator(D)
discriminator.compile(
    loss='binary_crossentropy',
    optimizer=Adam(0.0002, 0.5)  # learning_rate=2e-4, beta_1=0.5
)

# Build generator (not compiled separately; its weights are updated through combined_model)
generator = build_generator(latent_dim)

# Combine models: noise z -> generator -> discriminator -> probability of "real"
z = Input(shape=(latent_dim,))
img = generator(z)
discriminator.trainable = False  # Freeze discriminator during generator training
# NOTE(review): this relies on the Keras 2 behaviour where `trainable` is captured
# at compile() time, so discriminator.train_on_batch (compiled above, while still
# trainable) keeps updating its weights. Keras 3 appears to read the flag when the
# train step runs, which can leave the discriminator with no trainable weights
# ("The model does not have any trainable weights" warning). Confirm against the
# installed Keras version.
fake_pred = discriminator(img)
combined_model = Model(z, fake_pred)

# Compile combined model (must happen after the freeze above)
combined_model.compile(
    loss='binary_crossentropy',
    optimizer=Adam(0.0002, 0.5)
)


In [None]:
batch_size = 32
epochs = 30000       # number of training iterations (one mini-batch each), not full passes over the data
sample_period = 200  # save a grid of generated images every `sample_period` iterations

# Target labels for binary cross-entropy: 1 = real image, 0 = generated image
ones = np.ones((batch_size, 1))
zeros = np.zeros((batch_size, 1))

# Per-iteration loss history, appended to by the training loop
d_losses = []
g_losses = []

# Directory for generated sample images. exist_ok avoids the check-then-create
# race of a separate exists() test and keeps the cell safe to re-run.
os.makedirs('gan_images', exist_ok=True)

In [None]:
def sample_images(epoch, rows=5, cols=5):
    """Save a rows x cols grid of generator samples to ``gan_images/{epoch}.png``.

    Reads the notebook-level ``generator``, ``latent_dim``, ``H`` and ``W``.

    Args:
        epoch: Training iteration number; used only to name the output file.
        rows: Number of grid rows (default 5, as before).
        cols: Number of grid columns (default 5, as before).
    """
    noise = np.random.randn(rows * cols, latent_dim)
    # verbose=0 stops predict() from printing a progress bar on every call
    imgs = generator.predict(noise, verbose=0)
    imgs = 0.5 * imgs + 0.5  # tanh output in [-1, 1] -> [0, 1] for display
    # squeeze=False keeps axs 2-D, so single-row or single-column grids also work
    fig, axs = plt.subplots(rows, cols, squeeze=False)
    for ax, sample in zip(axs.flat, imgs):
        ax.imshow(sample.reshape(H, W), cmap='gray')
        ax.axis('off')
    fig.savefig(f"gan_images/{epoch}.png")
    # Close this specific figure (not just the "current" one) so repeated calls
    # from the training loop do not accumulate open figures.
    plt.close(fig)


In [None]:
# Training loop
for epoch in range(epochs):
    # Train Discriminator
    idx = np.random.randint(0, x_train.shape[0], batch_size)
    real_imgs = x_train[idx]

    noise = np.random.randn(batch_size, latent_dim)
    fake_imgs = generator.predict(noise)

    d_loss_real = discriminator.train_on_batch(real_imgs, ones)
    d_loss_fake = discriminator.train_on_batch(fake_imgs, zeros)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # Train Generator
    noise = np.random.randn(batch_size, latent_dim)
    g_loss = combined_model.train_on_batch(noise, ones)

    # Append losses
    d_losses.append(d_loss)
    g_losses.append(g_loss)

    # Print progress
if epoch % 100 == 0:
    # Access the scalar value from g_loss if it's a list or array
    g_loss_scalar = g_loss[0] if isinstance(g_loss, (list, np.ndarray)) else g_loss
    print(f"epoch: {epoch}/{epochs}, d_loss: {d_loss:.2f}, g_loss: {g_loss_scalar:.2f}")

    # Save images at intervals
    if epoch % sample_period == 0:
        sample_images(epoch)


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step




[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 33ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 39ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25

In [None]:
# List the sample-image PNGs that sample_images() writes into gan_images/
!ls gan_images

In [None]:
from skimage.io import imread
import matplotlib.pyplot as plt
import os

# Display the saved sample grids at a few checkpoints to eyeball training progress.
for epoch in [0, 1000, 2000, 3000]:
    img_path = f"gan_images/{epoch}.png"
    if not os.path.exists(img_path):
        # Report instead of silently skipping, so a missing checkpoint is noticed
        print(f"No sample image for epoch {epoch} at {img_path}")
        continue
    a = imread(img_path)
    fig, ax = plt.subplots()
    ax.imshow(a)
    ax.set_title(f"Generated samples at epoch {epoch}")
    ax.axis('off')
    plt.show()