In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import PIL
from tensorflow.keras import layers
import glob
import os
import time
import tensorflow_datasets as tfds
import cv2

In [None]:
# Record which TensorFlow version this notebook was executed with.
print(tf.version.VERSION)

In [None]:
# tf.data pipeline sizes: shuffle buffer and mini-batch.
# NOTE(review): 1020 presumably equals the number of training images (full
# shuffle) -- confirm against the shape printed after loading.
BUFFER_SIZE, BATCH_SIZE = 1020, 256

In [None]:
# Load the entire train split at once (batch_size=-1 asks tfds for the full
# split as tensors rather than a streamed dataset), then convert to NumPy.
train_dataset = tfds.load(
    "oxford_flowers102",
    split=tfds.Split.TRAIN,
    batch_size=-1,
)
numpy_ds = tfds.as_numpy(train_dataset)

In [None]:
# Raw, not-yet-resized training images; print their shape as a sanity check.
train_images_np = numpy_ds["image"]
raw_shape = train_images_np.shape
print(raw_shape)

In [None]:
# Resize every raw image to IMG_SIZE x IMG_SIZE and scale pixels to [0, 1].
# The uint8 images are stacked first and cast to float32 once, so no float64
# copy of the whole dataset is ever materialised.
# `train_images_np` is deliberately NOT reassigned here: it is the same object
# as `numpy_ds["image"]`, so clearing it freed no memory, and it made this cell
# non-idempotent (re-running it produced an empty training set).
IMG_SIZE = 128
train_images = np.stack(
    [cv2.resize(img, (IMG_SIZE, IMG_SIZE)) for img in train_images_np]
).astype(np.float32) / 255.0
assert train_images.shape[1:] == (IMG_SIZE, IMG_SIZE, 3), train_images.shape
print(train_images.shape)

In [None]:
def make_generator(noise_dim=100):
  """Build the DCGAN generator: latent noise vector -> 128x128 RGB image.

  Args:
    noise_dim: length of the latent vector consumed by the first Dense layer.
      Defaults to 100, the latent size used for `noise` and `seed` in this
      notebook.

  Returns:
    An uncompiled `tf.keras.Sequential` whose output has shape
    (batch, 128, 128, 3) with values in [0, 1].
  """
  model = tf.keras.Sequential()

  # Project the latent vector and reshape it into a 32x32x256 feature map.
  model.add(layers.Dense(32 * 32 * 256, use_bias=False, input_shape=(noise_dim,)))
  model.add(layers.BatchNormalization())
  model.add(layers.LeakyReLU())

  model.add(layers.Reshape((32, 32, 256)))

  # Conv2DTranspose blocks upsample 32 -> 32 -> 64 -> 128 pixels.
  model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
  assert model.output_shape == (None, 32, 32, 128)
  model.add(layers.BatchNormalization())
  model.add(layers.LeakyReLU())

  model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
  assert model.output_shape == (None, 64, 64, 64)
  model.add(layers.BatchNormalization())
  model.add(layers.LeakyReLU())

  # The real training images are scaled to [0, 1] (pixel / 255), so use a
  # sigmoid output to match that range. A tanh output would let the generator
  # emit values in [-1, 0) that never occur in real data.
  model.add(layers.Conv2DTranspose(3, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='sigmoid'))
  assert model.output_shape == (None, 128, 128, 3)

  return model

In [None]:
# Smoke-test the untrained generator: one random latent vector -> one image.
generator = make_generator()
noise = tf.random.normal(shape=[1, 100])
generated_image = generator(noise, training=False)

fig, ax = plt.subplots()
ax.imshow(generated_image[0, :, :, 0])  # first colour channel only
plt.show()

In [None]:
def make_discriminator():
    """Build the DCGAN discriminator: 128x128 RGB image -> one real/fake score.

    Three stride-2 convolution stages with 64, 128 and 256 filters, each
    followed by LeakyReLU and 30% dropout, then Flatten and a single Dense
    unit. The final Dense layer has no activation, so the score is a raw logit.
    """
    model = tf.keras.Sequential()

    for stage, n_filters in enumerate((64, 128, 256)):
        # Only the very first layer declares the model's input shape.
        extra = {'input_shape': [128, 128, 3]} if stage == 0 else {}
        model.add(layers.Conv2D(n_filters, (5, 5), strides=(2, 2), padding='same', **extra))
        model.add(layers.LeakyReLU())
        model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

In [None]:
# Smoke-test the untrained discriminator on the sample from the generator test.
# Inference mode (dropout off) is made explicit here.
discriminator = make_discriminator()
decision = discriminator(generated_image, training=False)
print(decision)

In [None]:
# The discriminator ends in a plain Dense(1), i.e. it outputs logits, so the
# loss applies the sigmoid itself.
cross_entropy = tf.keras.losses.BinaryCrossentropy(
    from_logits=True,
)

In [None]:
def generator_loss(fake_output):
  """Generator loss: cross-entropy of the discriminator's scores on fakes vs. "real".

  Args:
    fake_output: discriminator outputs for generated images.

  Returns:
    Scalar loss that is low when the discriminator labels the fakes as real (1).
  """
  all_real = tf.ones_like(fake_output)
  return cross_entropy(all_real, fake_output)

In [None]:
def discriminator_loss(real_output, fake_output):
    """Discriminator loss: real scores vs. label 1 plus fake scores vs. label 0.

    Args:
        real_output: discriminator outputs for real images.
        fake_output: discriminator outputs for generated images.

    Returns:
        The sum of the two binary cross-entropy terms.
    """
    return (cross_entropy(tf.ones_like(real_output), real_output)
            + cross_entropy(tf.zeros_like(fake_output), fake_output))

In [None]:
# One independent Adam instance per network, both with learning rate 1e-4.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

In [None]:
# Where checkpoints are written.
# NOTE(review): this relative path presumably assumes a Colab runtime with
# Google Drive mounted under ./drive -- confirm before running elsewhere.
checkpoint_dir = 'drive/My Drive/Flower_GAN/training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

# Track both networks and both optimizers so they are saved/restored together.
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
)


In [None]:
EPOCHS = 10000
noise_dim = 100                # length of each latent noise vector
num_example_to_generate = 16   # how many fixed samples to visualise

# Fixed latent vectors, drawn once and reused over time so that samples
# produced at different points of training can be compared.
seed = tf.random.normal(shape=[num_example_to_generate, noise_dim])

In [None]:
@tf.function
def train_step(images):
  """Run one adversarial update of both networks on a batch of real images.

  Args:
    images: one batch from `dataset` (presumably float32 of shape
      (batch, 128, 128, 3) -- TODO confirm).

  Returns:
    `(gen_loss, disc_loss)` as scalar tensors, for optional logging. Callers
    that ignore the return value are unaffected.
  """
  # Size the noise batch from the real batch rather than BATCH_SIZE. The
  # dataset is batched without drop_remainder, so the last batch of an epoch
  # can be smaller. A fixed size would then give the discriminator unequal
  # numbers of real and fake samples.
  batch_size = tf.shape(images)[0]
  noise = tf.random.normal([batch_size, noise_dim])

  # Two tapes, because each network is updated from its own loss.
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    # Generate images from noise in training mode (BatchNorm uses batch stats).
    generated_images = generator(noise, training=True)

    real_output = discriminator(images, training=True)
    fake_output = discriminator(generated_images, training=True)

    gen_loss = generator_loss(fake_output)
    disc_loss = discriminator_loss(real_output, fake_output)

  gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
  gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

  generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
  discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

  return gen_loss, disc_loss

In [None]:
def train(dataset, epochs, checkpoint_every=1000):
  """Train the GAN for `epochs` full passes over `dataset`.

  Args:
    dataset: iterable of real-image batches, each passed to `train_step`.
    epochs: number of passes over `dataset`.
    checkpoint_every: save a checkpoint every this many epochs (default 1000,
      the previous hard-coded interval). The final epoch is always saved too.

  Raises:
    ValueError: if `checkpoint_every` is smaller than 1.
  """
  if checkpoint_every < 1:
    raise ValueError("checkpoint_every must be >= 1, got {}".format(checkpoint_every))

  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      train_step(image_batch)

    epoch_num = epoch + 1
    # Also save after the last epoch, so a run whose length is not a multiple
    # of `checkpoint_every` does not lose its final epochs.
    if epoch_num % checkpoint_every == 0 or epoch_num == epochs:
      checkpoint.save(file_prefix=checkpoint_prefix)

    print("Time for epoch {} was {} seconds".format(epoch_num, time.time() - start))

In [None]:
# Input pipeline: one element per image, shuffled, then grouped into batches.
dataset = (
    tf.data.Dataset.from_tensor_slices(train_images)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

print(dataset)

In [None]:
# Full training run.
# NOTE(review): the checkpoint-restore cell comes after this one; to resume
# instead of retraining from scratch, restore before running this cell.
train(dataset=dataset, epochs=EPOCHS)

In [None]:
# Generate from the fixed `seed` vectors and show the first sample's first
# colour channel.
generated_test = generator(seed, training=False)

fig, ax = plt.subplots()
ax.imshow(generated_test[0, :, :, 0])
plt.show()

In [None]:
# Show one preprocessed real training image for visual comparison.
fig, ax = plt.subplots()
ax.imshow(train_images[0])
plt.show()

In [None]:
# Restore the most recent checkpoint, if one exists. tf.train.latest_checkpoint
# returns None when the directory holds no checkpoint, and restore(None)
# restores nothing without complaint, so report that case explicitly.
# NOTE(review): this cell runs after training; move it before the training
# cell to resume a previous run instead of retraining from scratch.
latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
if latest_ckpt is None:
  print("No checkpoint found in {}".format(checkpoint_dir))
else:
  checkpoint.restore(latest_ckpt)
  print("Restored checkpoint {}".format(latest_ckpt))