In [None]:
import datetime
import os
import random
random.seed(12435345)
import sys
import time
sys.path.append(os.path.abspath("../.."))

import numpy as np
import tensorflow as tf
from PIL import Image

from molanet.models.cgan_pix2pix import Pix2PixModel
from molanet.models.cgan_pix2pix import IMAGE_SIZE
from molanet.poc_utils import load_image,get_image_batch,transform_batch,save_ndarrays_asimage,save

## Helpers

In [None]:
def sample_model(filenames: [str], batch_size: int, epoch: int, sess: tf.Session, source_dir: str, target_dir: str,
                 model: Pix2PixModel,
                 sample_dir: str, max_samples: int = 3):
    """Evaluate the model on one batch and write a comparison image to disk.

    A batch is fetched with ``get_image_batch`` / ``transform_batch``, fed through
    ``model.fake_B`` together with both losses, and the first
    ``min(batch_size, max_samples)`` results are stacked vertically and saved as
    ``sample_<epoch>.png`` (source, generated output, target, absolute error).

    All post-processing is done with NumPy. Building ``tf.squeeze`` / ``tf.reshape``
    ops here would add new nodes to the graph on every call (unbounded growth when
    this is invoked from the training loop) and an axis-less squeeze would also
    drop the batch axis for ``batch_size == 1``.

    :param filenames: candidate image file names passed to ``get_image_batch``
    :param batch_size: number of images fed to the model
    :param epoch: number used in the output file name
    :param sess: session used to evaluate the model tensors
    :param source_dir: directory of the source images
    :param target_dir: directory of the target images
    :param model: model providing ``fake_B``, ``d_loss``, ``g_loss`` and the two input placeholders
    :param sample_dir: directory the sample image is written to
    :param max_samples: upper bound on the number of images shown in the sample
    """
    nsamples = min(batch_size, max_samples)

    batch = get_image_batch(batch_size, filenames, source_dir, target_dir)
    batch_src, batch_target = transform_batch(batch)

    sample, d_loss, g_loss = sess.run(
        [model.fake_B, model.d_loss, model.g_loss],
        feed_dict={model.real_data_source: batch_src,
                   model.real_data_target: batch_target}
    )

    # assumes images are square and laid out as [batch, size, size, channels] — TODO confirm
    size = sample.shape[2]
    zdim = batch_src.shape[-1]

    def to_pixel_range(values):
        # Map values from [-1, 1] to [0, 255]; presumably the range produced by transform_batch / the generator.
        return (values + 1.0) / 2.0 * 255

    # Keep only the first nsamples images and stack them along the height axis:
    # [nsamples, size, size, 1] -> [size * nsamples, size]
    sample = to_pixel_range(np.reshape(np.asarray(sample)[:nsamples], [size * nsamples, size]))
    original_source = to_pixel_range(
        np.reshape(np.asarray(batch_src)[:nsamples], [size * nsamples, size, zdim]))
    original_target = to_pixel_range(
        np.reshape(np.asarray(batch_target)[:nsamples], [size * nsamples, size]))
    sample_error = np.absolute(original_target - sample)

    save_ndarrays_asimage(os.path.join(sample_dir, 'sample_%d.png' % epoch), original_source, sample,
                          original_target, sample_error)
    print("[Sample] d_loss: {:.8f}, g_loss: {:.8f}".format(d_loss, g_loss))


## Training

In [None]:
# Start from an empty graph so re-running this cell does not stack a second model on the first.
tf.reset_default_graph()

# Training data locations (relative to the notebook).
source_dir = r'../../../pix2pix-poc-data/training/source'
target_dir = r'../../../pix2pix-poc-data/training/target'
# os.listdir returns entries in arbitrary order; sort so the file list is identical across
# runs and machines, otherwise the random seed set at import time cannot make runs repeatable.
sourcefiles = sorted(f for f in os.listdir(source_dir) if os.path.isfile(os.path.join(source_dir, f)))

# Output locations: one sample directory per run, named after the start time.
now = datetime.datetime.now()
sample_dir = "./samples/sample-%d-%d-%d--%02d%02d" % (
    now.day, now.month, now.year, now.hour, now.minute)  # Generated samples
checkpoint_dir = "../logs/pix2pix"  # Model
# Creates ./samples as well when it is missing; no error when the directories already exist.
os.makedirs(sample_dir, exist_ok=True)

# Parameters
batch_size = 5
size = IMAGE_SIZE
num_feature_maps = 64  # feature maps on first conv layer
L1_lambda_generator = 100
iterations = 50000
d_updates = 1  # number of training sessions for discriminator per iteration
g_updates = 1  # generator training sessions
d_learning_rate = 0.0002
g_learning_rate = 0.0002
d_beta1 = 0.5
g_beta1 = 0.5

# Non-learning related config
restore_iteration = None  # set to the iteration of a saved checkpoint to resume training from it
use_random_image_as_sample = True
max_samples = 5  # upper bound on images per generated sample file

# Logging cadence, in iterations.
sample_interval = 50       # a sample image is written when iteration % sample_interval == 1
checkpoint_interval = 500  # a checkpoint is saved when iteration % checkpoint_interval == 2

with tf.Session() as sess:
    model = Pix2PixModel(batch_size=batch_size,
                         image_size=size,
                         src_color_channels=3,
                         target_color_channels=1,
                         g_l1_lambda=L1_lambda_generator)

    # Built before the optimizers so the saver keeps covering the same variable set
    # as existing checkpoints (model variables, without optimizer slot variables).
    saver = tf.train.Saver()

    # Optimizers: each network gets its own learning rate and beta1.
    disc_optim = tf.train.AdamOptimizer(learning_rate=d_learning_rate, beta1=d_beta1)
    gen_optim = tf.train.AdamOptimizer(learning_rate=g_learning_rate, beta1=g_beta1)

    disc_update_step = disc_optim.minimize(model.d_loss, var_list=model.d_vars)
    gen_update_step = gen_optim.minimize(model.g_loss, var_list=model.g_vars)

    # Initialize first, restore second: running the initializer after the restore
    # would overwrite the restored weights with fresh random values.
    sess.run(tf.global_variables_initializer())

    if restore_iteration is not None and restore_iteration > 0:
        iteration_start = restore_iteration + 1
        checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
        if checkpoint is None or not checkpoint.model_checkpoint_path:
            raise FileNotFoundError(
                "No checkpoint found in '%s' to restore from" % checkpoint_dir)
        checkpoint_name = os.path.basename(checkpoint.model_checkpoint_path)
        print('checkpoint_name=' + str(checkpoint_name))
        saver.restore(sess, os.path.join(checkpoint_dir, checkpoint_name))
    else:
        iteration_start = 0

    # logging
    #image_sum = image_summary(model, max_image_outputs = 5)
    writer = tf.summary.FileWriter(checkpoint_dir, sess.graph)
    g_sum = tf.summary.merge([model.d__sum,
                              model.d_loss_fake_sum,
                              model.g_loss_sum])

    d_sum = tf.summary.merge([model.d_sum, model.d_loss_real_sum, model.d_loss_sum])

    start_time = time.time()
    try:
        for iteration in range(iteration_start, iterations):
            batch = get_image_batch(batch_size, sourcefiles,
                                    source_dir=source_dir,
                                    target_dir=target_dir)
            batch_src, batch_target = transform_batch(batch)
            feed = {model.real_data_source: batch_src,
                    model.real_data_target: batch_target}

            # as proposed in glsgan paper update discriminator N time every iteration
            for _ in range(d_updates):
                _, summary_str = sess.run([disc_update_step, d_sum], feed_dict=feed)
                writer.add_summary(summary_str, iteration)

            # Update G network
            for _ in range(g_updates):
                _, summary_str = sess.run([gen_update_step, g_sum], feed_dict=feed)
                writer.add_summary(summary_str, iteration)

            # Periodically write a sample image. This sits outside the update loops so a
            # sample is produced at most once per iteration, not once per generator update.
            if iteration % sample_interval == 1:
                sample_model(sourcefiles, batch_size, iteration, sess, source_dir, target_dir, model, sample_dir,
                             max_samples=max_samples)

            if iteration % checkpoint_interval == 2:
                save(sess, saver, checkpoint_dir, iteration)
    finally:
        # Flush pending summaries and release the event file even when training is interrupted.
        writer.close()