## 1 Setup

In [0]:
!pip install -U -q PyDrive
import os
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

#### set up local working directories

In [2]:
# get current working directory
cwd = os.getcwd()

# choose a local (colab) directory to store the data.
local_root_dir = os.path.expanduser('{}/SAGAN/'.format(cwd))
try:
  # make root directory
  os.makedirs(local_root_dir)
  # create sub directories
  for subdir in ['dataset', 'trained_models']:
    os.makedirs(os.path.join(local_root_dir, subdir))
    print('created ',os.path.join(local_root_dir, subdir))
except: pass

created  /content/SAGAN/dataset
created  /content/SAGAN/trained_models


#### download dataset hosted in G-drive

In [0]:
# 1. Authenticate and create the PyDrive client.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

In [4]:
# G-Drive shareId of the dataset folder
fileId = '1Pn8C73WWQ6tLGvfg2_GGQwIXczKj9td4'
file_list = drive.ListFile(
    {'q': "'{}' in parents".format('1Pn8C73WWQ6tLGvfg2_GGQwIXczKj9td4')}).GetList()
for f in file_list:
  print(f['title'], f['id'])

celeba_dataset.tfrecord 1Xs7JtH1ChBVu0uWBYGOnBYaqbbtierH_
celeba_dataset_highres.tfrecord 127qyHBbUxXdkDcPa92TgYI4e5zk63EHR


In [0]:
# download from G-Drive to local dataset directory
fname = os.path.join(local_root_dir+'dataset/', file_list[0]['title'])
f_ = drive.CreateFile({'id': file_list[0]['id']})
f_.GetContentFile(fname)

In [6]:
# check tfrecord was downloaded correctly
!ls /content/SAGAN/dataset/

celeba_dataset.tfrecord


In [7]:
# mount G-drive
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/drive


## 2 Import Dependencies

In [0]:
import tensorflow as tf
import numpy as np
import tensorflow as tf
import math
import os
import sys
import glob
import matplotlib.pyplot as plt

from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import tag_constants, signature_constants
from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def

tf.enable_eager_execution()
tfe = tf.contrib.eager

## 3 Input Data Pipeline


#### parse tf example

In [0]:
def parse_img_example(record, target_height=128, target_width=128):
    """
    function to parse tfRecord examples back into tensors.
    """
    keys_to_features = {
        "image" : tf.FixedLenFeature((), tf.string),
        "height": tf.FixedLenFeature((), tf.int64),
        "width" : tf.FixedLenFeature((), tf.int64)
    }    
    features = tf.parse_single_example(record, keys_to_features)
    # convert features to tensors
    image = tf.decode_raw(features['image'], tf.uint8)
    height = tf.cast(features['height'], tf.int64)
    width = tf.cast(features['width'], tf.int64)
    # reshape input to original dimensions and cast image to type float
    image = tf.reshape(image, (height, width, 3))
    # reshape images via center crop and pad to same shape
    image = tf.image.resize_image_with_crop_or_pad(image, target_height, target_width)
    return image

#### normalize image tensor

In [0]:
def normalizer(image, dtype):
    # normalize image pixel values to within [-1,1]
    image = tf.cast(image, dtype=dtype) / 128.0 - 1.0
    # noise addition normalization
    image += tf.random_uniform(shape=tf.shape(image), minval=0., maxval=1./128., dtype=dtype)
    return image

#### construct tf dataset

In [0]:
def create_tfdataset(tfrecord_file, img_height, img_width, shuffle_buffer, epochs, batch_size, pThreads=4):
    # create a tf dataset obj from TFRecord file
    dataset = tf.data.TFRecordDataset(tfrecord_file)
    # use dataset.map() in conjunction with the parse_exmp function to 
    # de-serialize each example record in TFRecord file
    dataset = dataset.map(lambda img_exmpl: parse_img_example(img_exmpl, img_height, img_width), num_parallel_calls=pThreads)
    # normalize image
    dataset = dataset.map(lambda image: normalizer(image, dtype=tf.float32), num_parallel_calls=pThreads)
    # configure dataset epoch, shuffle, padding and batching operations
    dataset = dataset.shuffle(shuffle_buffer).repeat(epochs).batch(batch_size)
    return dataset

## 4 Generator & Discriminator Models


#### power iteration function for applying spectral normalization

In [0]:
import tensorflow.keras.initializers as initializers
import tensorflow.keras.constraints as constraints
import tensorflow.keras.activations as activations
import tensorflow.keras.regularizers as regularizers
from tensorflow.keras.layers import Layer
import tensorflow.keras.backend as K
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import standard_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.layers import utils
from tensorflow.python.ops import nn
from tensorflow.python.layers import core
from tensorflow.python.layers import convolutional as tf_convolutional_layers
from tensorflow.python.util.tf_export import tf_export


def _l2normalizer(v, epsilon=1e-12):
    return v / (K.sum(v ** 2) ** 0.5 + epsilon)


def power_iteration(W, u, rounds=1):
    '''
    Accroding the paper, we only need to do power iteration one time.
    '''
    _u = u

    for i in range(rounds):
        _v = _l2normalizer(K.dot(_u, W))
        _u = _l2normalizer(K.dot(_v, K.transpose(W)))

    W_sn = K.sum(K.dot(_u, W) * _v)
    return W_sn, _u, _v


#### custom dense layer with spectral norm applied

In [0]:
# dense layer with spectral norm toggle
@tf_export('keras.layers.Dense')
class Dense(core.Dense):
    def __init__(self, 
                 units,
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 spectral_normalization=True,
                 **kwargs):
        
        # initializing parent class
        super(Dense, self).__init__(
            units=int(units), 
            activation=activations.get(activation),
            use_bias=use_bias,
            kernel_initializer=initializers.get(kernel_initializer),
            bias_initializer=initializers.get(bias_initializer),
            kernel_regularizer=regularizers.get(kernel_regularizer),
            bias_regularizer=regularizers.get(bias_regularizer),
            activity_regularizer=regularizers.get(activity_regularizer),
            kernel_constraint=constraints.get(kernel_constraint),
            bias_constraint=constraints.get(bias_constraint),
            **kwargs)

        self.u = K.random_normal_variable([1, units], 0, 1, dtype=self.dtype, name="sn_estimate")  # [1, out_channels]
        self.spectral_normalization = spectral_normalization

    def compute_spectral_normal(self, training=True):
        # Spectrally Normalized Weight
        if self.spectral_normalization:
            # Get kernel tensor shape [batch, units]
            W_shape = self.kernel.shape.as_list()

            # Flatten the Tensor
            W_mat = K.reshape(self.kernel, [W_shape[-1], -1])  # [out_channels, N]

            W_sn, u, v = power_iteration(W_mat, self.u)

            if training:
                # Update estimated 1st singular vector
                self.u.assign(u)

            return self.kernel / W_sn
        else:
            return self.kernel

    def call(self, inputs, training=True):
        inputs = ops.convert_to_tensor(inputs, dtype=self.dtype)
        shape = inputs.get_shape().as_list()
        if len(shape) > 2:
            # Broadcasting is required for the inputs.
            outputs = standard_ops.tensordot(inputs, self.compute_spectral_normal(training), [[len(shape) - 1],
                                                                   [0]])
            # Reshape the output back to the original ndim of the input.
            if not context.executing_eagerly():
                output_shape = shape[:-1] + [self.units]
                outputs.set_shape(output_shape)
        else:
            outputs = gen_math_ops.mat_mul(inputs, self.kernel)
        if self.use_bias:
            outputs = nn.bias_add(outputs, self.bias)
        if self.activation is not None:
            return self.activation(outputs)  # pylint: disable=not-callable
        return outputs

#### custom convolution-2D function with spectral norm applied

In [0]:
# dense layer with spectral norm toggle
@tf_export('keras.layers.Conv2D', 'keras.layers.Convolution2D')
class Conv2D(tf_convolutional_layers.Conv2D, Layer):
      
    def __init__(self,
                 filters,
                 kernel_size,
                 strides=(1, 1),
                 padding='valid',
                 data_format=None,
                 dilation_rate=(1, 1),
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 spectral_normalization=True,
                 bias_constraint=None,
                 **kwargs):
      
        if data_format is None:
            data_format = K.image_data_format()
            
        super(Conv2D, self).__init__(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=activations.get(activation),
            use_bias=use_bias,
            kernel_initializer=initializers.get(kernel_initializer),
            bias_initializer=initializers.get(bias_initializer),
            kernel_regularizer=regularizers.get(kernel_regularizer),
            bias_regularizer=regularizers.get(bias_regularizer),
            activity_regularizer=regularizers.get(activity_regularizer),
            kernel_constraint=constraints.get(kernel_constraint),
            bias_constraint=constraints.get(bias_constraint),
            **kwargs)

        self.u = K.random_normal_variable([1, filters], 0, 1, dtype=self.dtype, name="sn_estimate")  # [1, out_channels]
        self.spectral_normalization = spectral_normalization

    def compute_spectral_normal(self, training=True):
        # Spectrally Normalized Weight
        if self.spectral_normalization:
            # Get kernel tensor shape [kernel_h, kernel_w, in_channels, out_channels]
            W_shape = self.kernel.shape.as_list()

            # Flatten the Tensor
            W_mat = K.reshape(self.kernel, [W_shape[-1], -1])  # [out_channels, N]

            W_sn, u, v = power_iteration(W_mat, self.u)

            if training:
                # Update estimated 1st singular vector
                self.u.assign(u)

            return self.kernel / W_sn
        else:
            return self.kernel

    def call(self, inputs, training=None):

        outputs = K.conv2d(
            inputs,
            self.compute_spectral_normal(training),
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate)

        if self.bias is not None:
            outputs = K.bias_add(
                outputs,
                self.bias,
                data_format=self.data_format)

        if self.activation is not None:
            return self.activation(outputs)

        return outputs

    def get_config(self):
        config = {
            'filters': self.filters,
            'kernel_size': self.kernel_size,
            'strides': self.strides,
            'padding': self.padding,
            'data_format': self.data_format,
            'dilation_rate': self.dilation_rate,
            'activation': activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'activity_regularizer':
                regularizers.serialize(self.activity_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'bias_constraint': constraints.serialize(self.bias_constraint)
        }
        base_config = super(Conv2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


#### custom convolution-2D transpose function with spectral norm applied

In [0]:
# dense layer with spectral norm toggle
@tf_export('tf.keras.layers.Conv2DTranspose',
           'tf.keras.layers.Convolution2DTranspose')
class Conv2DTranspose(tf_convolutional_layers.Conv2DTranspose, Layer):
  
    def __init__(self,
                 filters,
                 kernel_size,
                 strides=(1, 1),
                 padding='valid',
                 data_format=None,
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 spectral_normalization=True,
                 **kwargs):
      
        if data_format is None:
            data_format = K.image_data_format()
            
        super(Conv2DTranspose, self).__init__(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            activation=activations.get(activation),
            use_bias=use_bias,
            kernel_initializer=initializers.get(kernel_initializer),
            bias_initializer=initializers.get(bias_initializer),
            kernel_regularizer=regularizers.get(kernel_regularizer),
            bias_regularizer=regularizers.get(bias_regularizer),
            activity_regularizer=regularizers.get(activity_regularizer),
            kernel_constraint=constraints.get(kernel_constraint),
            bias_constraint=constraints.get(bias_constraint),
            **kwargs)
        
        self.spectral_normalization = spectral_normalization
        self.u = K.random_normal_variable([1, filters], 0, 1, dtype=self.dtype, name="sn_estimate")  # [1, out_channels]

    def compute_spectral_normal(self, training=True):
        # Spectrally Normalized Weight

        if self.spectral_normalization:
            # Get the kernel tensor shape
            W_shape = self.kernel.shape.as_list()

            # Flatten the Tensor
            # For transpose conv, the kernel shape is [H,W,Out,In]
            W_mat = K.reshape(self.kernel, [W_shape[-2], -1])  # [out_c, N]

            sigma, u, v = power_iteration(W_mat, self.u)

            if training:
                # Update estimated 1st singular vector
                self.u.assign(u)

            return self.kernel / sigma
        else:
            return self.kernel

    # Overwrite the call() method to include Spectral normalization call
    def call(self, inputs, training=True):
        inputs_shape = array_ops.shape(inputs)
        batch_size = inputs_shape[0]
        if self.data_format == 'channels_first':
            c_axis, h_axis, w_axis = 1, 2, 3
        else:
            c_axis, h_axis, w_axis = 3, 1, 2

        height, width = inputs_shape[h_axis], inputs_shape[w_axis]
        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.strides

        # Infer the dynamic output shape:
        out_height = utils.deconv_output_length(height,
                                                kernel_h,
                                                self.padding,
                                                stride_h)
        out_width = utils.deconv_output_length(width,
                                               kernel_w,
                                               self.padding,
                                               stride_w)
        if self.data_format == 'channels_first':
            output_shape = (batch_size, self.filters, out_height, out_width)
            strides = (1, 1, stride_h, stride_w)
        else:
            output_shape = (batch_size, out_height, out_width, self.filters)
            strides = (1, stride_h, stride_w, 1)

        output_shape_tensor = array_ops.stack(output_shape)
        outputs = nn.conv2d_transpose(
            inputs,
            self.compute_spectral_normal(training=training),
            output_shape_tensor,
            strides,
            padding=self.padding.upper(),
            data_format=utils.convert_data_format(self.data_format, ndim=4))

        if not context.executing_eagerly():
            # Infer the static output shape:
            out_shape = inputs.get_shape().as_list()
            out_shape[c_axis] = self.filters
            out_shape[h_axis] = utils.deconv_output_length(out_shape[h_axis],
                                                           kernel_h,
                                                           self.padding,
                                                           stride_h)
            out_shape[w_axis] = utils.deconv_output_length(out_shape[w_axis],
                                                           kernel_w,
                                                           self.padding,
                                                           stride_w)
            outputs.set_shape(out_shape)

        if self.use_bias:
            outputs = nn.bias_add(
                outputs,
                self.bias,
                data_format=utils.convert_data_format(self.data_format, ndim=4))

        if self.activation is not None:
            return self.activation(outputs)
        return outputs

    def get_config(self):
        config = {
            'filters': self.filters,
            'kernel_size': self.kernel_size,
            'strides': self.strides,
            'padding': self.padding,
            'data_format': self.data_format,
            'activation': activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'activity_regularizer':
                regularizers.serialize(self.activity_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'bias_constraint': constraints.serialize(self.bias_constraint),
            'spectral_normalization': self.spectral_normalization
        }
        base_config = super(Conv2DTranspose, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


#### self-attention layer definition

In [0]:
def hw_flatten(x):
    # Input shape x: [BATCH, HEIGHT, WIDTH, CHANNELS]
    # flat the feature volume across the tensor width and height
    x_shape = tf.shape(x)
    return tf.reshape(x, [x_shape[0], -1, x_shape[-1]]) # return [BATCH, W*H, CHANNELS]
  

class SelfAttention(tf.keras.Model):
    
    def __init__(self, number_of_filters, dtype):
        super(SelfAttention, self).__init__()

        self.f = Conv2D(filters=number_of_filters//8, 
                        kernel_size=1, 
                        strides=1,
                        spectral_normalization=True,                 
                        padding='SAME', 
                        name='f_x',
                        activation=None, 
                        dtype=dtype)

        self.g = Conv2D(filters=number_of_filters//8, 
                        kernel_size=1, 
                        strides=1,
                        spectral_normalization=True,                 
                        padding='SAME', 
                        name='g_x',
                        activation=None, 
                        dtype=dtype)

        self.h = Conv2D(filters=number_of_filters, 
                        kernel_size=1, 
                        strides=1,
                        spectral_normalization=True,                 
                        padding='SAME',
                        name='h_x',
                        activation=None, 
                        dtype=dtype)

        self.gamma = tfe.Variable(0., dtype=dtype, trainable=True, name="gamma")
        self.flatten = tf.keras.layers.Flatten()

        
    def call(self, x):
        
        f = self.f(x)
        g = self.g(x)
        h = self.h(x)

        f_flatten = hw_flatten(f)
        g_flatten = hw_flatten(g)
        h_flatten = hw_flatten(h)
        
        s = tf.matmul(a=g_flatten, 
                      b=f_flatten, 
                      transpose_b=True) # [B,N,C] * [B, N, C] = [B, N, N]

        b = tf.nn.softmax(s, axis=-1)
        o = tf.matmul(b, h_flatten)
        y = self.gamma * tf.reshape(o, tf.shape(x)) + x

        return y

#### discriminator model class definition

In [0]:
APPLY_SPECTRAL_NORM=True
APPLY_SELF_ATTENTION=True

In [0]:
class Discriminator(tf.keras.Model):    
    """
    Create the discriminator network
    :param images: Tensor of input image(s)
    :param alpha: Scalar value specifying the degree of leakage in leaky relu
    :param reuse: Boolean if the weights should be reused
    :return: Tuple of (tensor output of the discriminator, tensor logits of the discriminator)
    """
    def __init__(self, alpha, dtype):
        super(Discriminator, self).__init__()
        self.alpha = alpha
        # block 1
        self.conv1 = Conv2D(filters=32, kernel_size=4, strides=2, padding='same', 
                            data_format='channels_last', use_bias=True, 
                            spectral_normalization=APPLY_SPECTRAL_NORM,
                            activation=None, name='d_conv1')
        self.conv2 = Conv2D(filters=64, kernel_size=4, strides=2, padding='same', 
                            data_format='channels_last', use_bias=True, 
                            spectral_normalization=APPLY_SPECTRAL_NORM,
                            activation=None, name='d_conv2')
        
        self.attention = SelfAttention(64, dtype=dtype)
        
        self.conv3 = Conv2D(filters=128, kernel_size=4, strides=2, padding='same', 
                            data_format='channels_last', use_bias=True, 
                            spectral_normalization=APPLY_SPECTRAL_NORM,
                            activation=None, name='d_conv3')
        self.conv4 = Conv2D(filters=256, kernel_size=4, strides=2, padding='same', 
                            data_format='channels_last', use_bias=True, 
                            spectral_normalization=APPLY_SPECTRAL_NORM,
                            activation=None, name='d_conv4')
        self.conv5 = Conv2D(filters=512, kernel_size=4, strides=2, padding='same', 
                            data_format='channels_last', use_bias=True,
                            spectral_normalization=APPLY_SPECTRAL_NORM,
                            activation=None, name='d_conv5')
        self.flat = tf.keras.layers.Flatten()
        self.fc1 = Dense(units=1, dtype=dtype, 
                         spectral_normalization=APPLY_SPECTRAL_NORM, 
                         activation=None, name='d_logits')


    def call(self, inputs, is_training):
        # block 1 operation
        net = self.conv1(inputs, training=is_training)
        net = tf.nn.leaky_relu(net, alpha=self.alpha)
        # block 2 operation
        net = self.conv2(net, training=is_training)
        net = tf.nn.leaky_relu(net, alpha=self.alpha)
        
        if APPLY_SELF_ATTENTION:
            # self attention operation
            net = self.attention(net)
        
        # block 3 operation
        net = self.conv3(net, training=is_training)
        net = tf.nn.leaky_relu(net, alpha=self.alpha)
        # block 4 operation
        net = self.conv4(net, training=is_training)
        net = tf.nn.leaky_relu(net, alpha=self.alpha)
#         # block 5 operation
#         net = self.conv5(net)
#         net = tf.nn.leaky_relu(net, alpha=self.alpha)
        # logit output
        net = self.flat(net)
        logits = self.fc1(net)

        return logits
    

    def compute_loss(self, d_logits_real, d_logits_fake):
#         loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
#                                    logits=d_logits_real, 
#                                    labels=tf.ones_like(d_logits_real)))
#         loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
#                                    logits=d_logits_fake, 
#                                    labels=tf.zeros_like(d_logits_fake)))
#         return loss_real + loss_fake

        # Hinge loss
        real_loss = tf.reduce_mean(tf.nn.relu(1. - d_logits_real))
        fake_loss = tf.reduce_mean(tf.nn.relu(1. + d_logits_fake))

        return real_loss + fake_loss

#### generator model class definition

In [0]:
class Generator(tf.keras.Model):
    """
    Create the generator network
    :param z: Input z
    :param is_train: Boolean if generator is being used for training
    :return: The tensor output of the generator
    """
    def __init__(self, dtype):
        super(Generator, self).__init__()

        # init all layer components; note no actual computation is done
        # fully connected layer 1
        self.fc1 = Dense(units=4*4*512, dtype=dtype, activation='relu', 
                         spectral_normalization=APPLY_SPECTRAL_NORM,
                         name='g_fc1')
        # conv transpose + batch norm layer 1
        self.transp_conv1 = Conv2DTranspose(filters=512, kernel_size=4, strides=2, 
                                            spectral_normalization=APPLY_SPECTRAL_NORM,
                                            padding='same', activation=None, name='g_tr_conv1')
        self.bn1 = tf.keras.layers.BatchNormalization(scale=False, dtype=dtype, fused=False, name='g_bn1')
        # conv transpose + batch norm layer 2
        self.transp_conv2 = Conv2DTranspose(filters=128, kernel_size=4, strides=2,
                                            spectral_normalization=APPLY_SPECTRAL_NORM,
                                            padding='same', activation=None, name='g_tr_conv2')
        self.bn2 = tf.keras.layers.BatchNormalization(scale=False, dtype=dtype, fused=False, name='g_bn2')
        # conv transpose + batch norm layer 3
        self.transp_conv3 = Conv2DTranspose(filters=64, kernel_size=4, strides=2,
                                            spectral_normalization=APPLY_SPECTRAL_NORM,
                                            padding='same', activation=None, name='g_tr_conv3')
        self.bn3 = tf.keras.layers.BatchNormalization(scale=False, dtype=dtype, fused=False, name='g_bn3')
        
        # self attention layer
        self.attention = SelfAttention(64, dtype=dtype)
       
        # conv transpose + batch norm layer 4
        self.transp_conv4 = Conv2DTranspose(filters=32, kernel_size=4, strides=2,
                                            spectral_normalization=APPLY_SPECTRAL_NORM,
                                            padding='same', activation=None, name='g_tr_conv4')
        self.bn4 = tf.keras.layers.BatchNormalization(scale=False, dtype=dtype, fused=False, name='g_bn4')
#         # conv transpose + batch norm layer 5
#         self.transp_conv5 = Conv2DTranspose(filters=16, kernel_size=4, strides=2,
#                                             padding='same', activation=None, name='g_tr_conv5')
#         self.bn5 = tf.keras.layers.BatchNormalization(scale=False, dtype=dtype, fused=False, name='g_bn5')        
        # conv2D
        self.conv = Conv2D(filters=3, kernel_size=3, strides=1, dtype=dtype, 
                           spectral_normalization=APPLY_SPECTRAL_NORM,
                           padding='same', activation=None, name='g_conv')
        self.out = tf.keras.layers.Activation(activation='tanh', name='g_out')


    def call(self, z, is_training):

        net = self.fc1(z)
        net = tf.reshape(net, (-1,4,4,512), name='g_fc1_reshape')
        # first layer operation        
        net = self.transp_conv1(net, training=is_training)
        net = self.bn1(net, training=is_training)
        net = tf.nn.relu(net)
        # second layer operation
        net = self.transp_conv2(net, training=is_training)
        net = self.bn2(net, training=is_training)
        net = tf.nn.relu(net)
        # third layer operation
        net = self.transp_conv3(net, training=is_training)
        net = self.bn3(net, training=is_training)
        net = tf.nn.relu(net)
        
        if APPLY_SELF_ATTENTION:
            # self attention operation
            net = self.attention(net)
        
        # fourth layer operation
        net = self.transp_conv4(net, training=is_training)
        net = self.bn4(net, training=is_training)
        net = tf.nn.relu(net)
#         # fifth layer operation
#         net = self.transp_conv5(net)
#         net = self.bn5(net)
#         net = tf.nn.relu(net)        
        # output layer operation
        net = self.conv(net, training=is_training)
        output = self.out(net)

        return output


    def compute_loss(self, logits):
#         loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
#                               logits=logits, 
#                               labels=tf.ones_like(logits)))
#         return loss
      
        return - tf.reduce_mean(d_logits_fake)


## 5 Model Training

#### training paramenters

In [0]:
VERSION = "v1"
MODEL_NAME = "SAGAN_4dlr_hingeloss_64px_trial2"
MODEL_DIR = "SAGAN/trained_models/{}".format(MODEL_NAME)
LOG_DIR = os.path.join(MODEL_DIR, 'logs')
SRC_TFRECORD_PATH = "/content/SAGAN/dataset/celeba_dataset"    # "SAGAN/dataset/celeba_dataset"
IMG_WIDTH_PIXEL=64
IMG_HEIGHT_PIXEL=64
IMG_CHANNELS=3
Z_SIZE=64
SUMMARY_PER_N_STEPS=2500
SAVE_MODEL_PER_N_STEPS=2500
SAVE_IMG_PER_N_STEPS=1000
MAX_TRAINING_STEPS=100000

#### model hyper-paramenters

In [0]:
# alpha param for leaky ReLu
ALPHA = 0.2
# learning rates
G_LR = 0.0001
D_LR = 0.0004
# beta params for the Adam optimizer
BETA1 = 0.0
BETA2 = 0.9
# training batch size & epoch
BATCH_SIZE = 64
EPOCHS = 100

#### instantiating objects

- TF dataset object

In [0]:
dataset = create_tfdataset(tfrecord_file=SRC_TFRECORD_PATH+'.tfrecord',
                           img_height=IMG_HEIGHT_PIXEL,
                           img_width=IMG_WIDTH_PIXEL,
                           shuffle_buffer=512, 
                           epochs=EPOCHS,
                           batch_size=BATCH_SIZE, 
                           pThreads=4)

- Generator & Discriminator model objects

In [0]:
# generator network
g_net = Generator(dtype='float32')
# generator optimizer
g_optimizer = tf.train.AdamOptimizer(learning_rate=G_LR, 
                                     beta1=BETA1, 
                                     beta2=BETA2)

# discriminator network
d_net = Discriminator(alpha=ALPHA, dtype='float32')
# discriminator optimizer
d_optimizer = tf.train.AdamOptimizer(learning_rate=D_LR, 
                                     beta1=BETA1, 
                                     beta2=BETA2)

#### tensorboard summary & checkpoint writers

In [0]:
# set up tensorboard writer
tf_board_writer = tf.contrib.summary.create_file_writer(LOG_DIR)
tf_board_writer.set_as_default()

In [0]:
img_sample_dir = os.path.join(MODEL_DIR, 'img_samples')
if not os.path.isdir(img_sample_dir):
    os.mkdir(img_sample_dir)

# set up checkpoint directories
global_step = tf.train.get_or_create_global_step()

# generator checkpoint directory 
g_checkpoint_dir = os.path.join(MODEL_DIR, 'generator')
g_root = tfe.Checkpoint(optimizer=g_optimizer, 
                        model=g_net,
                        optimizer_step=global_step)

# discriminator checkpoint directory 
d_checkpoint_dir = os.path.join(MODEL_DIR, 'discriminator')
d_root = tfe.Checkpoint(optimizer=d_optimizer, 
                        model=d_net,
                        optimizer_step=global_step)

#### load from previous checkpoint (if exist)

In [0]:
# copy previous results over from drive
!cp -r /content/drive/'My Drive'/colab_notebooks/SAGAN/trained_models/SAGAN_4dlr_hingeloss_64px_trial2 /content/SAGAN/trained_models

In [98]:
# restore generator/discriminator from previous checkpoints (if exist)
if os.path.exists(MODEL_DIR):
    try:
        g_root.restore(tf.train.latest_checkpoint(g_checkpoint_dir))
    except Exception as ex:
        print('Could not load Generator model from {}'.format(g_checkpoint_dir))
    try:
        d_root.restore(tf.train.latest_checkpoint(d_checkpoint_dir))
    except Exception as ex:
        print('Could not load Discriminator model from {}'.format(d_checkpoint_dir))
    global_step = tf.train.get_or_create_global_step()
    print('Resuming training from latest checkpoint')
    print('Generator and Discriminator models loaded from global step: {}'.\
          format(tf.train.get_or_create_global_step().numpy()))
else:
    print('Model folder not found.')

Resuming training from latest checkpoint
Generator and Discriminator models loaded from global step: 100000


#### training

In [0]:
# # fixed eval_z vector
# eval_z = tf.random_normal(shape=(BATCH_SIZE, Z_SIZE), dtype='float32')

# # save to file for next time
# np.save(img_sample_dir+'/z_input.npy', eval_z.numpy()) 

# load from file
eval_z = np.load(img_sample_dir+'/z_input.npy')

In [23]:
for _, (batch_real_images) in enumerate(dataset):

  # construct random normal z input to feed into generator
  input_z = tf.random_normal(shape=(BATCH_SIZE, Z_SIZE), dtype='float32')

  with tf.contrib.summary.record_summaries_every_n_global_steps(SUMMARY_PER_N_STEPS):

      # define gradient tapes to start recording computation operations
      with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
          # run generator net with random normal z input to generate image batch
          g_fake_images = g_net(input_z, is_training=True)  
          # run discriminator net with real input images
          d_logits_real = d_net(batch_real_images, is_training=True)
          # run discriminator net with generated fake images
          d_logits_fake = d_net(g_fake_images, is_training=True)
          # compute generator loss by feeding back the discriminator logits output
          g_loss = g_net.compute_loss(d_logits_fake)
          # compute discriminator hinge loss
          d_loss = d_net.compute_loss(d_logits_real, d_logits_fake)

      # write losses to tensorboard as scalars & generated images
      tf.contrib.summary.scalar('generator_loss', g_loss)
      tf.contrib.summary.scalar('discriminator_loss', d_loss)
      tf.contrib.summary.image('generator_image', tf.to_float(g_fake_images), max_images=9)

      # get all discriminator variables (quantities to optimize)
      d_variables = d_net.variables
      # compute d(d_loss)/dx; where x is all discriminator variables
      d_grads = d_tape.gradient(d_loss, d_variables)

      # get all generator variables (quantities to optimize)
      g_variables = g_net.variables
      # compute d(g_loss)/dx; where x is all generator variables
      g_grads = g_tape.gradient(g_loss, g_variables)

      # update all variables
      d_optimizer.apply_gradients(zip(d_grads, d_variables),
                                  global_step=global_step)
      g_optimizer.apply_gradients(zip(g_grads, g_variables),
                                  global_step=global_step)

  # output training status
  counter = global_step.numpy()
  if counter % 100==0:
      print('training step {}: discriminator loss {}; generator loss {}'.format(counter, d_loss, g_loss))

  # TRAINING PROCESS CONTROL FLOW
  # every X steps, generate a batch of images
  if counter % SAVE_IMG_PER_N_STEPS==0:
      print('Current step:{}'.format(counter))
      with tf.contrib.summary.always_record_summaries():
          g_sample_images = g_net(eval_z, is_training=False)
          tf.contrib.summary.image('fixed_latent_generator_image', 
                                   tf.to_float(g_sample_images), 
                                   max_images=BATCH_SIZE)
      # save image as numpy array
      np.save(img_sample_dir+'/{}.npy'.format(counter), g_sample_images.numpy()) 
          
  if counter % SAVE_MODEL_PER_N_STEPS==0:
      print('Saving model snapshot')
      g_root.save(file_prefix=os.path.join(g_checkpoint_dir, "model.ckpt"))
      d_root.save(file_prefix=os.path.join(d_checkpoint_dir, "model.ckpt"))
      
  # save generator model at end of training
  if counter >= MAX_TRAINING_STEPS:
      print('End of training; saving models')
      g_root.save(file_prefix=os.path.join(g_checkpoint_dir, "model.ckpt"))
      d_root.save(file_prefix=os.path.join(d_checkpoint_dir, "model.ckpt"))
      sys.exit()

training step 100: discriminator loss 0.5758100152015686; generator loss 0.5085903406143188
training step 200: discriminator loss 0.31599318981170654; generator loss 0.7916896939277649
training step 300: discriminator loss 1.2474250793457031; generator loss 1.2834632396697998
training step 400: discriminator loss 1.3714842796325684; generator loss -0.09806103259325027
training step 500: discriminator loss 0.9476824402809143; generator loss 0.504289984703064
training step 600: discriminator loss 1.208711862564087; generator loss 0.5330641269683838
training step 700: discriminator loss 1.1163430213928223; generator loss 0.20929700136184692
training step 800: discriminator loss 0.7843515872955322; generator loss 0.48741888999938965
training step 900: discriminator loss 1.3531945943832397; generator loss 0.17101359367370605
training step 1000: discriminator loss 1.4328596591949463; generator loss 0.46320053935050964
Current step:1000
training step 1100: discriminator loss 2.484928607940674

SystemExit: ignored

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


#### copy training log and model checkpoints to G-drive

In [0]:
# copy results over to G-Drive
!cp -r /content/SAGAN/trained_models/ /content/drive/'My Drive'/colab_notebooks/SAGAN/

## 6 Generate Training Progression Video

In [0]:
from sklearn.preprocessing import MinMaxScaler
from glob import glob

In [0]:
def scaleImg(dataset):
    scaler = MinMaxScaler()
    image = scaler.fit_transform(dataset.reshape(dataset.shape[0]*dataset.shape[1], dataset.shape[2]))
    image = image.reshape(dataset.shape[0], dataset.shape[1], dataset.shape[2])
    return image

  
def normalizeImg(img):
    # img is expected to be 4D tensor with [batch, height, width, color]
    # normalize pixel intensity in each image in batch
    for i in range(img.shape[0]):
        img[i,:,:,:] = scaleImg(img[i,:,:,:])
        img[i,:,:,:] = img[i,:,:,:] / img[i,:,:,:].max() # extra normalization step
    # convert to tensor
    img = tf.constant(img)
    # convert to image grid
    img_grid = tf.contrib.gan.eval.image_grid(
                    input_tensor=img,
                    grid_shape=(8,8),
                    image_shape=(64, 64),
                    num_channels=3
                )
    return img_grid.numpy()
  
  
def save_img_progress(file_prefix):
    files = glob(file_prefix)
    for f in files:
        # load image array from file
        img = np.load(f)
        # normalize 
        img_grid = normalizeImg(img)
        # save img grid to jpg
        plt.imsave(f.replace('.npy','.jpg'), img_grid[0,:,:,:])

In [0]:
save_img_progress(file_prefix='/content/SAGAN/trained_models/VANGAN_4dlr_hingeloss_64px/img_samples/*0.npy')

In [37]:
!pip install imageio

Collecting imageio
[?25l  Downloading https://files.pythonhosted.org/packages/28/b4/cbb592964dfd71a9de6a5b08f882fd334fb99ae09ddc82081dbb2f718c81/imageio-2.4.1.tar.gz (3.3MB)
[K    100% |████████████████████████████████| 3.3MB 10.0MB/s 
Building wheels for collected packages: imageio
  Running setup.py bdist_wheel for imageio ... [?25l- \ | / done
[?25h  Stored in directory: /root/.cache/pip/wheels/e0/43/31/605de9372ceaf657f152d3d5e82f42cf265d81db8bbe63cde1
Successfully built imageio
Installing collected packages: imageio
Successfully installed imageio-2.4.1


In [0]:
# make video
import imageio
frames = []

image_folder = '/content/SAGAN/trained_models/{}/img_samples/'.format(MODEL_NAME)
filenames = ['{}.jpg'.format(int(i)) for i in np.arange(1e3,1e5,1e3)]

for f in filenames:
    file = os.path.join(image_folder, f)
    frames.append(imageio.imread(file))
imageio.mimsave('/content/SAGAN/trained_models/{}/img_samples/progression.gif'.format(MODEL_NAME), frames)


In [0]:
# copy results over to G-Drive
!cp -r /content/SAGAN/trained_models/VANGAN_4dlr_hingeloss_64px/img_samples /content/drive/'My Drive'/colab_notebooks/SAGAN/trained_models/VANGAN_4dlr_hingeloss_64px/