# Start of Coding

## Import python libraries

In [0]:
%tensorflow_version 1.x
import tensorflow as tf
import numpy as np
import keras.datasets.mnist as mnist
import math
import os
import matplotlib.pyplot as plt
import tensorflow.contrib.slim as slim
import imageio

# Needed to access Google Cloud Storage, which is used to save/restore the model and to load the training/testing data.
from google.colab import auth
auth.authenticate_user()

Using TensorFlow backend.


The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



In [0]:
class Flags:
    """
    @author: N.Gu
    Hyper-parameters and storage locations shared by every cell of this notebook.

    The model is trained progressively, one stage per resolution. `size_list` below
    defines 8 stages: 4x4, 8x8, 16x16, ..., 512x512. Every `*_list` attribute is indexed
    by the stage number, so each of them carries one entry per resolution in `size_list`.
    Larger resolutions can be added by extending all of these lists, as long as training
    data of that resolution is available.
    """
    def __init__(self):
        # Only "celeba" is handled by the input functions and networks in this notebook.
        self.dataset =  "celeba" 
        # Dimensionality of the latent vector z.
        self.z_dim = 512  
        # Image side length (in pixels) for each stage.
        self.size_list = [ 4, 8, 16, 32, 64, 128, 256, 512 ]
        # Number of feature channels used by the generator convolutions of each stage.
        self.generator_channel_list =  [ 512, 512, 512, 512, 256, 128, 64 , 32 ] 
        # Per stage: channels produced by fromRGB / the first conv, and by the second conv.
        self.discriminator_in_channel_list =  [   512, 512, 512, 512, 256, 128, 64, 32 ]
        self.discriminator_out_channel_list  = [ 512, 512, 512, 512, 512, 256, 128, 64  ]
        # The encoder reuses the discriminator's channel layout (same list objects).
        self.encoder_in_channel_list = self.discriminator_in_channel_list 
        self.encoder_out_channel_list = self.discriminator_out_channel_list
        # Learning rate for each stage.
        self.lr_list = [  1e-3, 1e-3, 1e-3, 1e-3, 1e-3, 1e-3, 1e-3, 1e-3  ]
        # for different resolution different batchsize is used, due to to the limitation of memory
        # smaller batchsize is benefitial to the resolution transition
        # NOTE(review): the "*8" factor presumably reflects the 8 TPU cores -- confirm.
        self.batch_size_list = [128*8, 64*8, 64*8, 32*8, 32*8, 16*8, 8*8, 2*8]


        # let the model see 800k images for each stage(or substage)
        self.num_images_per_stage = 800000
        # the training steps for each resolution stage. Except the first resolution, all the other resolution
        # contains "resolution transition"+"resolution stabilization" two substages
        # (hence the factor 2 for every stage after the first).
        self.train_steps_list = [ int( self.num_images_per_stage/self.batch_size_list[0] ) ]+ \
                            [int( self.num_images_per_stage*2/batch_size ) for batch_size in self.batch_size_list[1:] ]
        # gRPC address of the Colab TPU; raises KeyError if COLAB_TPU_ADDR is not set in the environment.
        self.tpu = 'grpc://' + os.environ['COLAB_TPU_ADDR']
        # The path to save the model on google cloud storage
        self.model_dir = "gs://reverse-vae-celeba-hq/model/"
        # The path to load the data from on google cloud storage
        self.data_dir = "gs://reverse-vae-celeba-hq/data/"
        # How many batches sent to TPU for one run before return back to the host CPU to do some evaluation and possibly saving model
        self.iterations_per_loop = 200
        self.use_tpu = True
        # use the equalized learning rate adopted by PGGAN
        self.is_equalized_learning_rate = True

# Single module-level configuration instance read by every function below.
FLAGS = Flags()


## Some helper functions for visualizing the results.

In [0]:
def add_padding( x, padding_size=(8,8,8,8), padding_value = 1 ):
    """Surround every image of a batch with a constant-valued border.

    Args:
        x: 4-D array laid out as [batch, height, width, channel].
        padding_size: border widths in pixels as (top, left, bottom, right).
            Any entry may be 0.
        padding_value: value written into the border.

    Returns:
        A float32 array of shape
        [batch, height + top + bottom, width + left + right, channel]
        holding `x` in its interior and `padding_value` everywhere else.
    """
    x = np.asarray(x)
    top, left, bottom, right = padding_size
    batch, height, width, channels = x.shape
    padded_x = np.full(
        (batch, height + top + bottom, width + left + right, channels),
        padding_value,
        dtype=np.float32,
    )
    # Use explicit end indices: a negative-offset slice such as `top:-bottom`
    # collapses to an empty range when the bottom/right padding is 0.
    padded_x[:, top:top + height, left:left + width, :] = x
    return padded_x

# to convert a bulk of images into grid of images
def make_grid(images,  ncol= None):
    """Tile a batch of images into one large image.

    Args:
        images: 4-D array laid out as [batch, height, width, channel] holding
            normalized float images. If the value range (max - min) exceeds 1 the
            batch is treated as lying in [-1, 1]: it is clipped to that interval and
            rescaled to [0, 1].
        ncol: number of columns of the grid. When None, the grid is made as close
            to a square as possible.

    Returns:
        A float32 array of shape [rows * height, cols * width, channel]. Cells that
        are not covered by an image stay white (1.0).

    Raises:
        ValueError: if `ncol` is smaller than 1.
    """
    images = np.asarray(images)
    if np.max(images) - np.min(images) > 1:
        images = np.clip(images, -1, 1) / 2 + 0.5

    image_num, im_h, im_w, im_c = images.shape
    if ncol is None:
        num_w = int(np.ceil(np.sqrt(image_num)))
    else:
        num_w = int(ncol)
        if num_w < 1:
            raise ValueError("ncol must be a positive integer, got %r" % (ncol,))
    num_h = int(np.ceil(image_num / num_w))

    # create a white pannel, which is a [height, width, channel] ndarray
    pannel = np.ones((num_h * im_h, num_w * im_w, im_c), dtype=np.float32)

    for i in range(image_num):
        # Row-major placement; integer arithmetic avoids float rounding of i / num_w.
        row, col = divmod(i, num_w)
        start_h = row * im_h
        start_w = col * im_w
        pannel[start_h:start_h + im_h, start_w:start_w + im_w, :] = images[i]
    return pannel

## Prepare the input functions for the tf estimator

Before moving further, one should prepare 0.17 million 256x256 celeba images and store them in Google Cloud Storage at the location given by FLAGS.data_dir. The celeba images should be normalized to [-1,1] and stored as TFRecords. There should be 17 TFRecords, each containing 10000 images. The first 15 TFRecords are used for training, and the last two are used for evaluation and testing.

In [0]:
# Feature specification used by tf.parse_single_example in the input functions below.
# It is only defined when the configured dataset is "celeba"; the input functions
# reference it unconditionally inside their "celeba" branch.
if FLAGS.dataset == "celeba":
    celeba_tfrecord_dics = {
            # 'image': tf.VarLenFeature(dtype=tf.float32),  # This is not compatible with TPU
            # Each record stores one float32 image with a fixed 1024x1024x3 layout.
            'image': tf.FixedLenFeature( shape=( 1024,1024, 3 ), dtype=tf.float32), 
            # Three int64 values; not read by the input functions in this notebook.
            'image_shape': tf.FixedLenFeature(shape=(3,), dtype=tf.int64), 
    }

#### The input function part:

In [0]:
 """the g_w is used to control when to update the generator's parameters
       here the pattern "0 0 0 0 1" means training the generator once for every n_critic = 5 batches as stated in WGAN-GP
       This setting is due to the special property of the TPU mechanism
       In PG-ReverseVAE this usage is deactivated since n_critic = 1 as used in PGGAN, but it's good to keep it for future usage: 
       e.g explore the influence of n_critic on performance. 
"""
def train_input_fn( dataset, stage,  batch_size ):
    """Build the training input pipeline for one resolution stage.

    Args:
        dataset: name of the dataset; only "celeba" is implemented.
        stage: index into FLAGS.size_list selecting the output resolution.
        batch_size: number of examples per batch.

    Returns:
        A tf.data.Dataset yielding `(features, labels)` batches, where `features`
        is a dict with "x" (images resized to the stage resolution, scaled with
        `*2-1`) and "g_w" (the generator-update weight), and `labels` is an
        all-zero placeholder of shape [batch_size, 1].

    Raises:
        ValueError: if `dataset` is not supported (previously this surfaced as an
            UnboundLocalError on the return statement).
    """
    if dataset != "celeba":
        raise ValueError("unsupported dataset: %r (only 'celeba' is implemented)" % (dataset,))

    resolution = FLAGS.size_list[stage]

    # NOTE(review): this reads the *eval* TFRecord for training, whereas the notebook text
    # describes separate training records -- confirm which file is intended.
    dataset_x_train = tf.data.TFRecordDataset( FLAGS.data_dir+"celeba_hq_eval.tfrecord", compression_type='GZIP')
    dataset_x_train = dataset_x_train.shuffle(60000).repeat()
    # One g_w value per example: four batches of zeros followed by one batch of ones.
    pattern = np.array([0,0,0,0,1]).repeat(batch_size).astype(np.float32)
    dataset_g_w = tf.data.Dataset.from_tensor_slices( { "g_w": pattern  } ).repeat()
    # Dummy labels; the model does not use them.
    dataset_output = tf.data.Dataset.from_tensor_slices( ( np.zeros( ( batch_size,1 ) ).astype(np.float32) ) ).repeat()
    ds = tf.data.Dataset.zip(( dataset_x_train, dataset_g_w, dataset_output))

    def map_func(a,b,c):
        parsed_example = tf.parse_single_example(a, celeba_tfrecord_dics)
        # Add a batch axis for resize_bilinear and rescale with *2-1.
        image = tf.reshape(parsed_example['image'], [1, 1024,1024,3] )*2-1
        resized = tf.image.resize_bilinear( image, [resolution, resolution], align_corners=True)
        features = {"x": tf.squeeze(resized, axis =0)}
        features.update(b)
        return features, c

    ds = ds.map(map_func).batch( batch_size, drop_remainder = True ).prefetch(buffer_size=1)
    return ds


In [0]:
def eval_input_fn(dataset, stage,batch_size):
    """Build the evaluation input pipeline for one resolution stage.

    Args:
        dataset: name of the dataset; only "celeba" is implemented.
        stage: index into FLAGS.size_list selecting the output resolution.
        batch_size: number of examples per batch.

    Returns:
        A tf.data.Dataset yielding `(images, labels)` batches: images resized to
        the stage resolution and scaled with `*2-1`, and all-zero placeholder labels.

    Raises:
        ValueError: if `dataset` is not supported (previously this surfaced as an
            UnboundLocalError on the return statement).
    """
    if dataset != "celeba":
        raise ValueError("unsupported dataset: %r (only 'celeba' is implemented)" % (dataset,))

    resolution = FLAGS.size_list[stage]

    dataset_x_eval = tf.data.TFRecordDataset( FLAGS.data_dir+"celeba_hq_eval.tfrecord"  , compression_type='GZIP')
    dataset_x_eval = dataset_x_eval.shuffle(10000).repeat()
    # Dummy labels; the model does not use them.
    dataset_output = tf.data.Dataset.from_tensor_slices( ( np.zeros( ( batch_size ) ).astype(np.float32) ) ).repeat()
    ds = tf.data.Dataset.zip(( dataset_x_eval, dataset_output))

    def map_func(a,b):
        parsed_example = tf.parse_single_example(a, celeba_tfrecord_dics)
        # Add a batch axis for resize_bilinear and rescale with *2-1.
        image = tf.reshape(parsed_example['image'],  [1,1024,1024,3] )*2-1
        resized = tf.image.resize_bilinear( image, [resolution, resolution], align_corners=True)
        return tf.squeeze(resized, axis =0), b

    ds = ds.map(map_func).batch(batch_size, drop_remainder = True).prefetch(buffer_size =1)
    return ds

In [0]:
def predict_input_fn(dataset, stage,batch_size):
    """Build the input pipeline used for test reconstruction.

    Args:
        dataset: name of the dataset; only "celeba" is implemented.
        stage: index into FLAGS.size_list selecting the output resolution.
        batch_size: number of examples per batch (the last batch may be smaller).

    Returns:
        A tf.data.Dataset over 64 shuffled records, yielding `(images, labels)`
        batches: images resized to the stage resolution and scaled with `*2-1`,
        and all-zero placeholder labels of shape [batch, 1].

    Raises:
        ValueError: if `dataset` is not supported (previously the dataset *name*
            was returned unchanged, which only failed later inside the estimator).
    """
    if dataset != "celeba":
        raise ValueError("unsupported dataset: %r (only 'celeba' is implemented)" % (dataset,))

    resolution = FLAGS.size_list[stage]

    # A separate local name keeps the `dataset` argument (a string) from being shadowed.
    ds = tf.data.TFRecordDataset( FLAGS.data_dir+"celeba_hq_eval.tfrecord" , compression_type='GZIP').shuffle(1000).take(64)

    def map_func(a):
        parsed_example = tf.parse_single_example(a, celeba_tfrecord_dics)
        # Add a batch axis for resize_bilinear and rescale with *2-1.
        image = tf.reshape(parsed_example['image'],  [1, 1024,1024,3] )*2-1
        resized = tf.image.resize_bilinear( image, [resolution, resolution], align_corners=True)
        return tf.squeeze(resized, axis =0), tf.zeros(shape= (1,))

    ds = ds.map( map_func )
    ds = ds.batch(batch_size,drop_remainder = False)
    return ds

## Design the PG-ReverseVAE model

#### The metric function used for evaluation:

In [0]:
def metric_fn(loss_gen, loss_dis, W_dis, loss_recon_z ):
    """Turn the evaluation tensors into streaming-mean metrics.

    Each argument is wrapped in `tf.metrics.mean`; the returned dict maps the
    metric name shown in the evaluation result to that metric.
    """
    named_tensors = (
        ("loss_gen", loss_gen),
        ("loss_dis", loss_dis),
        ("wasserstein_distance", W_dis),
        ("loss_recon_z", loss_recon_z),
    )
    return {name: tf.metrics.mean(tensor) for name, tensor in named_tensors}

#### The model structure part:

In [0]:
def pixelnorm( net ):
    """Divide `net` by the root-mean-square of its values along axis 3.

    A small constant (1e-8) is added under the square root for numerical safety.
    """
    mean_square = tf.reduce_mean( net**2 , axis = 3, keepdims= True)
    rms = tf.sqrt(mean_square + 1e-8)
    return tf.divide(net, rms)

In [0]:
def minibatch_std_dev( net ):
    """Append one extra channel holding the mean minibatch standard deviation.

    The standard deviation is taken over axis 0 for every remaining position,
    averaged into a single scalar, and broadcast into a map with the same first
    three dimensions as `net`, which is concatenated on axis 3.
    """
    centered = net - tf.reduce_mean( net, axis = 0, keepdims= True  )
    variance = tf.reduce_mean(  centered**2, axis= 0, keepdims= False )
    per_position_std = tf.sqrt( variance + 1e-8 )
    dims = tf.shape(net)
    std_map = tf.reduce_mean( per_position_std ) * tf.ones( [ dims[0], dims[1], dims[2], 1 ] )
    return tf.concat( [ net, std_map ], axis = 3 )

In [0]:
def activate(net, activation, alpha=None ):
    """Apply the named activation function to `net`.

    Args:
        net: input tensor.
        activation: one of "leaky_relu", "relu" or "sigmoid".
        alpha: negative slope for "leaky_relu"; defaults to 0.2 when None.
            Ignored by the other activations.

    Returns:
        The activated tensor.

    Raises:
        ValueError: if `activation` is not one of the supported names. (This used
            to print a message and call exit(1), which terminates the whole
            interpreter/kernel instead of reporting a catchable error.)
    """
    if activation == "leaky_relu":
        return tf.nn.leaky_relu(net, alpha=0.2 if alpha is None else alpha)
    if activation == "relu":
        return tf.nn.relu(net)
    if activation == "sigmoid":
        return tf.nn.sigmoid(net)
    raise ValueError("unrecognized activation: %r" % (activation,))

In [0]:
def conv_layer(net,  trainable ,  units  , kernel, strides,  dropout_rate = 0,  activation = None , alpha= None, var_name = None  ):
    """2-D convolution with "same" padding, optional activation and optional dropout.

    When FLAGS.is_equalized_learning_rate is set, the kernel is drawn from a unit
    normal and multiplied by sqrt(2 / fan_in) every time it is used; the variables
    are named 'kernel_<var_name>' and 'bias_<var_name>' in the current variable
    scope. Otherwise a standard tf.layers.Conv2D is used and `var_name` is ignored.

    Args:
        net: input tensor; its last dimension is taken as the input channel count.
        trainable: whether the created variables are trainable; also passed as the
            `training` flag of the dropout layer.
        units: number of output channels.
        kernel: side length of the square kernel.
        strides: convolution strides.
        dropout_rate: dropout is applied only when this is greater than 0.
        activation: name understood by `activate`, or None for a linear output.
        alpha: forwarded to `activate`.
        var_name: suffix of the variable names in the equalized-learning-rate branch.

    Returns:
        The output tensor.
    """
    if FLAGS.is_equalized_learning_rate:
        # this is for equalized learning rate implementation
        in_channels = net.get_shape().as_list()[-1]
        fan_in = kernel * kernel * in_channels
        runtime_scale = tf.constant(np.sqrt(2. / fan_in), dtype=tf.float32)
        weights = tf.get_variable(
            'kernel_'+var_name,
            shape=[kernel, kernel, in_channels, units],
            initializer=tf.random_normal_initializer(stddev=1.),
            trainable=trainable,
        )
        net = tf.nn.conv2d(net, runtime_scale * weights, strides, padding="SAME")
        bias = tf.get_variable(
            'bias_'+var_name,
            [units],
            initializer=tf.constant_initializer(0.),
            trainable=trainable,
        )
        net = tf.nn.bias_add(net, bias)
    else:
        net = tf.layers.Conv2D(units, kernel, strides, "same", trainable=trainable)(net)

    if activation is not None:
        net = activate(net, activation, alpha)
    if dropout_rate > 0:
        net = tf.layers.Dropout(rate=dropout_rate)(net, training=trainable)
    return net

In [0]:
def toRGB( net, trainable, units, scope = "toRGB"):
    """Project feature maps to `units` output channels with a linear 1x1 convolution.

    The convolution variables live in the (auto-reused) variable scope `scope`.
    """
    with tf.variable_scope(scope, reuse= tf.AUTO_REUSE ):
        return conv_layer(
            net,
            trainable,
            units,
            kernel=1,
            strides=(1,1),
            dropout_rate=0,
            activation=None,
            alpha=None,
            var_name="toRGB",
        )

def fromRGB( net, trainable, units, scope = "fromRGB" ):
    """Lift an image to `units` feature channels with a 1x1 convolution + leaky ReLU (0.2).

    The convolution variables live in the (auto-reused) variable scope `scope`.
    """
    with tf.variable_scope(scope, reuse= tf.AUTO_REUSE ):
        return conv_layer(
            net,
            trainable,
            units,
            kernel=1,
            strides=(1,1),
            dropout_rate=0,
            activation="leaky_relu",
            alpha=0.2,
            var_name="fromRGB",
        )

In [0]:
def generator( z, trainable = True, scope = "generator", dataset = "celeba", stage = 0, alpha = 0 ):
    """Progressively-grown generator: maps latent vectors to images of the stage resolution.

    Args:
        z: latent input fed to the first dense layer.
        trainable: whether the created variables are trainable.
        scope: outer variable scope; all variables are created with AUTO_REUSE.
        dataset: only "celeba" is implemented. For any other value the function
            falls through and implicitly returns None.
        stage: index into FLAGS.size_list of the output resolution.
        alpha: blending weight of the newest stage; the up-sampled RGB output of
            the previous stage gets weight (1 - alpha). Unused when stage == 0.

    Returns:
        The RGB output (3 channels) at resolution FLAGS.size_list[stage].
    """
    with tf.variable_scope( scope, reuse= tf.AUTO_REUSE ):
        if dataset == "celeba":
            size_list =  FLAGS.size_list
            channel_list =   FLAGS.generator_channel_list 
            # Stage 0: dense projection, reshape to a size_list[0] x size_list[0] map, one conv block.
            # NOTE(review): the dense width is hard-coded as 4*4*512 while the reshape uses
            # size_list[0] / channel_list[0]; the two only agree for the current FLAGS values.
            with tf.variable_scope( "stage0", reuse= tf.AUTO_REUSE):
                net = tf.layers.Dense( 4*4*512, activation=tf.nn.leaky_relu, trainable= trainable )(z)
                net = tf.reshape( net, [ tf.shape(net)[0], size_list[0] ,size_list[0], channel_list[0] ] )
                net = pixelnorm(conv_layer( net, trainable, channel_list[0], 3, (1,1), 0 , "leaky_relu", 0.2, "conv1" ))
                x = toRGB( net, trainable, 3 )

            for stg in range(1, stage+1):
                # Remember the RGB output of the stage just below the final one for the blend below.
                if stg == stage:
                    previous_x = x
                with tf.variable_scope( "stage%d"%(stg), reuse= tf.AUTO_REUSE ):
                    # Up-sample the features to this stage's resolution, then two conv blocks.
                    net = tf.image.resize_bilinear( net, [size_list[stg], size_list[stg]], align_corners= True )
                    net = pixelnorm(conv_layer( net, trainable, channel_list[stg], 3, (1,1), 0, "leaky_relu", 0.2 , "conv1"  ))
                    net = pixelnorm(conv_layer( net, trainable, channel_list[stg], 3, (1,1), 0, "leaky_relu", 0.2 , "conv2" ))
                    x = toRGB( net, trainable, 3 )
            if stage > 0:
                # Resolution transition: blend the up-sampled previous-stage image with the new output.
                x = (1-alpha) * tf.image.resize_bilinear(previous_x, [ size_list[stage], size_list[stage]], align_corners= True ) + alpha * x
            
            return x

In [0]:
def discriminator( x,  trainable= True,scope = "discriminator", dataset = "celeba", stage = 0, alpha = 0 ):
    """Progressively-grown critic: maps images of the stage resolution to one score each.

    Args:
        x: input images at resolution FLAGS.size_list[stage].
        trainable: whether the created variables are trainable.
        scope: outer variable scope; all variables are created with AUTO_REUSE.
        dataset: only "celeba" is implemented. For any other value `net` is never
            assigned and the final `return net` fails.
        stage: index of the input resolution.
        alpha: blending weight of the newest stage's features; the skip path
            (down-sampled input through the previous stage's fromRGB) gets
            weight (1 - alpha). Unused when stage == 0.

    Returns:
        The output of the final Dense(1) layer.
    """
    with tf.variable_scope( scope, reuse = tf.AUTO_REUSE  ):
        if dataset == "celeba":
            size_list = FLAGS.size_list
            in_channel_list =  FLAGS.discriminator_in_channel_list  
            out_channel_list = FLAGS.discriminator_out_channel_list 
            if stage > 0:
                # Newest stage: fromRGB, two convs, then 2x average pooling.
                with tf.variable_scope( "stage%d"%(stage), reuse=tf.AUTO_REUSE ):
                    net = fromRGB( x, trainable, in_channel_list[stage])
                    net = conv_layer( net, trainable, in_channel_list[stage], 3, (1,1), 0, "leaky_relu", 0.2, "conv1" )
                    net = conv_layer( net, trainable, out_channel_list[stage], 3, (1,1), 0, "leaky_relu", 0.2, "conv2")
                    net = tf.layers.AveragePooling2D( 2,2,"same" )(net)
                # Skip path: pool the raw image and feed it to the fromRGB of the previous stage's scope.
                with tf.variable_scope( "stage%d"%(stage-1), reuse=tf.AUTO_REUSE ):
                    net_1 = fromRGB( tf.layers.AveragePooling2D(2,2,"same")(x), trainable, in_channel_list[stage-1] )
                # Resolution transition: blend the two paths.
                net = (1-alpha) * net_1 + alpha * net
            else:
                with tf.variable_scope( "stage0", reuse= tf.AUTO_REUSE):
                    net = fromRGB( x, trainable, in_channel_list[0])
            
            # Remaining intermediate stages, from stage-1 down to 1 (stage 0 is handled below).
            for stg in range( stage-1, 0, -1 ):
                with tf.variable_scope( "stage%d"%(stg), reuse= tf.AUTO_REUSE ):
                    net = conv_layer( net, trainable, in_channel_list[stg], 3, (1,1), 0, "leaky_relu", 0.2, "conv1")
                    net = conv_layer( net, trainable, out_channel_list[stg], 3, (1,1), 0, "leaky_relu", 0.2, "conv2")
                    net = tf.layers.AveragePooling2D( 2,2,"same" )(net)

            # Stage 0 head: minibatch-stddev channel, one conv, then two dense layers down to one score.
            with tf.variable_scope( "stage0", reuse= tf.AUTO_REUSE):
                net = minibatch_std_dev(net)
                net = conv_layer( net, trainable, in_channel_list[0], 3, (1,1), 0, "leaky_relu", 0.2, "conv1")
                net = tf.layers.Flatten()(net)
                net = tf.layers.Dense( out_channel_list[0], activation= tf.nn.leaky_relu, trainable= trainable)(net)
                net = tf.layers.Dense(1, trainable= trainable )(net)

        return net

In [0]:
def encoder( x,  trainable= True,scope = "encoder", dataset = "celeba", stage = 0, alpha = 0 ):
    """Progressively-grown encoder: maps images of the stage resolution to latent codes.

    The layer layout mirrors `discriminator` except that there is no
    minibatch-stddev channel and the final dense layer has FLAGS.z_dim outputs.

    Args:
        x: input images at resolution FLAGS.size_list[stage].
        trainable: whether the created variables are trainable.
        scope: outer variable scope; all variables are created with AUTO_REUSE.
        dataset: only "celeba" is implemented. For any other value `net` is never
            assigned and the final `return net` fails.
        stage: index of the input resolution.
        alpha: blending weight of the newest stage's features; the skip path gets
            weight (1 - alpha). Unused when stage == 0.

    Returns:
        The output of the final Dense(FLAGS.z_dim) layer.
    """
    with tf.variable_scope( scope, reuse = tf.AUTO_REUSE  ):
        if dataset == "celeba":
            size_list = FLAGS.size_list
            in_channel_list =  FLAGS.encoder_in_channel_list  
            out_channel_list = FLAGS.encoder_out_channel_list
            if stage > 0:
                # Newest stage: fromRGB, two convs, then 2x average pooling.
                with tf.variable_scope( "stage%d"%(stage), reuse=tf.AUTO_REUSE ):
                    net = fromRGB( x, trainable, in_channel_list[stage])
                    net = conv_layer( net, trainable, in_channel_list[stage], 3, (1,1), 0, "leaky_relu", 0.2, "conv1")
                    net = conv_layer( net, trainable, out_channel_list[stage], 3, (1,1), 0, "leaky_relu", 0.2, "conv2")
                    net = tf.layers.AveragePooling2D( 2,2,"same" )(net)
                # Skip path through the previous stage's fromRGB, then the transition blend.
                with tf.variable_scope( "stage%d"%(stage-1), reuse=tf.AUTO_REUSE ):
                    net_1 = fromRGB( tf.layers.AveragePooling2D(2,2,"same")(x), trainable, in_channel_list[stage-1])
                    net = (1-alpha) * net_1 + alpha * net
            else:
                with tf.variable_scope( "stage0", reuse= tf.AUTO_REUSE):
                    net = fromRGB( x, trainable, in_channel_list[0])
            
            # Remaining intermediate stages, from stage-1 down to 1 (stage 0 is handled below).
            for stg in range( stage-1, 0, -1 ):
                with tf.variable_scope( "stage%d"%(stg), reuse= tf.AUTO_REUSE ):
                    net = conv_layer( net, trainable, in_channel_list[stg], 3, (1,1), 0, "leaky_relu", 0.2, "conv1")
                    net = conv_layer( net, trainable, out_channel_list[stg], 3, (1,1), 0, "leaky_relu", 0.2, "conv2")
                    net = tf.layers.AveragePooling2D( 2,2,"same" )(net)

            # Stage 0 head: one conv, then two dense layers producing the latent code.
            with tf.variable_scope( "stage0", reuse= tf.AUTO_REUSE):
                net = conv_layer( net, trainable, in_channel_list[0], 3, (1,1), 0, "leaky_relu", 0.2, "conv1")
                net = tf.layers.Flatten()(net)
                net = tf.layers.Dense( out_channel_list[0], activation= tf.nn.leaky_relu, trainable= trainable)(net)
                net = tf.layers.Dense( FLAGS.z_dim , trainable= trainable )(net)
            
        return net

### Load the model weights of the previous resolutions to initialize the model of the current resolution

In [0]:
def init_weights(scope_name, path):
    """Create a Scaffold `init_fn` that restores selected variables from a checkpoint.

    Args:
        scope_name: scope name(s) passed as `include` to
            slim.get_variables_to_restore; only variables in these scopes are restored.
        path: directory holding the checkpoint to load from, or None to skip
            restoring altogether.

    Returns:
        None when `path` is None; otherwise a callable `InitFn(scaffold, sess)`
        that assigns the checkpoint values inside `sess`.

    Raises:
        ValueError: if no checkpoint can be found under `path`. (This used to
            print a message and call exit(), which terminates the whole
            interpreter/kernel instead of reporting a catchable error.)
    """
    if path is None:
        return None

    # look for checkpoint
    model_path = tf.train.latest_checkpoint(path)
    if not model_path:
        raise ValueError("could not find the fine tune ckpt at {}".format(path))

    # only restore variables in the scope_name scope
    variables_to_restore = slim.get_variables_to_restore(include=scope_name)
    # Build the assignment function that performs the actual restore.
    initializer_fn = slim.assign_from_checkpoint_fn(model_path, variables_to_restore)

    def InitFn(scaffold,sess):
        initializer_fn(sess)
    return InitFn

## Building the model and define the behavior of training, evaluation and prediction

In [0]:
def model_fn(features, labels, mode, params):
    """TPUEstimator model function for one resolution stage.

    Args:
        features: This is the x-arg from the input_fn. In TRAIN mode it is a dict
            whose "x" entry holds the images; in EVAL/PREDICT mode it is the image
            tensor itself.
        labels:   This is the y-arg from the input_fn (unused placeholder),
                  see e.g. train_input_fn for these two.
        mode:     Either TRAIN, EVAL, or PREDICT
        params:   User-defined hyper-parameters; "dataset", "learning_rate" and
                  "stage" are read here.

    Returns:
        A tf.estimator.tpu.TPUEstimatorSpec for the requested mode.
    """
    dataset = params["dataset"]
    lr = params["learning_rate"]
    stage = params["stage"]

    """ Part I. create the model networks"""
    is_train = mode == tf.estimator.ModeKeys.TRAIN

    if is_train:
        x = features["x"]
    else:
        x = features

    # here x has a range [-1, 1]
    global_step = tf.train.get_global_step()
    # Ramps linearly from 0 to 1 over the first half of the stage's training steps.
    # alpha_for_transition for stage0 is actually never used since stage0 does not require resolution transition
    alpha_for_transition = tf.clip_by_value(  tf.cast(global_step , tf.float32 ) / (FLAGS.train_steps_list[stage]/2), 0, 1  )

    size_list = FLAGS.size_list
    if stage > 0:
        # Fade the real images from their low-resolution version to full resolution,
        # matching the fade-in applied inside the networks.
        x_low =  tf.image.resize_bilinear( x, [size_list[stage-1],size_list[stage-1]], align_corners= True)
        x_low_up =  tf.image.resize_bilinear( x_low, [size_list[stage],size_list[stage]], align_corners= True  )
        x = alpha_for_transition * x + ( 1-alpha_for_transition )* x_low_up

    random_z = tf.random.normal( [tf.shape(x)[0], FLAGS.z_dim ]  )
    gen_x = generator( random_z, trainable= is_train , dataset= dataset, stage=stage, alpha= alpha_for_transition )
    dis_x = discriminator( x, trainable= is_train, dataset=dataset, stage=stage, alpha= alpha_for_transition  )
    dis_gen_x = discriminator( gen_x, trainable= is_train, dataset=dataset, stage=stage , alpha= alpha_for_transition  )
    # for the encoder part
    recon_z = encoder( gen_x, trainable = is_train, dataset = dataset, stage=stage ,  alpha= alpha_for_transition )
    # for image reconstruction
    enc_z = encoder( x, trainable= False, dataset= dataset, stage=stage, alpha= alpha_for_transition  )
    recon_x = generator( enc_z, trainable= False, dataset= dataset, stage=stage , alpha= alpha_for_transition )

    # This is used to compute the gradient penalty
    epsilon = tf.random.uniform( [ tf.shape(x)[0],1,1,1 ], minval=0, maxval= 1 )
    interp_x = epsilon * x + (1-epsilon) * gen_x
    dis_interp_x = discriminator( interp_x, trainable= is_train, dataset=dataset, stage= stage , alpha= alpha_for_transition )
    gradient_x = tf.gradients( dis_interp_x, [ interp_x ]  )[0]
    gradient_penalty = tf.square( tf.sqrt( tf.reduce_sum( tf.square(gradient_x ),[1,2,3] ) + 1e-6) - 1  )
    # The critic outputs are [batch, 1] while the penalty above is [batch]; give it a
    # matching trailing axis so the per-example loss below stays [batch, 1] instead of
    # broadcasting to [batch, batch]. The mean of the loss is unchanged by this.
    gradient_penalty = tf.reshape( gradient_penalty, [-1, 1] )
    LAMBDA = 10

    e_drift = 1e-3
    dis_drift_loss = e_drift * dis_x**2

    """load parameters from previous model"""
    # tf.train.init_from_checkpoint also require load the Adam optimizer's parameter, which may cause some error information sometimes, here we use scaffold_fn
    # same as init_from_checkpoint, scaffold_fn will be only executed once when there is no checkpoints for current model. If there is some checkpoints for current model,
    # this function is not executed, which means that it will not override some existing checkpoints' parameters (under testing ...)
    scaffold_fn = None
    if stage >0:
        previous_model_dir = FLAGS.model_dir+"stage%d"%(stage-1)
        scope_name_list =[]
        for stg in range( stage ):
            scope_name_list += [ "discriminator/stage%d/"%(stg), "generator/stage%d/"%(stg), "encoder/stage%d/"%(stg) ]
        def scaffold_fn():
            return tf.train.Scaffold( init_fn=  init_weights( scope_name_list, previous_model_dir )  )


    """Part II. define the loss and relative parameters for mode == TRAIN/EVAL/PREDICT"""
    ## compute loss
    loss_dis = dis_gen_x  - dis_x + LAMBDA * gradient_penalty + dis_drift_loss
    loss_gen = - dis_gen_x
    W_dis = dis_x - dis_gen_x
    loss_recon_z = tf.reduce_sum( tf.square(random_z - recon_z), [1])
    loss_reg_std_z = tf.abs( tf.math.reduce_std( recon_z ) -1 )

    ## operations for the training mode, define the optimizer, and reconfig it using tpu.CrossShardOptimizer
    if mode == tf.estimator.ModeKeys.TRAIN:
        # The "g_w" feature is currently not applied to the generator/encoder losses (n_critic = 1).
        loss_dis = tf.reduce_mean( loss_dis   )
        loss_gen = tf.reduce_mean( loss_gen   )
        W_dis = tf.reduce_mean(W_dis)
        loss_recon_z = tf.reduce_mean(loss_recon_z  )

        # Define the optimizer
        d_optimizer = tf.train.AdamOptimizer(learning_rate=lr, beta1=0, beta2= 0.99 )
        g_optimizer = tf.train.AdamOptimizer(learning_rate=lr, beta1=0, beta2= 0.99 )
        e_optimizer = tf.train.AdamOptimizer(learning_rate=lr, beta1=0, beta2= 0.99 )
        e2g_optimizer = tf.train.AdamOptimizer(learning_rate=lr/5, beta1=0, beta2= 0.99 )


        if FLAGS.use_tpu:
            d_optimizer = tf.tpu.CrossShardOptimizer(d_optimizer)
            g_optimizer = tf.tpu.CrossShardOptimizer(g_optimizer)
            e_optimizer = tf.tpu.CrossShardOptimizer(e_optimizer)
            e2g_optimizer = tf.tpu.CrossShardOptimizer(e2g_optimizer)

        with tf.control_dependencies( tf.get_collection( tf.GraphKeys.UPDATE_OPS )):
            d_op = d_optimizer.minimize( loss = loss_dis, var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,\
                                           scope="discriminator")  )
            # Only the generator step advances the global step.
            g_op = g_optimizer.minimize(loss = loss_gen,  var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,\
                                           scope="generator"),global_step= global_step )
            e_op = e_optimizer.minimize(loss = loss_recon_z + loss_reg_std_z ,  var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,\
                                           scope="encoder"))
            # The latent reconstruction loss also updates the generator, with a smaller weight and learning rate.
            e2g_op = e2g_optimizer.minimize(loss = loss_recon_z * 1e-2 ,  var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,\
                                           scope="generator"))
            train_op = tf.group( [ d_op , g_op, e_op, e2g_op ] )
            # NOTE(review): the spec is deliberately still created inside this `with` block,
            # exactly as before; the original author flagged its placement as behaving strangely.
            spec= tf.estimator.tpu.TPUEstimatorSpec(mode=mode, loss= W_dis ,train_op= train_op,  scaffold_fn = scaffold_fn  )
    ## for EVAL mode, the parameters eval_metrics takes a tuple or list of two elements. The first element is a callable function,
    ## The second element is a list of parameters. The return value of the callable function will be shown in the evaluation results
    elif mode == tf.estimator.ModeKeys.EVAL:
        spec = tf.estimator.tpu.TPUEstimatorSpec(mode=mode, loss= tf.reduce_mean( W_dis), eval_metrics=(metric_fn, [loss_gen, loss_dis, W_dis, loss_recon_z ] ) )

    elif mode == tf.estimator.ModeKeys.PREDICT:
        predictions = { "generated_images":  gen_x ,
                        "input_images": x,
                        "reconstructed_images": recon_x
                          }
        spec= tf.estimator.tpu.TPUEstimatorSpec( mode = mode, predictions = predictions )

    return spec

## Create the TPUEstimator entity, and run the train / evaluate/ predict

In [0]:
# One TPUEstimator per resolution stage, indexed by stage number.
model_for_stage = []

for stg in range( len(FLAGS.size_list) ):
    # Each stage keeps its checkpoints in its own sub-directory "<model_dir>stage<N>".
    run_config = tf.estimator.tpu.RunConfig(
        model_dir=FLAGS.model_dir+"stage%d"%(stg),
        cluster=tf.distribute.cluster_resolver.TPUClusterResolver(FLAGS.tpu),
        session_config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True),
        tpu_config=tf.estimator.tpu.TPUConfig(FLAGS.iterations_per_loop),
        )

    # The stage index, learning rate and dataset name reach model_fn through `params`.
    model = tf.estimator.tpu.TPUEstimator(
                               model_fn=model_fn,
                               params = {"learning_rate": FLAGS.lr_list[stg]  , "dataset": FLAGS.dataset, "stage":stg },
                               config = run_config,
                               use_tpu= FLAGS.use_tpu,
                               train_batch_size=FLAGS.batch_size_list[stg]  ,
                               eval_batch_size=FLAGS.batch_size_list[stg] ,
                               predict_batch_size= 64,
                              ) 
    model_for_stage.append(model)

INFO:tensorflow:Using config: {'_model_dir': 'gs://reverse-vae-celeba-hq/model/stage0', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
log_device_placement: true
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.76.77.186:8470"
    }
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f7450a38278>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.76.77.186:8470', '_evaluation_master': 'grpc://10.76.77.186:8470', '_is_chief': True, '_num_ps_re

### Training

In [0]:
# Train the stages in order, lowest resolution first; for stage > 0, model_fn
# initializes from the checkpoint directory of the previous stage.
# After this loop, `model` refers to the estimator of the last stage.
for stg in range(len(FLAGS.size_list)):
    model = model_for_stage[stg]
    # The lambda is invoked during this train() call, so it sees the current value of `stg`.
    model.train( input_fn = lambda params: train_input_fn( FLAGS.dataset, stg, params["batch_size"] ), max_steps=FLAGS.train_steps_list[stg]  )

INFO:tensorflow:Querying Tensorflow master (grpc://10.76.77.186:8470) for TPU system metadata.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 7055712508931361655)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 8594660463716365610)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 17415441877167209408)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 825608924856497278)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 8506710606270258419)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worke

### Evaluate

In [0]:
stage = 0
# Select the estimator that belongs to `stage`. Without this, `model` is whatever an
# earlier cell left behind (the last-stage estimator after the training loop), which
# does not match the resolution produced by eval_input_fn for this stage.
model = model_for_stage[stage]
eval_result = model.evaluate(input_fn=lambda params: eval_input_fn( FLAGS.dataset,stage,params["batch_size"]), steps = 10)

In [0]:
eval_result

### Predict

In [0]:
# Choose the stage to sample from and the estimator trained for it.
stage = 0
model = model_for_stage[stage]

# `predict` yields one dict per example; keep the three image entries of each.
pred_results = model.predict(
    input_fn=lambda params: predict_input_fn(FLAGS.dataset, stage, params["batch_size"])
)
images = [
    [result[key] for key in ("generated_images", "input_images", "reconstructed_images")]
    for result in pred_results
]

# Regroup the per-example triples into one array per image kind.
generated_images, input_images, reconstructed_images = [
    np.array([per_example[slot] for per_example in images]) for slot in range(3)
]


In [0]:
# Tile 64 randomly chosen generated images (each with a border) into one grid,
# save it to disk and display it.
grid_imges = make_grid(
    add_padding(
        generated_images[np.random.choice(generated_images.shape[0], 64, replace=False)]
    )
)
grid_imges = np.squeeze(grid_imges)
imageio.imwrite("random_generate.png", grid_imges)
# A 2-D grid means single-channel images: switch to the gray colormap.
if grid_imges.ndim < 3:
    plt.gray()
plt.figure(dpi=200)
plt.imshow(grid_imges)
plt.show()

In [0]:
# Show 64 randomly chosen inputs next to their reconstructions: each pair is
# concatenated along the width axis, then all pairs are tiled into one grid.
randidx = np.random.choice( input_images.shape[0], 64, replace= False  )
grid_imges = np.squeeze(
    make_grid(
        np.concatenate(
            [ add_padding(input_images[randidx]), add_padding(reconstructed_images[randidx]) ],
            axis=2,
        )
    )
)
imageio.imwrite("recon.png", grid_imges)
# Only switch to the gray colormap for single-channel grids, as the random-generation
# cell does; an unconditional plt.gray() changes the global default colormap even for
# RGB output.
if len(grid_imges.shape) < 3:
    plt.gray()
plt.figure(dpi=200)
plt.imshow(grid_imges)
plt.show()