# Generate Images

Use TF-GAN's `TPUGANEstimator` to train a model on a TPU that generates images of airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks.

The model uses the Deep Convolutional Generative Adversarial Network (DCGAN) architecture, which builds the GAN's generator and discriminator from Convolutional Neural Network (CNN) layers so that they work well with images.

## TPU connection

* Navigate to Edit → Notebook Settings → Hardware Accelerator = TPU
* Check that we can connect to the TPU

In [None]:
import os
import tensorflow.compat.v1 as tf
import pprint
# The TPU endpoint is read from the COLAB_TPU_ADDR environment variable; its
# absence means the runtime was not started with a TPU accelerator.
assert 'COLAB_TPU_ADDR' in os.environ, 'Did you forget to switch to TPU?'
tpu_address = 'grpc://' + os.environ['COLAB_TPU_ADDR']

# Open a short-lived session against that address only to enumerate devices.
with tf.Session(tpu_address) as sess:
  devices = sess.list_devices()
pprint.pprint(devices)

# One boolean per device: does its string description mention 'TPU'?
device_is_tpu = ['TPU' in str(device) for device in devices]
assert any(device_is_tpu), 'Did you forget to switch to TPU?'

[_DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:CPU:0, CPU, -1, -5438336312773099473),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 8589934592, 3940608670928494938),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, -1286911392309231460),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 2780663737291636112),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 2884526665050947995),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, -970520913650829580),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 3277700708642474361),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, -3585997035800732648),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 6958148156217588824),
 _DeviceAttributes(/job:tpu_wor

### Authentication

* To run on Google's free Cloud TPUs, set up a [Google Cloud Storage bucket](https://cloud.google.com/storage/) to store logs and checkpoints. 
* Running this notebook alone should fall under the [free pricing tier](https://cloud.google.com/storage/pricing).

In [None]:
import json
import os
import pprint
import re
import time
import tensorflow.compat.v1 as tf
import tensorflow_gcs_config

# Storage bucket for Estimator logs and training dataset.
bucket = 'likarajo_bucket' #@param {type:"string"}

# Fail fast on an empty bucket name before any GCS path is built from it.
assert bucket, 'Must specify an existing GCS bucket name'
print('Using bucket: {}'.format(bucket))

# Each run writes to its own timestamped directory inside the bucket, so
# checkpoints from separate runs do not collide.
model_dir = 'gs://{}/{}'.format(bucket, time.strftime('tpuestimator-tfgan/%Y-%m-%d-%H-%M-%S'))
print('Using model dir: {}'.format(model_dir))

# The TPU endpoint is read from the COLAB_TPU_ADDR environment variable.
assert 'COLAB_TPU_ADDR' in os.environ, 'Missing TPU; did you request a TPU in Notebook Settings?'
tpu_address = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])

# Authenticate the Colab user; this must happen before the credentials can be
# forwarded to the TPU below.
from google.colab import auth
auth.authenticate_user()

# Upload credentials to TPU.
tf.config.experimental_connect_to_host(tpu_address)
tensorflow_gcs_config.configure_gcs_from_colab_auth()
# Now credentials are set for all future sessions on this TPU.

Using bucket: likarajo_bucket
Using model dir: gs://likarajo_bucket/tpuestimator-tfgan/2021-01-27-19-03-25


[None, <tf.Tensor: shape=(), dtype=int32, numpy=0>]

## Import necessary packages

In [None]:
import os

import tensorflow.compat.v1 as tf
# Disable noisy outputs.
tf.logging.set_verbosity(tf.logging.ERROR)
tf.autograph.set_verbosity(0, False)

try:
  import tensorflow_gan as tfgan
except ModuleNotFoundError:
  !pip install tensorflow-gan
  import tensorflow_gan as tfgan

import tensorflow_datasets as tfds

import tensorflow_hub as hub

import numpy as np

import matplotlib.pyplot as plt
# Allow matplotlib images to render immediately.
%matplotlib inline

import warnings
warnings.filterwarnings("ignore")

## Get Data

The model is trained on the [CIFAR10](https://wikipedia.org/wiki/CIFAR-10) dataset.


### Input pipeline


In [None]:
import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds

# GCS location inside the user's bucket that `input_fn` passes to `tfds.load`
# as its data directory.
dataset_dir = 'gs://{}/{}'.format(bucket, 'datasets')

def input_fn(mode, params):
  """Builds the `tf.data` input pipeline for the requested Estimator mode.

  Args:
    mode: A `tf.estimator.ModeKeys` value. TRAIN reads the shuffled 'train'
      split, PREDICT yields noise only, and any other mode reads the
      unshuffled 'test' split.
    params: Dict that must contain 'batch_size' and 'noise_dims'.

  Returns:
    For PREDICT, a dataset holding a single noise batch of shape
    [batch_size, noise_dims]. Otherwise a repeating dataset of
    (noise, images) pairs, where the images are scaled to [-1, 1].
  """
  assert 'batch_size' in params
  assert 'noise_dims' in params
  batch_size = params['batch_size']
  noise_dims = params['noise_dims']

  is_train = (mode == tf.estimator.ModeKeys.TRAIN)
  is_predict = (mode == tf.estimator.ModeKeys.PREDICT)

  def _sample_noise(_):
    # A fresh Gaussian noise batch each time the dataset element is produced.
    return tf.random_normal([batch_size, noise_dims])

  noise_ds = tf.data.Dataset.from_tensors(0).map(_sample_noise)
  # If 'predict', just generate one batch; otherwise repeat indefinitely.
  noise_ds = noise_ds.repeat(1 if is_predict else None)
  if is_predict:
    return noise_ds

  def _preprocess(element):
    # Map [0, 255] to [-1, 1].
    pixels = tf.cast(element['image'], tf.float32)
    return (pixels - 127.5) / 127.5

  images_ds = tfds.load(
      'cifar10:3.*.*',
      split='train' if is_train else 'test',
      data_dir=dataset_dir)
  images_ds = images_ds.map(_preprocess, num_parallel_calls=4)
  images_ds = images_ds.cache()
  images_ds = images_ds.repeat()
  if is_train:
    images_ds = images_ds.shuffle(
        buffer_size=10000, reshuffle_each_iteration=True)
  images_ds = images_ds.batch(batch_size, drop_remainder=True)
  images_ds = images_ds.prefetch(tf.data.experimental.AUTOTUNE)

  return tf.data.Dataset.zip((noise_ds, images_ds))


def noise_input_fn(params):
  """Returns a dataset holding one fixed, reproducible batch of noise.

  The same noise is produced on every call, so images generated from it at
  different training steps can be compared directly.

  Args:
    params: Dict containing 'batch_size' and 'noise_dims'.

  Returns:
    A single-element `tf.data.Dataset` whose element is a float32 tensor of
    shape [batch_size, noise_dims].
  """
  # A private RandomState seeded with 0 draws exactly the values that
  # `np.random.seed(0)` followed by `np.random.randn` would, but without
  # resetting the process-wide NumPy RNG as a side effect of every call.
  rng = np.random.RandomState(0)
  np_noise = rng.randn(params['batch_size'], params['noise_dims'])
  return tf.data.Dataset.from_tensors(tf.constant(np_noise, dtype=tf.float32))

### Download the data

In [None]:
# Build the EVAL pipeline once with throwaway sizes. The returned dataset is
# discarded; the call is made so that `tfds.load` inside `input_fn` downloads
# and prepares CIFAR10 under `dataset_dir`.
params = {'batch_size': 1, 'noise_dims':1}
input_fn(tf.estimator.ModeKeys.EVAL, params)

[1mDownloading and preparing dataset cifar10/3.0.2 (download: 162.17 MiB, generated: 132.40 MiB, total: 294.58 MiB) to gs://likarajo_bucket/datasets/cifar10/3.0.2...[0m


InvalidArgumentError: ignored

### Sanity check the data


In [None]:
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
import tensorflow_gan as tfgan

# 80 images fill the 8 x 10 grid requested below.
params = {'batch_size': 80, 'noise_dims':64}
ds = input_fn(tf.estimator.ModeKeys.EVAL, params)
# Each dataset element is a (noise, images) pair; keep only the images of the
# first batch.
numpy_imgs = next(iter(tfds.as_numpy(ds)))[1]
image_grid = tfgan.eval.python_image_grid(numpy_imgs, grid_shape=(8, 10))

def _show_image_grid(image_grid):
  """Displays an image grid with the axes hidden.

  Args:
    image_grid: Array of pixel values, expected to lie in [-1, 1]; it is
      rescaled to [0, 1] before being handed to `plt.imshow`.
  """
  displayable = (image_grid + 1.0) / 2.0  # [-1, 1] -> [0, 1]
  plt.axis('off')
  plt.imshow(displayable, aspect='auto')
  plt.show()
# Display the grid of real CIFAR10 samples assembled above.
_show_image_grid(image_grid)

[1mDownloading and preparing dataset cifar10/3.0.2 (download: 162.17 MiB, generated: 132.40 MiB, total: 294.58 MiB) to gs://likarajo_bucket/datasets/cifar10/3.0.2...[0m


InvalidArgumentError: ignored

## Neural Net Architecture

A GAN consists of two networks:
*  A generator that takes input noise and outputs images
*  A discriminator that takes images and outputs an unnormalized score (logit) indicating how confident it is that they are real

### Network building functions

In [None]:
def _leaky_relu(x):
  """Applies a leaky ReLU activation with a negative slope of 0.2."""
  negative_slope = 0.2
  return tf.nn.leaky_relu(x, alpha=negative_slope)


def _batch_norm(x, is_training, name):
  """Applies batch normalization with this notebook's fixed hyperparameters.

  Args:
    x: Input tensor.
    is_training: Forwarded as the layer's `training` flag.
    name: Name given to the layer.

  Returns:
    The normalized tensor.
  """
  bn_settings = dict(
      momentum=0.9, epsilon=1e-5, training=is_training, name=name)
  return tf.layers.batch_normalization(x, **bn_settings)


def _dense(x, channels, name):
  """Fully connected layer with truncated-normal (stddev 0.02) kernel init.

  Args:
    x: Input tensor.
    channels: Number of output units.
    name: Name given to the layer.

  Returns:
    The layer's output tensor.
  """
  weight_init = tf.truncated_normal_initializer(stddev=0.02)
  return tf.layers.dense(
      x, channels, kernel_initializer=weight_init, name=name)


def _conv2d(x, filters, kernel_size, stride, name):
  """Square-kernel 2D convolution with 'same' padding.

  Args:
    x: Input tensor.
    filters: Number of output filters.
    kernel_size: Side length of the square kernel.
    stride: Stride applied along both spatial dimensions.
    name: Name given to the layer.

  Returns:
    The layer's output tensor.
  """
  weight_init = tf.truncated_normal_initializer(stddev=0.02)
  kernel_shape = [kernel_size, kernel_size]
  strides = [stride, stride]
  return tf.layers.conv2d(
      x, filters, kernel_shape,
      strides=strides, padding='same',
      kernel_initializer=weight_init,
      name=name)


def _deconv2d(x, filters, kernel_size, stride, name):
  """Square-kernel transposed 2D convolution with 'same' padding.

  Args:
    x: Input tensor.
    filters: Number of output filters.
    kernel_size: Side length of the square kernel.
    stride: Stride applied along both spatial dimensions.
    name: Name given to the layer.

  Returns:
    The layer's output tensor.
  """
  weight_init = tf.truncated_normal_initializer(stddev=0.02)
  kernel_shape = [kernel_size, kernel_size]
  strides = [stride, stride]
  return tf.layers.conv2d_transpose(
      x, filters, kernel_shape,
      strides=strides, padding='same',
      kernel_initializer=weight_init,
      name=name)

### Discriminator

In [None]:
def discriminator(images, unused_conditioning, is_training=True,
                  scope='Discriminator'):
  """Scores a batch of CIFAR images as real or generated.

  Three stride-2 convolutions shrink the input by a factor of 8 per spatial
  dimension; the result is flattened to 4 * 4 * 256 features (which matches
  32 x 32 inputs) and projected to a single logit per image.

  Args:
    images: A Tensor of real or generated images, shaped
      [batch size, height, width, channels].
    unused_conditioning: Present only because the TFGAN discriminator
      signature supports conditional GANs. This model is unconditional, so
      the argument is ignored.
    is_training: Forwarded to the batch norm layers. If `True` they use batch
      statistics; if `False` they use the accumulated moving averages.
    scope: A variable scope or string for the discriminator.

  Returns:
    A Tensor of shape [batch size, 1] holding unbounded logits; larger values
    indicate higher confidence that the image is real.
  """
  # (filters, conv layer name, batch norm layer name) for the conv blocks that
  # follow the first, un-normalized one.
  normalized_blocks = (
      (128, 'd_conv2', 'd_bn2'),
      (256, 'd_conv3', 'd_bn3'),
  )

  with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
    # First block: strided conv followed directly by the activation.
    net = _leaky_relu(_conv2d(images, 64, 5, 2, name='d_conv1'))

    # Remaining blocks: strided conv -> batch norm -> activation.
    for filters, conv_name, bn_name in normalized_blocks:
      net = _conv2d(net, filters, 5, 2, name=conv_name)
      net = _batch_norm(net, is_training, name=bn_name)
      net = _leaky_relu(net)

    flat = tf.reshape(net, [-1, 4 * 4 * 256])
    logits = _dense(flat, 1, name='d_fc_4')

  return logits

### Generator

In [None]:
def generator(noise, is_training=True, scope='Generator'):
  """Maps noise vectors to CIFAR-sized images.

  The noise is projected to 4096 features, reshaped to a 4 x 4 x 256 feature
  map, and upsampled by three stride-2 transposed convolutions. A final tanh
  bounds the output to [-1, 1].

  Args:
    noise: A 2D Tensor of shape [batch size, noise dim]. The model is
      unconditional, so this is the generator's only input.
    is_training: Forwarded to the batch norm layers. If `True` they use batch
      statistics; if `False` they use the accumulated moving averages.
    scope: A variable scope or string for the generator.

  Returns:
    A single Tensor with a batch of generated images (3 output channels).
  """
  with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
    # Project the noise and fold it into a small spatial feature map.
    projected = _dense(noise, 4096, name='g_fc1')
    projected = _batch_norm(projected, is_training, name='g_bn1')
    feature_map = tf.reshape(tf.nn.relu(projected), [-1, 4, 4, 256])

    # Two upsampling blocks: transposed conv -> batch norm -> ReLU.
    feature_map = _deconv2d(feature_map, 128, 5, 2, name='g_dconv2')
    feature_map = _batch_norm(feature_map, is_training, name='g_bn2')
    feature_map = tf.nn.relu(feature_map)

    feature_map = _deconv2d(feature_map, 64, 4, 2, name='g_dconv3')
    feature_map = _batch_norm(feature_map, is_training, name='g_bn3')
    feature_map = tf.nn.relu(feature_map)

    # Final upsampling to 3 channels, squashed by tanh.
    images = tf.tanh(_deconv2d(feature_map, 3, 4, 2, name='g_dconv4'))

  return images

### Evaluation Utilities

Metrics:
* `Inception Score`
* `Frechet Inception Distance`


In [None]:
import tensorflow.compat.v1 as tf
import tensorflow_gan as tfgan
import tensorflow_hub as hub
import numpy as np
# The evaluation helpers below build graphs and run them in sessions, so eager
# execution is turned off.
tf.disable_eager_execution()

# Total number of images scored per evaluation.
eval_batch_size = 4000 #@param
# Number of images handled per batch; generated-image evaluation assembles its
# total from `eval_batch_size // images_per_batch` such batches.
images_per_batch = 2000 #@param

def get_real_image_logits(num_images, classifier_model):
  """Returns an array with logits from real images and a CIFAR classifier.

  We normally want many thousands of examples to run eval. However, we can't fit
  inference for all of them in memory at once. Instead, the images are pushed
  through the classifier in batches of `images_per_batch` and the per-batch
  logits are concatenated.

  Args:
    num_images: Total number of images to produce logits for. Must be positive.
    classifier_model: A Python function that takes images and produces logits.

  Returns:
    A numpy array of logits whose first dimension is exactly `num_images`.

  Raises:
    ValueError: If `num_images` is not positive.
  """
  if num_images < 1:
    raise ValueError(
        'num_images must be positive, got {}'.format(num_images))

  ds = input_fn(tf.estimator.ModeKeys.TRAIN,
                {'batch_size': images_per_batch, 'noise_dims': 1})
  iterator = tf.data.make_one_shot_iterator(ds)

  # Each dataset element is a (noise, images) pair; only the images are used.
  cifar_imgs = iterator.get_next()[1]
  logits_op = classifier_model(cifar_imgs)

  # A single `sess.run` yields only `images_per_batch` rows, so run as many
  # batches as needed to cover `num_images` (ceil division) and trim the
  # surplus from the last batch.
  num_batches = -(-num_images // images_per_batch)
  with tf.train.MonitoredSession() as sess:
    batches = [sess.run(logits_op) for _ in range(num_batches)]
  logits = np.concatenate(batches, axis=0)[:num_images]

  assert len(logits.shape) == 2
  assert logits.shape[0] == num_images
  return logits

def init_global_real_logits():
  """Initialize a global variable with classifier logits for real data.

  The logits are stored in the module-level `real_logits`. They are computed
  when that global is missing, is `None`, or does not have the shape implied
  by the current `eval_batch_size`; otherwise the cached value is reused.
  """
  # We can hold all the real logits in memory at once, since CIFAR10 isn't that
  # big. Be sure to calculate it only once.
  global real_logits
  expected_shape = (eval_batch_size, 10)

  # `globals().get` covers the "never defined" case without relying on a
  # NameError, and also lets a `None` or stale-shaped value be recomputed.
  cached = globals().get('real_logits')
  if cached is None or getattr(cached, 'shape', None) != expected_shape:
    # Use a scratch graph so the classifier ops do not land in the default one.
    with tf.Graph().as_default():
      classifier_model = hub.Module("https://tfhub.dev/deepmind/ganeval-cifar10-convnet/1")
      real_logits = get_real_image_logits(
          eval_batch_size, classifier_model)
  assert real_logits.shape == expected_shape
  
def calculate_real_data_classifier_score():
  """Evaluates the classifier score of the cached real-data logits.

  Reads the module-level `real_logits`, which must already be populated.

  Returns:
    The evaluated classifier score.
  """
  assert real_logits is not None
  score_op = tfgan.eval.classifier_score_from_logits(real_logits)
  with tf.train.MonitoredSession() as session:
    return session.run(score_op)


def get_inception_score_and_fid(est):
  """Computes the classifier score and Frechet distance for generated images.

  Args:
    est: Estimator whose `predict(input_fn)` yields dicts with a
      'generated_data' entry.

  Returns:
    A `(classifier_score, frechet_distance)` tuple of evaluated values. The
    distance is measured against the module-level `real_logits`.
  """
  assert real_logits is not None

  tf.reset_default_graph()

  def sample_fn():
    # One `predict` pass yields one batch of generated images.
    generated = [pred['generated_data'] for pred in est.predict(input_fn)]
    batch = np.array(generated)
    assert batch.shape == (images_per_batch, 32, 32, 3)
    return batch

  # Generate the evaluation set batch by batch, then join the batches.
  num_batches = eval_batch_size // images_per_batch
  fake_batches = []
  for _ in range(num_batches):
    fake_batches.append(sample_fn())
  fake_imgs = tf.concat(fake_batches, axis=0)

  classifier_fn = hub.Module("https://tfhub.dev/deepmind/ganeval-cifar10-convnet/1")
  fake_logits = classifier_fn(fake_imgs)
  fake_logits.shape.assert_is_compatible_with([eval_batch_size, 10])

  score_op = tfgan.eval.classifier_score_from_logits(fake_logits)
  fid_op = tfgan.eval.frechet_classifier_distance_from_activations(
      real_logits, fake_logits)

  with tf.train.MonitoredSession() as session:
    cscore_np, fid_np = session.run([score_op, fid_op])

  return cscore_np, fid_np

### GAN Estimator

TF-GAN's `TPUGANEstimator` extends TensorFlow's `TPUEstimator` class. `TPUEstimator` handles the details of deploying the network on a TPU.

In [None]:
import os
import tensorflow.compat.v1 as tf
import tensorflow_gan as tfgan
# The Estimator workflow below is graph-based, so eager execution is turned off.
tf.disable_eager_execution()

# Size of the noise vector fed to the generator.
noise_dims = 1024 #@param
# Learning rates for the two Adam optimizers below.
generator_lr = 0.0002  #@param
discriminator_lr = 0.0002  #@param
train_batch_size = 1024  #@param

# Checkpoints and logs go to `model_dir`; the work runs on the TPU at
# `tpu_address`.
# NOTE(review): `iterations_per_loop` is set from `images_per_batch`, which is
# otherwise used as a batch size — confirm this coupling is intended.
config = tf.estimator.tpu.RunConfig(
    model_dir=model_dir,
    master=tpu_address,
    tpu_config=tf.estimator.tpu.TPUConfig(iterations_per_loop=images_per_batch))
# Wire the generator and discriminator networks, their losses and their
# optimizers into a single TPU-backed GAN estimator.
est = tfgan.estimator.TPUGANEstimator(
    generator_fn=generator,
    discriminator_fn=discriminator,
    generator_loss_fn=tfgan.losses.modified_generator_loss,
    discriminator_loss_fn=tfgan.losses.modified_discriminator_loss,
    generator_optimizer=tf.train.AdamOptimizer(generator_lr, 0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(discriminator_lr, 0.5),
    joint_train=True,  # train G and D jointly instead of sequentially.
    train_batch_size=train_batch_size,
    # Prediction batches must match the shape asserted by the eval helpers.
    predict_batch_size=images_per_batch,
    use_tpu=True,
    params={'noise_dims': noise_dims},
    config=config)

## GAN Training and Evaluation

In [None]:
import time
import matplotlib.pyplot as plt

# Total number of training steps, and how many steps to train between
# evaluations.
max_steps = 50000 #@param
steps_per_eval = 5000 #@param

cur_step = 0
start_time = time.time()
# Metric history, one entry per evaluation; plotted after training.
cscores, fids, steps = [], [], []
# Real-data logits are needed for the Frechet distance and as a reference
# classifier score, so compute them once before training starts.
init_global_real_logits()
print('Initialized classifier logits for real data.')
classifier_score_real_data = calculate_real_data_classifier_score()
print('Calculated classifier score for real data.')
while cur_step < max_steps:
  # Train for a fixed number of steps.
  start_step = cur_step
  # Clamp so the final chunk does not overshoot `max_steps`.
  step_to_stop_at = min(cur_step + steps_per_eval, max_steps)
  start = time.time()
  # NOTE(review): `max_steps` here is presumably the cumulative step to stop
  # at rather than a per-call count — consistent with `step_to_stop_at`.
  est.train(input_fn, max_steps=step_to_stop_at)
  end = time.time()
  cur_step = step_to_stop_at
  
  # Print some performance statistics.
  steps_taken = step_to_stop_at - start_step
  time_taken = end - start
  steps_per_sec = steps_taken / time_taken
  min_since_start = (time.time() - start_time) / 60.0
  print("Current step: %i, %.4f steps / sec, time since start: %.1f min" % (
      cur_step, steps_per_sec, min_since_start))
  
  # Calculate some evaluation metrics.
  eval_start_time = time.time()
  cscore, fid = get_inception_score_and_fid(est)
  eval_time = (time.time() - eval_start_time)
  cscores.append(cscore)
  fids.append(fid)
  steps.append(cur_step)
  # The generated-data score is printed next to the real-data score for
  # reference.
  print("Classifier score: %.2f / %.2f, FID: %.1f, "
        "time to calculate eval: %.2f sec" % (
            cscore, classifier_score_real_data, fid, eval_time))
  
  # Generate and show some predictions.
  # `noise_input_fn` supplies the same noise every time, so successive grids
  # are comparable; the first 80 images fill the 8 x 10 grid.
  predictions = np.array(
      [x['generated_data'] for x in est.predict(noise_input_fn)])[:80]
  image_grid = tfgan.eval.python_image_grid(predictions, grid_shape=(8, 10))
  _show_image_grid(image_grid)


### Metric Visualization

In [None]:
# Plot the metrics vs step, one explicitly created figure per metric. Using
# explicit figures/axes avoids the stray empty figure that a trailing
# `plt.figure()` would leave behind.
fid_fig, fid_ax = plt.subplots()
fid_ax.plot(steps, fids)
fid_ax.set_title('Frechet distance per step')
fid_ax.set_xlabel('Training step')
fid_ax.set_ylabel('Frechet distance')

score_fig, score_ax = plt.subplots()
score_ax.plot(steps, cscores, label='Generated data')
# Flat reference line: the classifier score obtained on real data.
score_ax.plot(steps, [classifier_score_real_data] * len(steps),
              label='Real data')
score_ax.set_title('Classifier Score per step')
score_ax.set_xlabel('Training step')
score_ax.set_ylabel('Classifier score')
score_ax.legend()

plt.show()