In [None]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from tensorflow.keras import Sequential
from tensorflow.keras.datasets import mnist
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.optimizers import Adam

tfl = tf.keras.layers

In [None]:
def build_generator(latent_dim, dropout=0.4, relu_alpha=0.2, momentum=0.9, initializer=None):
    """Build the (uncompiled) DCGAN generator.

    A dense projection of the latent vector is reshaped into a 7x7x128
    feature map. Two stride-2 transposed convolutions upsample it
    (7 -> 14 -> 28), a stride-1 transposed convolution refines it, and a
    final convolution maps it to a single tanh-activated channel.

    Args:
        latent_dim: length of each input latent vector.
        dropout: dropout rate, applied once right after the reshape.
        relu_alpha: negative slope of every LeakyReLU.
        momentum: momentum of every BatchNormalization layer.
        initializer: kernel initializer passed to every weighted layer
            (``None`` leaves the choice to Keras).

    Returns:
        A ``Sequential`` model named ``'Generator'``.
    """
    seed_shape = (7, 7, 128)
    model = Sequential(name='Generator')

    def add_norm_and_activation():
        # every hidden stage ends with BatchNorm followed by LeakyReLU
        model.add(tfl.BatchNormalization(momentum=momentum))
        model.add(tfl.LeakyReLU(alpha=relu_alpha))

    # project the latent vector, then reshape it into a small feature map
    model.add(tfl.Dense(np.prod(seed_shape), kernel_initializer=initializer, input_dim=latent_dim))
    add_norm_and_activation()
    model.add(tfl.Reshape(seed_shape))
    model.add(tfl.Dropout(dropout))

    # (filters, stride): two 2x upsampling stages, then one same-size stage
    for filters, stride in ((128, 2), (64, 2), (32, 1)):
        model.add(tfl.Conv2DTranspose(filters, kernel_size=5, strides=stride, padding='same',
                                      kernel_initializer=initializer))
        add_norm_and_activation()

    # collapse to one channel; tanh bounds pixel values to (-1, 1)
    model.add(tfl.Conv2D(1, kernel_size=5, activation='tanh', padding='same', kernel_initializer=initializer))

    return model

In [None]:
def build_discriminator(image_shape, dropout=0.4, relu_alpha=0.2, initializer=None):
    """Build and compile the DCGAN discriminator.

    Four 5x5 convolutions (the first three with stride 2), each followed by
    LeakyReLU. Dropout follows every convolution except the first. A single
    sigmoid unit scores each image; the training helpers label real images 1
    and fake images 0.

    Args:
        image_shape: shape of one input image, declared on the first layer.
        dropout: dropout rate after the 2nd, 3rd and 4th convolutions.
        relu_alpha: negative slope of every LeakyReLU.
        initializer: kernel initializer passed to every weighted layer
            (``None`` leaves the choice to Keras).

    Returns:
        A ``Sequential`` model named ``'Discriminator'``. It is compiled with
        binary cross-entropy, Adam(learning_rate=0.0002, beta_1=0.5) and an
        accuracy metric.
    """
    model = Sequential(name='Discriminator')

    # (filters, stride) for each convolutional stage
    stages = ((32, 2), (64, 2), (128, 2), (256, 1))
    for position, (filters, stride) in enumerate(stages):
        is_first = position == 0
        # only the first layer declares the input shape
        shape_kwargs = {'input_shape': image_shape} if is_first else {}
        model.add(tfl.Conv2D(filters, kernel_size=5, strides=stride, padding='same',
                             kernel_initializer=initializer, **shape_kwargs))
        model.add(tfl.LeakyReLU(alpha=relu_alpha))
        if not is_first:
            model.add(tfl.Dropout(dropout))

    model.add(tfl.Flatten())
    model.add(tfl.Dense(1, activation='sigmoid', kernel_initializer=initializer))

    optimizer = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return model

In [None]:
def build_gan(generator, discriminator):
    """Stack generator -> discriminator into the compiled model used to train the generator.

    Side effect: sets ``discriminator.trainable = False`` before compiling, so
    optimizing the stacked model only updates generator weights.

    NOTE(review): the discriminator is still trained through its own
    ``train_on_batch``, presumably because it was compiled earlier
    (``build_discriminator`` compiles it) while still trainable. This relies on
    Keras honouring the trainable state captured at compile time; confirm this
    for the installed Keras version.

    Returns:
        A ``Sequential`` model named ``'GAN'``, compiled with binary
        cross-entropy and Adam(learning_rate=0.0002, beta_1=0.5).
    """
    # freeze the discriminator inside the combined model
    discriminator.trainable = False

    gan = Sequential([generator, discriminator], name='GAN')
    gan.compile(
        loss='binary_crossentropy',
        optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
    )
    return gan

In [None]:
def load_data():
    """Load the MNIST train and test images as one float32 array scaled to [-1, 1].

    Labels are discarded. A trailing channel axis is added because the Keras
    conv layers expect channels-last input. Pixels are mapped from [0, 255] to
    [-1, 1], matching the generator's tanh output range.
    """
    (train_images, _), (test_images, _) = mnist.load_data()

    # train and test splits are pooled: the GAN has no use for a held-out set
    pooled = np.concatenate([train_images, test_images], axis=0)

    with_channel = pooled[..., np.newaxis].astype('float32')
    return with_channel / 127.5 - 1

def generate_latent_data(latent_dim, n_samples):
    """Draw ``n_samples`` latent vectors of length ``latent_dim`` from N(0, 1).

    Uses NumPy's global RNG and returns an array of shape
    ``(n_samples, latent_dim)``.
    """
    shape = (n_samples, latent_dim)
    return np.random.normal(loc=0.0, scale=1.0, size=shape)

def generate_real_data(data, n_samples):
    """Sample ``n_samples`` real images (with replacement), labelled 1.

    Args:
        data: array of real images; rows are drawn along axis 0.
        n_samples: number of images to draw.

    Returns:
        ``(X, y)``: the sampled rows of ``data`` and an ``(n_samples, 1)``
        array of ones.
    """
    chosen = np.random.randint(0, data.shape[0], n_samples)
    labels = np.ones((n_samples, 1))
    return data[chosen], labels

def generate_fake_data(generator, latent_dim, n_samples):
    """Generate ``n_samples`` fake images, labelled 0, for discriminator training.

    Args:
        generator: model mapping latent vectors to images.
        latent_dim: length of each latent vector.
        n_samples: number of images to generate.

    Returns:
        ``(X, y)``: the generator's predictions for freshly sampled latent
        vectors and an ``(n_samples, 1)`` array of zeros.
    """
    latent = generate_latent_data(latent_dim, n_samples)
    # verbose=0: this helper is called repeatedly from the training loop, and
    # on Keras versions whose predict() defaults to a progress bar every call
    # would print a line and flood the notebook output
    X = generator.predict(latent, verbose=0)
    y = np.zeros((n_samples, 1))
    return X, y

def save_images(images, image_name, rows, cols):
    """Save a ``rows`` x ``cols`` grid of images to ``{image_name}.png``.

    Args:
        images: array indexed as ``images[k, :, :, 0]`` with values in
            [-1, 1] (the generator's tanh range); at least ``rows * cols``
            images are required.
        image_name: output path without the ``.png`` extension.
        rows: number of grid rows.
        cols: number of grid columns.
    """
    # map tanh output from [-1, 1] back to [0, 1] for display
    images = (images + 1) / 2.0

    # squeeze=False keeps `axs` 2-D even when rows or cols is 1, so a single
    # row or column of previews does not crash the indexing below
    fig, axs = plt.subplots(rows, cols, figsize=(15, 15), squeeze=False)

    # axs.flat walks the grid row by row, i.e. image index = i * cols + j
    for index, ax in enumerate(axs.flat):
        ax.imshow(images[index, :, :, 0], cmap='gray_r')
        ax.axis('off')

    fig.savefig(f'{image_name}.png')
    # close this specific figure so repeated calls do not accumulate open figures
    plt.close(fig)

In [None]:
def train(generator, discriminator, gan, data, latent_dim, soft_label=0.9, pct_wrong=0.05,
          epochs=1000, discriminator_batch_size=128, generator_batch_size=128,
          print_interval=10, save_interval=50, image_dir='images', image_rows=5, image_cols=5):
    """Run the adversarial training loop, alternating discriminator and generator updates.

    Here an "epoch" is one training step (one discriminator update on a real
    batch, one on a fake batch, and one generator update), not a full pass
    over ``data``.

    Args:
        generator: model mapping latent vectors to images.
        discriminator: compiled model scoring images (real = 1, fake = 0).
            Its ``train_on_batch``/``evaluate`` results are unpacked as
            ``(loss, accuracy)``.
        gan: compiled generator -> discriminator stack used to update the generator.
        data: array of real images to sample from.
        latent_dim: length of each latent vector.
        soft_label: target used for "real" labels when training the
            discriminator (one-sided label smoothing); 1 disables it.
        pct_wrong: fraction of ``half_batch`` added, with flipped labels, to
            both the real and the fake discriminator batches (label noise).
        epochs: index of the last step. The loop runs ``epochs + 1`` steps so
            that step ``epochs`` itself is printed and saved.
        discriminator_batch_size: real + fake samples per discriminator step,
            before the flipped-label samples are added.
        generator_batch_size: latent vectors per generator step.
        print_interval: print losses every this many steps.
        save_interval: save a preview grid every this many steps.
        image_dir: directory for preview images; created if missing.
        image_rows, image_cols: preview grid shape (not the image size).
    """
    # the discriminator sees half real and half fake samples per step
    half_batch = discriminator_batch_size // 2

    # flipped-label samples added to each half; may round down to zero
    num_wrong = int(pct_wrong * half_batch)

    # fixed latent samples, reused at every save so the previews are comparable
    latent = generate_latent_data(latent_dim, image_rows * image_cols)

    # save_images writes into image_dir, so make sure it exists
    os.makedirs(image_dir, exist_ok=True)

    for epoch in range(epochs + 1):
        # select random batch of images
        X_real, y_real = generate_real_data(data, half_batch)

        # generate batch of fake images
        X_fake, y_fake = generate_fake_data(generator, latent_dim, half_batch)

        # label noise: fakes labelled real join the real batch and reals
        # labelled fake join the fake batch. Skipped when the fraction rounds
        # to zero samples, since predicting on an empty batch can fail.
        if num_wrong > 0:
            X, y = generate_fake_data(generator, latent_dim, num_wrong)
            X_real, y_real = np.vstack((X_real, X)), np.vstack((y_real, 1 - y))

            X, y = generate_real_data(data, num_wrong)
            X_fake, y_fake = np.vstack((X_fake, X)), np.vstack((y_fake, 1 - y))

        # train discriminator (real targets are scaled by soft_label)
        d_loss_real, acc_real = discriminator.train_on_batch(X_real, y_real * soft_label)
        d_loss_fake, acc_fake = discriminator.train_on_batch(X_fake, y_fake)

        # accuracy against soft targets is presumably not meaningful, so
        # re-score the real batch (after the update) against hard labels
        if soft_label != 1:
            d_loss_real, acc_real = discriminator.evaluate(X_real, y_real, verbose=0)

        d_loss = (d_loss_real + d_loss_fake) / 2.0

        # train generator: push the frozen discriminator to call fresh fakes real
        X_gan = generate_latent_data(latent_dim, generator_batch_size)
        y_gan = np.ones((generator_batch_size, 1))
        g_loss = gan.train_on_batch(X_gan, y_gan)

        # print progress
        if epoch % print_interval == 0:
            e = f'{epoch}/{epochs}'
            print(f'epoch {e}, D: [loss={d_loss:.3f}, real={acc_real*100:.2f}, fake={acc_fake*100:.2f}], G: [loss={g_loss:.3f}]')

        # save images generated from the fixed latent samples
        if epoch % save_interval == 0:
            images = generator.predict(latent, verbose=0)
            image_name = f'{image_dir}/mnist_epoch_{epoch}'
            save_images(images, image_name, image_rows, image_cols)

In [None]:
# image geometry in pixels (MNIST digits, one channel); not to be confused
# with train()'s image_rows/image_cols parameters, which size the preview grid
image_rows = 28
image_cols = 28
image_shape = (image_rows, image_cols, 1)

# length of the generator's input noise vector
latent_dim = 100

# seed the Python, NumPy and TensorFlow RNGs so Restart & Run All is
# repeatable (GPU kernels may still introduce some nondeterminism)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# kernel initializer passed to every weighted layer of both networks
initializer = RandomNormal(mean=0.0, stddev=0.05)

In [None]:
generator = build_generator(latent_dim, initializer=initializer)
# the generator's images must match the discriminator's expected input shape
assert tuple(generator.output_shape[1:]) == image_shape, generator.output_shape
generator.summary()

In [None]:
# compiled discriminator taking images of shape image_shape
discriminator = build_discriminator(image_shape=image_shape, initializer=initializer)
discriminator.summary()

In [None]:
# combined model; freezes the discriminator's weights inside the stack
gan = build_gan(generator=generator, discriminator=discriminator)
gan.summary()

In [None]:
data = load_data()
# invariants: images match image_shape and were scaled into [-1, 1]
assert data.shape[1:] == image_shape, data.shape
assert data.min() >= -1.0 and data.max() <= 1.0, (data.min(), data.max())
data.shape

In [None]:
image_dir = 'images'
# portable, pure-Python replacement for the shell alias `%mkdir -p`
# (`-p` is POSIX shell syntax and the magic only works inside IPython)
os.makedirs(image_dir, exist_ok=True)

In [None]:
# run the adversarial training loop; previews are written into image_dir
train(
    generator,
    discriminator,
    gan,
    data,
    latent_dim,
    image_dir=image_dir,
    epochs=1000,
    discriminator_batch_size=128,
    generator_batch_size=200,
)