In [None]:
# Change kernel to apache_beam

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf

import glob
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time
import imageio

from IPython import display


In [None]:
# Show which TensorFlow version this kernel has loaded (displayed as the cell output).
tf.__version__

In [None]:
def load_dataset():
    """Return the MNIST training images prepared for GAN training.

    The training split is fetched through ``tf.keras.datasets.mnist.load_data``;
    the labels and the test split are discarded. The images are reshaped to
    ``(N, 28, 28, 1)``, cast to ``float32`` and rescaled with
    ``(x - 127.5) / 127.5`` so that they land in ``[-1, 1]``, the same range
    as the generator's ``tanh`` output.

    Returns:
        The rescaled training images as a ``float32`` array.
    """
    (pixels, _), _ = tf.keras.datasets.mnist.load_data()

    # Append a single channel axis and switch to floating point.
    pixels = pixels.reshape(-1, 28, 28, 1).astype('float32')

    # Normalize to [-1, 1].
    return (pixels - 127.5) / 127.5


train_images = load_dataset()

In [None]:
# Shuffle buffer size; presumably chosen to equal the number of MNIST training
# images so that the whole set is shuffled -- confirm against the dataset size.
BUFFER_SIZE = 60000
# Number of real images fed to the discriminator per training step.
BATCH_SIZE = 256

# Input pipeline: slice the image array into examples, shuffle, then batch.
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_images)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

In [None]:
def create_generator():
    """Build the generator network.

    The model maps a 100-dimensional noise vector to a ``28 x 28 x 1`` image.
    A dense projection is reshaped to ``7 x 7 x 256`` and then passed through
    three transposed convolutions (two of them with stride 2) that bring it to
    ``28 x 28``. The last layer uses ``tanh``, so outputs lie in ``[-1, 1]``.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.

    Raises:
        AssertionError: if an intermediate output shape differs from the
            expected one.
    """
    net = tf.keras.Sequential()

    def _normalize_and_activate():
        # Every hidden stage is followed by batch norm and a LeakyReLU.
        net.add(layers.BatchNormalization())
        net.add(layers.LeakyReLU())

    # Project the noise vector and fold it into a 7x7 feature map.
    net.add(layers.Dense(7 * 7 * 256, use_bias=False, input_shape=(100,)))
    _normalize_and_activate()
    net.add(layers.Reshape((7, 7, 256)))
    assert net.output_shape == (None, 7, 7, 256)

    # Stride 1: same spatial size, fewer channels.
    net.add(layers.Conv2DTranspose(128, (5, 5), padding='same', use_bias=False))
    assert net.output_shape == (None, 7, 7, 128)
    _normalize_and_activate()

    # Stride 2: 7x7 -> 14x14.
    net.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same',
                                   use_bias=False))
    assert net.output_shape == (None, 14, 14, 64)
    _normalize_and_activate()

    # Stride 2: 14x14 -> 28x28, single channel, tanh output.
    net.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same',
                                   use_bias=False, activation='tanh'))
    assert net.output_shape == (None, 28, 28, 1)

    return net


generator = create_generator()
    
    

In [None]:
# Sanity check: push one random latent vector through the freshly built generator.
noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)

# Drop the batch and channel axes to obtain a 28x28 image for display.
plt.imshow(generated_image[0, :, :, 0])

In [None]:
def create_discriminator():
    """Build the discriminator network.

    The model takes ``28 x 28 x 1`` images and applies two stride-2
    convolutions, each followed by LeakyReLU and dropout (rate 0.3), then
    flattens and projects to a single unit. The final ``Dense(1)`` has no
    activation, so the model outputs raw logits.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    return tf.keras.Sequential([
        layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                      input_shape=[28, 28, 1]),
        layers.LeakyReLU(),
        layers.Dropout(0.3),

        layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(),
        layers.Dropout(0.3),

        layers.Flatten(),
        layers.Dense(1),
    ])


discriminator = create_discriminator()
    

In [None]:
# Print the discriminator's layer-by-layer summary.
discriminator.summary()

In [None]:
# Score the sample produced by the generator above. The discriminator ends in
# Dense(1) with no activation, so the printed value is a raw logit.
decision = discriminator(generated_image)

print(decision)

In [None]:
# Define the binary cross-entropy loss object shared by the generator and
# discriminator losses. from_logits=True because the discriminator's final
# Dense(1) layer applies no activation and therefore outputs logits.

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits = True)

In [None]:
# Discriminator loss

def discriminator_loss(real_output, fake_output):
    """Cross-entropy loss for the discriminator.

    Args:
        real_output: discriminator logits for real images (target label 1).
        fake_output: discriminator logits for generated images (target label 0).

    Returns:
        The sum of the loss on the real batch and the loss on the fake batch.
    """
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake

# Generator loss

def generator_loss(fake_output):
    """Cross-entropy loss for the generator.

    The generator is scored against an all-ones target: the loss is low when
    the discriminator's logits for generated images look like "real".

    Args:
        fake_output: discriminator logits for generated images.

    Returns:
        The cross-entropy between an all-ones target and ``fake_output``.
    """
    wanted_labels = tf.ones_like(fake_output)
    return cross_entropy(wanted_labels, fake_output)

# Separate Adam optimizers: the two networks are trained independently.
# `learning_rate` is the supported argument name; `lr` is a deprecated alias
# that newer Keras optimizer implementations no longer accept.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)



In [None]:
# Checkpointing: both networks and both optimizers are tracked by one
# Checkpoint object so that their state is saved together.
checkpoint_dir = './checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
)

# Training parameters
EPOCHS = 50
noise_dim = 100                 # length of the generator's input noise vector
num_examples_to_generate = 16   # images drawn in each preview figure

# Fixed latent vectors reused for every preview, so the figures saved at
# different epochs show the same inputs.
seed = tf.random.normal([num_examples_to_generate, noise_dim])

In [None]:
@tf.function
def train_step(images):
    """Run one optimization step for both the generator and the discriminator.

    A batch of noise is turned into fake images, the discriminator scores the
    real and the fake batch, and each network is updated from its own loss
    using its own optimizer.

    Args:
        images: a batch of real training images, shaped like the
            discriminator's input (``[batch, 28, 28, 1]``).

    Returns:
        None. The generator/discriminator weights and optimizer states are
        updated in place.
    """
    # Generate as many fakes as there are real images in this batch. The
    # dataset is batched without drop_remainder, so the last batch of an epoch
    # can be smaller than BATCH_SIZE.
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, noise_dim])

    # Only the forward pass and the losses need to be recorded by the tapes.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    # Gradients are computed after leaving the tape context: while a tape is
    # active it keeps recording, so computing them inside the `with` block
    # would make disc_tape trace the generator's backward pass for nothing.
    gradients_of_generator = gen_tape.gradient(
        gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(
        disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(
        zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(
        zip(gradients_of_discriminator, discriminator.trainable_variables))

In [None]:
def train(dataset, epochs):
    """Train the GAN and render a preview figure after every epoch.

    For each epoch every batch of ``dataset`` is passed to ``train_step``;
    afterwards the notebook output is cleared and a grid of images generated
    from the fixed ``seed`` is shown and saved. A checkpoint is written every
    15 epochs, and the duration of each epoch is printed.

    Args:
        dataset: iterable of image batches accepted by ``train_step``.
        epochs: number of passes over ``dataset``.
    """
    for epoch in range(epochs):
        start = time.time()

        for image_batch in dataset:
            train_step(image_batch)

        # Replace the previous epoch's output with the current preview.
        display.clear_output(wait=True)
        generate_and_save_images(generator, epoch + 1, seed)

        # Save the model every 15 epochs.
        if (epoch + 1) % 15 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

        # Report how long the epoch took (the timer was previously started
        # but never read).
        print('Time for epoch {} is {:.2f} sec'.format(epoch + 1, time.time() - start))

    # Final preview once training has finished.
    display.clear_output(wait=True)
    generate_and_save_images(generator, epochs, seed)

In [None]:
def generate_and_save_images(model, epoch, test_input):
    """Generate images from ``test_input``, show them in a grid and save a PNG.

    Args:
        model: the generator; called as ``model(test_input, training=False)``
            and expected to return a ``[N, height, width, 1]`` batch with
            values in ``[-1, 1]``.
        epoch: integer used in the output file name
            (``image_at_epoch_XXXX.png``).
        test_input: batch of noise vectors fed to ``model``.
    """
    # training=False so that layers such as batch normalization run in
    # inference mode.
    predictions = model(test_input, training=False)

    num_images = predictions.shape[0]
    # Smallest square grid that holds every image (4x4 for 16 samples). A
    # fixed 4x4 grid would raise as soon as more than 16 samples are passed.
    grid_side = max(1, int(np.ceil(np.sqrt(num_images))))

    fig = plt.figure(figsize=(4, 4))
    for i in range(num_images):
        plt.subplot(grid_side, grid_side, i + 1)
        # Undo the [-1, 1] normalization to get back to the 0..255 pixel range.
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
    # One figure is created per call; close it explicitly so figures do not
    # pile up in memory over a long training run.
    plt.close(fig)
        
        

In [None]:
%%time
# Run the full training loop; the %%time cell magic reports how long it took.
train(train_dataset, EPOCHS)