# Install gdown Python package.

In [0]:
# gdown fetches files from Google Drive links. It is not referenced directly by
# the code in this notebook — presumably installed for dataset downloads.
!pip install -U --no-cache-dir gdown

# Install Tensorflow-Addons.

In [0]:
# Provides `tfa.layers.InstanceNormalization`, used by the discriminator.
!pip install tensorflow-addons

# Install Tensorflow-datasets.

In [0]:
# Provides the CelebA loader used below (`tfds.builder('celeb_a')`).
!pip install tensorflow-datasets

# Use TensorFlow 2.x.

In [0]:
# `%tensorflow_version` is an IPython magic that appears to be Colab-specific
# (this notebook also mounts Drive via `google.colab`). Outside Colab it is
# expected to fail, so the error is deliberately ignored and the installed
# TensorFlow is used as-is.
try:
  %tensorflow_version 2.x
except Exception:
  pass

import tensorflow as tf
import tensorflow.keras.layers as layers

import numpy as np
np.random.seed(7)
# Layer weight initialisers and `tf.random.shuffle` draw from TensorFlow's RNG,
# which `np.random.seed` does not affect; seed it too for repeatable runs.
tf.random.set_seed(7)

print(tf.__version__)

# Configuration parameters.

In [0]:
# --- Data ------------------------------------------------------------------
batch_size = 32                        # examples per batch
number_of_attributes = 40              # length of each CelebA attribute vector
image_shape = (128, 128, 3)            # (height, width, channels) fed to the models

# --- Adversarial objective -------------------------------------------------
adversarial_loss_mode = 'wgan'         # selects the loss pair in adversarial_loss_functions()
gradient_penalty_mode = '1-gp'         # gradient-penalty settings; not yet consumed
gradient_penalty_sample_mode = 'line'  # by the code in this notebook

# --- Loss weights ----------------------------------------------------------
d_gradient_penalty_weight = 10.0       # discriminator: gradient penalty
d_attribute_loss_weight = 1.0          # discriminator: attribute classification
g_attribute_loss_weight = 10.0         # generator: attribute classification
g_reconstruction_loss_weight = 100.0   # generator: L1 reconstruction

# Load CelebA dataset.

In [0]:
# `tfds` is not imported anywhere else in this notebook; import it before use
# (otherwise this cell raises NameError).
import tensorflow_datasets as tfds

builder = tfds.builder('celeb_a')
print(builder.info)

### Option 1: download and prepare the dataset with TFDS (the source files are hosted on Google Drive).

In [0]:
# Download the raw CelebA files and build the TFDS records. This can be slow;
# the following cells offer copying an already-prepared copy from Drive instead.
builder.download_and_prepare()

### Option 2: copy a previously prepared dataset from your own Google Drive.

In [0]:
# Mount the user's Google Drive at /content/drive so a prepared copy of the
# dataset can be copied locally in the next cell.
from google.colab import drive
drive.mount('/content/drive')

In [0]:
# Copy the prepared `tensorflow_datasets` directory from Drive into /root,
# presumably the default TFDS data directory (~/tensorflow_datasets) here.
!cp -r '/content/drive/My Drive/datasets/tensorflow_datasets' /root/.

### View dataset contents.

In [0]:
# Sanity check: list the prepared CelebA files.
!ls -al /root/tensorflow_datasets/celeb_a/

### Create CelebA dataset splits.
* train
* validation
* test

In [0]:
# Without a specific split, TFDS returns a dict keyed by split name
# ('train', 'validation', 'test' are used below), each a tf.data.Dataset.
celeba_datasets = builder.as_dataset(split=None)
print(celeba_datasets)

In [0]:
# Shuffle individual examples before batching. `as_dataset()` does not shuffle
# by default, so the GAN would otherwise always see CelebA in the same order.
# NOTE(review): elements are presumably still raw TFDS feature dicts. Before
# training, images need resizing to `image_shape` and scaling to [-1, 1] (the
# decoder ends in tanh), and the attributes need stacking into a float vector.
train_dataset = celeba_datasets['train']
train_dataset = train_dataset.shuffle(buffer_size=1000)
train_dataset = train_dataset.batch(batch_size)

In [0]:
# Validation split, batched in file order (no shuffling).
val_dataset = celeba_datasets['validation'].batch(batch_size)

In [0]:
# Test split, batched in file order (no shuffling).
test_dataset = celeba_datasets['test'].batch(batch_size)

# Create the optimizer.

*   Adam optimizer
*   Learning rate = 0.0002
*   β1 = 0.5
*   β2 = 0.999

In [0]:
# Adam with the hyperparameters listed above.
optimizer = tf.optimizers.Adam(
    learning_rate=2e-4,  # 0.0002
    beta_1=0.5,
    beta_2=0.999,
)

# Create AttGAN encoder model.

In [0]:
class UNetGenc(layers.Layer):
  """AttGAN encoder: a stack of strided convolution blocks.

  Each block is Conv2D(4x4, stride 2, 'same') -> BatchNormalization ->
  LeakyReLU(0.2), so it halves the spatial size. The filter count starts at
  `dimension` and doubles after every block. The output of every block is
  returned so the decoder can use them as skip connections.

  Args:
    dimension: number of filters in the first block.
    downsamplings_layers: number of downsampling blocks.
  """

  def __init__(self, dimension=64, downsamplings_layers=5):
    super(UNetGenc, self).__init__()

    # Use the constructor arguments (previously overwritten by hard-coded 64 / 5).
    self._dimension = dimension
    self._downsamplings_layers = downsamplings_layers

    # Create the sub-layers once, here, so this layer tracks their weights and
    # reuses them across calls. Instantiating layers inside `call` built new,
    # randomly initialised layers on every invocation. Those layers were not
    # tracked by this layer, so their weights could never be trained.
    blocks = []
    output_units = self._dimension
    for _ in range(self._downsamplings_layers):
      blocks.append(tf.keras.Sequential([
          layers.Conv2D(output_units, (4, 4), strides=(2, 2), padding='same'),
          layers.BatchNormalization(),
          layers.LeakyReLU(alpha=0.2),
      ]))
      output_units = output_units * 2
    self._encoder_blocks = blocks

  def call(self, inputs, training=None):
    """Encode `inputs` (presumably (N, H, W, C) images — TODO confirm).

    Args:
      inputs: image batch.
      training: forwarded to BatchNormalization (batch statistics when True,
        moving averages when False).

    Returns:
      List with the output of every block, shallowest first.
    """
    input_layer = inputs
    output_layers = []
    for block in self._encoder_blocks:
      input_layer = block(input_layer, training=training)
      output_layers.append(input_layer)

    return(output_layers)

In [0]:
# Smoke test: run the encoder on random images and report each feature-map shape.
input_image = np.random.rand(batch_size, *image_shape)
encoder = UNetGenc()
output = encoder(input_image)
print('number of layers', len(output))
for layer in output:
  print('layer shape', layer.shape)

# Concatenate features and attributes.
*   Tile each attribute vector across the spatial (H, W) dimensions of the features.
*   Concat features + attributes along the channel axis.
*   features shape - (N, H, W, C_a)
*   attributes shape - (N, 1, 1, C_b) or (N, C_b)

In [0]:
def concatenate(list_of_features, list_of_attributes=None):
  """Concatenate feature maps with spatially tiled attribute vectors.

  Args:
    list_of_features: a tensor or a list/tuple of tensors to concatenate;
      presumably (N, H, W, C_a) maps sharing N, H and W — TODO confirm.
    list_of_attributes: optional tensor or list/tuple of tensors, each of
      shape (N, C_b) or (N, 1, 1, C_b). Each is cast to the features' dtype,
      reshaped to (N, 1, 1, C_b) and tiled to the height and width of the
      first feature map. `None` (the default) means no attributes.

  Returns:
    All features followed by all tiled attributes, concatenated along the
    last (channel) axis.
  """
  # `None` replaces the former mutable `[]` default; both mean "no attributes".
  if list_of_attributes is None:
    list_of_attributes = []
  list_of_features = list(list_of_features) if isinstance(list_of_features, (list, tuple)) else [list_of_features]
  list_of_attributes = list(list_of_attributes) if isinstance(list_of_attributes, (list, tuple)) else [list_of_attributes]

  reference = list_of_features[0]
  height, width = reference.shape[1], reference.shape[2]
  # Fall back to the runtime shape when a spatial dimension is not known
  # statically (tiling by `None` would fail).
  if height is None or width is None:
    runtime_shape = tf.shape(reference)
    height = runtime_shape[1] if height is None else height
    width = runtime_shape[2] if width is None else width

  tiled_attributes = []
  for attributes in list_of_attributes:
    # `tf.concat` requires one dtype; attributes may arrive as bool or float64.
    attributes = tf.cast(attributes, reference.dtype)
    attributes = tf.reshape(attributes, [-1, 1, 1, attributes.shape[-1]])
    tiled_attributes.append(tf.tile(attributes, [1, height, width, 1]))
  return tf.concat(list_of_features + tiled_attributes, axis=-1)

# Create AttGAN decoder model.

In [0]:
class UNetGdec(layers.Layer):
  """AttGAN decoder: transposed-convolution upsampling conditioned on attributes.

  `call([features, attributes])` takes the encoder's list of feature maps
  (deepest last) and one attribute vector per example.

  1. The deepest map is concatenated with the tiled attributes.
  2. It then passes through `upsamplings_layers - 1` blocks of
     Conv2DTranspose(4x4, stride 2) -> BatchNormalization -> LeakyReLU(0.2).
     - After block i, the mirrored encoder map `features[-2 - i]` is
       concatenated when `i < shortcut_layers`.
     - The tiled attributes are concatenated again when `i < inject_layers`.
  3. A final stride-2 Conv2DTranspose to 3 channels, followed by tanh,
     produces the image.

  Args:
    dimension: filter count of the first upsampling block.
    upsamplings_layers: total number of stride-2 upsamplings, including the
      output layer.
    shortcut_layers: number of leading blocks that receive skip connections.
    inject_layers: number of leading blocks that re-inject the attributes.
  """

  def __init__(self, dimension=64, upsamplings_layers=5, shortcut_layers=1, inject_layers=1):
    super(UNetGdec, self).__init__()

    # Use the constructor arguments (previously overwritten by hard-coded 64 / 5).
    self._dimension = dimension
    self._upsamplings_layers = upsamplings_layers
    self._shortcut_layers = shortcut_layers
    self._inject_layers = inject_layers

    # Create the sub-layers once so this layer tracks and reuses their weights.
    # Building them inside `call` created new, untrainable layers on each call.
    blocks = []
    output_units = self._dimension
    for _ in range(self._upsamplings_layers - 1):
      blocks.append(tf.keras.Sequential([
          layers.Conv2DTranspose(output_units, (4, 4), strides=(2, 2), padding='same'),
          layers.BatchNormalization(),
          layers.LeakyReLU(alpha=0.2),
      ]))
      # NOTE(review): the filter count doubles at every upsampling step here,
      # whereas the AttGAN reference decoder halves it (mirroring the encoder).
      # Kept unchanged to avoid altering the architecture — confirm intent.
      output_units = output_units * 2
    self._decoder_blocks = blocks
    self._output_layer = layers.Conv2DTranspose(3, (4, 4), strides=(2, 2), padding='same')

  def call(self, inputs, training=None):
    """Decode `[features, attributes]` into an image batch with values in (-1, 1).

    `training` is forwarded to BatchNormalization.
    """
    features, attributes = inputs

    input_layer = concatenate(features[-1], attributes)
    for layer_index, block in enumerate(self._decoder_blocks):
      input_layer = block(input_layer, training=training)

      if self._shortcut_layers > layer_index:
        input_layer = concatenate([input_layer, features[-2 - layer_index]])

      if self._inject_layers > layer_index:
        input_layer = concatenate(input_layer, attributes)

    input_layer = self._output_layer(input_layer)
    output_layer = tf.keras.activations.tanh(input_layer)

    return(output_layer)

In [0]:
# Smoke test: encode random images, decode them together with random
# attribute vectors, and print the shape of the decoded batch.
input_image = np.random.rand(batch_size, *image_shape)
attributes = np.random.rand(batch_size, number_of_attributes)

encoder = UNetGenc()
encoded_input = encoder(input_image)

decoder = UNetGdec()
decoded_output = decoder([encoded_input, attributes])
print(decoded_output.shape)

# Create AttGAN discriminator / classification model.

In [0]:
import tensorflow_addons as tfa

In [0]:
class Discriminator(layers.Layer):
  """AttGAN discriminator with an adversarial head and an attribute head.

  The shared trunk has `downsamplings_layers` blocks of Conv2D(4x4, stride 2)
  -> InstanceNormalization -> LeakyReLU(0.2). Filters start at `dimension`
  and double per block. The trunk output is flattened and fed to two heads:

    * Dense(dense_dimension) -> LeakyReLU(0.2) -> Dense(1): unbounded
      adversarial output;
    * Dense(dense_dimension) -> LeakyReLU(0.2) ->
      Dense(number_of_attributes, sigmoid): per-attribute probabilities.

  `call(images)` returns `[discriminator_output, attribute_output]`.
  """

  def __init__(self, number_of_attributes=40, dimension=64, dense_dimension=1024, downsamplings_layers=5):
    super(Discriminator, self).__init__()

    self._number_of_attributes = number_of_attributes
    self._dimension = dimension
    self._dense_dimension = dense_dimension
    self._downsamplings_layers = downsamplings_layers

    # Create the sub-layers once so this layer tracks and reuses their weights.
    # Building them inside `call` created new, untrainable layers on each call.
    blocks = []
    output_units = self._dimension
    for _ in range(self._downsamplings_layers):
      blocks.append(tf.keras.Sequential([
          layers.Conv2D(output_units, (4, 4), strides=(2, 2), padding='same'),
          tfa.layers.InstanceNormalization(),
          layers.LeakyReLU(alpha=0.2),
      ]))
      output_units = output_units * 2
    self._discriminator_blocks = blocks
    self._flatten = layers.Flatten()

    # No normalisation in the dense heads. InstanceNormalization on an
    # (N, features) tensor normalises every feature over a single element.
    # That maps every activation to the layer's beta, so the heads' outputs
    # did not depend on the image at all.
    self._gan_head = tf.keras.Sequential([
        layers.Dense(self._dense_dimension),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(1),
    ])
    self._attribute_head = tf.keras.Sequential([
        layers.Dense(self._dense_dimension),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(self._number_of_attributes, activation='sigmoid'),
    ])

  def call(self, inputs):
    """Return `[discriminator_output, attribute_output]` for an image batch."""
    input_layer = inputs
    for block in self._discriminator_blocks:
      input_layer = block(input_layer)
    input_layer = self._flatten(input_layer)

    discriminator_output = self._gan_head(input_layer)
    attribute_output = self._attribute_head(input_layer)

    return([discriminator_output, attribute_output])

In [0]:
# Smoke test: run the discriminator on random images and report both head shapes.
input_image = np.random.rand(batch_size, *image_shape)

discriminator = Discriminator(number_of_attributes)
discriminator_prediction, attribute_prediction = discriminator(input_image)

print('discriminator prediction shape', discriminator_prediction.shape)
print('attribute prediction shape', attribute_prediction.shape)

# Create adversarial loss functions.
*   Generator loss function
*   Discriminator loss function

## WGAN loss functions.
*   Generator loss function
*   Discriminator loss function

In [0]:
def wgan_loss_functions():
    """Build the Wasserstein GAN loss pair.

    Returns:
      A tuple (discriminator_loss_function, generator_loss_function):
        * discriminator_loss_function(real_logit, fake_logit) returns the
          pair (-mean(real_logit), mean(fake_logit));
        * generator_loss_function(fake_logit) returns -mean(fake_logit).
    """

    def generator_loss_function(fake_logit):
        return -tf.reduce_mean(fake_logit)

    def discriminator_loss_function(real_logit, fake_logit):
        real_term = -tf.reduce_mean(real_logit)
        fake_term = tf.reduce_mean(fake_logit)
        return real_term, fake_term

    return discriminator_loss_function, generator_loss_function

In [0]:
def adversarial_loss_functions(adversarial_loss_mode):
  """Return the (discriminator_loss_function, generator_loss_function) pair.

  Args:
    adversarial_loss_mode: name of the adversarial objective; only 'wgan' is
      implemented.

  Returns:
    The pair produced by `wgan_loss_functions()`.

  Raises:
    ValueError: if `adversarial_loss_mode` is not 'wgan'.
  """
  # Reject unknown modes explicitly. Silently falling back to WGAN would let a
  # misconfigured run train with an objective other than the one requested.
  if adversarial_loss_mode != 'wgan':
    raise ValueError(
        "Unsupported adversarial_loss_mode %r; only 'wgan' is implemented."
        % (adversarial_loss_mode,))
  return(wgan_loss_functions())

# Create the models and loss functions.
* Encoder model
* Decoder model
* Discriminator model
* Discriminator loss function
* Generator loss function

In [0]:
# Shared network instances; the composite models below reference these
# module-level names.
encoder, decoder = UNetGenc(), UNetGdec()
discriminator = Discriminator(number_of_attributes)

discriminator_loss_function, generator_loss_function = adversarial_loss_functions(
    adversarial_loss_mode)

# Create composite generator model.

In [0]:
def create_composite_generator(image_shape, number_of_attributes):
  """Build and compile a Keras model that computes the AttGAN generator loss.

  The model reuses the module-level `encoder`, `decoder` and `discriminator`.
  Target attributes `b` are the input attributes `a` shuffled across the
  batch. The loss is the sum of:

    * the adversarial generator loss on images decoded with `b`;
    * `g_attribute_loss_weight` x the binary cross-entropy between `b` and the
      discriminator's attribute predictions for those images;
    * `g_reconstruction_loss_weight` x the mean absolute error between the
      input images and the images decoded with `a`.

  Args:
    image_shape: (height, width, channels) of the input images.
    number_of_attributes: length of each attribute vector.

  Returns:
    A compiled `tf.keras.Model` taking [images, attributes]. Attributes are
    assumed to be in {0, 1} — TODO confirm. Its single output is the scalar
    loss, which is also registered with `add_loss` so that `fit` optimises it.
  """
  input_image = layers.Input(shape=image_shape)
  # `(number_of_attributes)` is just an int; Keras expects a shape tuple.
  input_attributes = layers.Input(shape=(number_of_attributes,))

  b = tf.random.shuffle(input_attributes)

  # Map {0, 1} attributes to {-1, 1} before conditioning the decoder.
  a_ = input_attributes * 2 - 1
  b_ = b * 2 - 1

  # Generate
  z = encoder(input_image)
  xa_ = decoder([z, a_])
  xb_ = decoder([z, b_])

  # Discriminate
  xb__logit_gan, xb__logit_att = discriminator(xb_)

  # Reduce every term to a scalar before combining them. Previously a scalar,
  # a per-example (N,) and a per-pixel (N, H, W) loss were added directly,
  # which broadcasts incorrectly.
  xb__loss_gan = generator_loss_function(xb__logit_gan)
  # Attributes are independent binary labels and the attribute head ends in a
  # sigmoid, so binary (not categorical) cross-entropy is the matching loss.
  xb__loss_att = tf.reduce_mean(tf.keras.losses.binary_crossentropy(b, xb__logit_att))
  xa__loss_rec = tf.reduce_mean(tf.keras.losses.mae(input_image, xa_))

  loss = (xb__loss_gan +
          xb__loss_att * g_attribute_loss_weight +
          xa__loss_rec * g_reconstruction_loss_weight)

  composite_model = tf.keras.models.Model([input_image, input_attributes], [loss])
  # `compile()` receives no loss, so register the computed one explicitly.
  # NOTE(review): the discriminator is part of this graph, so fitting this
  # model would also update its weights. Freeze it or use a custom train step.
  composite_model.add_loss(loss)
  composite_model.compile(optimizer=optimizer)

  return(composite_model)

In [0]:
# Build and compile the generator-side training model from the shared networks.
composite_generator = create_composite_generator(image_shape, number_of_attributes)

# Create composite discriminator model.

In [0]:
def create_composite_discriminator(image_shape, number_of_attributes):
  """Build and compile a Keras model that computes the AttGAN discriminator loss.

  The model reuses the module-level `encoder`, `decoder` and `discriminator`.
  Edited images are decoded with the input attributes shuffled across the
  batch. The loss is the sum of:

    * the two terms of `discriminator_loss_function` on the real and edited
      images' adversarial outputs;
    * `d_attribute_loss_weight` x the binary cross-entropy between the real
      images' attributes and the discriminator's attribute predictions.

  Args:
    image_shape: (height, width, channels) of the input images.
    number_of_attributes: length of each attribute vector.

  Returns:
    A compiled `tf.keras.Model` taking [images, attributes]. Attributes are
    assumed to be in {0, 1} — TODO confirm. Its single output is the scalar
    loss, which is also registered with `add_loss` so that `fit` optimises it.
  """
  # This previously duplicated `create_composite_generator` and so optimised
  # the generator objective. It now computes the discriminator objective.
  input_image = layers.Input(shape=image_shape)
  input_attributes = layers.Input(shape=(number_of_attributes,))

  b = tf.random.shuffle(input_attributes)
  b_ = b * 2 - 1

  # Generate edited images. stop_gradient keeps this loss from producing
  # gradients for the encoder/decoder.
  z = encoder(input_image)
  xb_ = tf.stop_gradient(decoder([z, b_]))

  # Discriminate real and edited images.
  xa_logit_gan, xa_logit_att = discriminator(input_image)
  xb__logit_gan, _ = discriminator(xb_)

  xa_loss_gan, xb__loss_gan = discriminator_loss_function(xa_logit_gan, xb__logit_gan)
  # The attribute head is trained on real images with their true attributes
  # (independent binary labels, sigmoid outputs).
  xa_loss_att = tf.reduce_mean(
      tf.keras.losses.binary_crossentropy(input_attributes, xa_logit_att))

  # TODO(review): the configured gradient penalty (gradient_penalty_mode,
  # d_gradient_penalty_weight) is not implemented yet. Without it (or weight
  # clipping), the WGAN critic is not Lipschitz-constrained.
  loss = (xa_loss_gan + xb__loss_gan +
          xa_loss_att * d_attribute_loss_weight)

  composite_model = tf.keras.models.Model([input_image, input_attributes], [loss])
  composite_model.add_loss(loss)
  # Use a separate optimizer instance with the same hyperparameters. The
  # generator and discriminator should not share Adam state, and newer Keras
  # optimizers refuse to update variables they were not built for.
  composite_model.compile(optimizer=type(optimizer).from_config(optimizer.get_config()))

  return(composite_model)

# Train the model.

In [0]:
# NOTE(review): placeholder — no training step is implemented yet.
# Inspect a single batch rather than printing every batch of the training
# split, which floods the notebook output.
for dataset_batch in train_dataset.take(1):
  print(dataset_batch)