In [24]:
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, Flatten, Reshape
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

In [25]:
# Image geometry (28x28, single channel, matching MNIST) and latent size.
img_rows, img_cols, channels = 28, 28, 1
img_shape = (img_rows, img_cols, channels)

# Length of the random noise vector fed to the generator.
z_dim = 100

In [26]:
def build_generator(img_shape, z_dim):
    """Build a fully connected generator mapping a noise vector to an image.

    Args:
        img_shape: target image shape, e.g. ``(28, 28, 1)``.
        z_dim: length of the input noise vector.

    Returns:
        An uncompiled ``Sequential`` model whose output has shape
        ``img_shape`` with values in ``[-1, 1]`` (tanh activation).
    """
    model = Sequential()

    model.add(Dense(128, input_dim=z_dim))
    model.add(LeakyReLU(alpha=0.01))
    # Derive the unit count from img_shape instead of hard-coding 28*28*1,
    # so the Reshape below always receives a matching number of values.
    model.add(Dense(int(np.prod(img_shape)), activation="tanh"))
    model.add(Reshape(img_shape))
    return model

In [27]:
def build_discriminator(img_shape):
    """Build a fully connected binary classifier for images of ``img_shape``.

    The image is flattened, passed through one 128-unit hidden layer with a
    LeakyReLU(0.01) activation, and reduced to a single sigmoid output.
    Returns an uncompiled ``Sequential`` model.
    """
    stack = [
        Flatten(input_shape=img_shape),
        Dense(128),
        LeakyReLU(alpha=0.01),
        Dense(1, activation="sigmoid"),
    ]
    return Sequential(stack)

In [28]:
def build_gan(generator, discriminator):
    """Chain ``generator`` and ``discriminator`` into one Sequential model.

    The combined model maps noise straight to the discriminator's output and
    is what the training loop uses to update the generator.
    """
    return Sequential([generator, discriminator])

# Standalone discriminator: compiled while still trainable so that
# discriminator.train_on_batch(...) updates its weights directly.
discriminator = build_discriminator(img_shape)
discriminator.compile(loss="binary_crossentropy", optimizer=Adam(), metrics=['accuracy'])

generator = build_generator(img_shape, z_dim)

# Freeze the discriminator *after* its own compile but *before* compiling the
# combined model, so training `gan` is intended to update only the generator.
# Do not reorder these statements.
# NOTE(review): relies on Keras capturing `trainable` at compile time —
# confirm against the installed TensorFlow version.
discriminator.trainable = False

# Combined model: noise -> generator -> (frozen) discriminator -> score.
gan = build_gan(generator, discriminator)
gan.compile(loss="binary_crossentropy", optimizer=Adam())

In [38]:
# Metrics recorded by train() every `sample_interval` iterations.
losses = []
accuracies = []
iteration_checkpoints = []

def train(iterations, batch_size, sample_interval):
    """Train the GAN on MNIST, alternating discriminator and generator steps.

    Every ``sample_interval`` iterations the discriminator loss/accuracy and
    the generator loss are appended to the module-level ``losses``,
    ``accuracies`` and ``iteration_checkpoints`` lists, printed, and a grid
    of generated samples is plotted with ``sample_images``.

    Args:
        iterations: number of training iterations (one batch each).
        batch_size: number of real and of fake images per discriminator step.
        sample_interval: how often, in iterations, to log and plot.
    """
    (X_train, _), (_, _) = mnist.load_data()
    # Rescale pixels [0, 255] -> [-1, 1] to match the generator's tanh range.
    X_train = X_train.astype(np.float32) / 127.5 - 1.0
    X_train = np.expand_dims(X_train, axis=3)

    real = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for iteration in range(iterations):
        # --- Discriminator step: one batch of real images, one of fakes.
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]
        d_loss_real = discriminator.train_on_batch(imgs, real)

        z = np.random.normal(0, 1, (batch_size, z_dim))
        gen_imgs = generator.predict(z, verbose=0)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        # train_on_batch returns [loss, accuracy]; average real and fake.
        d_loss, accuracy = 0.5 * np.add(d_loss_real, d_loss_fake)

        # --- Generator step: fresh noise labelled "real", so the generator is
        # pushed towards outputs the frozen discriminator scores as real.
        z = np.random.normal(0, 1, (batch_size, z_dim))
        g_loss = gan.train_on_batch(z, real)

        if (iteration + 1) % sample_interval == 0:
            losses.append((d_loss, g_loss))
            accuracies.append(100.0 * accuracy)
            iteration_checkpoints.append(iteration + 1)

            print(" {} [D Loss : {}, acc : {:.2f}] [G Loss : {}]".format(iteration + 1, d_loss, 100.0*accuracy, g_loss))
            sample_images(generator)

In [39]:
def sample_images(generator, img_grid_rows=4, img_grid_col=4):
    """Plot a grid of images produced by ``generator`` from random noise.

    Args:
        generator: model mapping ``(n, z_dim)`` noise to images indexed as
            ``[n, rows, cols, channel]`` with values in ``[-1, 1]``; only
            channel 0 is drawn.
        img_grid_rows: number of grid rows.
        img_grid_col: number of grid columns.
    """
    n_images = img_grid_rows * img_grid_col
    z = np.random.normal(0, 1, (n_images, z_dim))
    gen_imgs = generator.predict(z, verbose=0)
    # Map tanh output from [-1, 1] back to [0, 1] for display.
    gen_imgs = 0.5 * gen_imgs + 0.5

    # squeeze=False keeps `axs` 2-D even when a grid dimension is 1.
    fig, axs = plt.subplots(img_grid_rows, img_grid_col, figsize=(4, 4),
                            sharey=True, sharex=True, squeeze=False)
    # axs.flat walks the grid row by row, matching the original cnt order.
    for cnt, ax in enumerate(axs.flat):
        ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
        ax.axis('off')
    # Render now, next to the training log line, rather than accumulating one
    # open figure per call until the whole training cell finishes.
    plt.show()

In [None]:
# Training run configuration.
# NOTE(review): with batch_size = 1 each discriminator step sees one real and
# one fake image, so the reported accuracy can only be 0, 50 or 100% (as in
# the log output); a larger batch gives smoother loss/accuracy estimates.
iterations = 20_000
batch_size = 1
sample_interval = 1_000

train(iterations=iterations, batch_size=batch_size, sample_interval=sample_interval)

2022-08-26 16:43:05.153797: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.




2022-08-26 16:43:06.200245: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.


 1000 [D Loss : 0.5001898407936096, acc : 50.000000] [G Loss : 1.9069288969039917]
 2000 [D Loss : 0.2025152277201414, acc : 100.000000] [G Loss : 1.3436298370361328]
 3000 [D Loss : 1.5268458724021912, acc : 50.000000] [G Loss : 1.3509421348571777]
 4000 [D Loss : 0.8965376615524292, acc : 50.000000] [G Loss : 0.6804546117782593]
 5000 [D Loss : 0.157100360840559, acc : 100.000000] [G Loss : 1.6843806505203247]
 6000 [D Loss : 0.34096622467041016, acc : 100.000000] [G Loss : 1.2097841501235962]
 7000 [D Loss : 0.5165829509496689, acc : 100.000000] [G Loss : 0.7824682593345642]
 8000 [D Loss : 1.0752598084509373, acc : 50.000000] [G Loss : 0.2523835599422455]
