# ESL estimation with GAN

This notebook trains a generative adversarial network (GAN) to synthesize stray-light images, using generator and discriminator architectures specified in YAML configuration files.

In [1]:
import datetime
import sys
from pathlib import Path
import os
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from configuredEstimator import TensorDict
from pathlib import Path

# Start from an empty default TensorFlow graph.
tf.reset_default_graph()
# Show INFO-level TensorFlow log messages.
tf.logging.set_verbosity(tf.logging.INFO)

sys.path.append("/jup/projects/mtgtools2/")  # ensure we find the MTG libraries
# The two imports below depend on the path appended above, so they must stay after it.
import pymtg.fci.esl.training_data as esl
import pymtg

# Data locations.
# NOTE(review): these are absolute, machine-specific paths; consider making them configurable.
MNT_DIR = Path("/scratch/andreu/")
# retinas.json is located relative to the installed pymtg package.
RETINAS_JSON = Path(pymtg.__file__).parent.parent / "test/FCI_ESL/pyesl/config/retinas.json"
# Passed to esl.prepare_training_data as the ESL data directory (channel group ch123).
ESL_PATH = MNT_DIR / "ESL/nc/output/20170615T152646/ch123/"
# Passed to esl.prepare_training_data as the FCI swath directory.
FCI_PATH = MNT_DIR / "ESL/nc/swaths/"

def collapse_scene(scene, step=1):
    """Merge a sequence of 2-D scene images into a single image.

    Starting from ``scene[0]``, the entry ``frame.transpose()[-step]`` of every
    following frame (the ``-step``-th column for a 2-D frame) is inserted into the
    accumulated image at column index ``-step`` (along axis 1).

    NOTE(review): ``np.insert`` with a negative index inserts *before* the last
    ``step`` columns. With the default ``step=1``, the last column of ``scene[0]``
    therefore stays the final column, and the new columns are placed just before it.
    If appending at the end is intended, the index should be ``merged.shape[1]``.
    Confirm the intended order.

    Args:
        scene: indexable/sliceable sequence of 2-D arrays with equal row counts.
        step: column (counted from the end) taken from each subsequent frame, and
            the insertion position counted from the end.

    Returns:
        The merged 2-D array.
    """
    merged = scene[0]
    for frame in scene[1:]:
        new_column = frame.transpose()[-step]
        merged = np.insert(merged, -step, new_column, axis=1)  # create single input scene.
    return merged

def plot_scene(scene):
    """Show ``collapse_scene(scene)`` as an image titled 'Input scene'."""
    collapsed_image = collapse_scene(scene)
    plt.imshow(collapsed_image)
    plt.title('Input scene')
    plt.show()

def plot_esl(esl, title='ESL'):
    """Show the transpose of the array ``esl`` as an image with the given ``title``.

    Note: inside this function the parameter name shadows the module-level ``esl`` import.
    """
    transposed = esl.T
    plt.imshow(transposed)
    plt.title(title)
    plt.show()
    
def get_model_name(name='Unnamed_conf_model'):
    """Return ``name`` followed by ``_`` and the current local time as ``%Y%m%dT%H%M%S``.

    NOTE(review): an identical function is defined again in a later code cell.
    """
    timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
    return "{}_{}".format(name, timestamp)
                          
def get_model_dir(model_dir, **kwargs):
    """Return the string ``"<model_dir>/<timestamped model name>"``.

    Keyword arguments (e.g. ``name=``) are forwarded to ``get_model_name``.
    NOTE(review): a later cell redefines this function to return a ``pathlib.Path``.
    """
    model_name = get_model_name(**kwargs)
    return "{}/{}".format(model_dir, model_name)
                          
def input_fn_np(size, train_or_test='train', collapsed=False, num_batches=1, mode=tf.estimator.ModeKeys.TRAIN, plot_data=False, verbose=False):
    """
    Load (earth scene, ESL) training batches as NumPy arrays.

    Calls ``esl.prepare_training_data`` with the module-level paths, then draws
    ``num_batches`` batches of ``size`` samples for channel "ch123" from the
    ``train_or_test`` split.

    If not collapsed, the per-batch arrays are concatenated along axis 0:
    features x of shape (size*num_batches, 129, 186) and labels y of shape
    (size*num_batches, 113).
    If collapsed, the scenes of each batch are merged with ``collapse_scene``
    (one extra column per scene after the first) and its labels are flattened:
    features (num_batches, 129, 186 + size - 1) and labels (num_batches, 113*size).
    (The concrete sizes are those documented by the author. They depend on the
    data returned by ``esl``.)

    Args:
        size: number of samples per batch.
        train_or_test: split to draw from. Keep 'train' for now, as 'test' is not validated.
        collapsed: whether to collapse each batch into a single wide scene (see above).
        num_batches: number of batches to draw.
        mode: unused by this function; kept for backward compatibility with existing callers.
        plot_data: if True, plot the returned data for visual validation.
        verbose: if True, print the call arguments.

    Returns:
        Tuple ``(features, labels)`` of NumPy arrays.

    Raises:
        ValueError: if the data source yields no batches.
    """
    if verbose:
        print("Calling input_fn with collapsed={}, num_batches={}, size={}".format(collapsed, num_batches, size))
    files = esl.prepare_training_data(FCI_PATH, ESL_PATH, 0, str(RETINAS_JSON))
    # The positional 20000 is passed through unchanged to ``batches``; its meaning is defined in pymtg.
    out_batch = list(files[train_or_test].batches("ch123", size, 20000, num_batches=num_batches))  # use always 'train', 'test' is not validated
    if not out_batch:
        raise ValueError("No batches returned for split '{}' (size={}, num_batches={})".format(
            train_or_test, size, num_batches))
    if collapsed:
        # Use at most num_batches batches. Slicing avoids an IndexError when fewer were produced.
        used_batches = out_batch[:num_batches]
        features = np.array([collapse_scene(batch[0]) for batch in used_batches])
        labels = np.array([batch[1].flatten() for batch in used_batches])
    else:
        # Stack every batch along axis 0: (num_batches*size, ...) for features and labels alike.
        features = np.concatenate([batch[0] for batch in out_batch], axis=0)
        labels = np.concatenate([batch[1] for batch in out_batch], axis=0)
    if plot_data:  # for visual validation of the served data
        if collapsed:
            # One figure per collapsed scene, then one per label reshaped back to a batch's label shape.
            label_shape = out_batch[0][1].shape
            for feature_image in features:
                plt.imshow(feature_image)
                plt.show()
            for label in labels:
                plt.imshow(label.reshape(label_shape))
                plt.show()
        else:
            # All samples are merged into a single scene image and a single ESL image.
            plot_scene(features)
            plot_esl(labels)
    return features, labels

# TensorFlow Graph definition

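Since this network is a Generative Adversarial Network (GAN), the graph is composed of two main elements:
* A **generator** network, whose function is to create synthesized stray-light images from earth images.
* A **discriminator** network, whose function is to distinguish the synthesized stray-light images from the real stray-light images.

Note that, in the training cell below, the generator is currently optimized on the mean squared error against the real ESL images only. The adversarial generator loss is left commented out, so the discriminator does not yet influence the generator.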

**NOTE**: the GAN model works only with the non-collapsed mode for the ESL batches. The collapsed mode is not supported, as it has given poorer results anyway.

The flow is controlled via the following parameters:

* SIZE is the number of samples per batch; the ESL images fed at each step have shape (SIZE x NUM_BATCHES, 113).
* TRAIN_STEPS is the number of training iterations.
* TEST_STEPS is the number of test iterations whose predictions are plotted after training has finished.
* MODEL_GEN is the name of the YAML configuration file that defines the architecture of the generator.
* MODEL_DIS is the name of the YAML configuration file that defines the architecture of the discriminator.
* NUM_BATCHES is the number of batches fed at every iteration. It is redundant with the SIZE parameter, as the total size of the ESL is (SIZE x NUM_BATCHES, 113). It is advisable to leave it set to 1.
* VERBOSE: if set to True, the shapes of the tensors created during graph definition are reported.
* PRINT_ITER controls how often the generator and discriminator losses are displayed: a message is printed every PRINT_ITER training steps.
* SAVE_N_CHKPT controls how many intermediate checkpoints are saved during training. A final checkpoint is always saved in addition.

In [2]:
# Data / batching
SIZE = 100  # samples per batch
NUM_BATCHES = 1  # batches fed per training step; keep at 1

# Model architectures (YAML configuration files)
MODEL_GEN = 'CONV3.yml'
MODEL_DIS = 'DISC_GAN.yml'

# Training / evaluation schedule
TRAIN_STEPS = 100
TEST_STEPS = 2  # number of test images to predict
PRINT_ITER = 2  # print loss every N iterations
SAVE_N_CHKPT = 5  # how many checkpoints to save

# Logging
VERBOSE = False

In [None]:
import tensorflow as tf

COLLAPSED = False  # the GAN only supports the non-collapsed batch layout

def get_gan_config_name():
    """Build a run name from the generator/discriminator configs, batch layout and training settings.

    Example: 'GEN_<MODEL_GEN>_DIS_<MODEL_DIS>_not_collapsed_size_<SIZE>_batches_<NUM_BATCHES>_train_steps_<TRAIN_STEPS>'.
    """
    layout = 'collapsed' if COLLAPSED else 'not_collapsed'
    parts = [
        'GEN', MODEL_GEN,
        'DIS', MODEL_DIS,
        layout,
        'size', str(SIZE),
        'batches', str(NUM_BATCHES),
        'train_steps', str(TRAIN_STEPS),
    ]
    return '_'.join(parts)

def get_model_name(name='Unnamed_conf_model'):
    """Append an underscore and a ``%Y%m%dT%H%M%S`` stamp of the current local time to ``name``.

    NOTE(review): duplicates the definition in the first code cell; one copy could be removed.
    """
    stamp = "{:%Y%m%dT%H%M%S}".format(datetime.datetime.now())
    return "{}_{}".format(name, stamp)
                          
def get_model_dir(model_dir, **kwargs):
    """Return ``Path(model_dir) / get_model_name(**kwargs)``, a timestamped run path.

    The directory itself is not created here.
    """
    base = Path(model_dir)
    return base.joinpath(get_model_name(**kwargs))

def generator(z, reuse=False, verbose=False):
    """Build the generator network described by ``MODEL_GEN`` on the input tensor ``z``.

    Args:
        z: input tensor (the earth-scene placeholder in this notebook).
        reuse: forwarded to ``TensorDict``, as ``discriminator`` does, so a second
            generator instance can share the 'GEN' variables.
        verbose: forwarded to ``TensorDict`` to report tensor shapes.

    Returns:
        The last tensor of the generator's ``TensorDict`` (``td.last()``).
    """
    td = TensorDict(MODEL_GEN, z, prefix='GEN', batch_dim=0, create_scope=True, reuse=reuse, verbose=verbose)
    return td.last()

def discriminator(x, reuse=False, verbose=False):
    """Build the discriminator network described by ``MODEL_DIS`` on the input tensor ``x``.

    In this notebook it is called first on the real images and then with
    ``reuse=True`` on the generated ones, so both calls share the 'DISC' variables.
    Returns the last tensor of the ``TensorDict``, which is used as logits.
    """
    return TensorDict(MODEL_DIS, x, prefix='DISC', batch_dim=None, create_scope=True,
                      reuse=reuse, verbose=verbose).last()

# Rebuild the graph from scratch so that this cell can be re-run: otherwise the
# GEN/DISC variables from the previous run would already exist in the default graph.
tf.reset_default_graph()

# Draw one batch only to size the placeholders; every fed batch must match these shapes.
earth_images, esl_images = input_fn_np(SIZE, num_batches=NUM_BATCHES, plot_data=False, collapsed=COLLAPSED)
X = tf.placeholder(tf.float32, esl_images.shape)  # real ESL images
Z = tf.placeholder(tf.float32, earth_images.shape)  # earth scenes (generator input)
G_sample = generator(Z, verbose=VERBOSE)  # fake samples
r_logits = discriminator(X, verbose=VERBOSE)  # first call creates the DISC variables
f_logits = discriminator(G_sample, reuse=True, verbose=VERBOSE)  # second call reuses them

# loss functions
# Discriminator: sigmoid cross-entropy with real -> 1 and fake -> 0,
# i.e. -[log D(x) + log(1 - D(G(z)))].
disc_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=r_logits, labels=tf.ones_like(r_logits)) +
                           tf.nn.sigmoid_cross_entropy_with_logits(logits=f_logits, labels=tf.zeros_like(f_logits)))

# Generator: MSE against the real ESL (tf.losses.mean_squared_error already returns a scalar).
# NOTE(review): the discriminator output does not enter gen_loss, so the generator is not trained
# adversarially. The adversarial alternative would be:
#   gen_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=f_logits, labels=tf.ones_like(f_logits)))
gen_loss = tf.losses.mean_squared_error(labels=X, predictions=G_sample)

# optimizers
# Each optimizer only updates the variables of its own network, selected by scope.
LEARNING_RATE = 0.0001
gen_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="GEN")
gen_step = tf.train.RMSPropOptimizer(learning_rate=LEARNING_RATE).minimize(gen_loss, var_list=gen_vars)  # G train step
disc_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="DISC")
disc_step = tf.train.RMSPropOptimizer(learning_rate=LEARNING_RATE).minimize(disc_loss, var_list=disc_vars)  # D train step

# summary prep
tf.summary.scalar('disc_loss', disc_loss)
tf.summary.scalar('gen_loss', gen_loss)
merged = tf.summary.merge_all()
saver = tf.train.Saver()

# run session
model_dir = get_model_dir("/scratch/tensorboard", name=get_gan_config_name())
chkpt_dir = model_dir / 'saved_checkpoints'
chkpt_dir.mkdir(parents=True, exist_ok=True)  # ensure the checkpoint directory exists before saving
# Save an intermediate checkpoint every `save_every` steps, i.e. at most SAVE_N_CHKPT of them.
# Integer ceiling division avoids skipping checkpoints when TRAIN_STEPS is not a multiple of
# SAVE_N_CHKPT. SAVE_N_CHKPT <= 0 disables intermediate checkpoints; the final one is always saved.
save_every = max(1, -(-TRAIN_STEPS // SAVE_N_CHKPT)) if SAVE_N_CHKPT > 0 else None
print("Saving summaries and checkpoints on {}".format(model_dir))
# NOTE(review): in the recorded run the discriminator loss stays at 1.3862944 = 2*ln(2), which is the
# value for logits of 0 on both terms. The discriminator does not appear to learn; worth investigating.
with tf.Session() as sess:
    tf.global_variables_initializer().run(session=sess)
    train_writer = tf.summary.FileWriter(str(model_dir), sess.graph)
    try:
        for i in range(TRAIN_STEPS):
            Z_batch, X_batch = input_fn_np(SIZE, num_batches=NUM_BATCHES, plot_data=False, collapsed=COLLAPSED)
            feed = {X: X_batch, Z: Z_batch}
            _, dloss, summary = sess.run([disc_step, disc_loss, merged], feed_dict=feed)
            _, gloss = sess.run([gen_step, gen_loss], feed_dict=feed)
            train_writer.add_summary(summary, i)
            if i % PRINT_ITER == 0:
                print("Run {}, Generator loss = {}, Discriminator loss = {}".format(i, gloss, dloss))
            if save_every and i % save_every == 0:
                saver.save(sess, str(chkpt_dir / '{}.ckpt'.format(i)))
    finally:
        # Flush pending summaries to disk and release the event file.
        train_writer.close()
    saver.save(sess, str(chkpt_dir / 'final.ckpt'))

    # prediction test: plot input scene, real ESL and predicted ESL for TEST_STEPS fresh batches
    for j in range(TEST_STEPS):
        Z_batch, X_batch = input_fn_np(SIZE, num_batches=NUM_BATCHES, plot_data=False, collapsed=COLLAPSED)
        esl_pred = sess.run(G_sample, feed_dict={Z: Z_batch})
        plot_scene(Z_batch)
        plot_esl(X_batch, title='Real ESL')
        plot_esl(esl_pred, title='Predicted ESL')

Saving summaries and checkpoints on /scratch/tensorboard/GEN_CONV3.yml_DIS_DISC_GAN.yml_not_collapsed_size_100_batches_1_train_steps_100_20180823T143140
Run 0, Generator loss = 0.0009325678693130612, Discriminator loss = 1.3862943649291992
Run 2, Generator loss = 0.001425484661012888, Discriminator loss = 1.3862943649291992
Run 4, Generator loss = 0.0015025836182758212, Discriminator loss = 1.3862943649291992
Run 6, Generator loss = 8.599522698204964e-05, Discriminator loss = 1.3862943649291992
Run 8, Generator loss = 0.0016339042922481894, Discriminator loss = 1.3862943649291992
Run 10, Generator loss = 0.001478342222981155, Discriminator loss = 1.3862943649291992
Run 12, Generator loss = 0.0008007189608179033, Discriminator loss = 1.3862943649291992
Run 14, Generator loss = 0.000669955275952816, Discriminator loss = 1.3862943649291992
