In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
import numpy as np

In [8]:
#gtenerator model
def build_generator(latent_dim):
	model = models.Sequential([
			layers.Dense(128,input_dim=latent_dim),
			layers.LeakyReLU(alpha=0.2),
			layers.Dense(784,activation='tanh'),
			layers.Reshape((28,28))])
	return model
			

In [9]:
#discriminator model
def build_discriminator():
    model = models.Sequential([
        layers.Flatten(input_shape=(28,28)),
        layers.Dense(128),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(1,activation='sigmoid')
    ])
    return model

In [10]:
#GAN model
def build_gan(generator,discriminator):
    model = models.Sequential([
        generator,
        discriminator
    ])
    discriminator.trainable = False  #Freeze discriminator during GAN training
    return model

In [15]:
#load mnist dataset
(train_images, _), (_,_) = tf.keras.datasets.mnist.load_data()

train_images = train_images.reshape(train_images.shape[0], 28,28).astype('float32')
train_images = (train_images -127.5)/127.5  #normalize between -1 and 1


In [17]:
latent_dim = 100

generator = build_generator(latent_dim=latent_dim)
discriminator = build_discriminator()

gan = build_gan(generator, discriminator)

discriminator.compile(loss='binary_crossentropy',
                optimizer = optimizers.Adam(learning_rate=0.0002, beta_1=0.5))

gan.compile(loss='binary_crossentropy',
            optimizer = optimizers.Adam(learning_rate=0.0002,beta_1=0.5))

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(**kwargs)


In [18]:
# Training loop

epochs = 5
batch_size = 128
steps_per_epoch = train_images.shape[0] //batch_size

for epoch in range(epochs):
    for step in range(steps_per_epoch):
        #Train discriminator
        real_images = train_images[np.random.randint(0, train_images.shape[0], batch_size)]
        fake_images = generator.predict(np.random.normal(0,1,(batch_size, latent_dim)))
        discriminator_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size,1)))
        discriminator_loss_fake = discriminator.train_on_batch(fake_images, np.zeroes((batch_size,1)))
        discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)

        # Train generator
        noise = np.random.normal(0,1,(batch_size, latent_dim))
        generator_loss = gan.train_on_batch(noise, np.ones((batch_size,1)))

        #Print progress
        print(f"Epoch {epoch+1}/{epochs}, Step {step+1}/{steps_per_epoch}),"
              f"Discriminator Loss : {discriminator_loss}, Generator Loss : {generator_loss}")

[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step  




Epoch 1/5, Step 1/468),Discriminator Loss : 0.7898069620132446, Generator Loss : 0.6638312339782715
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step 
Epoch 1/5, Step 2/468),Discriminator Loss : 0.723846435546875, Generator Loss : 0.6275845766067505
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step 
Epoch 1/5, Step 3/468),Discriminator Loss : 0.6966452598571777, Generator Loss : 0.593453586101532
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 980us/step
Epoch 1/5, Step 4/468),Discriminator Loss : 0.6766347289085388, Generator Loss : 0.5550447106361389
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step 
Epoch 1/5, Step 5/468),Discriminator Loss : 0.6640108823776245, Generator Loss : 0.5207845568656921
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step 
Epoch 1/5, Step 6/468),Discriminator Loss : 0.645713746547699, Generator Loss : 0.48924651741981506
[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m

In [None]:
import matplotlib.pyplot as plt

def visualize_generated_images(generator, latent_dim , num_images=10):
    noise = np.random.normal(0,1, (num_images,latent_dim))
    generated_images = generator.predict(noise)

    plt.figure(figsize=(10,10))
    for i in range(num_images):
        plt.subplot(1,num_images, i+1)
        plt.imshow(generated_images[i], cmap='gray')
        plt.axis('off')
    plt.show()

visualize_generated_images(generator, latent_dim)