Install the tools needed to mount Google Drive, then authenticate with your Google account.

In [0]:
# Install FUSE support and google-drive-ocamlfuse (from the alessandro-strada PPA).
!apt-get install -y -qq software-properties-common python-software-properties module-init-tools
# NOTE(review): `2>&1 > /dev/null` points stderr at the notebook output and
# discards only stdout. If the intent is to silence both streams, the order
# would need to be `> /dev/null 2>&1`.
!add-apt-repository -y ppa:alessandro-strada/ppa 2>&1 > /dev/null
!apt-get update -qq 2>&1 > /dev/null
!apt-get -y install -qq google-drive-ocamlfuse fuse
# Authenticate the Colab user and load the application-default credentials;
# their client id/secret are handed to google-drive-ocamlfuse below.
from google.colab import auth
auth.authenticate_user()
from oauth2client.client import GoogleCredentials
creds = GoogleCredentials.get_application_default()
import getpass
# Print only the authorization URL. Open it, then paste the verification code
# into the hidden getpass prompt; the code is piped to ocamlfuse on stdin.
!google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret} < /dev/null 2>&1 | grep URL
vcode = getpass.getpass()
!echo {vcode} | google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret}

Mount Google Drive.

In [0]:
# Unmount any previous FUSE mount at ./drive so the next cell can (re)mount it.
# This may print an error if nothing is mounted yet; that is harmless here.
!fusermount -u drive

In [0]:
# Create the mount point (no error if it already exists) and mount Google Drive on it.
!mkdir -p drive
!google-drive-ocamlfuse drive

Run this cell to check whether Google Drive is mounted.

In [0]:
# List the working directory, then the top level of the mounted drive.
!ls && ls drive 

Imports.

In [0]:
from absl import flags
import numpy as np
from six.moves import xrange 
import tensorflow as tf

from tensorflow.contrib.gan.python.eval.python import eval_utils
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.summary import summary

# Short aliases for the tf.contrib modules used throughout this notebook.
layers, tfgan = tf.contrib.layers, tf.contrib.gan

Check whether a GPU is available.

In [0]:
# Displays the GPU device name TensorFlow reports; an empty result suggests no GPU is attached.
tf.test.gpu_device_name()

CycleGAN generator with ResNet blocks.

In [0]:
def cyclegan_arg_scope(instance_norm_center=True,
                       instance_norm_scale=True,
                       instance_norm_epsilon=0.001,
                       weights_init_stddev=0.02,
                       weight_decay=0.0):
  """Returns the arg_scope used for the CycleGAN generator's conv layers.

  Every `layers.conv2d` inside the returned scope uses instance normalization,
  a zero-mean normal weight initializer and an optional L2 regularizer.

  Args:
    instance_norm_center: Passed as `center` to `layers.instance_norm`.
    instance_norm_scale: Passed as `scale` to `layers.instance_norm`.
    instance_norm_epsilon: Passed as `epsilon` to `layers.instance_norm`.
      `None` leaves the layer's own default epsilon in place.
    weights_init_stddev: Stddev of the normal weight initializer.
    weight_decay: L2 regularization strength; 0 or None disables it.

  Returns:
    An arg_scope suitable for `tf.contrib.framework.arg_scope(...)`.
  """
  instance_norm_params = {
      'center': instance_norm_center,
      'scale': instance_norm_scale,
  }
  # Forward epsilon explicitly. Without this the argument would be silently
  # ignored and instance_norm would fall back to its built-in default.
  if instance_norm_epsilon is not None:
    instance_norm_params['epsilon'] = instance_norm_epsilon

  weights_regularizer = None
  if weight_decay and weight_decay > 0.0:
    weights_regularizer = layers.l2_regularizer(weight_decay)

  with tf.contrib.framework.arg_scope(
      [layers.conv2d],
      normalizer_fn=layers.instance_norm,
      normalizer_params=instance_norm_params,
      weights_initializer=tf.random_normal_initializer(0, weights_init_stddev),
      weights_regularizer=weights_regularizer) as sc:
    return sc


def cyclegan_upsample(net, num_outputs, stride, method='conv2d_transpose'):
  """Upsamples `net` spatially by `stride` and projects to `num_outputs`.

  Args:
    net: 4-D input tensor. Height/width are read from dims 1 and 2, so this
      presumably expects NHWC; confirm for other layouts.
    num_outputs: Number of output channels of the (transpose) convolution.
    stride: Two-element [height_factor, width_factor] upsampling factors.
    method: 'nn_upsample_conv' (nearest-neighbour resize + 3x3 conv),
      'bilinear_upsample_conv' (bilinear resize + 3x3 conv) or
      'conv2d_transpose' (3x3 transpose conv).

  Returns:
    The upsampled tensor.

  Raises:
    ValueError: If `method` is not one of the supported strategies.
  """
  supported_methods = ('nn_upsample_conv', 'bilinear_upsample_conv',
                       'conv2d_transpose')
  # Validate before creating any scopes/ops so a typo fails fast and the
  # message is actually formatted (previously the arguments were passed as a
  # tuple to ValueError instead of being interpolated).
  if method not in supported_methods:
    raise ValueError('Unknown method: [%s]' % method)

  with tf.variable_scope('upconv'):
    net_shape = tf.shape(net)
    height = net_shape[1]
    width = net_shape[2]

    # One pixel of reflect padding per side keeps H/W unchanged through the
    # following 3x3 'valid' convolution.
    spatial_pad_1 = np.array([[0, 0], [1, 1], [1, 1], [0, 0]])

    if method == 'nn_upsample_conv':
      net = tf.image.resize_nearest_neighbor(
          net, [stride[0] * height, stride[1] * width])
      net = tf.pad(net, spatial_pad_1, 'REFLECT')
      net = layers.conv2d(net, num_outputs, kernel_size=[3, 3], padding='valid')
    elif method == 'bilinear_upsample_conv':
      net = tf.image.resize_bilinear(
          net, [stride[0] * height, stride[1] * width])
      net = tf.pad(net, spatial_pad_1, 'REFLECT')
      net = layers.conv2d(net, num_outputs, kernel_size=[3, 3], padding='valid')
    else:  # 'conv2d_transpose' (validated above).
      net = layers.conv2d_transpose(
          net, num_outputs, kernel_size=[3, 3], stride=stride, padding='valid')
      # With stride 2, a 'valid' 3x3 transpose conv yields 2*H + 1 rows/cols;
      # dropping the first row and column leaves exactly 2*H.
      net = net[:, 1:, 1:, :]

    return net


def _dynamic_or_static_shape(tensor):
  """Returns `tensor`'s shape, constant-folded when possible.

  If TensorFlow can evaluate the shape at graph-construction time, the folded
  value from `tf.contrib.util.constant_value` is returned; otherwise the
  dynamic `tf.shape` tensor is returned.
  """
  dynamic_shape = tf.shape(tensor)
  folded_shape = tf.contrib.util.constant_value(dynamic_shape)
  if folded_shape is None:
    return dynamic_shape
  return folded_shape


def cyclegan_generator_resnet(images,
                              arg_scope_fn=cyclegan_arg_scope,
                              num_resnet_blocks=6,
                              num_filters=64,
                              upsample_fn=cyclegan_upsample,
                              kernel_size=3,
                              num_outputs=3,
                              tanh_linear_slope=0.0,
                              is_training=False):
  """ResNet-style image-to-image generator.

  The architecture is a 7x7 input conv, two stride-2 encoder convs,
  `num_resnet_blocks` residual blocks, two 2x upsampling stages via
  `upsample_fn`, and a 7x7 output conv.

  Args:
    images: 4-D input tensor. Height/width are read from dims 1 and 2 and
      those dims are padded, so this presumably expects NHWC; confirm.
    arg_scope_fn: Zero-argument callable returning the arg_scope wrapped
      around the whole network.
    num_resnet_blocks: Number of residual blocks between encoder and decoder.
    num_filters: Filters of the input conv; the encoder uses 2x and 4x this.
    upsample_fn: Callable invoked as `upsample_fn(net, num_outputs=...,
      stride=[2, 2])` for each of the two decoder stages.
    kernel_size: int or [height, width] kernel for encoder/residual convs.
    num_outputs: Channels of the output conv. The final reshape to `images`'
      shape only succeeds if this equals the input's channel count.
    tanh_linear_slope: Slope of the linear term added to tanh(logits).
    is_training: Unused.

  Returns:
    `(predictions, end_points)`, where `predictions` is
    `tanh(logits) + logits * tanh_linear_slope` and `end_points` maps names
    such as 'encoder_0', 'resnet_block_<i>', 'decoder1', 'logits' to tensors.

  Raises:
    ValueError: If a statically known input height or width is not a multiple
      of 4.
  """
  del is_training

  end_points = {}

  # Static height/width; None when unknown at graph-construction time.
  input_size = images.shape.as_list()
  height, width = input_size[1], input_size[2]
  # Only known sizes are checked. The encoder halves H/W twice and the decoder
  # doubles them twice, so sizes must be divisible by 4 to round-trip.
  if height and height % 4 != 0:
    raise ValueError('The input height must be a multiple of 4.')
  if width and width % 4 != 0:
    raise ValueError('The input width must be a multiple of 4.')

  if not isinstance(kernel_size, (list, tuple)):
    kernel_size = [kernel_size, kernel_size]

  # Reflect-padding that lets a stride-1 'VALID' conv keep H/W unchanged; for
  # even kernel sizes the extra row/column goes to the bottom/right.
  kernel_height = kernel_size[0]
  kernel_width = kernel_size[1]
  pad_top = (kernel_height - 1) // 2
  pad_bottom = kernel_height // 2
  pad_left = (kernel_width - 1) // 2
  pad_right = kernel_width // 2
  paddings = np.array(
      [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]],
      dtype=np.int32)
  # 3 pixels per side keeps H/W unchanged through the 7x7 'VALID' convs.
  spatial_pad_3 = np.array([[0, 0], [3, 3], [3, 3], [0, 0]])

  with tf.contrib.framework.arg_scope(arg_scope_fn()):
    with tf.variable_scope('input'):
      # 7x7 input stage
      net = tf.pad(images, spatial_pad_3, 'REFLECT')
      net = layers.conv2d(
          net, num_filters, kernel_size=[7, 7], padding='VALID')
      end_points['encoder_0'] = net

    with tf.variable_scope('encoder'):
      # Two stride-2 convs: each halves H/W while the filter count doubles.
      with tf.contrib.framework.arg_scope(
          [layers.conv2d],
          kernel_size=kernel_size,
          stride=2,
          activation_fn=tf.nn.relu,
          padding='VALID'):

        net = tf.pad(net, paddings, 'REFLECT')
        net = layers.conv2d(net, num_filters * 2)
        end_points['encoder_1'] = net
        net = tf.pad(net, paddings, 'REFLECT')
        net = layers.conv2d(net, num_filters * 4)
        end_points['encoder_2'] = net

    with tf.variable_scope('residual_blocks'):
      # Each block: pad -> conv (ReLU) -> pad -> conv (no activation), then the
      # result is added back onto `net`; H/W and channels are preserved.
      with tf.contrib.framework.arg_scope(
          [layers.conv2d],
          kernel_size=kernel_size,
          stride=1,
          activation_fn=tf.nn.relu,
          padding='VALID'):
        for block_id in xrange(num_resnet_blocks):
          with tf.variable_scope('block_{}'.format(block_id)):
            res_net = tf.pad(net, paddings, 'REFLECT')
            res_net = layers.conv2d(res_net, num_filters * 4)
            res_net = tf.pad(res_net, paddings, 'REFLECT')
            res_net = layers.conv2d(res_net, num_filters * 4,
                                    activation_fn=None)
            net += res_net

            end_points['resnet_block_%d' % block_id] = net

    with tf.variable_scope('decoder'):
      # NOTE: this arg_scope only targets layers.conv2d. With the default
      # upsample_fn/method ('conv2d_transpose') the decoder uses
      # layers.conv2d_transpose, which this scope does not affect.
      with tf.contrib.framework.arg_scope(
          [layers.conv2d],
          kernel_size=kernel_size,
          stride=1,
          activation_fn=tf.nn.relu):

        with tf.variable_scope('decoder1'):
          net = upsample_fn(net, num_outputs=num_filters * 2, stride=[2, 2])
        end_points['decoder1'] = net

        with tf.variable_scope('decoder2'):
          net = upsample_fn(net, num_outputs=num_filters, stride=[2, 2])
        end_points['decoder2'] = net

    with tf.variable_scope('output'):
      # Final 7x7 conv with no activation or normalizer: raw logits.
      net = tf.pad(net, spatial_pad_3, 'REFLECT')
      logits = layers.conv2d(
          net,
          num_outputs, [7, 7],
          activation_fn=None,
          normalizer_fn=None,
          padding='valid')
      # Give the logits the input's shape (static when it can be folded).
      logits = tf.reshape(logits, _dynamic_or_static_shape(images))

      end_points['logits'] = logits
      end_points['predictions'] = tf.tanh(logits) + logits * tanh_linear_slope

  return end_points['predictions'], end_points

Pix2pix discriminator.

In [0]:
def pix2pix_arg_scope():
  """Returns the arg_scope used by the pix2pix discriminator.

  Both `layers.conv2d` and `layers.conv2d_transpose` get instance
  normalization (center/scale enabled, epsilon 1e-5) and weights drawn from a
  normal distribution with mean 0 and stddev 0.02.
  """
  norm_params = dict(center=True, scale=True, epsilon=0.00001)
  conv_layers = [layers.conv2d, layers.conv2d_transpose]
  with tf.contrib.framework.arg_scope(
      conv_layers,
      normalizer_fn=layers.instance_norm,
      normalizer_params=norm_params,
      weights_initializer=tf.random_normal_initializer(0, 0.02)) as scope:
    return scope


def pix2pix_discriminator(net, num_filters, padding=2, is_training=False):
  """Convolutional discriminator that outputs a spatial map of logits.

  Each of the `len(num_filters)` layers is a 4x4 conv with leaky-ReLU
  activation. Layer 0 has stride 2 and disables the normalizer. The middle
  layers have stride 2 and use whatever normalizer the enclosing arg_scope
  sets. The last entry of `num_filters` uses stride 1. A final stride-1 4x4
  conv with a single output channel and no activation/normalizer produces the
  logits. When `padding` is non-zero, every conv input is first reflect-padded
  by `padding` pixels on each spatial side.

  Args:
    net: 4-D input tensor. Dims 1 and 2 are padded, so this presumably
      expects NHWC.
    num_filters: Sequence of filter counts; must have at least 2 entries.
    padding: Reflect padding applied before each conv; 0 disables it.
    is_training: Unused.

  Returns:
    `(logits, end_points)`. `end_points` holds each conv output under
    'conv<i>', plus 'logits' and 'predictions' (sigmoid of the logits).

  Raises:
    ValueError: If `num_filters` has fewer than 2 entries.
  """
  del is_training
  end_points = {}

  num_layers = len(num_filters)
  # conv0 and conv<num_layers - 1> must be different layers. With a single
  # entry both would be created under scope 'conv0' (a variable collision),
  # and an empty list has no conv0 at all.
  if num_layers < 2:
    raise ValueError(
        'num_filters must have at least 2 entries, got %d.' % num_layers)

  def padded(net, scope):
    # Reflect-pad H and W by `padding` (no-op when padding is 0).
    if padding:
      with tf.variable_scope(scope):
        spatial_pad = tf.constant(
            [[0, 0], [padding, padding], [padding, padding], [0, 0]],
            dtype=tf.int32)
        return tf.pad(net, spatial_pad, 'REFLECT')
    else:
      return net

  with tf.contrib.framework.arg_scope(
      [layers.conv2d],
      kernel_size=[4, 4],
      stride=2,
      padding='valid',
      activation_fn=tf.nn.leaky_relu):

    net = layers.conv2d(
        padded(net, 'conv0'), num_filters[0], normalizer_fn=None, scope='conv0')

    end_points['conv0'] = net

    for i in range(1, num_layers - 1):
      net = layers.conv2d(
          padded(net, 'conv%d' % i), num_filters[i], scope='conv%d' % i)
      end_points['conv%d' % i] = net

    net = layers.conv2d(
        padded(net, 'conv%d' % (num_layers - 1)),
        num_filters[-1],
        stride=1,
        scope='conv%d' % (num_layers - 1))
    end_points['conv%d' % (num_layers - 1)] = net

    logits = layers.conv2d(
        padded(net, 'conv%d' % num_layers),
        1,
        stride=1,
        activation_fn=None,
        normalizer_fn=None,
        scope='conv%d' % num_layers)
    end_points['logits'] = logits
    end_points['predictions'] = tf.sigmoid(logits)
  return logits, end_points

Networks (wrappers).

In [0]:
def generator(input_images, num_resnet_blocks):
  """TFGAN generator function built on `cyclegan_generator_resnet`.

  Args:
    input_images: Rank-4 image tensor whose last (channel) dimension is
      statically known; the output has the same number of channels.
    num_resnet_blocks: Number of residual blocks in the generator.

  Returns:
    The generated images (the generator's `predictions`).

  Raises:
    ValueError: If the channel dimension is not statically known.
  """
  static_shape = input_images.shape
  static_shape.assert_has_rank(4)
  dims = static_shape.as_list()
  num_channels = dims[-1]
  if num_channels is None:
    raise ValueError(
        'Last dimension shape must be known but is None: %s' % dims)

  with tf.contrib.framework.arg_scope(cyclegan_arg_scope()):
    translated, _ = cyclegan_generator_resnet(
        input_images,
        num_outputs=num_channels,
        num_resnet_blocks=num_resnet_blocks)
  return translated


def discriminator(image_batch, unused_conditioning=None):
  """TFGAN discriminator function built on `pix2pix_discriminator`.

  Uses four conv layers with 64/128/256/512 filters and flattens the rank-4
  logits map to shape [batch, -1].

  Args:
    image_batch: Images to score.
    unused_conditioning: Ignored; present to match TFGAN's signature.

  Returns:
    The flattened logits tensor.
  """
  with tf.contrib.framework.arg_scope(pix2pix_arg_scope()):
    patch_logits, _ = pix2pix_discriminator(
        image_batch, num_filters=[64, 128, 256, 512])
    patch_logits.shape.assert_has_rank(4)
  return tf.contrib.layers.flatten(patch_logits)

Data Provider.

In [0]:
def normalize_image(image):
  """Casts `image` to float32 and maps pixel values [0, 255] to [-1, 1]."""
  as_float = tf.to_float(image)
  return (as_float - 127.5) / 127.5


def undo_normalize_image(normalized_image):
  """Inverts `normalize_image` for a batch holding a single image.

  Values are mapped from [-1, 1] back to [0, 255], rounded to the nearest
  integer and clipped before the uint8 cast. Rounding makes the round-trip
  with `normalize_image` exact; truncation would turn e.g. 0.0 into 127. The
  clip keeps out-of-range values (e.g. generator outputs slightly beyond
  [-1, 1]) from wrapping around, since casting negative floats to uint8 is
  undefined.

  Args:
    normalized_image: Array-like whose leading axis has size 1.

  Returns:
    A uint8 numpy array with the leading axis removed.

  Raises:
    ValueError: If the leading axis does not have size 1 (from np.squeeze).
  """
  image = np.squeeze(normalized_image, axis=0)
  pixels = np.rint(image * 127.5 + 127.5)
  return np.clip(pixels, 0, 255).astype(np.uint8)


def _sample_patch(image, patch_size, training=True):
  """Turns one image into a square `patch_size` x `patch_size` x 3 patch.

  The image is first center-cropped/padded to a square whose side is the
  shorter of its height and width.

  - In training mode it is resized to `int(patch_size * 1.172)` (300 for
    256-pixel patches), then randomly flipped left/right and randomly cropped
    back to `patch_size`.
  - Otherwise it is resized straight to `patch_size`.

  Finally the channel axis is forced to 3: images with fewer than 3 channels
  are tiled along that axis, and only the first 3 channels are kept.

  Args:
    image: 3-D image tensor; dims 0, 1, 2 are read as height, width,
      channels (presumably HWC; confirm).
    patch_size: Side length of the output patch.
    training: Whether to apply the random resize/flip/crop augmentation.

  Returns:
    A tensor of shape [patch_size, patch_size, 3].
  """
  dims = tf.shape(image)
  side = tf.minimum(dims[0], dims[1])
  square = tf.image.resize_image_with_crop_or_pad(image, side, side)
  # Add a batch axis for resize_images, then drop it again below.
  batch_of_one = tf.expand_dims(square, axis=0)

  if training:
    resize_to = int(patch_size * 1.172)
  else:
    resize_to = patch_size
  patch = tf.image.resize_images(batch_of_one, [resize_to, resize_to])
  patch = tf.squeeze(patch, axis=0)

  if training:
    seed = 9
    patch = tf.image.random_flip_left_right(patch, seed=seed)
    patch = tf.random_crop(
        patch, [patch_size, patch_size, dims[2]], seed=seed)

  # Repeat the channels max(1, 4 - C) times (3x for grayscale), keep 3.
  repeats = tf.maximum(1, 4 - tf.shape(patch)[2])
  patch = tf.tile(patch, [1, 1, repeats])
  return tf.slice(patch, [0, 0, 0], [patch_size, patch_size, 3])


def full_image_to_patch(image, patch_size, training=True):
  """Normalizes a decoded image to [-1, 1] and samples a 3-channel patch.

  Args:
    image: Decoded image tensor with pixel values in [0, 255].
    patch_size: Side length of the square patch.
    training: Forwarded to `_sample_patch` (enables random augmentation).

  Returns:
    A patch whose static shape is compatible with [patch_size, patch_size, 3].
  """
  patch = _sample_patch(normalize_image(image), patch_size, training=training)
  patch.shape.assert_is_compatible_with([patch_size, patch_size, 3])
  return patch


def parse_dataset(filename, patch_size):
  """Reads and decodes the image at `filename`, returning a training patch."""
  encoded = tf.read_file(filename)
  decoded = tf.image.decode_image(encoded)
  return full_image_to_patch(decoded, patch_size)


def provide_custom_datasets(image_file_patterns,
                            batch_size,
                            shuffle=True,
                            num_threads=1,
                            patch_size=128):
  """Builds one infinite stream of batched image patches per file pattern.

  Args:
    image_file_patterns: Iterable of glob patterns, one per image domain.
    batch_size: Number of patches per batch.
    shuffle: If True, file order is shuffled (and reshuffled every epoch).
    num_threads: Number of parallel calls used to decode and crop images.
    patch_size: Side length of the square patches built by `parse_dataset`.

  Returns:
    A list with one batched patch tensor per pattern, in the same order as
    `image_file_patterns`.

  Raises:
    ValueError: If a pattern matches no files.
  """
  outputs = []

  for pattern in image_file_patterns:
    # Sort so the pre-shuffle order does not depend on the filesystem.
    filenames = sorted(tf.gfile.Glob(pattern))
    if not filenames:
      raise ValueError('No files match pattern: %s' % pattern)
    dataset = tf.data.Dataset.from_tensor_slices(filenames)
    if shuffle:
      # Shuffle the cheap filename strings before decoding. Shuffling decoded
      # float patches instead would keep up to buffer_size full images in memory.
      dataset = dataset.shuffle(len(filenames), seed=5)
    dataset = dataset.map(lambda f: parse_dataset(f, patch_size),
                          num_parallel_calls=num_threads)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat()
    outputs.append(dataset.make_one_shot_iterator().get_next())
  return outputs

Training.

In [0]:
# Data: unpaired image sets for domain X (trainA) and domain Y (trainB),
# read from the Google Drive mount under `drive/`.
image_set_x_file_pattern = 'drive/app/horse2zebra/trainA/*.jpg'
image_set_y_file_pattern = 'drive/app/horse2zebra/trainB/*.jpg'
batch_size = 1
patch_size = 256  # side length of the square training patches

# Model.
num_resnet_blocks = 9
cycle_consistency_loss_weight = 10.0

# Optimization: both learning rates are held for the first half of
# max_number_of_steps and then decayed (see _get_lr).
generator_lr = 0.0001
discriminator_lr = 0.0001
max_number_of_steps = 500000

# Output directory and distributed-training settings (master target,
# parameter-server task count, task index; task 0 is the chief).
train_log_dir = 'drive/app/checkpoints'
master = ''
ps_tasks = 0
task = 0


def _assert_is_image(data):
  data.shape.assert_has_rank(4)
  data.shape[1:].assert_is_fully_defined()


def add_cyclegan_image_summaries(cyclegan_model, batch_size):
  """Adds image-comparison summaries for both CycleGAN directions.

  For X->Y and Y->X, the first example of the batch is shown in one row as
  input | generated | reconstruction.

  Args:
    cyclegan_model: A `tf.contrib.gan.CycleGANModel`.
    batch_size: Unused. Only the first example of each batch is summarized,
      whatever the batch size; the argument is kept for compatibility.

  Raises:
    ValueError: If `cyclegan_model` is not a CycleGANModel. The
      `_assert_is_image` checks also reject tensors that are not rank 4 with
      fully defined non-batch dimensions.
  """
  del batch_size  # Not needed: every summary uses exactly one example.

  if not isinstance(cyclegan_model, tf.contrib.gan.CycleGANModel):
    raise ValueError('`cyclegan_model` was not a CycleGANModel. Instead, was '
                     '%s' % type(cyclegan_model))

  _assert_is_image(cyclegan_model.model_x2y.generator_inputs)
  _assert_is_image(cyclegan_model.model_x2y.generated_data)
  _assert_is_image(cyclegan_model.reconstructed_x)
  _assert_is_image(cyclegan_model.model_y2x.generator_inputs)
  _assert_is_image(cyclegan_model.model_y2x.generated_data)
  _assert_is_image(cyclegan_model.reconstructed_y)

  def _add_comparison_summary(gan_model, reconstructions):
    # Each tensor is sliced to its first example, so exactly one image is
    # unstacked. Passing the batch size as `num` breaks for batch_size > 1.
    image_list = (
        array_ops.unstack(gan_model.generator_inputs[:1], num=1) +
        array_ops.unstack(gan_model.generated_data[:1], num=1) +
        array_ops.unstack(reconstructions[:1], num=1))
    # Lay all images out side by side in a single row.
    summary.image(
        'image_comparison', eval_utils.image_reshaper(
            image_list, num_cols=len(image_list)), max_outputs=1)

  with ops.name_scope('x2y_image_comparison_summaries'):
    _add_comparison_summary(
        cyclegan_model.model_x2y, cyclegan_model.reconstructed_x)
  with ops.name_scope('y2x_image_comparison_summaries'):
    _add_comparison_summary(
        cyclegan_model.model_y2x, cyclegan_model.reconstructed_y)


def _define_model(images_x, images_y):
  """Builds the CycleGAN model for domains X and Y and adds image summaries.

  Args:
    images_x: Batch of domain-X images.
    images_y: Batch of domain-Y images.

  Returns:
    The model built by `tfgan.cyclegan_model`.
  """
  def _generator_fn(x):
    # Reads the module-level `num_resnet_blocks` at call time.
    return generator(x, num_resnet_blocks)

  model = tfgan.cyclegan_model(
      generator_fn=_generator_fn,
      discriminator_fn=discriminator,
      data_x=images_x,
      data_y=images_y)
  add_cyclegan_image_summaries(model, batch_size)
  return model


def _get_lr(base_lr):
  """Learning-rate schedule: constant, then decayed to zero.

  The rate stays at `base_lr` for the first `max_number_of_steps // 2` global
  steps. After that it is decayed with `tf.train.polynomial_decay` to 0.0 over
  the remaining steps.

  Args:
    base_lr: Initial learning rate.

  Returns:
    A scalar learning-rate tensor driven by the global step.
  """
  step = tf.train.get_or_create_global_step()
  hold_steps = max_number_of_steps // 2
  decay_steps = max_number_of_steps - hold_steps

  def _constant():
    return base_lr

  def _decayed():
    return tf.train.polynomial_decay(
        learning_rate=base_lr,
        global_step=(step - hold_steps),
        decay_steps=decay_steps,
        end_learning_rate=0.0)

  return tf.cond(step < hold_steps, _constant, _decayed)


def _get_optimizer(gen_lr, dis_lr):
  """Returns `(generator_optimizer, discriminator_optimizer)`.

  Both are Adam optimizers with beta1=0.5, beta2=0.9 and use_locking=True,
  differing only in their learning rates.
  """
  adam_kwargs = {'beta1': 0.5, 'beta2': 0.9, 'use_locking': True}
  generator_optimizer = tf.train.AdamOptimizer(gen_lr, **adam_kwargs)
  discriminator_optimizer = tf.train.AdamOptimizer(dis_lr, **adam_kwargs)
  return generator_optimizer, discriminator_optimizer


def _define_train_ops(cyclegan_model, cyclegan_loss):
  """Creates the generator/discriminator train ops and learning-rate summaries.

  Args:
    cyclegan_model: The model returned by `_define_model`.
    cyclegan_loss: The loss returned by `tfgan.cyclegan_loss`.

  Returns:
    The train ops produced by `tfgan.gan_train_ops`.
  """
  gen_schedule = _get_lr(generator_lr)
  dis_schedule = _get_lr(discriminator_lr)
  generator_optimizer, discriminator_optimizer = _get_optimizer(
      gen_schedule, dis_schedule)

  train_ops = tfgan.gan_train_ops(
      cyclegan_model,
      cyclegan_loss,
      generator_optimizer=generator_optimizer,
      discriminator_optimizer=discriminator_optimizer,
      summarize_gradients=True,
      colocate_gradients_with_ops=True,
      aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

  for summary_name, learning_rate in (('generator_lr', gen_schedule),
                                      ('discriminator_lr', dis_schedule)):
    tf.summary.scalar(summary_name, learning_rate)
  return train_ops


def main():
  """Builds the CycleGAN training graph and trains it.

  All settings come from the module-level configuration constants (data
  patterns, batch/patch size, learning rates, step budget, log dir, ...).
  Training stops after `max_number_of_steps` global steps. If that value is
  falsy, the graph is built but no training runs.
  """
  if not tf.gfile.Exists(train_log_dir):
    tf.gfile.MakeDirs(train_log_dir)

  # Build into a fresh graph so re-running this cell in the same kernel does
  # not fail on variables/iterators left in the default graph by an earlier
  # (e.g. interrupted) run.
  with tf.Graph().as_default():
    # Graph-level seed; it must be set on the graph being built.
    tf.set_random_seed(10)

    with tf.device(tf.train.replica_device_setter(ps_tasks)):
      with tf.name_scope('inputs'):
        images_x, images_y = provide_custom_datasets(
            [image_set_x_file_pattern, image_set_y_file_pattern],
            batch_size=batch_size,
            patch_size=patch_size)

      cyclegan_model = _define_model(images_x, images_y)

      cyclegan_loss = tfgan.cyclegan_loss(
          cyclegan_model,
          cycle_consistency_loss_weight=cycle_consistency_loss_weight,
          tensor_pool_fn=tfgan.features.tensor_pool)

      train_ops = _define_train_ops(cyclegan_model, cyclegan_loss)

      # 1 generator step and 1 discriminator step per training iteration.
      train_steps = tfgan.GANTrainSteps(1, 1)
      status_message = tf.string_join(
          [
              'Starting train step: ',
              tf.as_string(tf.train.get_or_create_global_step())
          ],
          name='status_message')
      if not max_number_of_steps:
        return

      tfgan.gan_train(
          train_ops,
          train_log_dir,
          get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
          hooks=[
              tf.train.StopAtStepHook(num_steps=max_number_of_steps),
              tf.train.LoggingTensorHook([status_message], every_n_iter=10)
          ],
          master=master,
          is_chief=task == 0)

Start training.

In [0]:
# Build the graph and train until max_number_of_steps, logging to train_log_dir.
main()