# Generative Adversarial Network

In [1]:
from numpy import expand_dims
from numpy import zeros
from numpy import ones
from numpy import vstack
from numpy.random import randn
from numpy.random import randint
from keras.datasets.mnist import load_data
from tensorflow.keras.optimizers import Adam
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers import Flatten
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import Dropout
from matplotlib import pyplot

2024-06-16 15:06:09.727464: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-06-16 15:06:09.728816: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-06-16 15:06:09.756186: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-06-16 15:06:09.756694: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Discriminator Model

In [2]:
def define_discriminator(in_shape=(28,28,1)):
    """Build and compile the standalone discriminator.

    The network is a binary classifier: two stride-2 convolutions (each
    followed by LeakyReLU and dropout), then a flatten and a single
    sigmoid unit.

    Args:
        in_shape: shape of one input image, passed as ``input_shape`` to the
            first convolution. Defaults to ``(28, 28, 1)``.

    Returns:
        A compiled ``Sequential`` model using binary cross-entropy loss,
        the Adam optimizer and an ``accuracy`` metric.
    """
    model = Sequential()
    # first downsampling stage (stride 2, "same" padding)
    model.add(Conv2D(64, (3,3), strides=(2,2), padding="same", input_shape=in_shape))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.4))
    # second downsampling stage
    model.add(Conv2D(64, (3,3), strides=(2,2), padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.4))
    # classifier head: one sigmoid output per image
    model.add(Flatten())
    model.add(Dense(1,activation="sigmoid"))

    # `learning_rate` is the supported keyword; the old `lr` alias is
    # deprecated and, depending on the Keras version, is either dropped with
    # a warning or rejected outright.
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model

## Generator Model

In [3]:
def define_generator(latent_dim):
    """Build the (uncompiled) generator network.

    A dense layer expands the latent vector to 128 feature maps of 7x7,
    two stride-2 transposed convolutions upsample them, and a final 7x7
    convolution with a sigmoid activation produces a single-channel image.

    Args:
        latent_dim: length of the latent input vector (``input_dim`` of the
            first dense layer).

    Returns:
        A ``Sequential`` model; it is not compiled here.
    """
    base_units = 128 * 7 * 7
    layers = [
        # foundation: enough units to be reshaped into 7x7x128
        Dense(base_units, input_dim=latent_dim),
        LeakyReLU(alpha=0.2),
        Reshape((7, 7, 128)),
        # first stride-2 upsampling stage
        Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        LeakyReLU(alpha=0.2),
        # second stride-2 upsampling stage
        Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        LeakyReLU(alpha=0.2),
        # collapse to one channel with values squashed by a sigmoid
        Conv2D(1, (7, 7), activation="sigmoid", padding="same"),
    ]
    return Sequential(layers)

## GAN Function

In [4]:
def define_gan(g_model, d_model):
    """Stack generator and discriminator into the composite model.

    The composite model is the one used to update the generator: it feeds
    generator output straight into the discriminator.

    Args:
        g_model: the generator model.
        d_model: the discriminator model. Its ``trainable`` attribute is set
            to ``False`` here (the passed object is mutated).

    Returns:
        A compiled ``Sequential`` model (binary cross-entropy, Adam).
    """
    # freeze the discriminator inside the composite model
    # NOTE(review): relies on Keras applying `trainable` at compile time, so
    # the already-compiled standalone d_model should still train — confirm.
    d_model.trainable = False

    model = Sequential()
    model.add(g_model)
    model.add(d_model)

    # `learning_rate` is the supported keyword; the old `lr` alias is
    # deprecated and may be ignored or rejected by newer Keras versions.
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model

## Helper Functions

In [5]:
# load and prepare mnist training images
def load_real_samples():
    """Load the MNIST training images as float32 arrays with a channel axis.

    Only the training images are kept; labels and the test split are
    discarded. A trailing axis is appended and the values are divided by
    255.0.

    Returns:
        The prepared training image array.
    """
    (train_images, _labels), (_, _) = load_data()
    # append a trailing channel axis, then convert and rescale in one chain
    with_channel = expand_dims(train_images, axis=-1)
    return with_channel.astype('float32') / 255.0

# select real samples
def generate_real_samples(dataset, n_samples):
    """Draw a random batch from ``dataset`` labelled as real.

    Indices are drawn with replacement from ``[0, dataset.shape[0])``.

    Args:
        dataset: array indexed along its first axis.
        n_samples: number of samples to draw.

    Returns:
        A tuple ``(X, y)``: the selected rows and an ``(n_samples, 1)``
        array of ones.
    """
    picks = randint(0, dataset.shape[0], n_samples)
    real_labels = ones((n_samples, 1))
    return dataset[picks], real_labels

# generate points in latent space as input for the generator
def generate_latent_points(latent_dim,n_samples):
    """Sample standard-normal latent vectors for the generator.

    Args:
        latent_dim: length of each latent vector.
        n_samples: number of vectors to sample.

    Returns:
        An array of shape ``(n_samples, latent_dim)``.
    """
    # draw every value in one flat call, then fold into one row per sample
    total_values = latent_dim * n_samples
    return randn(total_values).reshape(n_samples, latent_dim)

# use the generator to generate n fake examples, with class labels
def generate_fake_samples(g_model,latent_dim,n_samples):
    """Produce generator output labelled as fake.

    Args:
        g_model: object exposing ``predict``; called on the latent batch.
        latent_dim: length of each latent vector.
        n_samples: number of samples to generate.

    Returns:
        A tuple ``(X, y)``: the ``predict`` output and an
        ``(n_samples, 1)`` array of zeros.
    """
    latent_batch = generate_latent_points(latent_dim, n_samples)
    fake_labels = zeros((n_samples, 1))
    fake_images = g_model.predict(latent_batch)
    return fake_images, fake_labels

# create and save a plot of generated images (reversed grayscale)
def save_plot(examples, epoch, n=10):
    """Save an ``n`` x ``n`` grid of images to a PNG file.

    Args:
        examples: image array indexed as ``examples[i, :, :, 0]``; the first
            ``n * n`` entries are plotted, so it must hold at least that many.
        epoch: zero-based epoch index; the file name uses ``epoch + 1``.
        n: grid side length. Defaults to 10.

    Side effects:
        Writes ``generated_plot_ep_XXX.png`` to the current working
        directory and closes the current pyplot figure.
    """
    for i in range(n * n):
        # subplot positions are 1-based
        pyplot.subplot(n,n,1+i)
        pyplot.axis('off')
        pyplot.imshow(examples[i,:,:,0], cmap='gray_r')

    filename = 'generated_plot_ep_%03d.png'%(epoch+1)
    # pyplot has no `savegif`; `savefig` is the function that writes the figure
    pyplot.savefig(filename)
    pyplot.close()

# evaluate the discriminator, plot generated images, save generator model
def summarize_performance(epoch, g_model, d_model, dataset, latent_dim, n_samples=100):
    """Report discriminator accuracy, then save a sample plot and the generator.

    Args:
        epoch: zero-based epoch index; output file names use ``epoch + 1``.
        g_model: generator model; used to create fake samples and saved to disk.
        d_model: discriminator model; ``evaluate`` must return a
            ``(loss, accuracy)`` pair.
        dataset: array of real images to sample from.
        latent_dim: length of the generator's latent vector.
        n_samples: number of real and of fake samples to evaluate on.

    Side effects:
        Prints the two accuracies, writes the image grid via ``save_plot``
        and saves the generator as ``generator_model_XXX.h5``.
    """
    epoch_number = epoch + 1

    # accuracy on a fresh batch of real images
    real_images, real_labels = generate_real_samples(dataset, n_samples)
    _loss, real_accuracy = d_model.evaluate(real_images, real_labels, verbose=0)

    # accuracy on a fresh batch of generated images
    fake_images, fake_labels = generate_fake_samples(g_model, latent_dim, n_samples)
    _loss, fake_accuracy = d_model.evaluate(fake_images, fake_labels, verbose=0)

    percentages = (real_accuracy * 100, fake_accuracy * 100)
    print("Accuracy real: %.0f%%, fake: %.0f%%" % percentages)

    save_plot(fake_images, epoch)
    g_model.save('generator_model_%03d.h5' % epoch_number)

## Training 

In [6]:
# train the generator and discriminator
def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=100, n_batch=256, summarize_every=10):
    """Alternate discriminator and generator updates over the dataset.

    Each step trains the discriminator on a batch that is half real and half
    generated images, then trains the generator through ``gan_model`` on
    ``n_batch`` latent points labelled as real (ones).

    Args:
        g_model: generator model.
        d_model: compiled discriminator; ``train_on_batch`` must return a
            ``(loss, accuracy)`` pair.
        gan_model: compiled composite model; ``train_on_batch`` must return
            a scalar loss.
        dataset: array of real training images.
        latent_dim: length of the generator's latent vector.
        n_epochs: number of passes over the dataset.
        n_batch: full batch size; the discriminator sees ``n_batch // 2``
            real and as many fake samples per step.
        summarize_every: call ``summarize_performance`` after every this many
            epochs. A falsy value (``0`` or ``None``) disables it.

    Side effects:
        Prints per-step losses; periodically writes plots and generator
        checkpoints through ``summarize_performance``.
    """
    bat_per_epoch = dataset.shape[0] // n_batch
    half_batch = n_batch // 2

    for i in range(n_epochs):
        for j in range(bat_per_epoch):
            # discriminator step: half real, half generated
            X_real, y_real = generate_real_samples(dataset, half_batch)
            X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
            X, y = vstack((X_real, X_fake)), vstack((y_real, y_fake))
            d_loss, _ = d_model.train_on_batch(X, y)

            # generator step: label generated samples as real so the
            # composite model's loss pushes the generator to fool d_model
            X_gan = generate_latent_points(latent_dim, n_batch)
            y_gan = ones((n_batch, 1))
            g_loss = gan_model.train_on_batch(X_gan, y_gan)

            print(">epoch%d, %d/%d, d_loss=%.3f, g_loss=%.3f" %(i+1,j+1,bat_per_epoch,d_loss,g_loss))

        # The original test `(i+1) % 10 == 10` can never be true (a remainder
        # modulo 10 is at most 9), so no summary was ever produced.
        if summarize_every and (i + 1) % summarize_every == 0:
            summarize_performance(i, g_model, d_model, dataset, latent_dim)
    

In [7]:
# length of the noise vector fed to the generator
latent_dim = 100

# real training images
dataset = load_real_samples()

# discriminator, generator, and the composite model that chains them
d_model = define_discriminator()
g_model = define_generator(latent_dim)
gan_model = define_gan(g_model, d_model)

# run training with the default epoch count and batch size
train(
    g_model=g_model,
    d_model=d_model,
    gan_model=gan_model,
    dataset=dataset,
    latent_dim=latent_dim,
)



Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
>epoch1, 1/234, d_loss=0.699, g_loss=0.804
>epoch1, 2/234, d_loss=0.665, g_loss=0.903
>epoch1, 3/234, d_loss=0.642, g_loss=0.955
>epoch1, 4/234, d_loss=0.642, g_loss=0.884
>epoch1, 5/234, d_loss=0.661, g_loss=0.749
>epoch1, 6/234, d_loss=0.656, g_loss=0.695
>epoch1, 7/234, d_loss=0.618, g_loss=0.690
>epoch1, 8/234, d_loss=0.573, g_loss=0.694
>epoch1, 9/234, d_loss=0.524, g_loss=0.700
>epoch1, 10/234, d_loss=0.484, g_loss=0.706
>epoch1, 11/234, d_loss=0.444, g_loss=0.712
>epoch1, 12/234, d_loss=0.408, g_loss=0.721
>epoch1, 13/234, d_loss=0.379, g_loss=0.732
>epoch1, 14/234, d_loss=0.359, g_loss=0.747
>epoch1, 15/234, d_loss=0.351, g_loss=0.767
>epoch1, 16/234, d_loss=0.330, g_loss=0.797
>epoch1, 17/234, d_loss=0.313, g_loss=0.841
>epoch1, 18/234, d_loss=0.296, g_loss=0.899
>epoch1, 19/234, d_loss=0.274, g_loss=0.986
>epoch1, 20/234, d_loss=0.247, g_loss=1.105
>epoch1, 21/234, d_loss=0.213, g_loss

KeyboardInterrupt: 