In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [0]:
# `%tensorflow_version` is an IPython line magic (presumably the Colab one that
# selects TensorFlow 2.x — TODO confirm on other frontends). Any exception it
# raises is deliberately ignored so the cell keeps going to the imports below.
try:
  %tensorflow_version 2.x
except Exception:
  pass

import tensorflow as tf
from tensorflow.keras import layers

# Show which TensorFlow version was actually imported.
print(tf.__version__)

In [0]:
import numpy as np

# Fix the random seeds so that "Restart & Run All" reproduces the same run.
# Seeding NumPy alone is not enough: the noise vectors (tf.random.normal),
# the dataset shuffle, the layer weight initialization and dropout in this
# notebook are all drawn by TensorFlow, which keeps its own global seed.
np.random.seed(7)
tf.random.set_seed(7)

In [0]:
# Load MNIST and keep only the first element of the first pair (the training
# images); the other three returned arrays are bound to the throwaway name `_`.
(train_images, _), (_, _) = tf.keras.datasets.mnist.load_data()

In [0]:
# Per-image shape: axes 1 and 2 of the loaded array followed by a single
# channel axis.
image_shape = tuple(train_images.shape[axis] for axis in (1, 2)) + (1,)

In [0]:
# Add the trailing channel axis and rescale pixels from [0, 255] to [-1, 1],
# which matches the range of the generator's final tanh activation.
#
# The target shape is taken from `image_shape` instead of a second hard-coded
# copy of 28x28x1. The dtype guard keeps the cell idempotent: the raw MNIST
# pixels are uint8, and after this cell they are float32, so re-running the
# cell does not normalize already-normalized data a second time.
if train_images.dtype == np.uint8:
    train_images = train_images.reshape((-1,) + tuple(image_shape)).astype('float32')
    train_images = (train_images - 127.5) / 127.5

In [0]:
# Shuffle buffer covering the whole training set, derived from the data rather
# than hard-coding the number of MNIST training examples (60000).
buffer_size = train_images.shape[0]
# Number of images per training batch.
batch_size = 100

In [0]:
# Input pipeline: one image per element, shuffled, then grouped into batches.
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_images)
    .shuffle(buffer_size)
    .batch(batch_size)
)

In [0]:
# Length of each random noise vector fed to the generator.
latent_dimension = 100
# Per-sample input shape of the generator (batch axis not included).
noise_shape = (latent_dimension,)

In [0]:
# Discriminator: two strided 5x5 convolutions, each followed by LeakyReLU and
# dropout, then a single linear unit. The final Dense layer has no activation,
# so the model emits one raw score per image.
discriminator = tf.keras.Sequential(
    [
        layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=image_shape),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        layers.Flatten(),
        layers.Dense(1),
    ],
    name='discriminator',
)

discriminator.summary()

In [0]:
# Shape of the feature map the generator's first Dense layer is reshaped into.
generator_shape = (7, 7, 256)
# Number of units needed to fill that feature map. np.prod returns a NumPy
# integer scalar; it is converted to a plain Python int because the value is
# used as a layer size, and layer-size arguments are expected to be ints.
generator_size = int(np.prod(generator_shape))

In [0]:
# Generator: project the noise vector to `generator_size` units, reshape to
# `generator_shape`, then apply three 5x5 transposed convolutions (strides 1,
# 2 and 2). Batch normalization and LeakyReLU follow every stage except the
# last, which uses tanh and a single output channel.
generator = tf.keras.Sequential(
    [
        layers.Dense(generator_size, use_bias=False, input_shape=noise_shape),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Reshape(generator_shape),
        layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Conv2DTranspose(
            1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'
        ),
    ],
    name='generator',
)

generator.summary()

In [0]:
import matplotlib.pyplot as plot
from IPython import display

# Smoke test: push one random noise vector through the generator in inference
# mode and show channel 0 of the single resulting image.
noise = tf.random.normal(shape=[1, latent_dimension])
generated_image = generator(noise, training=False)

plot.imshow(generated_image[0, :, :, 0], cmap='gray')

In [0]:
# Score the generated sample with the discriminator and print the result.
decision = discriminator(generated_image)
print(decision)

In [0]:
# Binary cross-entropy shared by both loss functions. from_logits=True because
# the discriminator's final Dense(1) layer has no activation, so its outputs
# are raw scores rather than probabilities.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [0]:
def compute_discriminator_loss(real_output, fake_output):
    """Return the discriminator loss for one batch.

    Args:
        real_output: discriminator outputs for real images.
        fake_output: discriminator outputs for generated images.

    Returns:
        Cross-entropy of ``real_output`` against all-ones targets plus
        cross-entropy of ``fake_output`` against all-zeros targets.
    """
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake

In [0]:
def compute_generator_loss(fake_output):
    """Return the generator loss for one batch.

    Args:
        fake_output: discriminator outputs for generated images.

    Returns:
        Cross-entropy of ``fake_output`` against all-ones targets, i.e. the
        loss is low when the discriminator scores the fakes as real.
    """
    real_targets = tf.ones_like(fake_output)
    return cross_entropy(real_targets, fake_output)

In [0]:
# One Adam optimizer per network, both with the same learning rate; the two
# models are updated independently in the training loop.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

In [0]:
import matplotlib.pyplot as plot

def plot_generated(number_of_samples=10, dim=(1, 10), figsize=(12, 2)):
    """Sample images from the generator and show them in a subplot grid.

    Args:
        number_of_samples: how many images to generate and plot.
        dim: ``(rows, columns)`` of the subplot grid; it must have room for
            ``number_of_samples`` cells.
        figsize: figure size passed to matplotlib.

    Raises:
        ValueError: if the grid described by ``dim`` has fewer cells than
            ``number_of_samples``.
    """
    rows, columns = dim[0], dim[1]
    # Fail before generating or drawing anything; otherwise matplotlib raises
    # part-way through the loop with a half-drawn figure left behind.
    if rows * columns < number_of_samples:
        raise ValueError(
            'dim={} has room for {} images but number_of_samples={}'.format(
                tuple(dim), rows * columns, number_of_samples))

    noise = tf.random.normal([number_of_samples, latent_dimension])
    # Call the model directly in inference mode: for a handful of samples this
    # avoids the batching machinery and progress output of Model.predict.
    generated_images = generator(noise, training=False).numpy()
    # Map the generator's tanh range [-1, 1] back to pixel values [0, 255].
    generated_images = generated_images * 127.5 + 127.5
    # Drop the channel axis so each image is a 2-D array for imshow.
    generated_images = generated_images.reshape(number_of_samples, image_shape[0], image_shape[1])

    plot.figure(figsize=figsize)
    for i in range(number_of_samples):
        plot.subplot(rows, columns, i + 1)
        plot.imshow(generated_images[i], cmap='gray')
        plot.axis('off')

    plot.tight_layout()
    plot.show()

In [0]:
def train(dataset, epochs=100, plot_frequency=1):
    """Train the generator and the discriminator adversarially.

    Both networks are updated once per batch from the same forward pass, each
    with its own gradient tape and optimizer.

    Args:
        dataset: iterable yielding batches of real images.
        epochs: number of full passes over ``dataset``.
        plot_frequency: plot generated samples after every ``plot_frequency``
            completed epochs (so the last epoch is plotted whenever ``epochs``
            is a multiple of it). Zero or a negative value disables plotting.
    """
    for epoch in range(epochs):

        for images in dataset:
            # Size the noise batch from the real batch instead of the global
            # `batch_size`, so a dataset batched differently (or a smaller
            # final batch) still gets a matching number of fake images.
            noise = tf.random.normal([tf.shape(images)[0], latent_dimension])

            with tf.GradientTape() as generator_tape, tf.GradientTape() as discriminator_tape:
                generated_images = generator(noise, training=True)

                real_output = discriminator(images, training=True)
                fake_output = discriminator(generated_images, training=True)

                generator_loss = compute_generator_loss(fake_output)
                discriminator_loss = compute_discriminator_loss(real_output, fake_output)

            gradients_of_generator = generator_tape.gradient(
                generator_loss, generator.trainable_variables)
            gradients_of_discriminator = discriminator_tape.gradient(
                discriminator_loss, discriminator.trainable_variables)

            generator_optimizer.apply_gradients(
                zip(gradients_of_generator, generator.trainable_variables))
            discriminator_optimizer.apply_gradients(
                zip(gradients_of_discriminator, discriminator.trainable_variables))

        # `epoch` is 0-based; count completed epochs so the schedule lands on
        # epochs plot_frequency, 2 * plot_frequency, ... and a non-positive
        # frequency cannot trigger a modulo-by-zero.
        completed_epochs = epoch + 1
        if plot_frequency > 0 and completed_epochs % plot_frequency == 0:
            plot_generated()

In [0]:
# Run the training loop on the batched dataset for 50 epochs; plot_frequency=2
# controls how often sample images are plotted during training.
train(train_dataset, epochs=50, plot_frequency=2)