In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from DCGAN import discriminator, generator, DCGAN

In [2]:
# `discriminator`, `generator` and `DCGAN` are bound by `from DCGAN import ...`
# in the first cell; they are used directly below, so no re-binding is needed.

In [3]:
# Discriminator: trained directly as a real/fake binary classifier, so also
# track accuracy.
discriminator.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
# Combined model: trained in `train` on noise with "real" targets, i.e. the
# generator update step.
# NOTE(review): conventionally the discriminator is frozen (trainable=False)
# before this compile so the generator step does not update it — confirm that
# the DCGAN module does this.
DCGAN.compile(
    optimizer='adam',
    loss='binary_crossentropy',
)

In [4]:
# Length of the noise vector fed to the generator.
# NOTE(review): assumed to match the generator's input size defined in the DCGAN module.
latent_dim = 100

def sample_images(generator, image_grid_rows=4, image_grid_columns=4):
    """Plot a grid of images produced by ``generator`` from Gaussian noise.

    Args:
        generator: model exposing a Keras-style ``predict`` that maps
            ``(n, latent_dim)`` noise to a batch of images indexed as
            ``[n, H, W, C]``; only channel 0 is displayed. Outputs are
            presumably in [-1, 1] (they are rescaled with ``0.5*x + 0.5``).
        image_grid_rows: number of rows in the grid.
        image_grid_columns: number of columns in the grid.
    """
    n_images = image_grid_rows * image_grid_columns
    z = np.random.normal(0, 1, (n_images, latent_dim))
    # verbose=0: otherwise predict() may print a progress bar on every call.
    gen_imgs = generator.predict(z, verbose=0)
    # Map [-1, 1] back to [0, 1] for display.
    gen_imgs = 0.5 * gen_imgs + 0.5

    # squeeze=False keeps `axs` 2-D even when one grid dimension is 1.
    fig, axs = plt.subplots(image_grid_rows, image_grid_columns, figsize=(4, 4),
                            sharex=True, sharey=True, squeeze=False)
    # axs.flat iterates row by row, matching the image order in gen_imgs.
    for cnt, ax in enumerate(axs.flat):
        ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
        ax.axis('off')
    # Render now, so samples appear as training progresses, then release the
    # figure; otherwise every figure from the training loop stays open until
    # the cell finishes.
    plt.show()
    plt.close(fig)

In [7]:
def train(iterations, batch_size, sample_interval):
    """Train the DCGAN on MNIST with alternating discriminator/generator steps.

    Each iteration:
      1. trains ``discriminator`` on one random batch of real MNIST images
         (target 1) and one batch of generator samples (target 0);
      2. trains the combined ``DCGAN`` model on fresh noise with target 1,
         i.e. updates the generator to fool the discriminator.

    Every ``sample_interval`` iterations the losses and accuracy are recorded
    and printed, and a grid of generated digits is plotted via
    ``sample_images``.

    Uses the notebook globals ``generator``, ``discriminator``, ``DCGAN`` and
    ``latent_dim``.

    Args:
        iterations: number of training iterations (one batch each).
        batch_size: number of real and of generated images per discriminator step.
        sample_interval: record/print/plot every this many iterations.

    Returns:
        ``(losses, accuracies, iteration_checkpoints)``, where ``losses`` is a
        list of ``(d_loss, g_loss)`` tuples, ``accuracies`` holds the
        discriminator accuracy in percent, and ``iteration_checkpoints`` holds
        the 1-based iteration numbers at which the values were recorded.
    """
    losses = []
    accuracies = []
    iteration_checkpoints = []

    (X_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
    # Rescale pixels from [0, 255] to [-1, 1]. Casting to float32 first halves
    # memory compared with NumPy's float64 default.
    X_train = X_train.astype(np.float32) / 127.5 - 1.0
    # Add a trailing channel axis.
    X_train = np.expand_dims(X_train, axis=3)

    real = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for iteration in range(iterations):
        # --- Discriminator step: one real batch and one generated batch ---
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]

        z = np.random.normal(0, 1, (batch_size, latent_dim))
        # verbose=0: otherwise predict() may print a progress bar every iteration.
        gen_imgs = generator.predict(z, verbose=0)

        # NOTE(review): Keras may only honour changes to `trainable` made after
        # compile() once the model is recompiled — confirm this toggling has
        # the intended effect with the installed Keras version.
        discriminator.trainable = True
        d_loss_real = discriminator.train_on_batch(imgs, real)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        # Average [loss, accuracy] over the real and fake batches.
        d_loss, accuracy = 0.5 * np.add(d_loss_real, d_loss_fake)

        # --- Generator step: train through DCGAN with "real" targets ---
        z = np.random.normal(0, 1, (batch_size, latent_dim))
        discriminator.trainable = False
        g_loss = DCGAN.train_on_batch(z, real)

        if (iteration + 1) % sample_interval == 0:
            losses.append((d_loss, g_loss))
            accuracies.append(100 * accuracy)
            iteration_checkpoints.append(iteration + 1)

            print(f'{iteration + 1} [D_loss: {d_loss:.4f}, accuracy: {100*accuracy:.2f}] [G_loss: {g_loss:.4f}]')
            sample_images(generator)

    return losses, accuracies, iteration_checkpoints

In [None]:
# Full run: 20k iterations with batches of 128, logging and plotting every 1000 iterations.
losses, accuracies, iteration_checkpoints = train(
    iterations=20000,
    batch_size=128,
    sample_interval=1000,
)