<a href="https://colab.research.google.com/github/jdubkim/SAGAN-Keras/blob/master/Pix2Pix.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from __future__ import absolute_import, division, print_function

!pip install tensorflow-gpu==2.0.0-alpha0
import tensorflow as tf
import numpy as np
import argparse
import os

print(tf.__version__)

from tensorflow.keras.layers import Input, Dropout, Concatenate, Conv2D, UpSampling2D, LeakyReLU, BatchNormalization
from tensorflow.keras.optimizers import Adam

print(tf.__version__)

Collecting tensorflow-gpu==2.0.0-alpha0
[?25l  Downloading https://files.pythonhosted.org/packages/1a/66/32cffad095253219d53f6b6c2a436637bbe45ac4e7be0244557210dc3918/tensorflow_gpu-2.0.0a0-cp36-cp36m-manylinux1_x86_64.whl (332.1MB)
[K    100% |████████████████████████████████| 332.1MB 65kB/s 
Collecting tb-nightly<1.14.0a20190302,>=1.14.0a20190301 (from tensorflow-gpu==2.0.0-alpha0)
[?25l  Downloading https://files.pythonhosted.org/packages/a9/51/aa1d756644bf4624c03844115e4ac4058eff77acd786b26315f051a4b195/tb_nightly-1.14.0a20190301-py3-none-any.whl (3.0MB)
[K    100% |████████████████████████████████| 3.0MB 7.9MB/s 
Collecting tf-estimator-nightly<1.14.0.dev2019030116,>=1.14.0.dev2019030115 (from tensorflow-gpu==2.0.0-alpha0)
[?25l  Downloading https://files.pythonhosted.org/packages/13/82/f16063b4eed210dc2ab057930ac1da4fbe1e91b7b051a6c8370b401e6ae7/tf_estimator_nightly-1.14.0.dev2019030115-py2.py3-none-any.whl (411kB)
[K    100% |████████████████████████████████| 419kB 12.2MB/s

In [2]:
_URL = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz'

path_to_zip = tf.keras.utils.get_file('facades.tar.gz',
                                      origin=_URL,  
                                      extract=True)

PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/')

Downloading data from https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz


In [0]:
def parse_args():
  desc = "Keras Implementation of Self-Attention GAN"
  
  parser = argparse.ArgumentParser(desc)

  parser.add_argument('--phase', type=str, default='train')
  parser.add_argument('--dataset', type=str, default='facades', help='mnist | cifar10')

  parser.add_argument('--epoch', type=int, default=100, help='The number of epochs to run')
  parser.add_argument('--iteration', type=int, default=10000, help='The number of training iterations')
  parser.add_argument('--print_freq', type=int, default=500, help='The number of image_print_freqy')
  parser.add_argument('--save_freq', type=int, default=500, help='The number of ckpt_save_freq')

  parser.add_argument('--g_lr', type=float, default=0.0001, help='learning rate for generator')
  parser.add_argument('--d_lr', type=float, default=0.0004, help='learning rate for discriminator')
  parser.add_argument('--beta1', type=float, default=0.0, help='beta1 for Adam optimizer')
  parser.add_argument('--beta2', type=float, default=0.9, help='beta2 for Adam optimizer')

  parser.add_argument('--z_dim', type=int, default=128, help='Dimension of noise vector')
  parser.add_argument('--up_sample', type=bool, default=True, help='using upsample-conv')
  parser.add_argument('--sn', type=bool, default=True, help='using spectral norm')
  parser.add_argument('--ld', type=float, default=10.0, help='gradient penalty lambda')
  parser.add_argument('--n_critic', type=int, default=1, help='The number of critic')

  parser.add_argument('--img_shape', type=tuple, default=(256, 256, 3), help='The size of image')
  parser.add_argument('--sample_num', type=int, default=64, help='The number of sample images')

  parser.add_argument('--test_num', type=int, default=10, help='The number of images generated by the test')

  parser.add_argument('--checkpoint_dir', type=str, default='checkpoint', help='Directory name to save checkpoints')
  parser.add_argument('--result_dir', type=str, default='results', help='Directory name to save generated images')
  parser.add_argument('--log_dir', type=str, default='logs', help='Directory name to save training logs')
  parser.add_argument('--sample_dir', type=str, default='samples', help='Directory name to save samples on training')

  return parser.parse_args()

In [0]:
class Pix2Pix():
  
  def __init__(self):
    self.img_rows = 256
    self.img_cols = 256
    self.channels = 3
    self.img_shape = (self.img_rows, self.img_cols, self.channels)
    
    # self.dataset = dataset # dataset name
    self.dataset_name = 'facades'
    self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                 img_res=(self.img_rows, self.img_cols))
    
    patch = int(self.img_rows / 2**4)
    self.disc_patch = (patch, patch, 1)
    self.gf = 64
    self.df = 64
    
    optimizer = Adam(0.0002, 0.5)
    
    self.discriminator = self.Discriminator()
    self.discriminator.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
    
    self.generator = self.Generator()
    
    img_A = Input(shape=self.img_shape)
    img_B = Input(shape=self.img_shape)    
    
    fake_A = self.generator(img_B)
    
    self.discriminator.trainable = False
    
    valid = self.discriminator([fake_A, img_B])
    
    self.combined = Model(inputs=[img_A, img_B], outputs=[valid, fake_A])
    
    self.combined.compile(loss=['mse', 'mae'],
                         loss_weights=[1, 100],
                         optimizer=optimizer)
    
  
  def Generator(self):
    
    def conv2d(layer_input, filters, f_size=4, bn=True):
      d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
      d = LeakyReLU(alpha=0.2)(d)
      if bn:
        d = BatchNormalization(momentum=0.8)(d)
        return d
          
    
    def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
      """Layers used during upsampling"""
      u = UpSampling2D(size=2)(layer_input)
      u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
      if dropout_rate:
        u = Dropout(dropout_rate)(u)
      u = BatchNormalization(momentum=0.8)(u)
      u = Concatenate()([u, skip_input])
       
      return u
          
    d0 = Input(shape=self.img_shape)
    
    # Downsampling
    d1 = conv2d(d0, self.gf, bn=False)
    d2 = conv2d(d1, self.gf*2)
    d3 = conv2d(d2, self.gf*4)
    d4 = conv2d(d3, self.gf*8)
    d5 = conv2d(d4, self.gf*8)
    d6 = conv2d(d5, self.gf*8)
    d7 = conv2d(d6, self.gf*8)
    
    # Upsampling
    
    u1 = deconv2d(d7, d6, self.gf*8)
    u2 = deconv2d(u1, d5, self.gf*8)
    u3 = deconv2d(u2, d4, self.gf*8)
    u4 = deconv2d(u3, d3, self.gf*4)
    u5 = deconv2d(u4, d2, self.gf*2)
    u6 = deconv2d(u5, d1, self.gf)
    
    u7 = UpSampling2D(size=2)(u6)
    
    output_img = Conv2D(self.channels, kernel_size=4, strides=1, padding='same', activation='tanh')(u7)
    
    return tf.keras.Model(d0, output_img)
  
  def Discriminator(self):
    
    def d_layer(layer_input, filters, f_size=4, bn=True):
      
      d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
      d = LeakyReLU(alpha=0.2)(d)
      if bn:
        d = BatchNormalization(momentum=0.8)(d)
      return d
    
    img_A = Input(shape=self.img_shape)
    img_B = Input(shape=self.img_shape)
    
    combined_imgs = Concatenate(axis=-1)([img_A, img_B])
    
    d1 = d_layer(combined_imgs, self.df, bn=False)
    d2 = d_layer(d1, self.df*2)
    d3 = d_layer(d2, self.df*4)
    d4 = d_layer(d3, self.df*8)
    
    validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)
    
    return tf.keras.Model([img_A, img_B], validity)
  
  def train(self, epochs, batch_size=1, sample_interval=50):
    
    valid = np.ones((batch_size, ) + self.disc_patch)
    fake = np.zeros((batch_size, ) + self.disc_patch)
    
    for epoch in range(epochs):
      for batch_i, (imgs_A, imgs_B) in enumerate(self.data_loader.load_batch(batch_size)):
        
        # Train Discriminator
        
        fake_A = self.generator(imgs_B)
        
        d_loss_real = self.discriminator.train_on_batch([imgs_A, imgs_B], valid)
        d_loss_fake = self.discriminator.train_on_batch([fake_A, imgs_B], fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        
        # Train Generator
        
        g_loss = self.combined.train_on_batch([imgs_A, imgs_B], [valid, imgs_A])
        
        print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %f]" % 
             (epoch, epochs, batch_i, self.data_loader.n_batches, d_loss[0], 100*d_loss[1],
             g_loss[0]))
        
        if batch_i % sample_interval == 0:
          self.sample_images(epoch, batch_i)
          
  def sample_images(self, epoch, batch_i):
    
    if patch_check:
      os.makedirs('images/%s' % self.dataset_name, exist_ok=True)
     
    r, c = 3, 3
    
    imgs_A, imgs_B = self.data_loader.load_data(batch_size=3, is_testing=True)
    fake_A = self.generator.predict(imgs_B)
    
    gen_imgs = np.concatenate([imgs_B, fake_A, imgs_A])
    
    gen_imgs = 0.5 * gen_imgs + 0.5
    
    titles = ['Condition', 'Generated', 'Original']
    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
      for j in range(c):
        axs[i, j].imshow(gen_imgs[cnt])
        axs[i, j].set_title(titles[i])
        axs[i, j].axis('off')
        cnt += 1
    
    fig.savefig("images/%s/%d_%d.png" % (self.dataset_name, epoch, batch_i))
    plt.close()    
    
        
    

In [0]:
def download_dataset(dataset_name, URL):
  
  # check if we have dataset or not
  if not os.path.exists('./' + dataset_name):
    path_to_zip = tf.keras.utils.get_file('facades.tar.gz',
                                      origin=_URL,  
                                      extract=True)

In [0]:
class DataLoader():
  def __init__(self, img_shape, PATH):
    self.width = img_shape[0]
    self.height = img_shape[1]
    self.PATH = PATH
    
  def load_img(self, img_file):
    img = tf.io.read_file(img_file)
    img = tf.image.decode_jpeg(img)
    
    w = tf.shape(img)[1]
    
    w = w // 2
    real_img = img[:, :w, :]
    input_img = img[:, w:, :]
    
    real_img = tf.cast(real_img, tf.float32)
    input_img = tf.cast(input_img, tf.float32)
    
    return input_img, real_img
    
  def resize(self, input_img, real_img, height, width):
    input_img = tf.image.resize(input_img, [height, width], 
                                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    real_img = tf.image.resize(real_img, [height, width], 
                                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    
    return input_img, real_img
  
  def random_crop(self, input_img, real_img):
    stacked_img = tf.stack([input_img, real_img], axis=0)
    cropped_img = tf.image.random_crop(
        stacked_img, size=[2, self.height, self.width, 3])
    
    return cropped_img[0], cropped_img[1]
  
  def normalize(self, input_img, real_img):
    input_img = (input_img / 127.5) - 1
    real_img = (real_img / 127.5) - 1
    
    return input_img, real_img
  
  def random_jitter(self, input_img, real_img):
    # resizing to 286 x 286 x 3
    input_img, real_img = self.resize(input_img, real_img, 286, 286)
    
    input_img, real_img = self.random_crop(input_img, real_img)
    
    if tf.random.uniform(()) > 0.5:
      input_img = tf.image.flip_left_right(input_img)
      
  
  def load_img_train(self, img_file):
    input_img, real_img = self.load(img_file)
    input_img, real_img = self.random_jitter(input_img, real_img)
    input_img, real_img = self.normalize(input_img, real_img)
    
    return input_img, real_img
  
  def load_img_test(self, img_file):
    input_img, real_img = self.load(img_file)
    input_img, real_img = self.resize(input_img, real_img, self.height, self.width)
    input_img, real_img = self.normalize(input_img, real_img)
    
    return input_img, real_img
      

In [7]:
if __name__ == '__main__':
  #args = parse_args()
  data_loaer = DataLoader((256, 256, 3), _URL)
  pix2pix = Pix2Pix()
  pix2pix.train(epochs=200, batch_size=1, sample_interval=200)

TypeError: ignored