<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/tutorials/generative/pix2pix"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/generative/pix2pix.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/docs/blob/master/site/en/tutorials/generative/pix2pix.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/generative/pix2pix.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

In [0]:
import tensorflow as tf
import tensorflow_datasets as tfds
import os
import time

from matplotlib import pyplot as plt
from IPython import display

In [0]:
# !pip install -U tensorboard

In [0]:
# Download and unpack the facades archive into the Keras cache, then point PATH
# at the extracted `facades/` directory sitting next to the archive.
_URL = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz'

path_to_zip = tf.keras.utils.get_file(
    'facades.tar.gz', origin=_URL, extract=True)

_archive_dir = os.path.dirname(path_to_zip)
PATH = os.path.join(_archive_dir, 'facades/')

In [0]:
# BUFFER_SIZE feeds Dataset.shuffle and BATCH_SIZE feeds Dataset.batch below.
BUFFER_SIZE, BATCH_SIZE = 400, 1
# Spatial size of the crops / resized images fed to the networks.
IMG_WIDTH = IMG_HEIGHT = 256

In [0]:
def load(image_file):
  """Read a side-by-side JPEG and split it into (input_image, real_image).

  The decoded image is cut at half its width: the left half is returned as
  `real_image`, the right half as `input_image`. Both are cast to float32.
  """
  decoded = tf.image.decode_jpeg(tf.io.read_file(image_file))

  half_width = tf.shape(decoded)[1] // 2
  left_half = decoded[:, :half_width, :]
  right_half = decoded[:, half_width:, :]

  return tf.cast(right_half, tf.float32), tf.cast(left_half, tf.float32)

In [0]:
inp, re = load(PATH + 'train/100.jpg')
# load() returns float32 pixels; divide by 255 so matplotlib receives values
# in [0, 1] when rendering each half in its own figure.
for sample in (inp, re):
  plt.figure()
  plt.imshow(sample / 255.0)

In [0]:
def resize(input_image, real_image, height, width):
  """Resize both images to (height, width) with nearest-neighbour sampling."""
  target_size = [height, width]
  nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR

  resized_input = tf.image.resize(input_image, target_size, method=nearest)
  resized_real = tf.image.resize(real_image, target_size, method=nearest)
  return resized_input, resized_real

In [0]:
def random_crop(input_image, real_image):
  """Take the same random IMG_HEIGHT x IMG_WIDTH crop from both images.

  The pair is stacked on a new leading axis so a single random_crop call
  applies one shared offset to both images.
  """
  pair = tf.stack([input_image, real_image], axis=0)
  crop_shape = [2, IMG_HEIGHT, IMG_WIDTH, 3]
  cropped_pair = tf.image.random_crop(pair, size=crop_shape)

  cropped_input, cropped_real = cropped_pair[0], cropped_pair[1]
  return cropped_input, cropped_real

In [0]:
# Map pixel values from [0, 255] onto [-1, 1].

def normalize(input_image, real_image):
  """Rescale both images with x / 127.5 - 1 and return them as a pair."""
  half_range = 127.5
  return (input_image / half_range) - 1, (real_image / half_range) - 1

In [0]:
@tf.function()
def random_jitter(input_image, real_image):
  """Apply the same random augmentation to an (input, real) image pair.

  Steps: upscale to 286 x 286, take a shared random 256 x 256 crop, then
  mirror both images left-right with probability 0.5.
  """
  # Upscale to 286 x 286 x 3 so the crop below has room to move.
  input_image, real_image = resize(input_image, real_image, 286, 286)

  # Shared random crop back down to 256 x 256 x 3.
  input_image, real_image = random_crop(input_image, real_image)

  mirror = tf.random.uniform(()) > 0.5
  if mirror:
    # Flip both images so the pair stays aligned.
    input_image = tf.image.flip_left_right(input_image)
    real_image = tf.image.flip_left_right(real_image)

  return input_image, real_image

In [0]:
# Show four independent random jitters of the sample input on a 2 x 2 grid.
plt.figure(figsize=(6, 6))
for i in range(4):
  rj_inp, rj_re = random_jitter(inp, re)
  scaled_preview = rj_inp / 255.0  # matplotlib expects floats in [0, 1]
  plt.subplot(2, 2, i + 1)
  plt.imshow(scaled_preview)
  plt.axis('off')
plt.show()

In [0]:
def load_image_train(image_file):
  """Load one training pair, apply random jitter, then normalize it."""
  raw_pair = load(image_file)
  jittered_pair = random_jitter(*raw_pair)
  return normalize(*jittered_pair)

In [0]:
def load_image_test(image_file):
  """Load one test pair, resize it to IMG_HEIGHT x IMG_WIDTH and normalize it.

  Unlike the training loader, no random jitter is applied.
  """
  raw_input, raw_real = load(image_file)
  sized_input, sized_real = resize(raw_input, raw_real, IMG_HEIGHT, IMG_WIDTH)
  return normalize(sized_input, sized_real)

## Input Pipeline

In [0]:
# Training pipeline: list files -> load/augment in parallel -> shuffle -> batch,
# then expose the batches as NumPy arrays for the JAX code below.
train_dataset = tfds.as_numpy(
    tf.data.Dataset.list_files(PATH + 'train/*.jpg')
    .map(load_image_train,
         num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

In [0]:
# Test pipeline: list files -> load/resize/normalize -> batch, exposed as NumPy.
test_dataset = tfds.as_numpy(
    tf.data.Dataset.list_files(PATH + 'test/*.jpg')
    .map(load_image_test)
    .batch(BATCH_SIZE)
)

## Build the Generator
  * The architecture of the generator is a modified U-Net.
  * Each block in the encoder is (Conv -> Batchnorm -> Leaky ReLU).
  * Each block in the decoder is (Transposed Conv -> Batchnorm -> Dropout (applied to the first 3 blocks) -> ReLU).
  * There are skip connections between the encoder and the decoder (as in U-Net).


In [0]:
# Number of feature channels produced by the generator's final transposed convolution.
OUTPUT_CHANNELS = 3

In [0]:
# !git clone 

In [0]:
# !pip uninstall flax
! pip install git+https://github.com/google/flax.git#egg=flax
import functools

import jax
import flax

import numpy as onp
import jax.numpy as jnp

In [0]:
class DownSample(flax.nn.Module):
  """Encoder block: stride-2 Conv -> optional BatchNorm -> leaky ReLU."""

  def apply(self, x, features, size, apply_batchnorm=True):
    """Run the block on `x`.

    Args:
      x: input feature map.
      features: number of output features of the convolution.
      size: side length of the square convolution kernel.
      apply_batchnorm: whether to apply BatchNorm after the convolution.
    """
    kernel = (size, size)
    out = flax.nn.Conv(x, features=features, kernel_size=kernel,
                       strides=(2, 2), padding='SAME', bias=False)
    if apply_batchnorm:
      out = flax.nn.BatchNorm(out)
    return flax.nn.leaky_relu(out)

class UpSample(flax.nn.Module):
  """Decoder block: stride-2 ConvTranspose -> BatchNorm -> optional Dropout -> ReLU."""

  def apply(self, x, features, size, apply_dropout=True):
    """Run the block on `x`.

    Args:
      x: input feature map.
      features: number of output features of the transposed convolution.
      size: side length of the square kernel.
      apply_dropout: whether to apply dropout (rate 0.5) after BatchNorm.
        Note that this defaults to True.
    """
    kernel = (size, size)
    out = flax.nn.ConvTranspose(x, features=features, kernel_size=kernel,
                                strides=(2, 2), padding='SAME', bias=False)
    out = flax.nn.BatchNorm(out)
    if apply_dropout:
      out = flax.nn.dropout(out, 0.5)
    return flax.nn.relu(out)

In [0]:
# Encoder spec, one row per DownSample block: [features, kernel_size, apply_batchnorm].
# Only the first block skips BatchNorm.
down_list = [[64, 4, False],
             [128, 4, True],
             [256, 4, True],
             [512, 4, True],
             [512, 4, True],
             [512, 4, True],
             [512, 4, True],
             [512, 4, True]]

# Decoder spec, one row per UpSample block: [features, kernel_size, apply_dropout].
# The flag is spelled out on every row because UpSample defaults to
# apply_dropout=True; omitting it would enable dropout on all blocks instead of
# only the first three.
up_list = [[512, 4, True],
           [512, 4, True],
           [512, 4, True],
           [512, 4, False],
           [256, 4, False],
           [128, 4, False],
           [64, 4, False]]

In [0]:
class Generator(flax.nn.Module):
  """U-Net style generator built from the `down_list` / `up_list` block specs."""

  def apply(self, x):
    """Map an input image batch to a generated image batch with values in [-1, 1].

    Args:
      x: input batch; the channel axis is assumed to be the last one, matching
        the (batch, height, width, 3) shape used by `create_model`.

    Returns:
      tanh-activated output with OUTPUT_CHANNELS features.
    """
    # Encoder: keep every intermediate activation for the skip connections.
    skips = []
    for down in down_list:
      x = DownSample(x, *down)
      skips.append(x)

    # The bottleneck output is not a skip; the rest are consumed deepest-first.
    skips = list(reversed(skips[:-1]))

    for up, skip in zip(up_list, skips):
      x = UpSample(x, *up)
      # Skip connections join along the channel (last) axis. The default
      # axis=0 would stack the tensors along the batch dimension instead.
      x = jnp.concatenate((x, skip), axis=-1)

    x = flax.nn.ConvTranspose(x, features=OUTPUT_CHANNELS, kernel_size=(4,4), strides=(2,2), padding='SAME')
    x = flax.nn.tanh(x)
    return x

In [0]:
# Weight applied to the L1 reconstruction term in generator_loss.
LAMBDA = 100

In [0]:
@jax.vmap
def binary_cross_entropy_loss(x, y):
  """Mean sigmoid cross-entropy between logits `x` and targets `y`.

  Evaluates x - x*y + log(1 + exp(-x)) element-wise in a numerically stable
  way and averages the result. Because of `jax.vmap`, the function is mapped
  over the leading axis of both arguments and returns one mean per entry.

  Args:
    x: logits (unnormalised scores).
    y: targets in [0, 1], same shape as `x`.

  Returns:
    Array with one mean loss per leading-axis entry.
  """
  # Shift used for the stable log-sum-exp: with max_val = max(-x, 0) both
  # exponents below are <= 0, so neither exp overflows and their sum is >= 1.
  max_val = jnp.clip(-x, 0, None)
  loss = x - x * y + max_val + jnp.log(jnp.exp(-max_val) + jnp.exp(-x - max_val))
  return loss.mean()

In [0]:
@jax.vmap
def generator_loss(disc_generated_output, gen_output, target):
  """Per-example generator loss: adversarial term plus LAMBDA-weighted L1 term.

  `jax.vmap` maps the function over the leading (batch) axis of all three
  arguments.

  Args:
    disc_generated_output: discriminator logits for the generated image.
    gen_output: generated image.
    target: ground-truth image.

  Returns:
    Tuple (total_gen_loss, gan_loss, l1_loss), one scalar each per example.
  """
  # binary_cross_entropy_loss takes (logits, labels): the discriminator output
  # is the logit and the all-ones tensor is the "real" label the generator
  # wants it to predict. The outer mean collapses the per-row means it
  # returns into a single scalar.
  gan_loss = jnp.mean(binary_cross_entropy_loss(
      disc_generated_output, jnp.ones_like(disc_generated_output)))

  l1_loss = jnp.mean(jnp.absolute(target - gen_output))

  total_gen_loss = gan_loss + (LAMBDA * l1_loss)

  # think about negative
  return total_gen_loss, gan_loss, l1_loss

In [0]:
class Discriminator(flax.nn.Module):
  """Convolutional discriminator that outputs a map of per-patch logits."""

  def apply(self, x):
    """Compute the logit map for `x`.

    Args:
      x: 4-D input batch laid out as (batch, height, width, channels).

    Returns:
      Single-feature logit map (no activation applied).
    """
    x = DownSample(x, 64, 4, False)
    x = DownSample(x, 128, 4)
    x = DownSample(x, 256, 4)

    # Zero-pad height and width only. A scalar pad width would also pad the
    # batch and channel axes, inventing extra examples and feature channels.
    spatial_pad = ((0, 0), (1, 1), (1, 1), (0, 0))

    x = jnp.pad(x, spatial_pad)
    # 'VALID' so the explicit padding above is the only padding applied.
    x = flax.nn.Conv(x, 512, kernel_size=(4,4), strides=(1,1), padding='VALID', bias=False)
    x = flax.nn.BatchNorm(x)
    x = flax.nn.leaky_relu(x)

    x = jnp.pad(x, spatial_pad)
    x = flax.nn.Conv(x, 1, kernel_size=(4,4), strides=(1,1), padding='VALID')

    return x

In [0]:
@jax.vmap
def discriminator_loss(disc_real_output, disc_generated_output):
  """Per-example discriminator loss: real-vs-ones plus generated-vs-zeros.

  `jax.vmap` maps the function over the leading (batch) axis of both
  arguments.

  Args:
    disc_real_output: discriminator logits for the real (input, target) pair.
    disc_generated_output: discriminator logits for the (input, generated) pair.

  Returns:
    One scalar loss per example.
  """
  # binary_cross_entropy_loss takes (logits, labels); the outer mean collapses
  # the per-row means it returns into a single scalar.
  real_loss = jnp.mean(binary_cross_entropy_loss(
      disc_real_output, jnp.ones_like(disc_real_output)))

  generated_loss = jnp.mean(binary_cross_entropy_loss(
      disc_generated_output, jnp.zeros_like(disc_generated_output)))

  total_disc_loss = real_loss + generated_loss

  return total_disc_loss

In [0]:
# Model factory.
@functools.partial(jax.jit, static_argnums=(1, 2, 3))
def create_model(key, batch_size, image_size, model_def):
  """Initialise `model_def` for square 3-channel float32 images.

  Args:
    key: PRNG key used for parameter initialisation.
    batch_size: batch dimension of the initialisation shape (static).
    image_size: height and width of the initialisation shape (static).
    model_def: flax module class to initialise (static).

  Returns:
    Tuple (model, init_state): the flax Model wrapping the initial parameters
    and the state collected while initialising.
  """
  input_spec = [((batch_size, image_size, image_size, 3), jnp.float32)]
  stochastic_rng = jax.random.PRNGKey(0)
  with flax.nn.stateful() as init_state:
    with flax.nn.stochastic(stochastic_rng):
      _, initial_params = model_def.init_by_shape(key, input_spec)
      model = flax.nn.Model(model_def, initial_params)
  return model, init_state

In [0]:
def create_optimizer(model, learning_rate, beta):
  """Build an Adam optimizer for `model` and replicate it with flax.jax_utils.

  Args:
    model: flax Model whose parameters are optimised.
    learning_rate: Adam learning rate.
    beta: value passed as Adam's `beta1`.
  """
  adam_def = flax.optim.Adam(learning_rate=learning_rate, beta1=beta)
  return flax.jax_utils.replicate(adam_def.create(model))

In [0]:
key = jax.random.PRNGKey(0)
# Give each network its own subkey: reusing one PRNG key for two
# initialisations makes their random draws correlated.
generator_key, discriminator_key = jax.random.split(key)
generator_model, generator_state = create_model(generator_key, BATCH_SIZE, IMG_HEIGHT, Generator)
discriminator_model, discriminator_state = create_model(discriminator_key, BATCH_SIZE, IMG_HEIGHT, Discriminator)

In [0]:
# Both networks use Adam with the same learning rate and beta1.
generator_optimizer = create_optimizer(
    generator_model, learning_rate=2e-4, beta=0.5)
discriminator_optimizer = create_optimizer(
    discriminator_model, learning_rate=2e-4, beta=0.5)

In [0]:
def generate_images(model, test_input, tar):
  """Plot input, ground truth and model prediction for the first batch item.

  Args:
    model: callable generator applied to `test_input`.
    test_input: batch of input images.
    tar: batch of ground-truth images.
  """
  with flax.nn.stochastic(jax.random.PRNGKey(0)):
    prediction = model(test_input)

  plt.figure(figsize=(15,15))

  panels = (
      ('Input Image', test_input[0]),
      ('Ground Truth', tar[0]),
      ('Predicted Image', prediction[0]),
  )
  for position, (caption, image) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(caption)
    # Rescale to [0, 1] so matplotlib can plot the image.
    plt.imshow(image * 0.5 + 0.5)
    plt.axis('off')
  plt.show()

In [0]:
# Preview the current generator on the first four test batches.
j = 0
for j, (example_input, example_target) in enumerate(test_dataset, start=1):
  generate_images(generator_model, example_input, example_target)
  if j > 3:
    break

In [0]:
@jax.jit
def train_step(generator_opt, discriminator_opt, input_image, target_image):
  """Perform a single training step for both networks.

  Each network gets its own scalar loss and is differentiated only with
  respect to its own model, so the generator gradient updates the generator
  and the discriminator gradient updates the discriminator.

  Args:
    generator_opt: optimizer whose `target` is the generator model.
    discriminator_opt: optimizer whose `target` is the discriminator model.
    input_image: batch of input images.
    target_image: batch of ground-truth images.

  Returns:
    Tuple (new_gen_opt, new_disc_opt, metrics), where `metrics` is a dict with
    the scalar 'gen_total_loss' and 'disc_loss' of this step.
  """
  # NOTE(review): the discriminator sees image pairs joined on the channel
  # axis, so it has to be initialised for a 6-channel input - confirm the
  # model-creation cell does that.
  def gen_loss_fn(gen_model):
    """Generator loss; the discriminator is held fixed."""
    with flax.nn.stochastic(jax.random.PRNGKey(0)):
      gen_output = gen_model(input_image)
      disc_generated_output = discriminator_opt.target(
          jnp.concatenate((input_image, gen_output), axis=-1))
    gen_total_loss, _, _ = generator_loss(
        disc_generated_output, gen_output, target_image)
    # value_and_grad needs a scalar; average the per-example losses.
    return jnp.mean(gen_total_loss)

  def disc_loss_fn(disc_model):
    """Discriminator loss; the generator is held fixed."""
    with flax.nn.stochastic(jax.random.PRNGKey(0)):
      gen_output = generator_opt.target(input_image)
      disc_real_output = disc_model(
          jnp.concatenate((input_image, target_image), axis=-1))
      disc_generated_output = disc_model(
          jnp.concatenate((input_image, gen_output), axis=-1))
    return jnp.mean(discriminator_loss(disc_real_output, disc_generated_output))

  gen_total_loss, gen_grad = jax.value_and_grad(gen_loss_fn)(generator_opt.target)
  disc_loss, disc_grad = jax.value_and_grad(disc_loss_fn)(discriminator_opt.target)

  new_gen_opt = generator_opt.apply_gradient(gen_grad)
  new_disc_opt = discriminator_opt.apply_gradient(disc_grad)

  metrics = {'gen_total_loss': gen_total_loss, 'disc_loss': disc_loss}
  return new_gen_opt, new_disc_opt, metrics

In [0]:
# Parallel version of train_step, mapped over devices on the 'batch' axis.
p_train_step = jax.pmap(train_step, axis_name='batch')

In [0]:
epochs = 100
for epoch in range(epochs):
  start = time.time()

  display.clear_output(wait=True)

  # Preview with the generator as currently trained: the live model is the
  # optimizer's target, and the optimizer is replicated, so take one copy.
  current_generator = flax.jax_utils.unreplicate(generator_optimizer).target
  j=0
  for example_input, example_target in test_dataset:
    generate_images(current_generator, example_input, example_target)
    j+=1
    if j > 1:
      break
  print("Epoch: ", epoch)

  # Train
  # NOTE(review): p_train_step is a pmap, so the leading axis of every
  # argument is consumed as the device axis - confirm the batches are
  # sharded accordingly.
  for n, (input_image, target_image) in enumerate(train_dataset):
    print('.', end='')
    # Parenthesised: `n+1 % 100` parses as `n + (1 % 100)`.
    if (n + 1) % 100 == 0:
      print()
    # Keep the updated optimizers; discarding them would mean no training.
    generator_optimizer, discriminator_optimizer, _ = p_train_step(
        generator_optimizer, discriminator_optimizer, input_image, target_image)
  print()

  # saving (checkpoint) the model every 20 epochs
  # if (epoch + 1) % 20 == 0:
  #   checkpoint.save(file_prefix = checkpoint_prefix)

  print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1, time.time()-start))

In [0]:
# Show predictions for the whole test set, using the generator held by the
# (replicated) optimizer.
trained_generator = flax.jax_utils.unreplicate(generator_optimizer).target
for example_input, example_target in test_dataset:
  generate_images(trained_generator, example_input, example_target)

In [0]:
def top():
  """Scoping demo: an inner function updates a variable of its enclosing function.

  Prints the variable before and after the inner update (3, then 8).
  """
  a = 3
  def k():
    # Without `nonlocal`, `a += 5` makes `a` local to k() and raises
    # UnboundLocalError when it is read.
    nonlocal a
    a +=5
  print(a)
  k()
  print(a)


In [0]:
# Run the scoping demo defined in the previous cell.
top()

In [0]:
# Example usage of `init_by_shape`:
#   input_shape = (batch_size, image_size, image_size, 3)
#   model_output, initial_params = model.init_by_shape(jax.random.PRNGKey(0),
#                                                      input_specs=[(input_shape, jnp.float32)])