# AttGAN

# Abstract - IEEE Xplore
Facial attribute editing aims to manipulate single or multiple attributes on a given face image, i.e., to generate a new face image with desired attributes while preserving other details. Recently, the generative adversarial net (GAN) and encoder-decoder architecture are usually incorporated to handle this task with promising results. Based on the encoder-decoder architecture, facial attribute editing is achieved by decoding the latent representation of a given face conditioned on the desired attributes. Some existing methods attempt to establish an attribute-independent latent representation for further attribute editing. However, such attribute-independent constraint on the latent representation is excessive because it restricts the capacity of the latent representation and may result in information loss, leading to over-smooth or distorted generation. Instead of imposing constraints on the latent representation, in this work, we propose to apply an attribute classification constraint to the generated image to just guarantee the correct change of desired attributes, i.e., to change what you want. Meanwhile, the reconstruction learning is introduced to preserve attribute-excluding details, in other words, to only change what you want. Besides, the adversarial learning is employed for visually realistic editing. These three components cooperate with each other forming an effective framework for high quality facial attribute editing, referred as AttGAN. Furthermore, the proposed method is extended for attribute style manipulation in an unsupervised manner. Experiments on two wild datasets, CelebA and LFW, show that the proposed method outperforms the state-of-the-art on realistic attribute editing with other facial details well preserved.

# References
* [AttGAN: Facial Attribute Editing by Only Changing What You Want - IEEE Xplore ](https://ieeexplore.ieee.org/document/8718508)
* [AttGAN: Facial Attribute Editing by Only Changing What You Want - arXiv.org](https://arxiv.org/abs/1711.10678)
* [AttGAN-Tensorflow](https://github.com/LynnHo/AttGAN-Tensorflow)
* [AttGAN-PyTorch](https://github.com/elvisyjlin/AttGAN-PyTorch)



# Prerequisite
* Align the CelebA dataset images using the [align_images](https://github.com/look4pritam/TensorFlowExamples/blob/master/GAN/CelebA/align_images.ipynb) script.
* Download the preprocessed CelebA dataset using this [link](https://drive.google.com/file/d/1diaLDdB-dNMsPhJX0uco4155ghi4KMXK/view?usp=sharing).

# Import TensorFlow 2.x.

In [0]:
try:
  %tensorflow_version 2.x
except Exception:
  pass

import tensorflow as tf
tf.random.set_seed(7)

import tensorflow.keras.layers as layers
import tensorflow.keras.models as models

import numpy as np
np.random.seed(7)

import matplotlib.pyplot as plot

print(tf.__version__)

# Set the root directory.

In [0]:
import os

root_dir = '/content/'
os.chdir(root_dir)

!ls -al

# Download the aligned CelebA dataset from Google Drive.

### Mount Google Drive.

In [0]:
from google.colab import drive
drive.mount('/content/drive')

### Copy the aligned CelebA dataset from Google Drive.

In [0]:
!ls -al '/content/drive/My Drive/CelebA/img_align_celeba.tar.gz'

In [0]:
!cp '/content/drive/My Drive/CelebA/img_align_celeba.tar.gz' .

In [0]:
!ls -al

### OR 

### Copy the aligned CelebA dataset using a Google Drive shared link.

In [0]:
!gdown --id 1z5bpHVrciXmE8obe8NeYa3u9zWvdLWwS

In [0]:
!ls -al

### OR

In [0]:
!gdown --id 1diaLDdB-dNMsPhJX0uco4155ghi4KMXK

In [0]:
!ls -al

### Extract the dataset.

In [0]:
!tar -xzf img_align_celeba.tar.gz

In [0]:
#!mv img_align_celeba.tar.gz '/content/drive/My Drive/CelebA/.'

In [0]:
!rm -rf img_align_celeba.tar.gz

### Verify dataset contents.

In [0]:
!ls -al
!ls -al img_align_celeba
# Count the image files; plain `ls` avoids the "total" header line that `ls -l` prints.
!ls img_align_celeba/images | wc -l

# Create dictionaries for facial attributes.
* Attributes to identifiers
* Identifiers to attributes

In [0]:
attributes_to_identifiers = {
    '5_o_Clock_Shadow': 0, 
    'Arched_Eyebrows': 1, 
    'Attractive': 2,       
    'Bags_Under_Eyes': 3,           
    'Bald': 4, 
    'Bangs': 5, 
    'Big_Lips': 6,           
    'Big_Nose': 7, 
    'Black_Hair': 8, 
    'Blond_Hair': 9, 
    'Blurry': 10,           
    'Brown_Hair': 11, 
    'Bushy_Eyebrows': 12, 
    'Chubby': 13,           
    'Double_Chin': 14, 
    'Eyeglasses': 15, 
    'Goatee': 16, 
    'Gray_Hair': 17, 
    'Heavy_Makeup': 18, 
    'High_Cheekbones': 19,          
    'Male': 20, 
    'Mouth_Slightly_Open': 21, 
    'Mustache': 22, 
    'Narrow_Eyes': 23, 
    'No_Beard': 24, 
    'Oval_Face': 25,           
    'Pale_Skin': 26, 
    'Pointy_Nose': 27, 
    'Receding_Hairline': 28,           
    'Rosy_Cheeks': 29, 
    'Sideburns': 30, 
    'Smiling': 31,           
    'Straight_Hair': 32, 
    'Wavy_Hair': 33, 
    'Wearing_Earrings': 34,           
    'Wearing_Hat': 35, 
    'Wearing_Lipstick': 36,           
    'Wearing_Necklace': 37, 
    'Wearing_Necktie': 38, 
    'Young': 39
    }

In [0]:
identifiers_to_attributes = {v: k for k, v in attributes_to_identifiers.items()}
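Typing the 40 identifiers by hand is easy to get wrong. The sketch below instead derives both lookup tables from the ordered attribute names with `enumerate`. It assumes the label files list the attributes in the same CelebA order used by `attributes_to_identifiers` above. `CELEBA_ATTRIBUTES`, `name_to_id` and `id_to_name` are illustrative names, not part of the notebook.

```python
# Sketch (assumption): the label files store the 40 CelebA attributes in this order,
# the same order encoded by the hand-written attributes_to_identifiers dictionary.
CELEBA_ATTRIBUTES = [
    '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
    'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
    'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
    'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
    'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
    'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks',
    'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings',
    'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie',
    'Young',
]

# Hypothetical helper tables: attribute name -> column index, and the inverse.
name_to_id = {name: index for index, name in enumerate(CELEBA_ATTRIBUTES)}
id_to_name = {index: name for name, index in name_to_id.items()}

# Round trip: name -> id -> name.
print(name_to_id['Male'], id_to_name[name_to_id['Male']])  # 20 Male
```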

# Prepare CelebA dataset in TensorFlow dataset format.


In [0]:
image_root_dir = 'img_align_celeba/images'

In [0]:
train_label_filename = 'img_align_celeba/train_label.txt'
val_label_filename = 'img_align_celeba/val_label.txt'
test_label_filename = 'img_align_celeba/test_label.txt'

In [0]:
batch_size = 32

In [0]:
def create_celeba_dataset(image_root_dir, attribute_filename):
  image_names = np.genfromtxt(attribute_filename, dtype=str, usecols=0)
  image_filename_array = np.array([os.path.join(image_root_dir, image_name) for image_name in image_names])

  attributes_array = np.genfromtxt(attribute_filename, dtype=float, usecols=range(1, 41))  

  number_of_batches = len(image_filename_array) // batch_size  

  memory_data = (image_filename_array, attributes_array)  
  dataset = tf.data.Dataset.from_tensor_slices(memory_data)
  
  return(dataset, number_of_batches)
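`create_celeba_dataset` reads column 0 as the image name and columns 1 to 40 as the attribute values. The sketch below shows that parsing on a tiny in-memory file. It assumes each line of `train_label.txt` holds an image name followed by 40 values in {-1, 1}. The two sample lines are made up.

```python
import io

import numpy as np

# Hypothetical two-line label file (assumed layout: name + 40 values in {-1, 1}).
line_1 = '000001.jpg ' + ' '.join(['1'] * 20 + ['-1'] * 20)
line_2 = '000002.jpg ' + ' '.join(['-1'] * 40)
label_text = line_1 + '\n' + line_2 + '\n'

# genfromtxt consumes the buffer, so each call gets a fresh StringIO.
image_names = np.genfromtxt(io.StringIO(label_text), dtype=str, usecols=0, encoding=None)
attributes_array = np.genfromtxt(io.StringIO(label_text), dtype=float,
                                 usecols=range(1, 41), encoding=None)

print(image_names.tolist())    # ['000001.jpg', '000002.jpg']
print(attributes_array.shape)  # (2, 40)
```

The file names are then joined with `image_root_dir`, and `number_of_batches` is simply the number of lines integer-divided by `batch_size`.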

# Preprocess the dataset.

In [0]:
default_attribute_names = [
 'Bald', 
 'Bangs', 
 'Black_Hair', 
 'Blond_Hair', 
 'Brown_Hair', 
 'Bushy_Eyebrows', 
 'Eyeglasses', 
 'Male', 
 'Mouth_Slightly_Open', 
 'Mustache', 
 'No_Beard', 
 'Pale_Skin', 
 'Young'
 ]

In [0]:
number_of_attributes = len(default_attribute_names)

In [0]:
image_load_shape = (143, 143, 3)
image_shape = (128, 128, 3)

In [0]:
buffer_size = 512

### Load an image from its filename.

In [0]:
def load_image(image_filename):
  input_image = tf.io.read_file(image_filename)
  input_image = tf.image.decode_jpeg(input_image, 3)
  return(input_image)

### Normalize image to [-1, 1].

In [0]:
def normalize_image(image):
  image = tf.cast(image, tf.float32)
  image = tf.clip_by_value(image, 0, 255) / 127.5 - 1
  return(image)
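The mapping used by `normalize_image` is $x \mapsto x / 127.5 - 1$. It sends 0 to -1, 127.5 to 0 and 255 to 1, which matches the range of the decoder's `tanh` output. The NumPy sketch below pairs it with its inverse, the same inverse the training loop uses before converting generated images to `uint8`. The function names are illustrative.

```python
import numpy as np

def normalize_pixels(pixels):
    # Same mapping as normalize_image: clip to [0, 255], then x / 127.5 - 1.
    pixels = np.clip(np.asarray(pixels, dtype=np.float32), 0, 255)
    return pixels / 127.5 - 1.0

def denormalize_pixels(values):
    # Inverse mapping from [-1, 1] back to [0, 255] for display.
    return (np.asarray(values, dtype=np.float32) + 1.0) / 2.0 * 255.0

print(normalize_pixels([0, 127.5, 255, 300]))  # 300 is clipped to 255 first
```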

### Random crop image.

In [0]:
def random_crop(image):
  cropped_image = tf.image.random_crop(image, size=image_shape)
  return(cropped_image)

### Apply random jitter to input image.

In [0]:
def random_jitter(image):  
  image = tf.image.resize(image, [image_load_shape[0], image_load_shape[1]])  
  image = random_crop(image)
  image = tf.image.random_flip_left_right(image)
  return(image)
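After `random_jitter` resizes the image to 143 x 143, it cuts a 128 x 128 window at a random offset between 0 and 15 in each direction. It then mirrors the window horizontally half of the time. The NumPy sketch below reproduces only the crop-and-flip step on an image that has already been resized; the resize is omitted. `random_jitter_numpy` is a hypothetical name.

```python
import numpy as np

def random_jitter_numpy(image, crop_size=128, rng=None):
    # Crop-and-flip part of random_jitter; assumes the image is already resized.
    rng = np.random.default_rng() if rng is None else rng
    height, width = image.shape[:2]
    # Valid top-left offsets are 0 .. (size - crop_size), inclusive.
    top = rng.integers(0, height - crop_size + 1)
    left = rng.integers(0, width - crop_size + 1)
    cropped = image[top:top + crop_size, left:left + crop_size]
    # Flip left-right with probability 0.5.
    if rng.random() < 0.5:
        cropped = cropped[:, ::-1]
    return cropped

image = np.zeros((143, 143, 3), dtype=np.uint8)
print(random_jitter_numpy(image, rng=np.random.default_rng(7)).shape)  # (128, 128, 3)
```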

### Define preprocessing for train dataset split.

In [0]:
def preprocess_attributes(attributes_array):
  # Pick the editable attributes, in the order of default_attribute_names.
  indices = [attributes_to_identifiers[attribute_name] for attribute_name in default_attribute_names]
  selected_attributes = tf.gather(attributes_array, indices)

  # Map CelebA labels from {-1, 1} to {0, 1} and cast to float32 to match the model.
  selected_attributes = tf.cast((selected_attributes + 1) / 2, tf.float32)
  return(selected_attributes)
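As a worked example of the attribute preprocessing, the NumPy sketch below takes one made-up 40-value label row. It selects the 13 editable columns and maps {-1, 1} to {0, 1}. The column ids are copied by hand from `attributes_to_identifiers`, in the order of `default_attribute_names`.

```python
import numpy as np

# Hypothetical CelebA label row in {-1, 1}; index = attribute id.
label_row = -np.ones(40)
label_row[[11, 15, 20, 39]] = 1   # Brown_Hair, Eyeglasses, Male, Young

# Ids of Bald, Bangs, Black_Hair, Blond_Hair, Brown_Hair, Bushy_Eyebrows, Eyeglasses,
# Male, Mouth_Slightly_Open, Mustache, No_Beard, Pale_Skin, Young.
selected_ids = [4, 5, 8, 9, 11, 12, 15, 20, 21, 22, 24, 26, 39]

# Gather the selected columns and map {-1, 1} -> {0, 1}.
selected = (label_row[selected_ids] + 1) / 2
print(selected)
```

The positions for Brown_Hair, Eyeglasses, Male and Young come out as 1 and every other position as 0.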

In [0]:
def preprocess_train_dataset(image_filename, attributes):  
  
  image = load_image(image_filename)
  image = random_jitter(image)
  image = normalize_image(image)

  attributes = preprocess_attributes(attributes)
  return(image, attributes)

### Define preprocessing for test dataset split.

In [0]:
def preprocess_test_dataset(image_filename, attributes):    
  
  image = load_image(image_filename)
  image = tf.image.resize(image, [image_shape[0], image_shape[1]]) 
  image = normalize_image(image)

  attributes = preprocess_attributes(attributes)
  return(image, attributes)

### Preprocess train dataset split.

In [0]:
auto_tune = tf.data.experimental.AUTOTUNE

In [0]:
train_dataset, number_of_batches = create_celeba_dataset(image_root_dir, train_label_filename)
print('number of batches -', number_of_batches)

In [0]:
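# Shuffle the (filename, attributes) pairs before decoding, so the shuffle buffer
# holds lightweight strings rather than decoded images.
train_dataset = train_dataset.shuffle(buffer_size)
train_dataset = train_dataset.map(preprocess_train_dataset, num_parallel_calls=auto_tune)
train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
train_dataset = train_dataset.prefetch(auto_tune)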

# Configuration parameters.

In [0]:
d_gradient_penalty_weight = 10.0
d_attribute_loss_weight = 1.0

In [0]:
g_attribute_loss_weight = 10.0
g_reconstruction_loss_weight = 100.0

In [0]:
load_previous_weights = False
save_current_weights = False

# Create the optimizer.

*   Adam optimizer
*   Learning rate = 0.0002
*   β1 = 0.5
*   β2 = 0.999

In [0]:
base_learning_rate = 0.0002

In [0]:
maximum_epochs = 60
start_decay_epoch = 30

In [0]:
def create_optimizer(base_learning_rate, maximum_epochs, start_decay_epoch, current_epoch):
  if (current_epoch >= start_decay_epoch):
     current_learning_rate = base_learning_rate * (1 - 1 / (maximum_epochs - start_decay_epoch + 1) * (current_epoch - start_decay_epoch + 1))
  else:
     current_learning_rate = base_learning_rate

  print('epochs -', current_epoch, 'learning rate -', current_learning_rate)

  optimizer = tf.optimizers.Adam(learning_rate=current_learning_rate, beta_1=0.5, beta_2=0.999)
  return(optimizer)
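`create_optimizer` keeps the learning rate at `base_learning_rate` until `start_decay_epoch`, then lowers it linearly over the remaining epochs. The pure-Python sketch below isolates that formula so you can print the schedule for the configured 60 epochs. `decayed_learning_rate` is a hypothetical helper and not part of the notebook.

```python
def decayed_learning_rate(base_learning_rate, maximum_epochs, start_decay_epoch, current_epoch):
    # Same formula as create_optimizer: constant until start_decay_epoch,
    # then a linear ramp towards zero over the remaining epochs.
    if current_epoch < start_decay_epoch:
        return base_learning_rate
    decay_epochs = maximum_epochs - start_decay_epoch + 1
    return base_learning_rate * (1 - (current_epoch - start_decay_epoch + 1) / decay_epochs)

for epoch in (0, 29, 30, 45, 59):
    print(epoch, decayed_learning_rate(0.0002, 60, 30, epoch))
```

With 60 epochs and decay starting at epoch 30, the rate never reaches exactly zero. Epoch 30 runs at 30/31 of the base rate and the last epoch (59) at 1/31 of it.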

# Create an encoder model.

In [0]:
from functools import partial

In [0]:
class Encoder(models.Model):

  def __init__(self, encoder_dimension=64, downsamplings_layers=5, 
               name='attgan-encoder', **kwargs):
    super(Encoder, self).__init__(name=name, **kwargs)
    
    self._encoder_dimension = encoder_dimension
    self._downsamplings_layers = downsamplings_layers

    self._encoders = []    
    filters = self._encoder_dimension  
    for block_index in range(self._downsamplings_layers):
      block_name = 'block-' + str(block_index + 1)

      current_encoder = self._convolution_block(filters, 4, name=block_name)
      self._encoders.append(current_encoder)

      filters = filters * 2

  def _convolution_block(self, filters, kernel_size, 
                         activation_fn=tf.nn.leaky_relu, batch_norm=True, 
                         input_shape=None, name=None):
    
    if( input_shape is None ):
      conv = partial(layers.Conv2D)
    else:
      conv = partial(layers.Conv2D, input_shape=input_shape)

    blocks = [
      conv(filters, 
           (kernel_size, kernel_size), 
           strides = (2,2), 
           padding="same", 
           kernel_initializer = tf.keras.initializers.TruncatedNormal(stddev=0.02), 
           use_bias = True, 
           bias_initializer = tf.keras.initializers.Constant(value=0.0), 
           name='conv')
        ]

    if( batch_norm ):
      blocks.append(layers.BatchNormalization(name='bnorm'))

    if(activation_fn is not None):
      if(activation_fn == tf.nn.leaky_relu):        
        blocks.append(layers.LeakyReLU(alpha=0.2, name='act'))
      else:
        blocks.append(layers.Activation(activation_fn, name='act'))

    return(models.Sequential(blocks, name=name))

  def call(self, input_image, training=True):
    image_features = []

    layer_input = input_image
    for current_encoder in self._encoders:
      layer_input = current_encoder(layer_input, training=training)
      image_features.append(layer_input)

    return(image_features)

In [0]:
sample_images, sample_attributes = next(iter(train_dataset))

encoder_model = Encoder(encoder_dimension=64, downsamplings_layers=5)
image_features = encoder_model.predict(sample_images)
print('number of image features -', len(image_features))
for index, image_feature in enumerate(image_features):
  print('image feature -', (index +1), image_feature.shape)

encoder_model.summary()
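Each encoder block is a stride-2 convolution with `"same"` padding. That halves (rounding up) the spatial size while the filter count doubles. The sketch below does that shape bookkeeping without TensorFlow. It should match the shapes printed by the cell above for 128 x 128 inputs. `encoder_feature_shapes` is an illustrative helper.

```python
def encoder_feature_shapes(image_size=128, encoder_dimension=64, downsampling_layers=5):
    # Stride-2 'same' convolution: spatial size becomes ceil(size / 2) per block,
    # and the number of filters doubles from one block to the next.
    shapes = []
    size, filters = image_size, encoder_dimension
    for _ in range(downsampling_layers):
        size = -(-size // 2)  # ceiling division
        shapes.append((size, size, filters))
        filters *= 2
    return shapes

print(encoder_feature_shapes())
# [(64, 64, 64), (32, 32, 128), (16, 16, 256), (8, 8, 512), (4, 4, 1024)]
```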

# Concatenate features and attributes.

In [0]:
def concatenate(list_of_features, list_of_attributes, layer_name):
  # Accept a single tensor or a list/tuple of tensors for both arguments.
  list_of_features = list(list_of_features) if isinstance(list_of_features, (list, tuple)) else [list_of_features]
  list_of_attributes = list(list_of_attributes) if isinstance(list_of_attributes, (list, tuple)) else [list_of_attributes]
  # Tile each (batch, n) attribute vector over the spatial grid of the first feature map.
  for index, attributes in enumerate(list_of_attributes):
    attributes = tf.reshape(attributes, [-1, 1, 1, attributes.shape[-1]], name=layer_name + 'reshape')
    attributes = tf.tile(attributes, [1, list_of_features[0].shape[1], list_of_features[0].shape[2], 1], name=layer_name + 'tile')
    list_of_attributes[index] = attributes
  # Concatenate features and tiled attributes along the channel axis.
  return tf.concat(list_of_features + list_of_attributes, axis=-1, name=layer_name + 'concat')
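The key step in `concatenate` reshapes each attribute vector to `(batch, 1, 1, n)` and tiles it over the feature map's height and width. After that, every spatial position carries the full attribute vector as extra channels. The NumPy sketch below reproduces that broadcasting on small made-up arrays. `concatenate_numpy` is a hypothetical name.

```python
import numpy as np

def concatenate_numpy(features, attributes):
    # features: (batch, height, width, channels); attributes: (batch, n_attributes).
    batch, height, width, _ = features.shape
    # Reshape to (batch, 1, 1, n) and tile over the spatial grid, as concatenate() does.
    tiled = np.tile(attributes.reshape(batch, 1, 1, -1), (1, height, width, 1))
    return np.concatenate([features, tiled], axis=-1)

features = np.zeros((2, 4, 4, 8), dtype=np.float32)
attributes = np.array([[1., -1., 1.], [-1., -1., 1.]], dtype=np.float32)
combined = concatenate_numpy(features, attributes)
print(combined.shape)  # (2, 4, 4, 11)
```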

# Create a decoder model.

In [0]:
class Decoder(models.Model):

  def __init__(self, decoder_dimension=64, upsamplings_layers=5, shortcut_layers=1, inject_layers=1,
               name='attgan-decoder', **kwargs):
    super(Decoder, self).__init__(name=name, **kwargs)
    
    self._decoder_dimension = decoder_dimension
    self._upsamplings_layers = upsamplings_layers
    self._shortcut_layers = shortcut_layers
    self._inject_layers = inject_layers

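    self._decoders = []
    # Widths shrink towards the output (1024, 512, 256, 128 for the defaults),
    # mirroring the encoder, as in the AttGAN generator described in the paper.
    filters = self._decoder_dimension * 2 ** (self._upsamplings_layers - 1)
    for block_index in range(self._upsamplings_layers - 1):
      block_name = 'block-' + str(block_index + 1)

      current_decoder = self._convolution_block(filters, 4, name=block_name)
      self._decoders.append(current_decoder)

      filters = filters // 2

    # The last block outputs 3 RGB channels with tanh, so images lie in [-1, 1].
    current_decoder = self._convolution_block(3, 4, activation_fn=tf.nn.tanh, batch_norm=False, name='block-' + str(self._upsamplings_layers))
    self._decoders.append(current_decoder)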

  def _convolution_block(self, filters, kernel_size, 
                         activation_fn=tf.nn.leaky_relu, batch_norm=True, 
                         input_shape=None, name=''):
    
    if( input_shape is None ):
      dconv = partial(layers.Conv2DTranspose)
    else:
      dconv = partial(layers.Conv2DTranspose, input_shape=input_shape)

    blocks = [
      dconv(filters, 
            (kernel_size, kernel_size), 
            strides = (2,2), 
            padding = "same", 
            kernel_initializer = tf.keras.initializers.RandomNormal(stddev=0.02),
            use_bias = True, 
            bias_initializer = tf.keras.initializers.Constant(value=0.0),
            name = 'dconv')
        ]

    if( batch_norm ):
      blocks.append(layers.BatchNormalization(name='bnorm'))

    if(activation_fn is not None):
      if(activation_fn == tf.nn.leaky_relu):        
        blocks.append(layers.LeakyReLU(alpha=0.2, name='act'))
      else:
        blocks.append(layers.Activation(activation_fn, name='act'))

    return(models.Sequential(blocks, name=name))

  def call(self, inputs, training=True):

    input_features, input_attributes = inputs    

    layer_name = 'block-0-shortcut-'    
    layer_input = concatenate(input_features[-1], input_attributes, layer_name=layer_name)    
    for block_index in range(self._upsamplings_layers):
      layer_name = 'block-' + str(block_index + 1) + '-'

      decoder_layer = self._decoders[block_index]
      layer_input = decoder_layer(layer_input, training=training)

      if (self._shortcut_layers > block_index):
        shortcut_name = layer_name + 'shortcut-'
        layer_input = concatenate([layer_input, input_features[-2 - block_index]], [], layer_name=shortcut_name)

      if (self._inject_layers > block_index):
        inject_name = layer_name + 'inject-'
        layer_input = concatenate(layer_input, input_attributes, layer_name=inject_name)      

    return(layer_input)

In [0]:
sample_images, sample_attributes = next(iter(train_dataset))

decoder_model = Decoder(decoder_dimension=64, upsamplings_layers=5, shortcut_layers=1, inject_layers=1)

image_features = encoder_model.predict(sample_images)
generated_images = decoder_model.predict([image_features, sample_attributes])
decoder_model.summary()
print('generated images -', generated_images.shape)
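The shortcut and inject options change how many channels each decoder block receives. The pure-Python sketch below follows the logic of `Decoder.call`. Block 1 sees the last encoder feature map plus the tiled attributes. After block *i*, a shortcut appends encoder feature `-2 - i` and an injection appends the attributes again. The decoder widths are an explicit argument of this hypothetical helper. The example uses halving widths 1024, 512, 256, 128 followed by 3 output channels; compare them with the filters the `Decoder` class actually builds.

```python
def decoder_input_channels(encoder_channels, decoder_filters, n_attributes,
                           shortcut_layers=1, inject_layers=1):
    # Channels entering each decoder block, following the logic of Decoder.call.
    channels = [encoder_channels[-1] + n_attributes]
    for block_index, filters in enumerate(decoder_filters[:-1]):
        current = filters
        if shortcut_layers > block_index:
            current += encoder_channels[-2 - block_index]  # U-Net style skip connection
        if inject_layers > block_index:
            current += n_attributes                        # attributes injected again
        channels.append(current)
    return channels

encoder_channels = [64, 128, 256, 512, 1024]
print(decoder_input_channels(encoder_channels, [1024, 512, 256, 128, 3], 13))
# [1037, 1549, 512, 256, 128]
```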

# Create the discriminator and attribute classifier model.


In [0]:
import tensorflow_addons as tfa

In [0]:
class Discriminator(models.Model):

  def __init__(self, number_of_attributes=40, 
               discriminator_dimension=64, dense_dimension=1024, downsamplings_layers=5,
               name='attgan-discriminator', **kwargs):
    super(Discriminator, self).__init__(name=name, **kwargs)
    
    self._number_of_attributes = number_of_attributes
    self._discriminator_dimension = discriminator_dimension
    self._dense_dimension = dense_dimension
    self._downsamplings_layers = downsamplings_layers
    
    self._features = None
    self._classifier = None
    self._discriminator = None

    self._create_features()
    self._create_classifier()
    self._create_discriminator()

  def _create_features(self):
    self._features = models.Sequential(name="features")
    filters = self._discriminator_dimension  
    for block_index in range(self._downsamplings_layers):
      block_name = 'block-' + str(block_index + 1)

      current_features = self._convolution_block(filters, 4, name=block_name)
      self._features.add(current_features)

      filters = filters * 2

  def _create_classifier(self):
    self._classifier = models.Sequential(name='classifier')
    self._classifier.add(self._dense_block(self._dense_dimension, activation_fn=tf.nn.leaky_relu, name='dense'))
    self._classifier.add(self._dense_block(self._number_of_attributes, activation_fn=None, name='predictions'))

  def _create_discriminator(self):
    self._discriminator = models.Sequential(name='discriminator')
    self._discriminator.add(self._dense_block(self._dense_dimension, activation_fn=tf.nn.leaky_relu, name='dense'))
    self._discriminator.add(self._dense_block(1, activation_fn=None, name='predictions'))


  def _convolution_block(self, filters, kernel_size, 
                         activation_fn=tf.nn.leaky_relu, batch_norm=True, 
                         input_shape=None, name=''):
    
    if( input_shape is None ):
      conv = partial(layers.Conv2D)
    else:
      conv = partial(layers.Conv2D, input_shape=input_shape)

    blocks = [
      conv(filters, (kernel_size, kernel_size), 
           strides = (2,2), 
           padding = "same", 
           kernel_initializer = tf.keras.initializers.RandomNormal(stddev=0.02), 
           use_bias = True, 
           bias_initializer = tf.keras.initializers.Constant(value=0.0),
           name = 'conv')
        ]

    if( batch_norm ):
      blocks.append(tfa.layers.InstanceNormalization(name='inorm'))

    if(activation_fn is not None):
      if(activation_fn == tf.nn.leaky_relu):        
        blocks.append(layers.LeakyReLU(alpha=0.2, name='act'))
      else:
        blocks.append(layers.Activation(activation_fn, name='act'))

    return(models.Sequential(blocks, name=name))

  def _dense_block(self, filters, 
                         activation_fn=tf.nn.leaky_relu, batch_norm=False, 
                         input_shape=None, name=None):
    
    if( input_shape is None ):
      dense = partial(layers.Dense)
    else:
      dense = partial(layers.Dense, input_shape=input_shape)

    blocks = [
      dense(filters, 
            kernel_initializer = tf.keras.initializers.RandomNormal(stddev=0.02), 
            use_bias = True, 
            bias_initializer = tf.keras.initializers.Constant(value=0.0), 
            name = 'dense')
        ]

    if( batch_norm ):
      blocks.append(layers.BatchNormalization(name='bnorm'))

    if(activation_fn is not None):
      if(activation_fn == tf.nn.leaky_relu):        
        blocks.append(layers.LeakyReLU(alpha=0.2, name='act'))
      else:
        blocks.append(layers.Activation(activation_fn, name='act'))

    return(models.Sequential(blocks, name=name))

  def call(self, input_image, training=True):

    layer_input = input_image   
    layer_input = self._features(layer_input, training=training)
    layer_input = layers.Flatten()(layer_input)

    classifier_predictions = self._classifier(layer_input, training=training)
    discriminator_prediction = self._discriminator(layer_input, training=training)
    
    return(discriminator_prediction, classifier_predictions)     

In [0]:
sample_images, sample_attributes = next(iter(train_dataset))

discriminator_model = Discriminator(number_of_attributes, discriminator_dimension=64, dense_dimension=1024, downsamplings_layers=5)
discriminator_prediction, classifier_predictions = discriminator_model.predict(sample_images)
discriminator_model.summary()
print('discriminator prediction -', discriminator_prediction.shape)
print('classifier predictions -', classifier_predictions.shape)
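The discriminator shares one convolutional trunk between two dense heads. One head outputs a single unbounded WGAN critic score per image. The other outputs one logit per attribute, which a sigmoid turns into a probability. The NumPy sketch below mimics that split with random weights. It is a scaled-down stand-in: the trunk output is 4 x 4 x 32 and the hidden width is 64, instead of 4 x 4 x 1024 and 1024. `dense_head` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_attributes = 2, 13

# Scaled-down stand-in for the trunk output (the real trunk gives 4 x 4 x 1024).
trunk_output = rng.standard_normal((batch, 4, 4, 32)).astype(np.float32)
flat = trunk_output.reshape(batch, -1)  # like layers.Flatten(): (2, 512)

def dense_head(x, hidden_units, out_units, rng):
    # Dense + LeakyReLU(0.2), then a linear Dense; weights ~ N(0, 0.02), zero biases.
    w1 = rng.normal(0.0, 0.02, (x.shape[1], hidden_units)).astype(np.float32)
    hidden = x @ w1
    hidden = np.where(hidden > 0, hidden, 0.2 * hidden)
    w2 = rng.normal(0.0, 0.02, (hidden_units, out_units)).astype(np.float32)
    return hidden @ w2

critic_scores = dense_head(flat, 64, 1, rng)                # one WGAN score per image
attribute_logits = dense_head(flat, 64, n_attributes, rng)  # one logit per attribute
print(critic_scores.shape, attribute_logits.shape)  # (2, 1) (2, 13)
```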

### Load previous model weights.
* Encoder model weights
* Decoder model weights
* Discriminator model weights



In [0]:
import os 

def encoder_filename():
  return('encoder.h5')

def encoder_gdrive_filename(weight_root_dir='/content/drive/My Drive/models/AttGAN/'):    
  return(os.path.join(weight_root_dir, encoder_filename()))

def decoder_filename():
  return('decoder.h5')

def decoder_gdrive_filename(weight_root_dir='/content/drive/My Drive/models/AttGAN/'):    
  return(os.path.join(weight_root_dir, decoder_filename()))

def discriminator_filename():
  return('discriminator.h5')

def discriminator_gdrive_filename(weight_root_dir='/content/drive/My Drive/models/AttGAN/'):  
  return(os.path.join(weight_root_dir, discriminator_filename()))

In [0]:
if(load_previous_weights):
  encoder_model.load_weights(encoder_gdrive_filename())
  decoder_model.load_weights(decoder_gdrive_filename())
  discriminator_model.load_weights(discriminator_gdrive_filename())

# Compute generator loss.

In [0]:
def compute_generator_loss(input_image, input_attributes):

  target_attributes = tf.random.shuffle(input_attributes)

  scaled_input_attributes = input_attributes * 2. - 1.
  scaled_target_attributes = target_attributes * 2. - 1.

  # Generator
  image_features = encoder_model(input_image, training=True)

  reconstructed_image = decoder_model([image_features, scaled_input_attributes], training=True)
  fake_image = decoder_model([image_features, scaled_target_attributes], training=True)

  # Discriminator
  fake_image_prediction, fake_image_attributes = discriminator_model(fake_image, training=False)

  fake_image_prediction_loss = tf.reduce_mean(-fake_image_prediction)
  fake_image_attributes_loss = tf.compat.v1.losses.sigmoid_cross_entropy(target_attributes, fake_image_attributes)  
  
  image_reconstruction_loss = tf.compat.v1.losses.absolute_difference(input_image, reconstructed_image)
   
  generator_loss = (  fake_image_prediction_loss 
                    + fake_image_attributes_loss * g_attribute_loss_weight 
                    + image_reconstruction_loss * g_reconstruction_loss_weight
                    )  
  '''
  print('fake_prediction', fake_image_prediction_loss.numpy(),
        'fake_attributes', fake_image_attributes_loss.numpy(),
        'image_reconstruction', image_reconstruction_loss.numpy(),
        'generator_loss', generator_loss.numpy())
  '''
  
  return(generator_loss)
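For reference, the objective that `compute_generator_loss` minimizes can be written as below. The notation is ours, not the notebook's. $E$ and $G$ are the encoder and decoder, and $D$ and $C$ are the critic and classifier heads of the discriminator. $\mathbf{a} \in \{0,1\}^n$ are the source attributes and $\mathbf{b}$ are the targets obtained by shuffling $\mathbf{a}$ across the batch. $\ell(y, z) = -y \log \sigma(z) - (1-y)\log(1-\sigma(z))$ is sigmoid cross-entropy on a logit $z$.

$$
\begin{aligned}
\hat{x}_{\mathbf{a}} &= G\big(E(x),\, 2\mathbf{a}-1\big), \qquad \hat{x}_{\mathbf{b}} = G\big(E(x),\, 2\mathbf{b}-1\big) \\
\mathcal{L}_G &= -\mathbb{E}\big[D(\hat{x}_{\mathbf{b}})\big]
 + \lambda_{\text{cls}}\, \mathbb{E}\Big[\tfrac{1}{n}\textstyle\sum_{i=1}^{n} \ell\big(b_i,\, C_i(\hat{x}_{\mathbf{b}})\big)\Big]
 + \lambda_{\text{rec}}\, \mathbb{E}\big[\operatorname{mean}\lvert x - \hat{x}_{\mathbf{a}} \rvert\big]
\end{aligned}
$$

Here $\lambda_{\text{cls}}$ is `g_attribute_loss_weight` (10) and $\lambda_{\text{rec}}$ is `g_reconstruction_loss_weight` (100). Only the reconstruction term uses the source attributes. The adversarial and classification terms are evaluated on the image edited towards the shuffled targets.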

In [0]:
sample_images, sample_attributes = next(iter(train_dataset))
generator_loss = compute_generator_loss(sample_images, sample_attributes)
print(generator_loss.numpy())

# Compute discriminator loss.

In [0]:
def compute_discriminator_loss(input_image, input_attributes):

  target_attributes = tf.random.shuffle(input_attributes)
  scaled_target_attributes = target_attributes * 2. - 1.

  # Generate
  image_features = encoder_model(input_image, training=False)
  fake_image = decoder_model([image_features, scaled_target_attributes], training=False)

  # Discriminate
  real_image_prediction, real_image_attributes = discriminator_model(input_image, training=True)
  fake_image_prediction, fake_image_attributes = discriminator_model(fake_image, training=True)

  # Discriminator losses
  real_image_gan_loss = tf.reduce_mean(-real_image_prediction)
  fake_image_gan_loss = tf.reduce_mean(fake_image_prediction)  

  # Gradient penalty (WGAN-GP) on random interpolations between real and fake images.
  with tf.GradientTape() as gp_tape:
    alpha = tf.random.uniform([tf.shape(input_image)[0], 1, 1, 1], 0., 1., dtype=tf.float32)
    sample_images = input_image + alpha * (fake_image - input_image)
    #sample_images = tf.clip_by_value(sample_images, -1., 1.)

    gp_tape.watch(sample_images)
    sample_predictions, _ = discriminator_model(sample_images, training=False)

  # Penalize per-image gradient norms that deviate from 1.
  gradients = gp_tape.gradient(sample_predictions, sample_images)
  grad_l2 = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]))
  gradient_penalty_value = tf.reduce_mean((grad_l2 - 1) ** 2)

  real_image_attributes_loss = tf.compat.v1.losses.sigmoid_cross_entropy(input_attributes, real_image_attributes)  

  discriminator_loss = (  real_image_gan_loss 
                        + fake_image_gan_loss 
                        + gradient_penalty_value * d_gradient_penalty_weight 
                        + real_image_attributes_loss * d_attribute_loss_weight                        
                        )  
  '''
  print('real_gan', real_image_gan_loss.numpy(),
        'fake_gan', fake_image_gan_loss.numpy(),
        'gp_value', gradient_penalty_value.numpy(),
        'real_attributes', real_image_attributes_loss.numpy(),
        'discriminator_loss', discriminator_loss.numpy())
  '''
  return(discriminator_loss)
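Written out in our own notation, `compute_discriminator_loss` is the WGAN-GP critic loss plus an attribute classification loss on real images only. $D$ and $C$ are the critic and classifier heads. $x$ is a real image with attributes $\mathbf{a} \in \{0,1\}^n$, and $\hat{x}_{\mathbf{b}}$ is the image generated for the shuffled targets $\mathbf{b}$. $\ell$ is sigmoid cross-entropy on logits.

$$
\begin{aligned}
\tilde{x} &= x + \alpha\,(\hat{x}_{\mathbf{b}} - x), \qquad \alpha \sim \mathcal{U}(0, 1) \text{ per image} \\
\mathcal{L}_D &= -\mathbb{E}\big[D(x)\big] + \mathbb{E}\big[D(\hat{x}_{\mathbf{b}})\big]
 + \lambda_{\text{gp}}\, \mathbb{E}\big[(\lVert \nabla_{\tilde{x}} D(\tilde{x}) \rVert_2 - 1)^2\big]
 + \lambda^{D}_{\text{cls}}\, \mathbb{E}\Big[\tfrac{1}{n}\textstyle\sum_{i=1}^{n} \ell\big(a_i,\, C_i(x)\big)\Big]
\end{aligned}
$$

Here $\lambda_{\text{gp}}$ is `d_gradient_penalty_weight` (10) and $\lambda^{D}_{\text{cls}}$ is `d_attribute_loss_weight` (1). The classifier head learns only from real labels. The generator then relies on that classifier to check that edited images show the target attributes.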

In [0]:
sample_images, sample_attributes = next(iter(train_dataset))
discriminator_loss = compute_discriminator_loss(sample_images, sample_attributes)
print(discriminator_loss.numpy())

# Train the model.

In [0]:
model_loss_frequency = 1000
model_save_frequency = 2000

In [0]:
def save_models():
  if(save_current_weights):  
    encoder_model.save_weights(encoder_gdrive_filename())        
    decoder_model.save_weights(decoder_gdrive_filename())      
    discriminator_model.save_weights(discriminator_gdrive_filename())  

In [0]:
image_filename = 'img_align_celeba/images/202520.jpg'
input_image = load_image(image_filename)
input_image = tf.image.resize(input_image, [image_shape[0], image_shape[1]], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) 
input_image = normalize_image(input_image)
input_image = tf.expand_dims(input_image, axis=0)

In [0]:
image_attributes =[-1., -1., -1., -1., 1., -1., 1., -1., 1., -1., 1., -1., 1.]
image_attributes = tf.expand_dims(image_attributes, axis=0)
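The 13 values in `image_attributes` follow the order of `default_attribute_names`, with +1 meaning "present". To edit the monitoring image, pass a modified vector to the decoder instead of the source vector. The pure-Python helpers below are hypothetical. They decode such a vector into names and flip one attribute.

```python
EDITABLE_ATTRIBUTES = ['Bald', 'Bangs', 'Black_Hair', 'Blond_Hair', 'Brown_Hair',
                       'Bushy_Eyebrows', 'Eyeglasses', 'Male', 'Mouth_Slightly_Open',
                       'Mustache', 'No_Beard', 'Pale_Skin', 'Young']

def describe(attribute_vector):
    # Names of the attributes switched on (+1) in a 13-value vector.
    return [name for name, value in zip(EDITABLE_ATTRIBUTES, attribute_vector) if value > 0]

def toggle(attribute_vector, name):
    # Return a copy with one attribute flipped between -1 and +1.
    edited = list(attribute_vector)
    index = EDITABLE_ATTRIBUTES.index(name)
    edited[index] = -edited[index]
    return edited

source = [-1., -1., -1., -1., 1., -1., 1., -1., 1., -1., 1., -1., 1.]
print(describe(source))
# ['Brown_Hair', 'Eyeglasses', 'Mouth_Slightly_Open', 'No_Beard', 'Young']
print(describe(toggle(source, 'Eyeglasses')))
# ['Brown_Hair', 'Mouth_Slightly_Open', 'No_Beard', 'Young']
```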

In [0]:
print('input image - ', input_image.shape)
print('image attributes -', image_attributes.shape)
# Map from [-1, 1] back to [0, 1] so imshow displays the image correctly.
plot.imshow((input_image[0] + 1.) / 2.)

In [0]:
image_features = encoder_model.predict(input_image)
generated_images = decoder_model.predict([image_features, image_attributes])
# The decoder's tanh output lies in [-1, 1]; rescale to [0, 1] for display.
plot.imshow((generated_images[0] + 1.) / 2.)

In [0]:
start_epoch = 0

In [0]:
def train(train_dataset, maximum_epochs, start_decay_epoch, start_epoch):

  g_optimizer = create_optimizer(base_learning_rate, maximum_epochs, start_decay_epoch, start_epoch)
  d_optimizer = create_optimizer(base_learning_rate, maximum_epochs, start_decay_epoch, start_epoch)

  generator_variables = [*encoder_model.trainable_variables, *decoder_model.trainable_variables]

  for current_epoch in range(start_epoch, maximum_epochs):

    # Apply the same linear decay as create_optimizer at every epoch. The optimizers
    # are kept rather than recreated, so Adam's moment estimates carry over.
    if (current_epoch >= start_decay_epoch):
      current_learning_rate = base_learning_rate * (1 - (current_epoch - start_decay_epoch + 1) / (maximum_epochs - start_decay_epoch + 1))
    else:
      current_learning_rate = base_learning_rate
    g_optimizer.learning_rate = current_learning_rate
    d_optimizer.learning_rate = current_learning_rate
    print('epochs -', current_epoch, 'learning rate -', current_learning_rate)

    batch_index = 0
    for images, attributes in train_dataset:
      batch_index = batch_index + 1

      # One generator update for every five discriminator updates.
      if (batch_index % 6 == 0):
        with tf.GradientTape() as generator_tape:
          generator_loss = compute_generator_loss(images, attributes)

        generator_gradients = generator_tape.gradient(generator_loss, generator_variables)
        g_optimizer.apply_gradients(zip(generator_gradients, generator_variables))

      else:
        with tf.GradientTape() as discriminator_tape:
          discriminator_loss = compute_discriminator_loss(images, attributes)

        discriminator_gradients = discriminator_tape.gradient(discriminator_loss, discriminator_model.trainable_variables)
        d_optimizer.apply_gradients(zip(discriminator_gradients, discriminator_model.trainable_variables))

      if (batch_index % model_loss_frequency == 0):
        print('epoch -', current_epoch, 'generator loss -', generator_loss.numpy(), 'discriminator loss -', discriminator_loss.numpy())
        # Edit the monitoring image towards image_attributes and display it.
        image_features = encoder_model.predict(input_image)
        generated_images = decoder_model.predict([image_features, image_attributes])
        generated_images = (generated_images + 1.) / 2. * 255.
        generated_images = generated_images.astype(np.uint8)
        plot.imshow(generated_images[0])
        plot.show()

      if (batch_index % model_save_frequency == 0):
        save_models()
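The `batch_index % 6` test alternates the two updates. Batches are counted from 1. Every sixth batch updates the encoder and decoder, and the other five update the discriminator, giving five critic steps per generator step (the ratio commonly used with WGAN-GP). The sketch below just prints that schedule. `update_schedule` is an illustrative helper.

```python
def update_schedule(number_of_batches, generator_every=6):
    # Mirrors the batch_index % 6 test in train(): batches are counted from 1,
    # every 6th batch updates the generator (G), all others the discriminator (D).
    return ['G' if batch_index % generator_every == 0 else 'D'
            for batch_index in range(1, number_of_batches + 1)]

print(''.join(update_schedule(12)))  # DDDDDGDDDDDG
```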

In [0]:
train(train_dataset, maximum_epochs, start_decay_epoch, start_epoch)