In [1]:
# Install TensorFlow if not present (Colab has it by default)
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Reshape, Flatten, LeakyReLU, BatchNormalization, Input, Embedding, multiply
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.datasets import mnist

print("TensorFlow version:", tf.__version__)

TensorFlow version: 2.18.0


In [2]:
# Load MNIST data
(X_train, y_train), (_, _) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5) / 127.5  # Normalize to [-1, 1]
X_train = np.expand_dims(X_train, axis=-1)
num_classes = 10
latent_dim = 100
img_shape = (28, 28, 1)


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


In [4]:
def build_generator():
    noise = Input(shape=(latent_dim,))
    label = Input(shape=(1,), dtype='int32')
    label_embedding = Flatten()(Embedding(num_classes, latent_dim)(label))
    model_input = multiply([noise, label_embedding])
    x = Dense(256)(model_input)
    x = LeakyReLU(negative_slope=0.2)(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = Dense(512)(x)
    x = LeakyReLU(negative_slope=0.2)(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = Dense(1024)(x)
    x = LeakyReLU(negative_slope=0.2)(x)
    x = BatchNormalization(momentum=0.8)(x)
    x = Dense(np.prod(img_shape), activation='tanh')(x)
    img = Reshape(img_shape)(x)
    return Model([noise, label], img)

def build_discriminator():
    img = Input(shape=img_shape)
    label = Input(shape=(1,), dtype='int32')
    label_embedding = Flatten()(Embedding(num_classes, np.prod(img_shape))(label))
    flat_img = Flatten()(img)
    model_input = multiply([flat_img, label_embedding])
    x = Dense(512)(model_input)
    x = LeakyReLU(negative_slope=0.2)(x)
    x = Dense(512)(x)
    x = LeakyReLU(negative_slope=0.2)(x)
    x = Dense(512)(x)
    x = LeakyReLU(negative_slope=0.2)(x)
    validity = Dense(1, activation='sigmoid')(x)
    return Model([img, label], validity)

generator = build_generator()
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])


In [5]:
noise = Input(shape=(latent_dim,))
label = Input(shape=(1,))
img = generator([noise, label])
discriminator.trainable = False
valid = discriminator([img, label])
combined = Model([noise, label], valid)
combined.compile(loss='binary_crossentropy', optimizer='adam')


In [6]:
epochs = 5000  # You can increase this if you have more time/resources
batch_size = 64
save_interval = 1000

for epoch in range(epochs):
    # Train discriminator
    idx = np.random.randint(0, X_train.shape[0], batch_size)
    imgs, labels = X_train[idx], y_train[idx]
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    gen_labels = np.random.randint(0, num_classes, batch_size).reshape(-1, 1)
    gen_imgs = generator.predict([noise, gen_labels], verbose=0)
    d_loss_real = discriminator.train_on_batch([imgs, labels], np.ones((batch_size, 1)))
    d_loss_fake = discriminator.train_on_batch([gen_imgs, gen_labels], np.zeros((batch_size, 1)))
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
    # Train generator
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    sampled_labels = np.random.randint(0, num_classes, batch_size).reshape(-1, 1)
    g_loss = combined.train_on_batch([noise, sampled_labels], np.ones((batch_size, 1)))
    if epoch % save_interval == 0:
        print(f"{epoch} [D loss: {d_loss[0]:.4f}, acc.: {100*d_loss[1]:.2f}%] [G loss: {g_loss:.4f}]")




0 [D loss: 0.6956, acc.: 37.11%] [G loss: 0.6948]
1000 [D loss: 0.7571, acc.: 13.29%] [G loss: 0.5842]
2000 [D loss: 0.7629, acc.: 12.92%] [G loss: 0.5743]
3000 [D loss: 0.7652, acc.: 12.84%] [G loss: 0.5705]
4000 [D loss: 0.7665, acc.: 12.76%] [G loss: 0.5684]


In [7]:
generator.save('mnist_gan_generator.h5')
print("Generator model saved as mnist_gan_generator.h5")



Generator model saved as mnist_gan_generator.h5


In [8]:
from google.colab import files
files.download('mnist_gan_generator.h5')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>