## Library loading

In [0]:
from __future__ import absolute_import, division, print_function, unicode_literals
# !pip install -q tensorflow-gpu==2.0.0-alpha0

In [2]:
import tensorflow as tf
import numpy as np 
import os
from tensorflow.keras.layers import Dense, Flatten, Conv2D, ReLU, Conv2DTranspose, LeakyReLU, Layer, ZeroPadding2D
from tensorflow.keras.activations import tanh
import time
print(tf.__version__)

2.0.0-alpha0


## Auxliary function

In [0]:
def random_target_domain_generation():
  ''' target domain generation '''
  
  target_domain_1 = np.random.uniform(low=0., high=3., size=batch_size).astype(np.int32) # facial expression attributes
  target_domain_1 = tf.one_hot(target_domain_1, depth=3)
  target_domain_2 = np.random.randint(2, size=(batch_size,2)) # male, young attributes
  target_domain =  np.concatenate([target_domain_1, target_domain_2], axis=-1)
  return target_domain
            
def generate_and_save_images(model, epoch, test_input):
  ''' from tensorflow homepage '''
  ''' batchnorm 바꿔라 '''
  # Notice `training` is set to False.
  # This is so all layers run in inference mode (batchnorm).
  
  predictions = model(test_input, training=False)

  fig = plt.figure(figsize=(4,4))

  for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i+1)
      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
      plt.axis('off')

  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()
  
def input_merge(images, domain):
  ''' input image and domain merge'''
  batch_size = images.shape[0]
  image_size = images.shape[1]
  channels = domain.shape[1]
  merged = np.zeros([batch_size,image_size,image_size,channels])
  for batch in range(batch_size):
    temp = tf.broadcast_to(domain[batch], [image_size,image_size,channels])
    merged[batch] = temp
  merged = tf.concat([images, merged], axis=-1)
  return merged  
  
    
class InstanceNormalization(tf.keras.layers.Layer):
  '''InstanceNormalization for only 4-rank Tensor (image data)'''

  def __init__(self, epsilon=1e-5):
    super(InstanceNormalization, self).__init__()
    self.epsilon = epsilon

  def build(self, input_shape):
    shape = tf.TensorShape(input_shape)
    param_shape = shape[-1]
    self.gamma = self.add_weight(name='gamma',
                                 shape=param_shape,
                                 initializer='ones',
                                 trainable=True)
    self.beta = self.add_weight(name='beta',
                                shape=param_shape,
                                initializer='zeros',
                                trainable=True)
    super(InstanceNormalization, self).build(input_shape)

  def call(self, inputs):
    input_shape = inputs.get_shape()
    HW = input_shape[1]*input_shape[2] 
    u_ti = 1/HW*np.sum(inputs, axis=(1,2))   # 2x3  
    for _ in range(2):
      u_ti = np.stack((u_ti, )*input_shape[1], axis=-1) # 2x3x128x128
    u_ti = np.swapaxes(u_ti,1,3) # 2x128x128x3  
    var_ti = 1/HW*np.sum((inputs - u_ti), axis=(1,2))**2 # 2x3  
    for _ in range(2):
      var_ti = np.stack((var_ti, )*input_shape[1], axis=-1) # 2x3x128x128
    var_ti = np.swapaxes(var_ti,1,3) # 2x128x128x3                      
    y_tijk = (inputs - u_ti) / np.sqrt(var_ti +  self.epsilon)  # 2x128x128x3
    return self.gamma * y_tijk + self.beta   
  


## Build Generator & Discriminator

In [0]:
class Downsampling_Part(tf.keras.Model):
  ''' Downsampling part of Generator'''
  
  def __init__(self):
    super(Downsampling_Part, self).__init__()
    self.conv1 = Conv2D(64, kernel_size = 7, strides = 1, padding = 'valid')
    self.conv2 = Conv2D(128, kernel_size = 4, strides = 2, padding = 'valid')
    self.conv3 = Conv2D(256, kernel_size = 4, strides = 2, padding = 'valid')
    self.zeropadding1 = ZeroPadding2D(3)
    self.zeropadding2 = ZeroPadding2D(1)
    self.zeropadding3 = ZeroPadding2D(1)
    self.instancenormalization1 = InstanceNormalization()
    self.instancenormalization2 = InstanceNormalization()
    self.instancenormalization3 = InstanceNormalization()
    self.activation_ReLU = ReLU()
  
  def call(self, images, labels):
    x = input_merge(images,labels)
    x = self.zeropadding1(x)
    x = self.conv1(x)
    x = self.instancenormalization1(x)
    x = self.activation_ReLU(x)
    assert x.shape == (2, 128, 128, 64) 

    x = self.zeropadding2(x)
    x = self.conv2(x)
    x = self.instancenormalization2(x)
    x = self.activation_ReLU(x)
    assert x.shape == (2, 64, 64, 128) 

    x = self.zeropadding3(x)
    x = self.conv3(x)
    x = self.instancenormalization3(x)
    x = self.activation_ReLU(x)
    assert x.shape == (2, 32, 32, 256) 
    return x
  

class ResnetIdentityBlock(tf.keras.Model):
  ''' ResentIdentityBlock for Residual part of Generator'''
  
  def __init__(self):
    super(ResnetIdentityBlock, self).__init__()
    self.conv1 = Conv2D(filters=256, kernel_size=3, strides=1, padding='valid')
    self.conv2 = Conv2D(filters=256, kernel_size=3, strides=1, padding='valid')
    self.zeropadding = ZeroPadding2D(1)
    self.instancenormalization = InstanceNormalization()
    self.activation_ReLU = ReLU()

  def call(self, input_tensor):
    x = self.zeropadding(input_tensor)
    x = self.conv1(x)
    x = self.instancenormalization(x)
    x = self.activation_ReLU(x)
    x = self.zeropadding(x)
    x = self.conv2(x)
    x += input_tensor
    x = self.activation_ReLU(x)
    return x

  
class Bottleneck_Part(tf.keras.Model):
  ''' Bottleneck part of Generator'''
  def __init__(self):
    super(Bottleneck_Part, self).__init__()
    self.ResnetIdentityBlock = ResnetIdentityBlock()
    
  def call(self, input_tensor):
    x  = self.ResnetIdentityBlock(input_tensor)
    x  = self.ResnetIdentityBlock(x)
    x  = self.ResnetIdentityBlock(x)
    x  = self.ResnetIdentityBlock(x)
    x  = self.ResnetIdentityBlock(x)
    x  = self.ResnetIdentityBlock(x)       
    return x
  

class Upsampling_Part(tf.keras.Model):
  ''' Upsampling part of Generator'''
  
  def __init__(self):
    super(Upsampling_Part, self).__init__()
    self.deconv1 = Conv2DTranspose(128, kernel_size = 4, strides = 2, padding = 'same')
    self.deconv2 = Conv2DTranspose(64, kernel_size = 4, strides = 2, padding = 'same')
    self.conv1 = Conv2D(3, kernel_size = 7, strides = 1, padding = 'same')
    self.zeropadding1 = ZeroPadding2D(1)
    self.zeropadding2 = ZeroPadding2D(1)
    self.zeropadding3 = ZeroPadding2D(3)
    self.activation_ReLU = ReLU()
    self.instancenormalization1 = InstanceNormalization()
    self.instancenormalization2 = InstanceNormalization()

    
  def call(self, x):
#     x = self.zeropadding1(x)
    x = self.deconv1(x)
    x = self.instancenormalization1(x)
    x = self.activation_ReLU(x)
    print(x.shape)
    assert x.shape == (2, 64, 64, 128) 

#     x = self.zeropadding2(x)
    x = self.deconv2(x)
    x = self.instancenormalization2(x)
    x = self.activation_ReLU(x)
    assert x.shape == (2, 128, 128, 64) 

#     x = self.zeropadding3(x)
    x = self.conv1(x)
    x = tanh(x)
    assert x.shape == (2, 128, 128, 3) 
    return x
  
  
class Build_generator(tf.keras.Model):
  ''' Building a generator'''
  def __init__(self):
    super(Build_generator, self).__init__()
    self.Downsampling = Downsampling_Part()
    self.ResidualBlock = Bottleneck_Part()
    self.Upsampling = Upsampling_Part()
    
  def call(self, images, labels):
    x = self.Downsampling(images, labels)
    print("The shape of the input tensor after downsampling :", x.shape)
    x = self.ResidualBlock(x)
    print("The shape of the input tensor after Bottleneck :", x.shape)
    x = self.Upsampling(x)
    print("The shape of the input tensor after upsampling :", x.shape)
    return x
 

class Build_discriminator(tf.keras.Model):
  ''' Building a discriminator'''
  
  def __init__(self, image_size, nd):
    super(Build_discriminator, self).__init__()
    self.conv1 = Conv2D(64, kernel_size = 4, strides = 2, padding = 'valid')
    self.conv2 = Conv2D(128, kernel_size = 4, strides = 2, padding = 'valid')
    self.conv3 = Conv2D(256, kernel_size = 4, strides = 2, padding = 'valid')
    self.conv4 = Conv2D(512, kernel_size = 4, strides = 2, padding = 'valid')
    self.conv5 = Conv2D(1024, kernel_size = 4, strides = 2, padding = 'valid')
    self.conv6 = Conv2D(2048, kernel_size = 4, strides = 2, padding = 'valid')    
    self.conv7_1 = Conv2D(1, kernel_size = 3, strides = 1)    
    self.conv7_2 = Conv2D(nd, kernel_size = int(image_size/64), strides = 1)    
    self.zeropadding0 = ZeroPadding2D(0) 
    self.zeropadding = ZeroPadding2D(1)
    self.activation_LeakyReLU = LeakyReLU(alpha=0.01)
    
  def call(self, x):
    x = self.zeropadding(x)
    x = self.conv1(x)
    x = self.activation_LeakyReLU(x)
    assert x.shape == (2, 64, 64, 64) 
    print("The shape of the input tensor after Input layer :", x.shape)

    x = self.zeropadding(x)
    x = self.conv2(x)
    x = self.activation_LeakyReLU(x)
    assert x.shape == (2, 32, 32, 128) 
    
    x = self.zeropadding(x)
    x = self.conv3(x)
    x = self.activation_LeakyReLU(x)
    assert x.shape == (2, 16, 16, 256) 
    
    x = self.zeropadding(x)
    x = self.conv4(x)
    x = self.activation_LeakyReLU(x)
    assert x.shape == (2, 8, 8, 512) 

    x = self.zeropadding(x)
    x = self.conv5(x)
    x = self.activation_LeakyReLU(x)
    assert x.shape == (2, 4, 4, 1024) 

    x = self.zeropadding(x)
    x = self.conv6(x)
    x = self.activation_LeakyReLU(x)
    assert x.shape == (2, 2, 2, 2048) 
    print("The shape of the input tensor after Hidden layer :", x.shape)
    
    x_src = self.zeropadding(x)
    D_src = self.conv7_1(x_src)
    D_cls = self.conv7_2(x)
    assert D_src.shape == (2, 2, 2, 1)
    assert D_cls.shape == (2, 1, 1, 5) 
    print("The shape of the input tensor after Output layer :", D_src.shape, D_cls.shape)
    
    return  D_src, D_cls
  
 

## Loss

In [0]:
def adverserial_loss(logits, real=True):
  cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)  
  if real == True:
    loss = cross_entropy(tf.ones_like(logits),logits)
  else:
    loss = cross_entropy(tf.zeros_like(logits),logits)
    return loss

def reconstruction_loss(image,rec_image,lambda_rec):
  return lambda_rec * np.abs(tf.reduce_mean(image  - rec_image))
 
def domain_cls_loss(domain, logits):
  cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)  
  return lambda_cls * cross_entropy(domain, logits)

## G_loss, D_loss
def G_loss(fake_D_src, target_D_cls, target_domain,input_image, reconstructed_image, lambda_cls,lambda_rec):
  loss = adverserial_loss(fake_D_src, real=True) + lambda_cls * domain_cls_loss(target_domain, target_D_cls) + lambda_rec * reconstruction_loss(input_image, reconstructed_image)
  return loss
  
def D_loss(real_D_src, src_logits, original_domain, original_D_cls, lambda_cls):
  loss = -1 * (adverserial_loss(real_D_src, real=True) + adverserial_loss(fake_D_src, real=False)) + lambda_cls* domain_cls_loss(original_domain,original_D_cls)
  return loss
  

## Optimizer

In [0]:
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

## Model checkpoint

## Train loop 

In [0]:
def train_step(input_image, original_domain, target_domain):
# '''fake_image (image) # generator: generated fake images
#  fake_D_src (logits) # discriminator: real / fake image classification for fake image
#  target_D_cls (logits) # discriminator: original / target label classification for fake image
#  real_D_src (logits) # discriminator: real / fake image classificiaion for real image
#  original_D_cls (logits) # discriminator: original / target label classificiaion for real image
#  fake_D_src (logits) # discriminator: original/target classification for fake image
#  reconstructed_image (image) # generator: generated real (reconstructed) images '''

  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        
    # generator + discriminator combined
    fake_image = generator(input_image, target_domain)  # step(b)
    fake_D_src, target_D_cls = discriminator(fake_image)  # step(d) # 우선 fake image를 넣어서 보조 classification을 학습
    reconstructed_image = generator(fake_image, original_domain) # step(c)

    # discriminator
    real_D_src, original_D_cls = discriminator(input_image) #step(a) 
    fake_D_src, fake_D_cls = discriminator(fake_image) #step(a)  # 여기서는 보조 classification 학습 안함
    
    generator_loss = G_loss(fake_D_src, target_D_cls, target_domain,input_image, reconstructed_image, lambda_cls,lambda_rec)
    discriminator_loss = D_loss(real_D_src, src_logits, original_domain, original_D_cls, lambda_cls)
    
    gradients_of_generator = gen_tape.gradient(generator_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(discriminator_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    
def train(train_dataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    for input_image, original_domain in train_dataset:
            
      target_domain = random_target_domain_generation()
      train_step(input_image, original_domain, target_domain)

    # Produce images for the GIF as we go
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)

    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

  # Generate after the final epoch
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)
  

## Execution

In [0]:
### original data generation
N = 10
batch_size = 2
image_size = 128
image_data = np.random.normal(size=[N, image_size, image_size, 3])
domain_1 = np.random.uniform(low=0., high=3., size=N).astype(np.int32) # facial expression attributes
domain_1 = tf.one_hot(domain_1, depth=3)
domain_2 = np.random.randint(2, size=(N,2)) # male, young attributes
domain =  np.concatenate([domain_1, domain_2], axis=-1)
domain = domain.reshape((-1,1,1, domain.shape[1]))
nd = domain.shape[-1] 

train_dataset = tf.data.Dataset.from_tensor_slices((image_data,domain))
train_dataset = train_dataset.batch(batch_size) 

### Build generator & discriminator
generator = Build_generator()
discriminator = Build_discriminator(image_size, nd)

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

### Loop information
epochs = 1
num_examples_to_generate = 16
seed = tf.random.normal([num_examples_to_generate])
lambda_cls = 0.1
lambda_rec = 0.5

# ### Train
# train(train_dataset, epochs)



In [0]:
# for input_image, original_domain in train_dataset.take(1):    
#   target_domain = random_target_domain_generation()
# #   train_step(input_image, original_domain, target_domain)
#   fake_image = generator(input_image, target_domain)  # step(b)
#   fake_D_src, target_D_cls = discriminator(fake_image)  # step(d) # 우선 fake image를 넣어서 보조 classification을 학습
#   print(fake_D_src, target_D_cls)
#   reconstructed_image = generator(fake_image, original_domain) # step(c)


The shape of the input tensor after downsampling : (2, 32, 32, 256)
The shape of the input tensor after Bottleneck : (2, 32, 32, 256)
(2, 64, 64, 128)
The shape of the input tensor after upsampling : (2, 128, 128, 3)
The shape of the input tensor after Input layer : (2, 64, 64, 64)
The shape of the input tensor after Hidden layer : (2, 2, 2, 2048)
The shape of the input tensor after Output layer : 

In [10]:
# a,b = discriminator(fake_image)
# a.shape
# b.shape

The shape of the input tensor after Input layer : (2, 64, 64, 64)
The shape of the input tensor after Hidden layer : (2, 2, 2, 2048)
The shape of the input tensor after Output layer : (2, 2, 2, 1) (2, 1, 1, 5)
