# Training Super-Resolution Models 
## (for Bathymetry resolution enhancement)

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import load_model
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.callbacks import ModelCheckpoint
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os 
import pathlib
import glob
import tqdm

In [None]:
# Report the installed TensorFlow version and how many GPUs TensorFlow can see,
# so a missing GPU is noticed before any training is started.
print("Tensorflow Version: ", tf.__version__)
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))

-----

#### Datasets and Parameters

In [None]:
# Registry of dataset directories, keyed by dataset name and then by split.
# The GEBCO entries provide both HR and LR directories; the Pangaea entries
# only provide HR directories (LR is then generated by downscaling, see
# load_dataset_with_degraded_LR).
dataset = {'GEBCO_x2' :    {'HR_train' : 'datasets/GEBCO/x2/train/HR',
                             'LR_train' : 'datasets/GEBCO/x2/train/LR',
                             'HR_test' : 'datasets/GEBCO/x2/test/HR',
                             'LR_test' : 'datasets/GEBCO/x2/test/LR',
                            }, 
            
            'GEBCO_x4' :    {'HR_train' : 'datasets/GEBCO/x4/train/HR',
                             'LR_train' : 'datasets/GEBCO/x4/train/LR',
                             'HR_test' : 'datasets/GEBCO/x4/test/HR',
                             'LR_test' : 'datasets/GEBCO/x4/test/LR',
                            }, 
            
            'Pangaea_256' : {'HR_train' : 'datasets/Pangaea/256/train/HR',
                             'HR_test' : 'datasets/Pangaea/256/test/HR',
                            }, 
            
            'Pangaea_128' : {'HR_train' : 'datasets/Pangaea/128/train/HR',
                             'HR_test' : 'datasets/Pangaea/128/test/HR',
                            }, 
           }

In [None]:
# Select the active dataset: exactly one line is left un-commented. The value
# must be a key of the `dataset` registry; it is copied into Params['data_name'].
#name = 'GEBCO_x2'
#name = 'GEBCO_x4'
name = 'Pangaea_256'
#name = 'Pangaea_128'

In [None]:
# Hyper-parameters shared by the pre-training and GAN cells.
# NOTE(review): 'hr_dimension' fixes the discriminator / VGG input size and
# presumably has to match the tile size of the selected dataset -- confirm.
Params = {
    'batch_size' : 8,                   # Number of image samples used in each training step          
    'hr_dimension' : 256,               # Dimension of a High Resolution (HR) Image
    'scale' : 2,                        # Factor by which Low Resolution (LR) Images will be downscaled.
    'data_name': name,                  # Dataset name
    'trunk_size' : 23,                  # Number of Residual blocks used in Generator,
    'init_lr' : 0.0001,                 # Initial Learning rate for generator. 
    'disc_init_lr' : 0.00005,           # Initial learning rate for discriminator.
    'ph2_steps' : 50000,                # Number of steps required for phase-2 training
    'decay_ph2' : 0.5,                  # Factor by which learning rates are modified during phase-2 training    
    'lambda' : 0.005,                   # To balance adversarial loss during phase-2 training. 
    'eta' : 0.01,                       # To balance L1 loss during phase-2 training.    
}

In [None]:
# Path the pre-trained generator is saved to and later reloaded from for GAN training.
# The 'models/GAN_Pangaea' directory is hard-coded regardless of the selected dataset.
model_name = 'models/GAN_Pangaea/Generator_'+ Params['data_name'] +'_x' + str(Params['scale'])

----- 

#### Input Pipeline

In [None]:
# normalize values from [min, max] to [0-1]
def normalize(image, idx):
    """Min-max scale a training sample to [0, 1].

    Args:
        image: Array (or scalar) to scale.
        idx: Index of the sample; selects its bounds from the module-level
            ``min_vals`` / ``max_vals`` lists, which are filled by the
            dataset loaders.

    Returns:
        ``(image - min) / (max - min)``. When the sample is constant
        (``max == min``) the shifted image is returned instead, so the
        result is all zeros rather than NaN.
    """
    min_val = min_vals[idx]
    max_val = max_vals[idx]
    value_range = max_val - min_val
    shifted = image - min_val
    if value_range == 0:
        # A flat tile has no dynamic range; dividing would yield NaN and
        # poison every loss the sample takes part in.
        return shifted
    return shifted / value_range

# normalize values from [min, max] to [0-1]
def normalize_test(image, idx):
    """Min-max scale a validation/test sample to [0, 1].

    Args:
        image: Array (or scalar) to scale.
        idx: Index of the sample; selects its bounds from the module-level
            ``test_min_vals`` / ``test_max_vals`` lists, which are filled by
            the dataset loaders when ``is_val_ds`` is true.

    Returns:
        ``(image - min) / (max - min)``. When the sample is constant
        (``max == min``) the shifted image is returned instead, so the
        result is all zeros rather than NaN.
    """
    min_val = test_min_vals[idx]
    max_val = test_max_vals[idx]
    value_range = max_val - min_val
    shifted = image - min_val
    if value_range == 0:
        # A flat tile has no dynamic range; dividing would yield NaN.
        return shifted
    return shifted / value_range

def load_dataset_with_degraded_LR(HR_file_path, is_val_ds, remove=False):
    """Load HR samples and synthesise their LR counterparts by bicubic downscaling.

    Every file in the selected directory is read with ``np.load``, min-max
    normalised per sample, replicated to three channels, and downscaled by
    ``Params['scale']`` to create the LR input.

    Side effects: stores the per-sample minima/maxima in the module-level
    ``min_vals`` / ``max_vals`` (training) or ``test_min_vals`` /
    ``test_max_vals`` (validation) lists, which ``normalize`` /
    ``normalize_test`` read.

    Args:
        HR_file_path: Split key inside ``dataset[Params['data_name']]``,
            e.g. ``'HR_train'`` or ``'HR_test'``.
        is_val_ds: When true, the validation statistics/normaliser are used.
        remove: Unused; kept for interface compatibility.

    Returns:
        Tuple ``(lr_ds, hr_ds)`` of float32 tensors.
    """
    global min_vals, max_vals, test_min_vals, test_max_vals

    img_dir = dataset[Params['data_name']][HR_file_path]
    # os.listdir returns entries in arbitrary order; sort so that sample
    # indices (and therefore the stored min/max lists) are reproducible.
    hr_paths = [os.path.join(img_dir, file_name) for file_name in sorted(os.listdir(img_dir))]
    print('HR Training Samples: ', len(hr_paths), 'Files found')

    hr_ds = [np.load(path) for path in hr_paths]

    if is_val_ds:
        test_min_vals = [sample.min() for sample in hr_ds]
        test_max_vals = [sample.max() for sample in hr_ds]
        hr_ds = [normalize_test(sample, idx) for idx, sample in enumerate(hr_ds)]
    else:
        # get min and max values of each sample pair
        min_vals = [sample.min() for sample in hr_ds]
        max_vals = [sample.max() for sample in hr_ds]
        hr_ds = [normalize(sample, idx) for idx, sample in enumerate(hr_ds)]

    # All samples must share one shape to be packed into a single tensor.
    hr_ds = tf.cast(hr_ds, tf.float32)
    # Replicate the single data channel three times (new trailing axis).
    hr_ds = tf.stack((hr_ds,) * 3, axis=-1)

    # Derive height and width separately so non-square tiles keep their
    # aspect ratio (the width used to be computed from the height).
    lr_size = [hr_ds.shape[1] // Params['scale'], hr_ds.shape[2] // Params['scale']]
    lr_degraded = [tf.image.resize(tf.squeeze(sample), lr_size,
                                   method=tf.image.ResizeMethod.BICUBIC)
                   for sample in hr_ds]
    lr_ds = tf.convert_to_tensor(lr_degraded)

    print("LR Dataset Shape: ", lr_ds.shape)
    print("HR Dataset Shape: ", hr_ds.shape)

    return lr_ds, hr_ds
    

In [None]:
## TODO: if the LR file path is None, generate the degraded LR data; otherwise load the real LR data from that path

In [None]:
def load_dataset_with_real_LR(HR_file_path, LR_file_path, is_val_ds, remove=False):
    """Load paired HR and LR samples from disk.

    Both directories are read with ``np.load``. Each LR sample is normalised
    with the minimum/maximum of the HR sample at the same index, then both
    sets are replicated to three channels.

    Side effects: stores the per-sample HR minima/maxima in the module-level
    ``min_vals`` / ``max_vals`` (training) or ``test_min_vals`` /
    ``test_max_vals`` (validation) lists.

    Args:
        HR_file_path: Split key of the HR directory inside
            ``dataset[Params['data_name']]``, e.g. ``'HR_train'``.
        LR_file_path: Split key of the LR directory, e.g. ``'LR_train'``.
        is_val_ds: When true, the validation statistics/normaliser are used.
        remove: Unused; kept for interface compatibility.

    Returns:
        Tuple ``(lr_ds, hr_ds)`` of float32 tensors.

    Raises:
        ValueError: If the HR and LR directories hold a different number of files.
    """
    global min_vals, max_vals, test_min_vals, test_max_vals

    # Directories live in the `dataset` registry; Params['data_name'] is only
    # the registry key (indexing the name string itself raised a TypeError).
    split_dirs = dataset[Params['data_name']]

    # Sorted listings keep the HR/LR pairing by index deterministic.
    # NOTE(review): assumes matching HR and LR files sort into the same
    # order (e.g. identical file names) -- confirm against the data layout.
    hr_dir = split_dirs[HR_file_path]
    hr_paths = [os.path.join(hr_dir, file_name) for file_name in sorted(os.listdir(hr_dir))]
    print('HR Training Samples: ', len(hr_paths), 'Files found')
    hr_ds = [np.load(path) for path in hr_paths]

    lr_dir = split_dirs[LR_file_path]
    lr_paths = [os.path.join(lr_dir, file_name) for file_name in sorted(os.listdir(lr_dir))]
    print('LR Training Samples: ', len(lr_paths), 'Files found')
    lr_ds = [np.load(path) for path in lr_paths]

    if len(hr_ds) != len(lr_ds):
        raise ValueError(
            f"HR/LR sample count mismatch: {len(hr_ds)} HR files in '{hr_dir}' "
            f"vs {len(lr_ds)} LR files in '{lr_dir}'")

    if is_val_ds:
        test_min_vals = [sample.min() for sample in hr_ds]
        test_max_vals = [sample.max() for sample in hr_ds]
        hr_ds = [normalize_test(sample, idx) for idx, sample in enumerate(hr_ds)]
        lr_ds = [normalize_test(sample, idx) for idx, sample in enumerate(lr_ds)]
    else:
        # get min and max values of each sample pair
        min_vals = [sample.min() for sample in hr_ds]
        max_vals = [sample.max() for sample in hr_ds]
        hr_ds = [normalize(sample, idx) for idx, sample in enumerate(hr_ds)]
        lr_ds = [normalize(sample, idx) for idx, sample in enumerate(lr_ds)]

    # Pack into single tensors and replicate the data channel three times.
    hr_ds = tf.stack((tf.cast(hr_ds, tf.float32),) * 3, axis=-1)
    lr_ds = tf.stack((tf.cast(lr_ds, tf.float32),) * 3, axis=-1)

    print("LR Dataset Shape: ", lr_ds.shape)
    print("HR Dataset Shape: ", hr_ds.shape)

    return lr_ds, hr_ds

In [None]:
def make_dataset(lr_train_ds, hr_train_ds, repeat=False):
    """Build the batched tf.data pipeline of (LR, HR) pairs.

    The pairs are cached, optionally repeated indefinitely, shuffled
    (re-shuffled on every iteration), batched with ``Params['batch_size']``
    (incomplete final batches are dropped) and prefetched.

    Args:
        lr_train_ds: Tensor of LR samples.
        hr_train_ds: Tensor of HR samples, aligned with ``lr_train_ds``.
        repeat: When ``True`` the pipeline yields batches forever.

    Returns:
        The resulting ``tf.data.Dataset``.
    """
    pairs = tf.data.Dataset.from_tensor_slices((lr_train_ds, hr_train_ds)).cache()

    # Only an explicitly requested repeat makes the pipeline infinite.
    if repeat == True:
        pairs = pairs.repeat()

    pairs = pairs.shuffle(buffer_size=10000, reshuffle_each_iteration=True)
    batches = pairs.batch(Params['batch_size'], drop_remainder=True)
    return batches.prefetch(tf.data.experimental.AUTOTUNE)

### Generator

In [None]:
def conv_block(input, filters, activation=True):
    """Apply a 3x3, stride-1, same-padded convolution, optionally followed by LeakyReLU(0.2).

    Args:
        input: Input feature tensor.
        filters: Number of output channels of the convolution.
        activation: When truthy, a LeakyReLU with slope 0.2 is appended.

    Returns:
        The resulting feature tensor.
    """
    conv = layers.Conv2D(
        filters,
        kernel_size=[3,3],
        strides=[1,1],
        padding='same',
        use_bias=True,
        kernel_initializer='he_normal',
        bias_initializer='zeros',
    )
    features = conv(input)
    if not activation:
        return features
    return layers.LeakyReLU(0.2)(features)

def small_block(input):
    """Densely connected residual block.

    Four 32-channel conv blocks are chained, each one concatenated onto
    everything that came before it. A final 64-channel convolution without
    activation fuses the result, which is scaled by 0.2 and added to the
    block input.

    Args:
        input: Input feature tensor (must be addable to a 64-channel tensor).

    Returns:
        The block output tensor.
    """
    dense = input
    for _ in range(4):
        grown = conv_block(dense, 32)
        dense = layers.Concatenate()([dense, grown])

    fused = conv_block(dense, 64, activation=False)

    # Residual scaling before the skip connection.
    scaled = layers.Lambda(lambda x:x * 0.2)(fused)
    return layers.Add()([scaled, input])

def large_block(input):
    """Residual-in-residual block: three chained small blocks plus a scaled skip connection.

    Args:
        input: Input feature tensor.

    Returns:
        ``0.2 * small_block^3(input) + input``.
    """
    trunk = input
    for _ in range(3):
        trunk = small_block(trunk)

    scaled = layers.Lambda(lambda x:x *0.2)(trunk)
    return layers.Add()([scaled, input])

def upsample_block(x, filters):
    """Upsample the feature map with ``UpSampling2D`` and refine it with a conv block.

    Args:
        x: Input feature tensor.
        filters: Number of output channels of the refining convolution.
            (This argument used to be ignored in favour of a hard-coded 64.)

    Returns:
        The upsampled feature tensor.
    """
    x = layers.UpSampling2D()(x)
    x = conv_block(x, filters, activation=True)
    return x
    
def generator_network(filter=64, trunk_size=Params['trunk_size'], scale=Params['scale'], out_channels=3):
    """Build the super-resolution generator.

    Architecture: input convolution -> ``trunk_size`` large blocks ->
    trunk convolution with a long skip connection -> upsampling blocks ->
    two output convolutions.

    Args:
        filter: Channel width of the convolutions.
        trunk_size: Number of ``large_block`` units in the trunk.
        scale: Upscaling factor. 2, 4 and 8 insert one, two and three
            upsampling blocks respectively; any other value inserts none.
        out_channels: Number of channels of the generated image.

    Returns:
        A ``tf.keras.models.Model`` mapping a 3-channel LR image of arbitrary
        spatial size to the generated HR image.
    """
    def conv3x3(channels):
        # All plain convolutions of the generator share this configuration.
        return layers.Conv2D(channels, kernel_size=[3,3], kernel_initializer='he_normal',
                             bias_initializer='zeros', strides=[1,1], padding='same', use_bias=True)

    lr_input = layers.Input(shape=(None, None, 3))
    shallow = conv3x3(filter)(lr_input)
    shallow = layers.LeakyReLU(0.2)(shallow)

    deep = shallow
    for _ in range(trunk_size):
        deep = large_block(deep)

    deep = conv3x3(filter)(deep)
    features = layers.Add()([deep, shallow])

    # Number of upsampling blocks per supported scale factor.
    upsample_steps = {2: 1, 4: 2, 8: 3}.get(scale, 0)
    for _ in range(upsample_steps):
        features = upsample_block(features, filter)

    features = conv3x3(filter)(features)
    features = layers.LeakyReLU(0.2)(features)

    hr_output = conv3x3(out_channels)(features)

    return tf.keras.models.Model(inputs=lr_input, outputs=hr_output)

-----

## Pre-Training Generator

Dataset with degraded LR 

In [None]:
# Training pairs: LR is generated by downscaling the HR training tiles.
lr_train, hr_train = load_dataset_with_degraded_LR('HR_train', is_val_ds=False)
# Finite (non-repeating) pipeline, as used by generator.fit below.
train_dataset = make_dataset(lr_train, hr_train)

# Validation pairs are kept as plain tensors and normalised with their own min/max values.
lr_val, hr_val = load_dataset_with_degraded_LR('HR_test', is_val_ds=True)

Dataset with real LR

In [None]:
# lr_train, hr_train = load_dataset_with_real_LR('HR_train', 'LR_train', is_val_ds=False)
# train_dataset = make_dataset(lr_train, hr_train)

# lr_val, hr_val = load_dataset_with_real_LR('HR_test', 'LR_test', is_val_ds=True)

In [None]:
import time
class TimeHistory(keras.callbacks.Callback):
    """Keras callback that records the wall-clock duration of every epoch.

    After training, ``self.times`` holds one duration in seconds per
    completed epoch, in epoch order.
    """

    def on_train_begin(self, logs=None):
        """Reset the list of recorded epoch durations."""
        self.times = []

    def on_epoch_begin(self, epoch, logs=None):
        """Remember when the epoch started."""
        self.epoch_time_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        """Append the elapsed time of the epoch that just finished."""
        self.times.append(time.time() - self.epoch_time_start)

In [None]:
# use mirrored strategy to use all available GPUs
mirrored_strategy = tf.distribute.MirroredStrategy()

# compile model
with mirrored_strategy.scope():
    generator = generator_network()
    #generator_200 = tf.keras.models.load_model(Params['model_dir'] + '/256_Generator/256_x4_200epochs_lr5_bs8')

    # Callbacks: per-epoch timing, early stopping on the validation loss, and
    # a checkpoint that keeps only the best model seen so far.
    time_callback = TimeHistory()
    es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=75)
    best_model_name = Params['data_name'] +'_'+ str(Params['hr_dimension']) + '_x' + str(Params['scale']) + '_best_model.h5'
    mc = ModelCheckpoint(best_model_name, monitor='val_loss', mode='min', verbose=1, save_best_only=True)

    # Pre-training minimises the pixel-wise L1 (MAE) loss.
    # `learning_rate` is the supported keyword; `lr` is a deprecated alias.
    generator.compile(
        loss=tf.keras.losses.MeanAbsoluteError(),
        optimizer=keras.optimizers.Adam(learning_rate=Params['init_lr'], beta_1=0.9, beta_2=0.999),
        metrics=[
            tf.keras.metrics.MeanAbsoluteError(),
        ]
    )

In [None]:
# Print the layer-by-layer summary of the freshly built generator.
generator.summary()

In [None]:
def plot_history(history):
    """Plot training vs. validation loss per epoch, save the figure and show it.

    The figure is written to ``history_plots/<data>_<hr_dim>_x<scale>_model_loss.png``.

    Args:
        history: Object returned by ``model.fit``; its ``history`` dict must
            contain the ``'loss'`` and ``'val_loss'`` series.
    """
    data_name = Params['data_name']
    hr_dim = str(Params['hr_dimension'])
    scale = str(Params['scale'])

    for series in ('loss', 'val_loss'):
        plt.plot(history.history[series])

    plt.title(f"{data_name} ({hr_dim} x{scale}) model loss")
    plt.ylabel('MAE')
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper right')
    plt.savefig(f"history_plots/{data_name}_{hr_dim}_x{scale}_model_loss.png")
    plt.show()

In [None]:
def print_training_details(history, times):
    """Append a short training report to a text file under ``history_plots/``.

    The report states the batch-size split, the learning rate, the epoch with
    the lowest validation loss, and an estimate of the training time up to
    that epoch.

    Args:
        history: Object returned by ``model.fit``; ``history.history['val_loss']``
            is read.
        times: Per-epoch durations in seconds (e.g. ``TimeHistory.times``).
            Must contain at least one entry.
    """
    val_losses = history.history['val_loss']
    best_epoch = np.argmin(val_losses) + 1

    # The second epoch is used as the representative duration; fall back to
    # the first one when training ran for a single epoch only.
    epoch_seconds = times[1] if len(times) > 1 else times[0]

    # Params['scale'] is an int and has to be converted before concatenation.
    with open("history_plots/training_details" + name + "_x" + str(Params['scale']) + ".txt", "a") as f:

        print(f"Training on 4 GPUs with tf.mirroredStrategy \
        (batch size {Params['batch_size']/4} per GPU = batch size {Params['batch_size']}). \
        \nLearning rate: {Params['init_lr']}", file=f)

        print(f"After epoch {best_epoch} the minimal validation loss \
        {min(val_losses):.4f} was reached.", file=f)

        print(f"Time in hours for training the model for {best_epoch} epochs: {(epoch_seconds*(best_epoch))/60/60} h.", file=f)

In [None]:
# Pre-train the generator. Training stops early once 'val_loss' has not
# improved for the patience configured on `es`; `mc` keeps the best model on
# disk and `time_callback` records the epoch durations.
history = generator.fit(train_dataset, 
                        validation_data=(lr_val, hr_val), 
                        epochs=1500, 
                        verbose=2, 
                        callbacks=[es,mc,time_callback])

plot_history(history)
print_training_details(history, time_callback.times)
# Reload the best checkpoint (not the last epoch) and store it under
# `model_name`, from where the GAN stage loads the generator.
model = keras.models.load_model(best_model_name)
model.save(model_name) 

# GAN

In [None]:
# Reload the data for GAN training. The pipeline repeats indefinitely because
# the GAN loop below is driven by a step counter rather than by epochs.
lr_train, hr_train = load_dataset_with_degraded_LR('HR_train', is_val_ds=False)
train_dataset = make_dataset(lr_train, hr_train, repeat=True)

# Validation pairs, normalised with their own min/max values.
lr_val, hr_val = load_dataset_with_degraded_LR('HR_test', is_val_ds=True)

In [None]:
# lr_train, hr_train = load_dataset_with_real_LR('HR_train', 'LR_train', is_val_ds=False)
# train_dataset = make_dataset(lr_train, hr_train, repeat=True)

# lr_val, hr_val = load_dataset_with_real_LR('HR_test', 'LR_test', is_val_ds=True)

#### Discriminator

In [None]:
def _conv_block_d(x, out_channel):
    """Discriminator stage: a 3x3 stride-1 convolution followed by a 4x4 stride-2 convolution.

    Each convolution is bias-free, same-padded and followed by batch
    normalisation (momentum 0.8) and LeakyReLU(0.2).

    Args:
        x: Input feature tensor.
        out_channel: Number of output channels of both convolutions.

    Returns:
        The resulting feature tensor.
    """
    for kernel_size, stride in ((3, 1), (4, 2)):
        x = layers.Conv2D(out_channel, kernel_size, stride, padding='same', use_bias=False)(x)
        x = layers.BatchNormalization(momentum=0.8)(x)
        x = layers.LeakyReLU(alpha=0.2)(x)
    return x

def discriminator_network(filters = 64, training=True):
    """Build the discriminator.

    The network takes a 3-channel image of size
    ``Params['hr_dimension']`` x ``Params['hr_dimension']`` and outputs a
    single unbounded logit per image.

    Args:
        filters: Channel width of the first stage; later stages use 2x, 4x and 8x.
        training: Unused; kept for interface compatibility.

    Returns:
        A ``tf.keras.models.Model`` mapping an image to one logit.
    """
    hr_dim = Params['hr_dimension']
    img = layers.Input(shape = (hr_dim, hr_dim, 3))

    # Stem: two 3x3 convolutions, the second one with stride 2.
    features = img
    for stride in (1, 2):
        features = layers.Conv2D(filters, [3,3], stride, padding='same', use_bias=False)(features)
        features = layers.BatchNormalization()(features)
        features = layers.LeakyReLU(alpha=0.2)(features)

    # Three stages with growing channel width.
    for multiplier in (2, 4, 8):
        features = _conv_block_d(features, filters * multiplier)

    # Classification head; the output is a raw logit (no final activation).
    flat = layers.Flatten()(features)
    hidden = layers.Dense(100)(flat)
    hidden = layers.LeakyReLU(alpha=0.2)(hidden)
    logit = layers.Dense(1)(hidden)

    return tf.keras.models.Model(inputs = img, outputs = logit)

#### Loss Functions

In [None]:
def pixel_loss(y_true, y_pred):
    """Mean absolute error between two tensors, computed in float32.

    The absolute error is first averaged over axis 0 and the result is then
    averaged over all remaining elements.

    Args:
        y_true: Reference tensor.
        y_pred: Predicted tensor with the same shape as ``y_true``.

    Returns:
        Scalar float32 tensor.
    """
    reference = tf.cast(y_true, tf.float32)
    prediction = tf.cast(y_pred, tf.float32)
    abs_error = tf.abs(reference - prediction)
    mean_over_first_axis = tf.reduce_mean(abs_error, axis = 0)
    return tf.reduce_mean(mean_over_first_axis)

# Function for calculating perceptual loss
def vgg_loss(weight=None, input_shape=None):
    """Create a perceptual loss function based on VGG19 features.

    Args:
        weight: Passed to VGG19 as ``weights`` (e.g. ``"imagenet"`` or ``None``).
        input_shape: Passed to VGG19 as ``input_shape``.

    Returns:
        A function ``loss(y_true, y_pred)`` that feeds both images through
        the truncated VGG19 and returns
        ``tf.compat.v1.losses.absolute_difference`` of the two feature maps.
    """
    vgg_model = tf.keras.applications.vgg19.VGG19(
        input_shape=input_shape, weights=weight, include_top=False
    )

    # The feature extractor is fixed; it must not be updated during training.
    for layer in vgg_model.layers:
        layer.trainable = False

    # Replace the activation of the last conv layer with the identity so the
    # features are taken before the non-linearity.
    # NOTE(review): this mutates the layer after the model was built; it
    # relies on Keras re-evaluating `activation` at call time -- confirm.
    vgg_model.get_layer("block5_conv4").activation = lambda x: x
    vgg = tf.keras.Model(
      inputs=[vgg_model.input],
      outputs=[vgg_model.get_layer("block5_conv4").output])

    def loss(y_true, y_pred):
        # NOTE(review): images are passed to VGG as-is, without
        # vgg19.preprocess_input -- confirm this is intended.
        return tf.compat.v1.losses.absolute_difference(vgg(y_true), vgg(y_pred))

    return loss

def relativistic_discriminator_loss(discriminator_real_outputs,
                                    discriminator_gen_outputs,
                                    scope=None):
    """Relativistic Average GAN discriminator loss.

    Each logit is taken relative to the mean logit of the opposite set; the
    discriminator is trained to rate real samples as 1 and generated samples
    as 0 under sigmoid cross-entropy.

    Args:
        discriminator_real_outputs: Discriminator logits for real images.
        discriminator_gen_outputs: Discriminator logits for generated images.
        scope: Optional name scope; a default name is used when ``None``.

    Returns:
        Scalar tensor: mean loss on real logits plus mean loss on generated logits.
    """
    def _mean_bce(logits, make_labels):
        # Mean sigmoid cross-entropy against constant labels of the same shape.
        return tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=make_labels(logits), logits=logits))

    with tf.compat.v1.name_scope(
        scope,
        'relativistic_discriminator_loss',
        values=[discriminator_real_outputs, discriminator_gen_outputs]):

        relative_real = discriminator_real_outputs - tf.reduce_mean(discriminator_gen_outputs)
        relative_gen = discriminator_gen_outputs - tf.reduce_mean(discriminator_real_outputs)

        loss_on_real = _mean_bce(relative_real, tf.ones_like)
        loss_on_gen = _mean_bce(relative_gen, tf.zeros_like)

    return loss_on_real + loss_on_gen


def relativistic_generator_loss(discriminator_real_outputs,
                                discriminator_gen_outputs,
                                scope=None):
    """Relativistic Average GAN generator loss.

    Mirror image of the discriminator loss: each logit is taken relative to
    the mean logit of the opposite set, and the labels are swapped (real
    samples target 0, generated samples target 1).

    Args:
        discriminator_real_outputs: Discriminator logits for real images.
        discriminator_gen_outputs: Discriminator logits for generated images.
        scope: Optional name scope; a default name is used when ``None``.

    Returns:
        Scalar tensor: mean loss on real logits plus mean loss on generated logits.
    """
    def _mean_bce(logits, make_labels):
        # Mean sigmoid cross-entropy against constant labels of the same shape.
        return tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=make_labels(logits), logits=logits))

    with tf.compat.v1.name_scope(
        scope,
        'relativistic_generator_loss',
        values=[discriminator_real_outputs, discriminator_gen_outputs]):

        relative_real = discriminator_real_outputs - tf.reduce_mean(discriminator_gen_outputs)
        relative_gen = discriminator_gen_outputs - tf.reduce_mean(discriminator_real_outputs)

        loss_on_real = _mean_bce(relative_real, tf.zeros_like)
        loss_on_gen = _mean_bce(relative_gen, tf.ones_like)

    return loss_on_real + loss_on_gen

#### Training of GAN 

In [None]:
# use mirrored strategy to use all available GPUs
# A new strategy object is created here; the GAN models and optimizers below
# are built under its scope.
mirrored_strategy = tf.distribute.MirroredStrategy()
# Distribute the (repeating) training pipeline across the replicas.
dist_dataset = mirrored_strategy.experimental_distribute_dataset(train_dataset)

In [None]:
# Saved discriminator to resume GAN training from.
# NOTE(review): the path is hard-coded (Pangaea_256, x8, step 49000) and does
# not follow Params['data_name'] / Params['scale'] -- confirm it matches the run.
disc_model_name = 'models/GAN_Pangaea/Discriminator_Pangaea_256_x8_49000'

In [None]:
# compile model
with mirrored_strategy.scope():

    # load pretrained generator
    generator = tf.keras.models.load_model(model_name)

    #discriminator = discriminator_network()
    discriminator = tf.keras.models.load_model(disc_model_name)

    # Generator and discriminator each get their OWN optimizer instance.
    # Binding both names to one Adam object made them share state, so any
    # learning-rate change on one silently changed the other.
    # NOTE(review): both start at Params['disc_init_lr'], which is the rate
    # the previously shared optimizer ended up with after its `assign`;
    # Params['init_lr'] may have been intended for the generator -- confirm.
    g_optimizer = tf.optimizers.Adam(
        learning_rate = Params['disc_init_lr'],
        beta_1 = 0.9,
        beta_2 = 0.99)
    d_optimizer = tf.optimizers.Adam(
        learning_rate = Params['disc_init_lr'],
        beta_1 = 0.9,
        beta_2 = 0.99)

    # Checkpoint object bundling both models and both optimizers.
    checkpoint = tf.train.Checkpoint(G=generator,
                                 D = discriminator,
                                 G_optimizer=g_optimizer,
                                 D_optimizer=d_optimizer)
   # local_device_option = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")

    # Perceptual loss on ImageNet-pretrained VGG19 features.
    perceptual_loss = vgg_loss(
        weight = "imagenet",
        input_shape = [Params['hr_dimension'], Params['hr_dimension'], 3])

    # Running means reported by the training loop.
    gen_metric = tf.keras.metrics.Mean()
    disc_metric = tf.keras.metrics.Mean()
    psnr_metric = tf.keras.metrics.Mean()
    metric = tf.keras.metrics.Mean()
    

In [None]:
def train_step(image_lr, image_hr):
    """Run one GAN update (discriminator and generator) on one replica.

    The generator loss combines the perceptual loss, the relativistic
    adversarial loss (weighted by ``Params['lambda']``) and the L1 pixel loss
    (weighted by ``Params['eta']``). All returned values are divided by the
    number of replicas so that a SUM reduction across replicas yields the
    global mean.

    Args:
        image_lr: Batch of low-resolution input images for this replica.
        image_hr: Batch of matching high-resolution target images.

    Returns:
        ``[disc_loss, gen_loss, psnr_loss]`` as scalar tensors.
    """
    # Gradients and reported values are summed across replicas, so each
    # replica contributes 1/num_replicas (this used to be a hard-coded 2).
    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # NOTE(review): both models are called without `training=True`; the
        # discriminator contains BatchNormalization layers -- confirm that
        # inference-mode normalisation is intended here.
        fake = generator(image_lr)

        percep_loss = tf.reduce_mean(perceptual_loss(image_hr, fake))
        l1_loss = pixel_loss(image_hr, fake)

        real_logits = discriminator(image_hr)
        fake_logits = discriminator(fake)

        loss_RaG = relativistic_generator_loss(real_logits,
                                               fake_logits)
        disc_loss = relativistic_discriminator_loss(real_logits,
                                                    fake_logits)

        gen_loss = percep_loss + Params['lambda'] * loss_RaG + Params['eta'] * l1_loss

        gen_loss = gen_loss / num_replicas
        disc_loss = disc_loss / num_replicas

    # Everything below only consumes the recorded forward pass, so it runs
    # outside the tape context and is not itself recorded.
    psnr = tf.image.psnr(tf.clip_by_value(fake, 0., 1.), tf.clip_by_value(image_hr, 0., 1.), max_val = 1.0)

    disc_metric(disc_loss)
    gen_metric(gen_loss)
    psnr_metric(psnr)

    disc_grad = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    d_optimizer.apply_gradients(zip(disc_grad, discriminator.trainable_variables))

    gen_grad = gen_tape.gradient(gen_loss, generator.trainable_variables)
    g_optimizer.apply_gradients(zip(gen_grad, generator.trainable_variables))

    # Return a scalar: the per-sample PSNR vector would otherwise be summed
    # element-wise across replicas and no longer be a PSNR value.
    psnr_loss = tf.reduce_mean(psnr) / num_replicas

    return [disc_loss, gen_loss, psnr_loss]

In [None]:
def reduce_list(l: list):
    """SUM-reduce every per-replica value in ``l`` across the replicas of ``mirrored_strategy``.

    Args:
        l: List of per-replica values as returned by ``mirrored_strategy.run``.

    Returns:
        List with one reduced tensor per input entry, in the same order.
    """
    reduced = []
    for per_replica_value in l:
        reduced.append(
            mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_value, axis=None))
    return reduced

@tf.function
def distributed_train_step(x, y):
    """Run ``train_step`` on every replica and SUM-reduce its outputs.

    Args:
        x: Distributed batch of low-resolution images.
        y: Distributed batch of high-resolution images.

    Returns:
        List of reduced tensors, in the order returned by ``train_step``.
    """
    per_replica_results = mirrored_strategy.run(train_step, args=(x, y))
    return reduce_list(per_replica_results)

In [None]:
def plot_gan_history(history, step, label):
    """Plot one GAN training curve and save it as a PNG.

    The figure is written to
    ``history_plots/<data_name>/GAN_x<scale>_<label>_<step>.png``; the
    directory is created when it does not exist yet.

    Args:
        history: Sequence of values to plot.
        step: Training step the snapshot belongs to; used in the file name.
            (The global ``step_count`` used to be read instead.)
        label: Curve label; used for the legend and the file name.
    """
    out_dir = 'history_plots/' + Params['data_name']
    # savefig does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)

    plt.plot(history, label=label)
    plt.legend()
    plt.savefig(out_dir + '/GAN_x' + str(Params['scale']) + '_' + label + '_' + str(step) + '.png')
    plt.close()

# **change directory**

In [None]:
iterator = iter(dist_dataset)

#step_count = 0
#decay_step = [9000, 30000, 50000]

# Resume from a previously saved state; learning rates are decayed when the
# step counter reaches the next entry of `decay_step`.
step_count = 49000
decay_step = [50000]

disc_hist = []
gen_hist = []
psnr_hist = []

while step_count < Params['ph2_steps']:

    lr, hr = next(iterator)

    # if tf.train.latest_checkpoint(Params['ckpt_dir']): 
    #     checkpoint.restore(tf.train.latest_checkpoint(Params['ckpt_dir']))

    disc_loss, gen_loss, psnr_loss  = distributed_train_step(lr, hr)

    disc_hist.append(disc_loss)
    gen_hist.append(gen_loss)
    psnr_hist.append(psnr_loss)

    # Every 1000 steps: report the running metrics, save both models and
    # write the loss / PSNR curves.
    if step_count % 1000 == 0:
        print("step {}".format(step_count) + "   Generator Loss = {}   ".format(gen_metric.result()) + 
              "Disc Loss = {}".format(disc_metric.result()) + "   PSNR : {}".format(psnr_metric.result()))

        run_tag = Params['data_name'] + '_x' + str(Params['scale']) + '_' + str(step_count)
        generator_dir = 'models/GAN_Pangaea/Generator_' + run_tag
        discriminator_dir = 'models/GAN_Pangaea/Discriminator_' + run_tag

        os.makedirs(generator_dir, exist_ok = True)
        os.makedirs(discriminator_dir, exist_ok = True)

        generator.save(generator_dir)
        discriminator.save(discriminator_dir)

        if step_count != 0 : 
            plot_gan_history(disc_hist, step_count, 'Discriminator_loss')
            plot_gan_history(gen_hist, step_count, 'Generator_loss')
            plot_gan_history(psnr_hist, step_count, 'PSNR')

    #checkpoint.write(Params['ckpt_dir'], options=local_device_option)

    # `decay_step` shrinks as milestones are consumed; guard against the
    # empty list so training can continue past the last milestone.
    if decay_step and step_count >= decay_step[0]:
        decay_step.pop(0)
        # Apply the decay once per distinct optimizer object: if both names
        # refer to the same instance, decaying "both" would square the factor.
        optimizers_to_decay = [g_optimizer]
        if d_optimizer is not g_optimizer:
            optimizers_to_decay.append(d_optimizer)
        for opt in optimizers_to_decay:
            opt.learning_rate.assign(opt.learning_rate * Params['decay_ph2'])

    step_count+=1
