In [1]:
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython import display
from tensorflow import zeros_like, ones_like, function, GradientTape
from tensorflow.keras import layers
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.random import normal
from tensorflow.train import Checkpoint

## Preprocessing

In [2]:
# Input-pipeline sizes. BUFFER_SIZE is presumably the shuffle-buffer length for
# the (not shown) dataset pipeline; BATCH_SIZE is also the number of noise
# vectors drawn per training step.
BUFFER_SIZE, BATCH_SIZE = 60_000, 256

## Generator

In [3]:
def make_generator_model(noise_dim=100):
    """Build the generator: a latent noise vector -> one 56x56x3 image.

    The noise is projected to a 7x7x256 feature map, then upsampled with
    transposed convolutions: 7x7 -> 7x7 -> 14x14 -> 28x28 -> 56x56. Only the
    final layer uses tanh, so generated pixel values lie in [-1, 1]. The output
    shape must match the discriminator's input shape.

    Args:
        noise_dim: length of the input latent vector (default 100).

    Returns:
        An uncompiled ``Sequential`` model mapping ``(batch, noise_dim)`` to
        ``(batch, 56, 56, 3)``.
    """
    model = Sequential()

    # Project the latent vector to enough units to reshape into 7x7x256.
    # No bias: the following BatchNormalization applies its own learned shift.
    model.add(layers.Dense(7 * 7 * 256, use_bias=False, input_shape=(noise_dim,)))
    # Normalize activations (zero mean / unit variance per batch), then apply
    # the non-linearity.
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    # Turn the flat vector into a spatial feature map.
    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)  # None is the batch size

    # Transposed convolution 7x7x256 -> 7x7x128 (stride 1: changes depth only).
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    # Transposed convolution 7x7x128 -> 14x14x64.
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    # Transposed convolution 14x14x64 -> 28x28x64. No activation on the layer
    # itself: BatchNorm + LeakyReLU follow, and tanh belongs only on the output.
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 28, 28, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    # Output layer 28x28x64 -> 56x56x3; tanh bounds pixel values to [-1, 1].
    model.add(layers.Conv2DTranspose(3, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 56, 56, 3)

    return model


In [4]:
def test_generator_model_untrained(model, noise_dim=100):
    """Show one image produced by ``model`` from random latent noise.

    A quick visual smoke test of an (untrained) generator.

    Args:
        model: generator to sample from; called as ``model(noise, training=False)``.
        noise_dim: length of the latent vector. It must match the generator's
            input size (``make_generator_model`` expects 100 by default).

    Returns:
        The generated image rescaled to [0, 1], as a NumPy array.
    """
    noise = normal([1, noise_dim])
    generated_image = model(noise, training=False)

    # Assumes the model outputs values in [-1, 1] (make_generator_model ends in
    # tanh). imshow clips float RGB data to [0, 1], so map [-1, 1] -> [0, 1];
    # multiplying by 255 would push almost every pixel outside the valid range.
    image = np.clip((np.asarray(generated_image[0]) + 1.0) / 2.0, 0.0, 1.0)

    plt.imshow(image)
    plt.axis('off')
    plt.show()
    return image

In [5]:
# Build the (still untrained) generator network defined above.
generator = make_generator_model()

In [6]:
# Print the generator's layers, output shapes and parameter counts.
generator.summary()

In [7]:
# Render one sample from the untrained generator as a visual smoke test.
# The return value is not needed here; the trailing ';' suppresses its display.
test_generator_model_untrained(generator);

## Discriminator

In [8]:
def make_discriminator_model(input_shape=(56, 56, 3)):
    """Build the discriminator: an image -> a single real/fake score.

    Args:
        input_shape: ``(height, width, channels)`` of the input images. It must
            match the generator's output shape (56x56x3 by default).

    Returns:
        An uncompiled ``Sequential`` model mapping ``(batch, *input_shape)`` to
        ``(batch, 1)`` raw scores (logits). The final layer deliberately has no
        sigmoid: the losses use ``BinaryCrossentropy(from_logits=True)``, which
        applies the sigmoid internally. Adding one here would apply it twice.
    """
    model = Sequential([
        # Two stride-2 convolutions halve the spatial size twice
        # (56 -> 28 -> 14 for the default input).
        layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=input_shape),
        layers.LeakyReLU(),
        layers.Dropout(0.4),

        layers.Conv2D(128, (3, 3), strides=(2, 2), padding='same'),
        layers.LeakyReLU(),
        layers.Dropout(0.4),

        layers.Flatten(),
        # One unnormalized logit per image.
        layers.Dense(1),
    ])

    return model

In [9]:
# Build the discriminator network defined above.
discriminator = make_discriminator_model()

In [10]:
# Print the discriminator's layers, output shapes and parameter counts.
discriminator.summary()

In [11]:
# Smoke-test that a generator sample feeds straight into the discriminator:
# score one freshly generated image (both networks still untrained).
generated_image = generator(normal([1, 100]), training=False)
decision = discriminator(generated_image, training=False)
decision

## Loss and Optimizers

### Discriminator loss:

In [12]:
# Shared binary cross-entropy loss object, called by both loss functions below.
# from_logits=True: predictions are treated as raw (pre-sigmoid) scores, so the
# discriminator's final layer must not apply a sigmoid itself.
cross_entropy = BinaryCrossentropy(from_logits=True)

In [13]:
def discriminator_loss(real_output, fake_output):
    """Discriminator loss: score real images as 1 and generated images as 0.

    Args:
        real_output: discriminator outputs for a batch of real images.
        fake_output: discriminator outputs for a batch of generated images.

    Returns:
        The sum of the cross-entropy on the real batch (target 1) and on the
        fake batch (target 0).
    """
    loss_on_real = cross_entropy(ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake

### Generator loss:

In [14]:
def generator_loss(fake_output):
    """Generator loss: cross-entropy of its fakes against the 'real' label 1.

    The loss is low when the discriminator scores generated images as real.

    Args:
        fake_output: discriminator outputs for a batch of generated images.
    """
    real_labels = ones_like(fake_output)
    return cross_entropy(real_labels, fake_output)

### Optimizers: 

In [15]:
# One Adam optimizer per network (each later updates only its own network's
# weights), both with a learning rate of 1e-4.
generator_optimizer, discriminator_optimizer = (
    Adam(learning_rate=1e-4),
    Adam(learning_rate=1e-4),
)

## Checkpoints

In [16]:
# Directory and file-name prefix used when saving training checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')

# Track both networks and both optimizers so their state is saved and
# restored together.
checkpoint = Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
)

## Training

### Training loops

In [17]:
# Training-run settings:
#   epochs                   - number of passes over the dataset
#   noise_dim                - length of each latent noise vector (also the
#                              generator's input size)
#   num_examples_to_generate - presumably the number of preview samples; the
#                              plotting helper lays its samples out in a 4x4 grid
epochs, noise_dim, num_examples_to_generate = 50, 100, 16

In [18]:
@function
def train_step(images):
    """Run one adversarial update of both networks on a batch of real images.

    Uses the module-level ``generator``, ``discriminator`` and their optimizers.

    Args:
        images: batch of real images. Assumed to have the discriminator's input
            shape and the generator's output range ([-1, 1]); TODO confirm
            against the (not shown) preprocessing.

    Returns:
        ``(gen_loss, disc_loss)`` as scalar tensors, e.g. for logging.
    """
    # Latent vectors for a full batch of fakes (tensorflow.random.normal).
    noise = normal([BATCH_SIZE, noise_dim])

    # Two tapes, so each loss is differentiated only w.r.t. its own network.
    with GradientTape() as gen_tape, GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    return gen_loss, disc_loss

### Train function

In [19]:
def _to_unit_range(batch):
    """Map pixel values from [-1, 1] to [0, 1] (clipped), as imshow expects for floats."""
    return np.clip((np.asarray(batch) + 1.0) / 2.0, 0.0, 1.0)


def generate_and_save_images(model, epoch, test_input):
    """Plot a grid of samples from ``model`` and save it as ``image_at_epoch_{epoch}.png``.

    Assumes the model outputs values in [-1, 1] (make_generator_model ends in tanh).

    Args:
        model: generator, called as ``model(test_input, training=False)``.
        epoch: epoch number used in the output file name.
        test_input: batch of latent vectors. Pass the same batch every epoch so
            the saved snapshots are directly comparable.
    """
    # training=False runs every layer in inference mode (BatchNormalization
    # uses its moving statistics instead of the current batch's).
    predictions = model(test_input, training=False)
    images = _to_unit_range(predictions)

    # Near-square grid sized to the batch (4x4 for 16 samples).
    n_images = images.shape[0]
    n_cols = max(1, int(np.ceil(np.sqrt(n_images))))
    n_rows = max(1, int(np.ceil(n_images / n_cols)))

    fig = plt.figure(figsize=(4, 4))
    for i in range(n_images):
        ax = fig.add_subplot(n_rows, n_cols, i + 1)
        ax.imshow(images[i])
        ax.axis('off')

    fig.savefig(f'image_at_epoch_{epoch}.png')
    plt.show()
    # Release the figure: this is called once per epoch during training.
    plt.close(fig)

In [20]:
def train(dataset, epochs, checkpoint_every=10, preview_noise=None):
    """Train the GAN for ``epochs`` passes over ``dataset``.

    Each epoch runs ``train_step`` on every batch, then plots and saves a
    preview grid from a fixed noise batch. Every ``checkpoint_every`` epochs it
    saves ``checkpoint`` under ``checkpoint_prefix``.

    Args:
        dataset: iterable of real-image batches (presumably a batched
            ``tf.data.Dataset``); each batch is passed to ``train_step``.
        epochs: number of full passes over ``dataset``.
        checkpoint_every: save a checkpoint every this many epochs (default 10).
            ``0`` or ``None`` disables checkpointing.
        preview_noise: latent batch used for the preview images. Defaults to
            ``num_examples_to_generate`` normal vectors of length ``noise_dim``,
            drawn once so every epoch's preview uses the same input.
    """
    if preview_noise is None:
        preview_noise = normal([num_examples_to_generate, noise_dim])

    for epoch in range(epochs):
        start = time.time()

        # One full pass over the data; this must sit inside the epoch loop.
        for image_batch in dataset:
            train_step(image_batch)

        # Produce images as you go, always from the same noise.
        display.clear_output(wait=True)
        generate_and_save_images(generator, epoch + 1, preview_noise)

        if checkpoint_every and (epoch + 1) % checkpoint_every == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))

    # Generate after the final epoch
    display.clear_output(wait=True)
    generate_and_save_images(generator, epochs, preview_noise)