In [3]:
import tensorflow as tf
# leaky relu and batch normalization often used with GANs
from tensorflow.keras.layers import Input, Dense, LeakyReLU, Dropout, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD, Adam
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import sys, os

In [2]:
# Fetch MNIST through the Keras dataset helper (downloaded on first use)
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Rescale both image arrays to (-1, +1), the same range as the generator's
# tanh output layer
x_train = x_train / 255.0 * 2 - 1
x_test = x_test / 255.0 * 2 - 1
print("x_train.shape:", x_train.shape)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
x_train.shape: (60000, 28, 28)


In [4]:
# Flatten every H x W image into one row of D = H * W features.
# NOTE: x_train is overwritten here, so this cell cannot be re-run without
# reloading the data first (the 3-way shape unpacking would fail on the 2-D array).
N, H, W = x_train.shape
D = H * W
x_train, x_test = x_train.reshape(-1, D), x_test.reshape(-1, D)

In [8]:
# Dimensionality of the latent space (hyperparameter): the length of each
# noise vector that is fed to the generator
latent_dim = 100

In [5]:
# Generator: the first half of the GAN. Maps a latent noise vector to a
# flattened image using a stack of fully connected layers.
def build_generator(latent_dim):
  """Build the generator network.

  Args:
    latent_dim: length of the input noise vector.

  Returns:
    An uncompiled Keras Model mapping (latent_dim,) noise to a D-dimensional
    output in (-1, +1). D is the notebook-level flattened image size.
  """
  noise_in = Input(shape=(latent_dim,))

  # Three widening hidden stages, each Dense + LeakyReLU followed by batch norm
  hidden = noise_in
  for units in (256, 512, 1024):
    hidden = Dense(units, activation=LeakyReLU(alpha=0.2))(hidden)
    hidden = BatchNormalization(momentum=0.8)(hidden)

  # tanh keeps the output in (-1, +1), the range the training pixels were scaled to
  pixels = Dense(D, activation='tanh')(hidden)

  return Model(noise_in, pixels)

In [6]:
# Discriminator: the second half of the GAN. Scores a flattened image as
# real or fake.
def build_discriminator(img_size):
  """Build the discriminator network.

  Args:
    img_size: number of features in one flattened input image.

  Returns:
    An uncompiled Keras Model mapping (img_size,) inputs to a single sigmoid
    output (binary real-vs-fake classification).
  """
  flat_img = Input(shape=(img_size,))

  # Two narrowing hidden layers with LeakyReLU activations
  features = flat_img
  for units in (512, 256):
    features = Dense(units, activation=LeakyReLU(alpha=0.2))(features)

  # One sigmoid unit: binary classification (real or fake)
  real_prob = Dense(1, activation='sigmoid')(features)

  return Model(flat_img, real_prob)

In [9]:
# Compile both models in preparation for training.
# The statement order in this cell matters: the discriminator is compiled
# BEFORE its `trainable` flag is turned off, and the combined model is
# compiled AFTER. Do not reorder.

# Build and compile discriminator (standalone: this is the model that is
# trained directly on real/fake batches)
discriminator = build_discriminator(D)
discriminator.compile(
    loss='binary_crossentropy',
    optimizer=Adam(0.0002, 0.5),
    metrics=['accuracy']
)

# Build the generator. It is never compiled on its own; it is only trained
# through the combined model below.
generator = build_generator(latent_dim)

# Create an input to represent noise sample from latent space
z = Input(shape=(latent_dim,))

# Pass noise through generator to get an image
img = generator(z)

# Make sure only the generator is trained (freeze weights of discriminator).
# NOTE(review): the intent is that discriminator.train_on_batch still updates
# the discriminator (it was compiled above while trainable), whereas the
# combined model compiled below leaves it frozen. This relies on Keras
# capturing the trainable state at compile time - confirm on the installed version.
discriminator.trainable = False

# Discriminator's verdict on the generated image. These images are fake, but
# during generator training they are labelled as real (targets of ones).
fake_pred = discriminator(img)

# Create the combined model object: noise -> generator -> discriminator
combined_model = Model(z, fake_pred)

# Compile combined model (flip the labels during training later).
# No metrics are requested here, unlike the discriminator above.
combined_model.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))

In [12]:
# Train the GAN: configuration

batch_size = 32
epochs = 30000       # number of training-loop iterations (one batch each)
sample_period = 200  # generate a sample every 200 steps

# Target labels for train_on_batch, built once here instead of on every iteration
zeros = np.zeros(batch_size)
ones = np.ones(batch_size)

# Loss histories, appended to once per iteration
d_losses = []
g_losses = []

# Folder that receives the generated image grids
if not os.path.exists('gan_images'):
  os.makedirs('gan_images')

In [11]:
# Generate a grid of random samples from the generator and save them to a file
def sample_images(epoch, rows=5, cols=5):
  """Save a rows x cols grid of freshly generated images.

  The grid is written to gan_images/<epoch>.png.

  Args:
    epoch: integer used as the output file name.
    rows: number of grid rows (default 5).
    cols: number of grid columns (default 5), giving 25 images by default.

  Relies on the notebook-level `generator`, `latent_dim`, `H` and `W`.
  """
  # One latent noise vector per grid cell: shape (rows * cols, latent_dim)
  noise = np.random.randn(rows * cols, latent_dim)
  imgs = generator.predict(noise)

  # Rescale images from (-1, +1) to 0 - 1 for display
  imgs = 0.5 * imgs + 0.5

  # squeeze=False keeps axs two-dimensional even for a 1 x 1 grid
  fig, axs = plt.subplots(rows, cols, squeeze=False)

  # Pair each axis with its own image. The previous version never advanced
  # its image index, so every cell displayed imgs[0].
  for ax, img in zip(axs.flat, imgs):
    ax.imshow(img.reshape(H, W), cmap='gray')
    ax.axis('off') # so we dont see lines in the plots

  fig.savefig("gan_images/%d.png" % epoch)
  plt.close(fig) # release this figure's resources

In [None]:
# Main training loop (alternate between the 2 parts of the GAN).
# Each iteration trains the discriminator on one real and one fake batch,
# then trains the generator through the combined model.
for epoch in range(epochs):
  # ~~~ TRAIN DISCRIMINATOR ~~~

  # Select a random batch of real images (indices drawn with replacement)
  idx = np.random.randint(0, x_train.shape[0], batch_size)
  real_imgs = x_train[idx]

  # Generate fake images from standard-normal noise: batch_size x latent_dim
  noise = np.random.randn(batch_size, latent_dim)
  fake_imgs = generator.predict(noise)

  # The discriminator was compiled with an accuracy metric, so each call
  # returns both loss and accuracy
  d_loss_real, d_acc_real = discriminator.train_on_batch(real_imgs, ones)
  d_loss_fake, d_acc_fake = discriminator.train_on_batch(fake_imgs, zeros)
  # Overall loss and accuracy: mean of the real-batch and fake-batch values
  d_loss = .5 * (d_loss_real + d_loss_fake)
  d_acc = .5 * (d_acc_real + d_acc_fake)

  # ~~~ TRAIN GENERATOR ~~~

  noise = np.random.randn(batch_size, latent_dim)
  # Input is noise and target is a vector of ones, because we want to trick
  # the discriminator into thinking images from the generator are real.
  # The combined model has no metrics, so only the loss is returned.
  g_loss = combined_model.train_on_batch(noise, ones)

  # Save the losses
  d_losses.append(d_loss)
  g_losses.append(g_loss)

  # Print epoch info every 100 epochs. This must test the loop counter
  # `epoch`; testing the constant `epochs` made the condition always true.
  if epoch % 100 == 0:
    print(f"epoch: {epoch+1}/{epochs}, d_loss: {d_loss:.2f}, d_acc: {d_acc:.2f}, g_loss: {g_loss:.2f}")

  # If we're on sample period, generate more random images
  if epoch % sample_period == 0:
    sample_images(epoch)

# Note: this takes a while to train, but here are some notes from the video:
# For accuracy, we see that despite the discriminator training, it never reaches high accuracy.
# This is because the generator improves as the discriminator does, which is what we want (the GAN should generate realistic images).
# For loss, both hover around the same point throughout training because the generator and discriminator feed off each other and improve in tandem.

In [None]:
# Overlay the generator and discriminator loss histories on one set of axes
for history, name in ((g_losses, 'g_losses'), (d_losses, 'd_losses')):
  plt.plot(history, label=name)
plt.legend()

In [None]:
# List the image files created so far (file name = the epoch number at which
# the grid was saved)
!ls gan_images

In [None]:
# Plot images, during early and later epochs.
# 0.png is the grid saved at epoch 0, i.e. after a single training iteration.
from skimage.io import imread
a = imread('gan_images/0.png')
plt.imshow(a)

In [None]:
# Sample grid saved at epoch 1000
a = imread('gan_images/1000.png')
plt.imshow(a)

In [None]:
# Sample grid saved at epoch 5000
a = imread('gan_images/5000.png')
plt.imshow(a)

In [None]:
# Epoch 10000 of 30000: 1/3 of the way through
# (things already look pretty good -> diminishing returns)
a = imread('gan_images/10000.png')
plt.imshow(a)

In [None]:
# Epoch 20000 of 30000: 2/3 of the way through
a = imread('gan_images/20000.png')
plt.imshow(a)

In [None]:
# Near the end of training: the grid saved at epoch 29000 (of 30000)
a = imread('gan_images/29000.png')
plt.imshow(a)