In [None]:
#imports
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers
from tensorflow.keras import losses
import numpy as np
from matplotlib import pyplot as plt
import cv2
import os
import time
from IPython import display
from tensorflow.keras import mixed_precision
from google.colab import files
import zipfile
# Mixed precision: layers compute in float16 while their variables stay float32.
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)


## Arguments for GAN
# Training schedule
EPOCHS = 300
BATCH_SIZE = 8
# Number of sample images written during training (one fixed noise row each)
EXAMPLES_TO_GENERATE = 8
# Image geometry; presumably must match the records in TF_RECORD_PATH — confirm
IMAGE_HEIGHT, IMAGE_WIDTH = 512, 512
TF_RECORD_PATH = "output.tfrecord"
# Length of the latent noise vector fed to the generator
noise_dim = 100
# Fixed noise, reused every time samples are generated
seed = tf.random.normal([EXAMPLES_TO_GENERATE, noise_dim])

# LeakyReLU slopes and BatchNormalization momentum
alpha_Discriminator = 0.2
alpha_Generator = 0.2
momentum_BatchNormalization = 0.8

################# Building Generator #################
def add_generator_layer(model, num_filters, kernel, num_strides): #from rectgan
  '''
  Append Conv2DTranspose -> BatchNormalization -> LeakyReLU layers to ``model``.

  Args:
    model: the tf.keras.Sequential generator being built (mutated in place).
    num_filters: output channels of the transposed convolution.
    kernel: kernel size passed to Conv2DTranspose.
    num_strides: stride of the transposed convolution; with padding='same'
      a stride of (2, 2) doubles height and width.
  '''
  # A bias would be redundant here: BatchNormalization directly follows.
  model.add(layers.Conv2DTranspose(num_filters, kernel_size=kernel, strides=num_strides, padding='same', use_bias=False))
  model.add(layers.BatchNormalization(momentum=momentum_BatchNormalization))
  # Use the configured generator slope; without it the Keras default (0.3)
  # applies and alpha_Generator has no effect. Passed positionally because the
  # keyword is `alpha` in Keras 2 but `negative_slope` in Keras 3.
  model.add(layers.LeakyReLU(alpha_Generator))

def make_generator():
  '''
  Build the generator: maps a noise vector of length ``noise_dim`` to a
  512x512x3 image with values in [-1, 1] (tanh output).

  A bias-free Dense layer projects the noise to a 16x16x512 feature map, five
  stride-2 Conv2DTranspose blocks upsample it to 512x512, and a final
  3-filter Conv2DTranspose + tanh produces the image.

  Returns:
    The tf.keras.Sequential generator; its summary is printed as a side effect.
  '''
  model = tf.keras.Sequential()
  # Input width follows noise_dim, the same constant used to draw noise and `seed`.
  model.add(layers.Dense(16*16*512, input_shape=(noise_dim,), use_bias=False))
  model.add(layers.BatchNormalization())
  model.add(layers.LeakyReLU(alpha=0.2))
  model.add(layers.Reshape((16, 16, 512)))

  # Each block doubles height and width: 16 -> 32 -> 64 -> 128 -> 256 -> 512.
  for filters in (512, 256, 128, 64, 32):
    add_generator_layer(model, filters, (5, 5), num_strides=(2, 2))

  model.add(layers.Conv2DTranspose(3, kernel_size=3, strides=1, padding='same', use_bias=False,
                                   kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.02)))
  # Emit float32 images even under the global mixed_float16 policy
  # (TF mixed-precision guide: keep model outputs in float32).
  model.add(layers.Activation('tanh', dtype='float32'))
  # summary() prints the table itself and returns None, so it is not wrapped in print().
  model.summary()
  assert model.output_shape == (None, 512, 512, 3), "Generator output dimensions should be (512, 512, 3), aborting."
  return model



def add_discriminator_layer(model, filters, kernel, num_strides):
  '''
  Append GaussianNoise -> Conv2D -> LeakyReLU -> Dropout layers to ``model``.

  Args:
    model: the tf.keras.Sequential discriminator being built (mutated in place).
    filters: output channels of the convolution.
    kernel: convolution kernel size.
    num_strides: convolution stride; (2, 2) halves height and width.
  '''
  conv_block = (
      layers.GaussianNoise(0.1),  # reddit tip
      layers.Conv2D(filters, kernel_size=kernel, strides=num_strides, padding='same', use_bias=False),
      layers.LeakyReLU(alpha_Discriminator),
      layers.Dropout(0.3),
  )
  for layer in conv_block:
    model.add(layer)

def make_discriminator():
  """
  Build the discriminator: seven stride-2 conv blocks, Flatten, then a
  single-unit Dense layer.

  For 512x512 inputs the conv stack reduces the spatial size to 4x4
  (512 / 2**7) before Flatten. The output is an unnormalised logit; the losses
  in this file are built with from_logits=True.
  """
  model = tf.keras.Sequential()
  for filters in (8, 16, 32, 64, 128, 256, 512):
    add_discriminator_layer(model, filters, (3, 3), num_strides=(2, 2))

  model.add(layers.Flatten())
  # Under the global mixed_float16 policy, compute the logit in float32 so the
  # loss receives a float32 value (TF mixed-precision guide). Variables are
  # float32 under either policy, so previously saved checkpoints still load.
  model.add(layers.Dense(1, dtype='float32'))
  return model

def discriminator_loss(real_output, fake_output):
  """
  Binary cross-entropy loss for the discriminator.

  Real logits are scored against target 1 and fake logits against target 0.
  Both inputs are raw logits (from_logits=True). Each term uses the loss
  object's default reduction, and the two terms are summed.
  """
  bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
  loss_on_real = bce(tf.ones_like(real_output), real_output)
  loss_on_fake = bce(tf.zeros_like(fake_output), fake_output)
  return loss_on_real + loss_on_fake

def generator_loss(fake_output):
  """Binary cross-entropy of the discriminator's logits on fakes against target 1 ("real")."""
  wanted = tf.ones_like(fake_output)
  bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
  return bce(wanted, fake_output)

######## Running the model ########


@tf.function  # traced into a graph for faster execution
def train_step(image_batch, generator, discriminator, generator_optimizer, discriminator_optimizer):
    """Run one simultaneous update of the generator and the discriminator.

    Args:
      image_batch: batch of real images with raw pixel values in [0, 255]
        (assumed: read_tfrecord yields uint8 via tf.image.decode_image);
        rescaled here to [-1, 1].
      generator: model mapping noise of width noise_dim to images.
      discriminator: model mapping images to one logit each.
      generator_optimizer: optimizer applied to generator.trainable_variables.
      discriminator_optimizer: optimizer applied to discriminator.trainable_variables.
    """
    # The generator ends in tanh, so fakes lie in [-1, 1]; real images must be
    # mapped to the same range, otherwise the discriminator can tell real from
    # fake by value range alone.
    real_images = tf.cast(image_batch, tf.float32) / 127.5 - 1.0
    # One noise vector per real image: the last batch of an epoch may hold
    # fewer than BATCH_SIZE images.
    noise = tf.random.normal([tf.shape(real_images)[0], noise_dim])

    # NOTE(review): the global policy is mixed_float16 but no loss scaling is
    # applied, so float16 gradients may underflow. The TF mixed-precision guide
    # recommends a LossScaleOptimizer for custom loops (API differs between
    # Keras 2 and Keras 3) — TODO once the Keras version is pinned.
    # Two tapes so each model gets gradients of its own loss.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(real_images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def generate_and_save_images(model, epoch, test_input):
    """Run ``model`` on ``test_input`` and write each generated image as a PNG.

    Image i is written to GAN_output_images/img{i}/image{i}_epoch_{epoch:04d}.png.
    Directories are not created here (see make_output_directories). A failed
    write is reported rather than silently ignored.

    Args:
      model: generator whose outputs are in [-1, 1] (tanh).
      epoch: number embedded, zero-padded to 4 digits, in the file name.
      test_input: noise batch; one image is written per row.
    """
    predictions = model(test_input, training=False)
    for i, prediction in enumerate(predictions):  # one image per noise row
        # Cast first (the model may emit float16 under mixed precision), then map
        # the tanh range [-1, 1] to [0, 255]; clipping guards rounding overshoot.
        pixels = (tf.cast(prediction, tf.float32) * 0.5 + 0.5) * 255.0
        image = tf.cast(tf.clip_by_value(pixels, 0.0, 255.0), tf.uint8).numpy()
        # Samples are RGB, like the tf.image-decoded training data, but OpenCV
        # writes BGR; without this conversion red and blue are swapped on disk.
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        path = f"GAN_output_images/img{i}/image{i}_epoch_{epoch:04d}.png"
        # cv2.imwrite signals failure (e.g. missing directory) only via its return value.
        if not cv2.imwrite(path, image):
            print(f"warning: could not write {path}")

def make_output_directories(examples_to_generate, base_dir='GAN_output_images'):
  '''
  Create one output directory per example image, plus their common parent.

  Creates ``base_dir`` and ``base_dir/img0`` ... ``base_dir/img{n-1}`` so that
  the work-in-progress samples of example i accumulate in ``img{i}``.
  Directories that already exist are left untouched, so repeated calls are safe.

  Args:
    examples_to_generate: number of per-example directories to create.
    base_dir: parent directory; the default is the path that
      generate_and_save_images writes to.
  '''
  # exist_ok=True replaces the racy os.path.exists-then-makedirs pattern.
  # base_dir is created explicitly so it exists even when examples_to_generate is 0.
  os.makedirs(base_dir, exist_ok=True)
  for i in range(examples_to_generate):
    os.makedirs(os.path.join(base_dir, f'img{i}'), exist_ok=True)

def decode_image(image):
  '''
  Decode a JPEG byte string into a float32 tensor of shape
  (IMAGE_HEIGHT, IMAGE_WIDTH, 3) with values mapped to the [-1, 1] range.

  NOTE(review): appears unused — read_tfrecord decodes with
  tf.image.decode_image instead; confirm before removing.
  '''
  rgb = tf.image.decode_jpeg(image, channels=3)
  # reshape requires exactly IMAGE_HEIGHT * IMAGE_WIDTH * 3 decoded values
  rgb = tf.reshape(rgb, [IMAGE_HEIGHT, IMAGE_WIDTH, 3])
  scaled = tf.cast(rgb, tf.float32) / 127.5  # [0, 255] -> [0, 2]
  return scaled - 1  # -> [-1, 1]


def read_tfrecord(example):
  '''
  Parse one serialized tf.train.Example and decode its "image" feature.

  Returns:
    uint8 tensor of shape (IMAGE_HEIGHT, IMAGE_WIDTH, 3) holding raw pixel
    values in [0, 255]; conversion to float is left to the consumer.

  Raises:
    InvalidArgumentError (at iteration time) if a record decodes to a
    different shape.
  '''
  tfrecord_format = {
    "image": tf.io.FixedLenFeature([], tf.string),  # feature name for rectangular_houses
  }

  parsed = tf.io.parse_single_example(example, tfrecord_format)
  # expand_animations=False always yields a rank-3 tensor (animated GIFs are
  # truncated to their first frame); without it the static rank is unknown.
  image = tf.image.decode_image(parsed['image'], channels=3, expand_animations=False)
  # Pin the static shape, and fail here with a clear message on a mis-sized
  # record rather than later at batching or inside the discriminator.
  return tf.ensure_shape(image, [IMAGE_HEIGHT, IMAGE_WIDTH, 3])

def train():
    """Build the GAN, stream the TFRecord dataset and run the training loop.

    Samples from the fixed ``seed`` are written on every even epoch index and
    once more after the last epoch. A checkpoint of both models and both
    optimizers is saved to ./training_checkpoints every 5 epochs.
    """
    # generate_and_save_images writes into GAN_output_images/img{i}/; cv2.imwrite
    # does not create directories and only returns False when they are missing.
    make_output_directories(EXAMPLES_TO_GENERATE)

    # Load and prepare the dataset
    print("make tfr")
    # Shuffle the compact serialized records *before* decoding: a 10000-element
    # buffer of decoded 512x512x3 uint8 images could need ~7 GiB of RAM if the
    # dataset is that large. Decoding runs in parallel.
    train_dataset = (
        tf.data.TFRecordDataset(TF_RECORD_PATH)
        .shuffle(buffer_size=10000)
        .map(read_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(BATCH_SIZE)
        .prefetch(tf.data.AUTOTUNE)
    )

    print("make generator+discriminator")
    generator = make_generator()
    discriminator = make_discriminator()

    print("make optimizers")
    generator_optimizer = tf.keras.optimizers.Adam(1e-4)
    discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

    # Set up checkpoints
    checkpoint_dir = './training_checkpoints'
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                     discriminator_optimizer=discriminator_optimizer,
                                     generator=generator,
                                     discriminator=discriminator)

    print("beginning training loop")
    for epoch in range(EPOCHS):
        start = time.time()

        for image_batch in train_dataset:
            train_step(image_batch, generator, discriminator, generator_optimizer, discriminator_optimizer)

        # Samples on even epoch indices, labelled with the 1-based epoch number.
        if epoch % 2 == 0:
            print("saving image...")
            display.clear_output(wait=True)
            generate_and_save_images(generator, epoch + 1, seed)

        # Save the model every 5 epochs
        if (epoch + 1) % 5 == 0:
            print("saving checkpoint  ...")
            checkpoint.save(file_prefix=checkpoint_prefix)

        print(f'Time for epoch {epoch + 1} is {time.time() - start} sec')
        # Deliberately no tf.keras.backend.clear_session() here: the models and
        # optimizers stay referenced by this function, so it cannot free them,
        # and it would reset Keras global state in the middle of training.

    # Generate after the final epoch
    display.clear_output(wait=True)
    generate_and_save_images(generator, EPOCHS, seed)

def main():
  '''Run train() on the first GPU; raise SystemError if TensorFlow reports none.'''
  gpu = tf.test.gpu_device_name()
  if gpu == '/device:GPU:0':
    print(f'Found GPU at: {gpu}')
    with tf.device('/device:GPU:0'):
      train()
  else:
    raise SystemError('GPU device not found')

if __name__ == "__main__":
  main()

Time for epoch 7 is 78.13900589942932 sec
