In [None]:
try:
    # Colab-only line magic that selects TensorFlow 2.x. The magic does not
    # exist outside Colab, so the resulting error is deliberately ignored.
    %tensorflow_version 2.x
except Exception:
    pass

In [None]:
import os
import sys

# Absolute path of the directory two levels above the working directory.
# Presumably this is what makes the local `utils` import resolve - confirm
# against the repository layout.
module_path = os.path.abspath(os.path.join(os.pardir, os.pardir))

# Register it once only, so re-running the cell does not grow sys.path.
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from utils import get_strategy, save_images, Timer

# Short alias for the Keras layers namespace, used by the model definitions below.
tfl = tf.keras.layers

In [None]:
# Record the TensorFlow version this run used.
print(f'TensorFlow version: {tf.__version__}')

In [None]:
# `get_strategy` comes from the project's `utils` module. Later cells read
# `strategy.num_replicas_in_sync` and enter `strategy.scope()`, so it is
# presumably a tf.distribute strategy - confirm in utils.
strategy = get_strategy()

In [None]:
# Show the devices TensorFlow can see; as the last expression of the cell the
# result is displayed.
# NOTE(review): `tensorflow.python.client.device_lib` is an internal module -
# consider the public `tf.config` device-listing API if the targeted
# TensorFlow version provides it.
from tensorflow.python.client import device_lib
device_lib.list_local_devices()

In [None]:
def load_data():
    """Load every MNIST image (train and test splits) as one float32 array.

    Labels are discarded. A trailing channel axis is appended and the pixel
    values are rescaled with ``x / 127.5 - 1``.

    Returns:
        A float32 numpy array holding the train images followed by the test
        images, with one extra trailing axis of size 1.
    """
    (train_split, _), (test_split, _) = tf.keras.datasets.mnist.load_data()

    # train images first, then test images, joined along the sample axis
    everything = np.concatenate((train_split, test_split), axis=0)

    # trailing channel axis, then switch to floating point
    everything = everything[..., np.newaxis].astype(np.float32)

    # rescale pixel values to [-1, 1]
    return everything / 127.5 - 1

In [None]:
train_images = load_data()

# Model dimensions: per-image shape (everything after the sample axis) and
# the size of the latent vector fed to the generator.
IMAGE_SHAPE = train_images.shape[1:]
LATENT_DIM = 100

# Shuffle buffer spans the whole dataset; the batch size is 256 per replica.
BUFFER_SIZE = len(train_images)
BATCH_SIZE = strategy.num_replicas_in_sync * 256

# shuffle first, then batch
train_data = tf.data.Dataset.from_tensor_slices(train_images)
train_data = train_data.shuffle(BUFFER_SIZE)
train_data = train_data.batch(BATCH_SIZE)

In [None]:
class Generator(tf.keras.Sequential):
    """Convolutional generator mapping latent vectors to images.

    A dense layer produces a small feature map of shape
    ``(rows // 4, cols // 4, first_layer_channels)``; two stride-2 transposed
    convolutions then upsample it by a factor of 4 to the target image size.
    The final layer uses ``tanh``, so outputs lie in [-1, 1].

    Args:
        image_shape: Target image shape as ``(rows, cols)`` or
            ``(rows, cols, channels)``. ``rows`` and ``cols`` must be
            divisible by 4. When no channel entry is given, one output
            channel is produced.
        latent_dim: Length of the latent input vector.
        first_layer_channels: Channel count of the first (smallest) feature map.

    Raises:
        ValueError: If ``rows`` or ``cols`` is not divisible by 4, because
            the upsampled output could then not match ``image_shape``.
    """

    def __init__(self, image_shape, latent_dim, first_layer_channels=256):
        rows, cols = image_shape[0], image_shape[1]
        if rows % 4 != 0 or cols % 4 != 0:
            raise ValueError(
                f'image rows and cols must be divisible by 4, got {rows}x{cols}'
            )

        # Follow the channel count of the target images; fall back to a single
        # channel when the shape carries no channel entry.
        out_channels = image_shape[2] if len(image_shape) > 2 else 1

        first_layer_shape = (rows // 4, cols // 4, first_layer_channels)

        super().__init__([
            tfl.InputLayer(input_shape=(latent_dim,)),

            # Biases are omitted on layers that are followed by BatchNormalization.
            tfl.Dense(int(np.prod(first_layer_shape)), use_bias=False),
            tfl.BatchNormalization(),
            tfl.LeakyReLU(),
            tfl.Reshape(first_layer_shape),

            tfl.Conv2DTranspose(128, kernel_size=5, strides=1, padding='same', use_bias=False),
            tfl.BatchNormalization(),
            tfl.LeakyReLU(),

            # Each stride-2 layer doubles the spatial size.
            tfl.Conv2DTranspose(64, kernel_size=5, strides=2, padding='same', use_bias=False),
            tfl.BatchNormalization(),
            tfl.LeakyReLU(),

            tfl.Conv2DTranspose(out_channels, kernel_size=5, strides=2, padding='same',
                                use_bias=False, activation='tanh')
        ], name='Generator')

        self.latent_dim = latent_dim

    def sample(self, samples=1, training=False):
        """Generate images from freshly drawn standard-normal latent vectors.

        Args:
            samples: Number of images to generate (Python int or scalar tensor).
            training: Forwarded to the model call as the ``training`` flag.

        Returns:
            The generator output for ``samples`` random latent vectors.
        """
        latent = tf.random.normal([samples, self.latent_dim])
        # Pass the flag by keyword so it cannot be mistaken for another
        # positional argument of the model call.
        return self(latent, training=training)

In [None]:
def make_discriminator(image_shape):
    """Build the discriminator network.

    Two stride-2 convolution stages (64 then 128 filters), each followed by
    LeakyReLU and 30% dropout, feed a flattened single-unit dense layer. The
    dense layer has no activation, so the model outputs one raw score per image.

    Args:
        image_shape: Shape of a single input image, passed to the input layer.

    Returns:
        A ``tf.keras.Sequential`` model named ``'Discriminator'``.
    """
    stack = [tfl.InputLayer(input_shape=image_shape)]

    # one downsampling stage per filter count
    for n_filters in (64, 128):
        stack += [
            tfl.Conv2D(n_filters, kernel_size=5, strides=2, padding='same'),
            tfl.LeakyReLU(),
            tfl.Dropout(0.3),
        ]

    stack += [tfl.Flatten(), tfl.Dense(1)]
    return tf.keras.Sequential(stack, name='Discriminator')

In [None]:
class GAN(object):
    """Couples a generator and a discriminator with their optimizers.

    Args:
        generator: Model exposing ``sample(samples, training)`` and
            ``trainable_variables``.
        discriminator: Model returning one raw logit per image.
        generator_optimizer: Optimizer applied to the generator variables.
        discriminator_optimizer: Optimizer applied to the discriminator variables.
    """

    def __init__(self, generator, discriminator, generator_optimizer, discriminator_optimizer):
        self.generator = generator
        self.discriminator = discriminator
        self.gen_opt = generator_optimizer
        self.disc_opt = discriminator_optimizer
        # The discriminator outputs raw logits, hence from_logits=True.
        self.cross_entropy = tf.losses.BinaryCrossentropy(from_logits=True)
        self.accuracy = tf.metrics.BinaryAccuracy()

    def train(self, images):
        """Run one optimization step for both networks on a batch of real images.

        Args:
            images: Batch of real images.

        Returns:
            Tuple ``(gen_loss, disc_loss, real_acc, fake_acc)``: the generator
            and discriminator losses, and the discriminator's accuracy on the
            real and on the generated batch. All four are computed from the
            forward pass made before the weights are updated.
        """
        # Size the fake batch from the real one rather than from a global
        # constant, so the final, smaller dataset batch stays balanced.
        batch_size = tf.shape(images)[0]

        # A non-persistent tape supports a single gradient() call, so each
        # network gets its own tape.
        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            generated_images = self.generator.sample(batch_size, training=True)

            real_output = self.discriminator(images, training=True)
            fake_output = self.discriminator(generated_images, training=True)

            real_y = tf.ones_like(real_output)
            fake_y = tf.zeros_like(fake_output)

            real_loss = self.cross_entropy(real_y, real_output)
            fake_loss = self.cross_entropy(fake_y, fake_output)

            # The generator is rewarded when its fakes are labelled as real.
            gen_loss = self.cross_entropy(tf.ones_like(fake_output), fake_output)
            disc_loss = real_loss + fake_loss

        gen_grad = gen_tape.gradient(gen_loss, self.generator.trainable_variables)
        disc_grad = disc_tape.gradient(disc_loss, self.discriminator.trainable_variables)

        self.gen_opt.apply_gradients(zip(gen_grad, self.generator.trainable_variables))
        self.disc_opt.apply_gradients(zip(disc_grad, self.discriminator.trainable_variables))

        # BinaryAccuracy thresholds its predictions at 0.5, which is only
        # meaningful for probabilities, so the logits are converted first.
        self.accuracy.reset_states()
        real_acc = self.accuracy(real_y, tf.sigmoid(real_output))

        self.accuracy.reset_states()
        fake_acc = self.accuracy(fake_y, tf.sigmoid(fake_output))

        return gen_loss, disc_loss, real_acc, fake_acc

In [None]:
# Create both networks, their optimizers and the GAN wrapper inside the
# strategy scope.
with strategy.scope():
    generator = Generator(image_shape=IMAGE_SHAPE, latent_dim=LATENT_DIM)
    discriminator = make_discriminator(image_shape=IMAGE_SHAPE)

    # identical Adam settings for both networks
    gen_opt = tf.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
    disc_opt = tf.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

    gan = GAN(
        generator=generator,
        discriminator=discriminator,
        generator_optimizer=gen_opt,
        discriminator_optimizer=disc_opt,
    )

In [None]:
# Print the generator's layer-by-layer summary.
generator.summary()

In [None]:
# Print the discriminator's layer-by-layer summary.
discriminator.summary()

In [None]:
# Draw a single sample from the generator and show its first channel. When
# the notebook is run top to bottom, no training has happened yet at this point.
generated_image = generator.sample()
plt.imshow(generated_image[0, :, :, 0], cmap='gray_r');

In [None]:
def train(model, dataset, epochs, print_interval=10):
    """Train ``model`` on ``dataset`` and save a grid of generated images per epoch.

    Reads the notebook-level names ``SEED``, ``IMAGE_DIR``, ``PLOT_ROWS`` and
    ``PLOT_COLS``, which must be defined before this function is called.

    Args:
        model: Object with a ``train(batch)`` method returning
            ``(gen_loss, disc_loss, real_acc, fake_acc)`` and a ``generator``
            attribute.
        dataset: Batched dataset supporting ``enumerate(start=...)``.
        epochs: Number of passes over ``dataset``.
        print_interval: A progress line is printed every this many batches.
    """
    stopwatch = Timer()
    stopwatch.start()

    for epoch_no in range(1, epochs + 1):
        stopwatch.split()

        for step, real_batch in dataset.enumerate(start=1):
            g_loss, d_loss, acc_real, acc_fake = model.train(real_batch)

            should_report = step % print_interval == 0
            if should_report:
                report = (
                    f'Epoch {epoch_no:04d}, Batch {step:03d},',
                    f'Loss: [G={g_loss:.3f}, D={d_loss:.3f}],',
                    f'Acc: [real={acc_real*100:.2f}, fake={acc_fake*100:.2f}],',
                    f'Time: [epoch={stopwatch:%s}, total={stopwatch:%e}]',
                )
                print(' '.join(report))

        # end of epoch: render the fixed seed and store the resulting digits
        snapshot = model.generator(SEED, training=False)
        save_images(snapshot, epoch_no, IMAGE_DIR, PLOT_ROWS, PLOT_COLS)

In [None]:
# Output directory and grid layout for the per-epoch image snapshots.
IMAGE_DIR = 'images'
PLOT_ROWS = 5
PLOT_COLS = 5

# Fixed latent vectors (one per grid cell), reused at every epoch so the
# saved grids are comparable across epochs.
SEED = tf.random.normal([PLOT_ROWS * PLOT_COLS, LATENT_DIM])

# Create the output directory in plain Python instead of the `%mkdir -p`
# shell-alias magic, so the cell also works on platforms without that alias.
os.makedirs(IMAGE_DIR, exist_ok=True)

# save initial "digits"
images = generator(SEED, training=False)
save_images(images, 0, IMAGE_DIR, PLOT_ROWS, PLOT_COLS)

In [None]:
# Number of full passes over `train_data`; `train` also saves one image grid
# per epoch.
EPOCHS = 100
train(gan, train_data, EPOCHS)

In [None]:
try:
    from google.colab import files
except ImportError:
    # Not running on Colab: the images are already on the local disk.
    pass
else:
    import shutil

    # `files.download` takes a file path, and IMAGE_DIR is a directory, so
    # pack the directory into a single zip archive and download that instead.
    archive_path = shutil.make_archive(IMAGE_DIR, 'zip', root_dir=IMAGE_DIR)
    files.download(archive_path)