In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from WGAN import critic, generator, WGAN

In [None]:
latent_dim = 100  # length of each noise vector z fed to the generator


def sample_images(generator, image_grid_rows=4, image_grid_columns=4, figsize=(4, 4)):
    """Generate a grid of images from `generator` and display it.

    Draws `image_grid_rows * image_grid_columns` noise vectors z ~ N(0, 1) of
    length `latent_dim`, runs them through `generator.predict`, maps the output
    with 0.5 * x + 0.5 (i.e. [-1, 1] -> [0, 1], presumably matching a tanh
    output layer — confirm in the WGAN module), and shows channel 0 of each
    image in grayscale.

    Args:
        generator: Keras-style model exposing `predict`; its output is indexed
            as (batch, height, width, channels).
        image_grid_rows: number of rows in the displayed grid.
        image_grid_columns: number of columns in the displayed grid.
        figsize: matplotlib figure size in inches (default keeps the original 4x4).
    """
    n_images = image_grid_rows * image_grid_columns
    z = np.random.normal(0, 1, (n_images, latent_dim))
    # verbose=0 keeps Keras' per-call progress bar out of the notebook output.
    gen_imgs = generator.predict(z, verbose=0)
    gen_imgs = 0.5 * gen_imgs + 0.5
    # squeeze=False always returns a 2-D array of Axes. With the default
    # squeezing, a 1xN / Nx1 grid (or a single image) broke axs[i, j] indexing.
    fig, axs = plt.subplots(image_grid_rows, image_grid_columns, figsize=figsize,
                            sharex=True, sharey=True, squeeze=False)
    # axs.flat walks the grid row by row, matching the order of gen_imgs.
    for ax, img in zip(axs.flat, gen_imgs):
        ax.imshow(img[:, :, 0], cmap='gray')
        ax.axis('off')
    plt.show()
    # This may be called many times during training; release the figure so
    # open figures do not accumulate in memory.
    plt.close(fig)

In [None]:
n_critic = 5       # critic updates per generator update
clip_value = 0.01  # critic weights are clipped into [-clip_value, clip_value]


def _clip_weights(model, limit):
    """Clip every weight array of every layer of `model` into [-limit, limit] in place."""
    for layer in model.layers:
        layer.set_weights([np.clip(w, -limit, limit) for w in layer.get_weights()])


def _first_loss(loss):
    """Return the loss term of a `train_on_batch` result.

    Keras returns a scalar when the model has no metrics and a list
    `[loss, *metrics]` otherwise; `np.atleast_1d` makes `[0]` work for both.
    """
    return np.atleast_1d(loss)[0]


def train(epochs, batch_size=128, sample_interval=50):
    """Train the imported WGAN (weight-clipping variant) on MNIST.

    Each epoch runs `n_critic` critic updates on one real batch and one
    generated batch, and clips the critic's weights after each update. It then
    runs one generator update through the combined `WGAN` model. Uses the
    module-level `generator`, `critic` and `WGAN` models.

    Args:
        epochs: number of generator updates.
        batch_size: images per batch for both the critic and generator steps.
        sample_interval: show a sample grid every this many epochs
            (0 or None disables sampling).
    """
    (X_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
    # Scale pixels from [0, 255] to [-1, 1] and add a trailing channel axis.
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = np.expand_dims(X_train, axis=3)

    # Targets: real -> -1, fake -> +1 (presumably paired with a y_true * y_pred
    # Wasserstein loss defined in the WGAN module — confirm there).
    valid = -np.ones((batch_size, 1))
    fake = np.ones((batch_size, 1))

    for epoch in range(epochs):

        # ---- critic updates ----
        for _ in range(n_critic):
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            z = np.random.normal(0, 1, (batch_size, latent_dim))
            # verbose=0: otherwise Keras prints a progress bar on every call,
            # i.e. n_critic times per epoch.
            gen_imgs = generator.predict(z, verbose=0)

            d_loss_real = critic.train_on_batch(imgs, valid)
            d_loss_fake = critic.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

            _clip_weights(critic, clip_value)

        # ---- generator update ----
        # Fresh noise rather than reusing the batch the critic was just trained on.
        z = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = WGAN.train_on_batch(z, valid)

        # NOTE(review): the `1 - loss` display offset is kept from the original
        # log format; it is not a standard WGAN quantity — confirm it is wanted.
        print(f"{epoch} [D loss: {1 - _first_loss(d_loss)}] [G loss: {1 - _first_loss(g_loss)}]")

        if sample_interval and epoch % sample_interval == 0:
            sample_images(generator)

In [None]:
# Full run: 4000 generator steps, batches of 32, sample grid every 50 steps.
train(epochs=4000, batch_size=32, sample_interval=50)