In [43]:
import datetime
import os
import random
import sys
import time
import importlib
sys.path.append(os.path.abspath("../.."))

import numpy as np
import tensorflow as tf
from PIL import Image

from molanet.models.glsgan import GlsGANModel
from molanet.models.glsgan import IMAGE_SIZE
import molanet.operations as ops

## Image Handling

In [44]:
def load_image(name: str, source_dir, target_dir, size=IMAGE_SIZE):
    """Load one source image and its segmentation mask, both resized to ``size`` x ``size``.

    The mask is looked up in ``target_dir`` under the source file name with ``.jpg`` replaced by
    ``_Segmentation.png``.

    :param name: file name of the source image inside ``source_dir``
    :param source_dir: directory containing the source images
    :param target_dir: directory containing the segmentation masks
    :param size: edge length in pixels of the square output images
    :return: tuple ``(source, target)`` of float32 arrays; ``source`` has shape (size, size, 3) with values in
             [0, 255], ``target`` has shape (size, size) with values 0.0 / 1.0
    """
    source_path = os.path.join(source_dir, name)
    target_path = os.path.join(target_dir, name.replace('.jpg', '_Segmentation.png'))
    resolution = (size, size)

    # Context managers close the underlying files deterministically instead of relying on garbage collection.
    with Image.open(source_path) as source_image:
        # Force 3 channels so grayscale/RGBA/palette inputs match the 3-channel source the model is built with.
        source = source_image.convert('RGB').resize(resolution, Image.BICUBIC)
    with Image.open(target_path) as target_image:
        # Nearest-neighbour keeps mask labels crisp; mode '1' binarises the mask (black and white).
        target = target_image.resize(resolution, Image.NEAREST).convert('1')

    return np.array(source).astype(np.float32), np.array(target).astype(np.float32)


def get_image_batch(batch_size, source_file_names, source_dir, target_dir) -> "list[tuple[np.ndarray, np.ndarray]]":
    """Load ``batch_size`` randomly chosen (source, target) image pairs.

    File names are drawn uniformly *with replacement* from ``source_file_names``, so a batch may contain the
    same image more than once.

    :param batch_size: number of pairs to load
    :param source_file_names: sequence of source image file names to draw from
    :param source_dir: directory containing the source images
    :param target_dir: directory containing the segmentation masks
    :return: list of ``(source, target)`` tuples as returned by ``load_image`` (not stacked into a batch array;
             see ``transform_batch``)
    :raises ValueError: if images are requested but ``source_file_names`` is empty
    """
    if batch_size > 0 and not source_file_names:
        raise ValueError('source_file_names is empty; cannot draw a batch')
    chosen_names = [random.choice(source_file_names) for _ in range(batch_size)]
    return [load_image(name, source_dir, target_dir) for name in chosen_names]

def transform_batch(image_batch):
    """Stack (source, target) pairs into model-ready float32 batches scaled to [-1, 1].

    :param image_batch: non-empty sequence of ``(source, target)`` pairs, e.g. from ``get_image_batch``;
                        ``source`` is (H, W, C) with values in [0, 255], ``target`` is (H, W) with values in [0, 1]
    :return: ``(batch_src, batch_target)`` with shapes (N, H, W, C) and (N, H, W, 1), both float32
    :raises ValueError: if ``image_batch`` is empty or an image has an unexpected number of dimensions
    """
    if len(image_batch) == 0:
        raise ValueError('image_batch must contain at least one (source, target) pair')

    def to_unit_range(array, max_value):
        # [0, max_value] -> [-1, 1]
        return (np.asarray(array) / max_value - 0.5) * 2.0

    sources = [to_unit_range(src, 255.0) for src, _ in image_batch]
    targets = [to_unit_range(target, 1.0) for _, target in image_batch]

    if any(src.ndim != 3 for src in sources) or any(target.ndim != 2 for target in targets):
        raise ValueError('expected (H, W, C) source images and (H, W) target masks')

    # np.stack builds the batch in one allocation instead of re-concatenating per image (quadratic copying).
    batch_src = np.stack(sources).astype(np.float32)
    batch_target = np.stack(targets)[..., np.newaxis].astype(np.float32)
    return batch_src, batch_target

def save_ndarrays_asimage(filename: str, *arrays: np.ndarray):
    """Concatenate images side by side (left to right) and save them as one 8-bit RGB image.

    :param filename: output path; Pillow picks the file format from the extension
    :param arrays: one or more images with values in [0, 255], each (H, W), (H, W, 1) or (H, W, 3), all with the
                   same height; single-channel images are replicated to 3 channels so they can be concatenated
    :raises ValueError: if no array is given or an array is not 2- or 3-dimensional
    """
    if not arrays:
        raise ValueError('at least one array is required')

    def fix_dimensions(array):
        array = np.asarray(array)
        if array.ndim > 3 or array.ndim < 2:
            raise ValueError('arrays must have 2 or 3 dimensions')
        if array.ndim == 2:
            array = array[:, :, np.newaxis]
        if array.shape[2] == 1:
            # go from black/white to rgb so concatenation with colour images works seamlessly
            array = np.repeat(array, 3, axis=2)
        return array

    # Always normalise, even for a single array: previously a lone array stayed wrapped in a tuple, which
    # produced a 4-D array that Image.fromarray rejects.
    image = np.concatenate([fix_dimensions(array) for array in arrays], axis=1)

    # Clip before casting: float -> uint8 conversion of out-of-range values wraps around instead of saturating.
    im = Image.fromarray(np.clip(image, 0, 255).astype(np.uint8))
    im.save(filename)

## Helpers

In [77]:
def save(sess, saver, checkpoint_dir, step, model_name="glsgan.model"):
    """Write a checkpoint of ``sess`` under ``checkpoint_dir``, creating the directory if necessary.

    :param sess: session whose variables are saved
    :param saver: object with a ``save(sess, save_path, global_step=...)`` method, e.g. ``tf.train.Saver``
    :param checkpoint_dir: directory the checkpoint files are written to
    :param step: training step passed as ``global_step`` (TensorFlow appends ``-<step>`` to the file prefix)
    :param model_name: checkpoint file prefix inside ``checkpoint_dir``
    """
    # exist_ok avoids the check-then-create race of os.path.exists followed by os.makedirs.
    os.makedirs(checkpoint_dir, exist_ok=True)
    saver.save(sess, os.path.join(checkpoint_dir, model_name), global_step=step)


def sample_model(filenames: [str], batch_size: int,epoch: int, sess: tf.Session, source_dir: str, target_dir: str, model: GlsGANModel,
                 sample_dir: str, max_samples:int=3):
    """Run the model on a random batch, save a preview image and print the batch losses.

    Writes ``sample_<epoch>.png`` into ``sample_dir`` (the directory must already exist). The image has four
    columns -- source, generated mask, ground-truth mask, absolute error -- with the first
    ``min(batch_size, max_samples)`` examples stacked vertically.

    :param filenames: source image file names to draw the batch from
    :param batch_size: number of images fed to the model
    :param epoch: number used in the output file name
    :param sess: session holding the trained variables
    :param source_dir: directory containing the source images
    :param target_dir: directory containing the segmentation masks
    :param model: model providing ``fake_B``, ``d_loss``, ``g_loss`` and the two input placeholders
    :param sample_dir: directory the preview image is written to
    :param max_samples: maximum number of examples shown in the preview
    """
    nsamples = min(batch_size, max_samples)

    # The full batch is fed (presumably the model's placeholders are built for a fixed batch_size);
    # only the first nsamples examples are shown.
    batch = get_image_batch(batch_size, filenames, source_dir, target_dir)
    batch_src, batch_target = transform_batch(batch)

    sample, d_loss, g_loss = sess.run(
        [model.fake_B, model.d_loss, model.g_loss],
        feed_dict={model.real_data_source: batch_src,
                   model.real_data_target: batch_target}
    )

    def to_preview(array):
        # [n, h, w] or [n, h, w, 1] -> [n*h, w];  [n, h, w, c] -> [n*h, w, c];  values [-1, 1] -> [0, 255].
        # Plain NumPy on purpose: tf.squeeze/tf.reshape + .eval() added new ops to the graph on every call,
        # and tf.squeeze also dropped the batch axis when batch_size == 1.
        array = np.asarray(array)[:nsamples]
        if array.ndim == 4 and array.shape[-1] == 1:
            array = array[..., 0]  # drop only the channel axis, never the batch axis
        stacked = array.reshape((-1,) + array.shape[2:])
        return (stacked + 1.0) / 2.0 * 255

    generated = to_preview(sample)
    original_source = to_preview(batch_src)
    original_target = to_preview(batch_target)
    sample_error = np.absolute(original_target - generated)

    save_ndarrays_asimage(os.path.join(sample_dir, 'sample_%d.png' % epoch), original_source, generated,
                          original_target, sample_error)
    print("[Sample] d_loss: {:.8f}, g_loss: {:.8f}".format(d_loss, g_loss))

## Training

In [None]:
tf.reset_default_graph()
source_dir = r'../../../pix2pix-poc-data/training/source'
target_dir = r'../../../pix2pix-poc-data/training/target'
sourcefiles = [f for f in os.listdir(source_dir) if os.path.isfile(os.path.join(source_dir, f))]

now = datetime.datetime.now()
sample_dir = "./samples/sample-%d-%d-%d--%02d%02d" % (
    now.day, now.month, now.year, now.hour, now.minute)  # Generated samples
checkpoint_dir = "./checkpoints"  # Model
os.makedirs(sample_dir, exist_ok=True)  # also creates ./samples if needed

batch_size = 5
size = IMAGE_SIZE
num_feature_maps = 64  # NOTE(review): not used anywhere in this cell
is_grayscale = False  # NOTE(review): not used anywhere in this cell
L1_lambda = 100  # weight of the model's L1 term; also scales the distance margin of the GLS cost below
glsgan_alpha = 0  # slope of the leaky ReLU in the GLS cost; 0 corresponds to LS-GAN
iterations = 50000
N = 1  # discriminator updates per iteration

# When > 0: restore the LATEST checkpoint in checkpoint_dir and continue at restore_iteration + 1.
restore_iteration = None
use_random_image_as_sample = True  # NOTE(review): not used anywhere in this cell
max_samples = 5


with tf.Session() as sess:
    model = GlsGANModel(batch_size=batch_size, image_size=size, src_color_channels=3, target_color_channels=1,
                        l1_lambda=L1_lambda)

    # Created right after the model and before the optimizers (as before), so the set of saved variables stays
    # the same as in existing checkpoints.
    saver = tf.train.Saver()

    # make glsgan
    # see glsgan paper and https://github.com/guojunq/glsgan/blob/master/glsgan.lua#L257
    def l1diff(x, y):
        # L1 distance between the rounded tensors, summed over the whole batch
        dist = tf.reduce_sum(tf.abs(tf.round(y) - tf.round(x)))
        return dist

    pdist = L1_lambda * l1diff(model.real_B, model.fake_B)
    # NOTE(review): if d_loss_real/d_loss_fake are batch-level scalars (defined in GlsGANModel, not visible here),
    # the leaky ReLU is applied once to the batch total rather than per sample -- confirm this matches the
    # GLS-GAN formulation and the sign convention of the two loss terms.
    cost1 = pdist + model.d_loss_real - model.d_loss_fake
    glsloss = ops.leaky_relu(cost1, glsgan_alpha)

    # Optimizers
    disc_optim = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5)  # TODO: pix2pix params
    gen_optim = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5)  # TODO: pix2pix params

    disc_update_step = disc_optim.minimize(glsloss, var_list=model.d_vars)
    gen_update_step = gen_optim.minimize(model.g_loss, var_list=model.g_vars)

    sess.run(tf.global_variables_initializer())

    # Restore only AFTER initialisation: restoring first would be overwritten by the initializer.
    if restore_iteration is not None and restore_iteration > 0:
        checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
        if checkpoint is None or not checkpoint.model_checkpoint_path:
            raise FileNotFoundError('no checkpoint found in ' + checkpoint_dir)
        checkpoint_name = os.path.basename(checkpoint.model_checkpoint_path)
        print('checkpoint_name=' + str(checkpoint_name))
        saver.restore(sess, os.path.join(checkpoint_dir, checkpoint_name))
        iteration_start = restore_iteration + 1
    else:
        iteration_start = 0

    # logging
    writer = tf.summary.FileWriter("./logs", sess.graph)
    g_sum = tf.summary.merge([model.d__sum,
                              model.fake_B_sum, model.d_loss_fake_sum, model.g_loss_sum])
    d_sum = tf.summary.merge([model.d_sum, model.d_loss_real_sum, model.d_loss_sum])

    start_time = time.time()
    try:
        for iteration in range(iteration_start, iterations):
            batch = get_image_batch(batch_size, sourcefiles,
                                    source_dir=source_dir,
                                    target_dir=target_dir)
            batch_src, batch_target = transform_batch(batch)
            feed = {model.real_data_source: batch_src,
                    model.real_data_target: batch_target}

            # as proposed in the glsgan paper, update the discriminator N times every iteration
            for _disc_step in range(N):
                _, summary_str = sess.run([disc_update_step, d_sum], feed_dict=feed)
                writer.add_summary(summary_str, iteration)

            # Update G network
            _, summary_str = sess.run([gen_update_step, g_sum], feed_dict=feed)
            writer.add_summary(summary_str, iteration)

            # One forward pass for all logged losses instead of three separate evaluations.
            errD_fake, errD_real, errG = sess.run([model.d_loss_fake, model.d_loss_real, model.g_loss],
                                                  feed_dict=feed)

            print("Epoch: [%2d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                  % (iteration, time.time() - start_time, errD_fake + errD_real, errG))

            if iteration % 50 == 1:
                # write a preview image of the current generator output
                sample_model(sourcefiles, batch_size, iteration, sess, source_dir, target_dir, model, sample_dir,
                             max_samples=max_samples)

            if iteration % 500 == 2:
                save(sess, saver, checkpoint_dir, iteration)
    finally:
        # flush pending summaries even when training is interrupted (e.g. KeyboardInterrupt in the notebook)
        writer.close()

Epoch: [ 0] time: 3.6876, d_loss: 1.37057042, g_loss: 91.46552277
Epoch: [ 1] time: 7.0418, d_loss: 1.41155851, g_loss: 85.44960022
[Sample] d_loss: 1.45718765, g_loss: 88.19380951
Epoch: [ 2] time: 11.3903, d_loss: 1.79811370, g_loss: 78.61933899
Epoch: [ 3] time: 19.8668, d_loss: 2.67224097, g_loss: 75.88973236
Epoch: [ 4] time: 23.9828, d_loss: 4.53027916, g_loss: 73.36322021
Epoch: [ 5] time: 27.8095, d_loss: 11.80871677, g_loss: 67.94054413
Epoch: [ 6] time: 30.7866, d_loss: 20.55378151, g_loss: 72.24281311
Epoch: [ 7] time: 34.2821, d_loss: 36.88986969, g_loss: 55.82383728
Epoch: [ 8] time: 37.6691, d_loss: 79.97319031, g_loss: 77.56661987
Epoch: [ 9] time: 41.0316, d_loss: 131.47407532, g_loss: 51.30364227
Epoch: [10] time: 44.1608, d_loss: 248.14453125, g_loss: 45.16580582
Epoch: [11] time: 47.5756, d_loss: 379.81887817, g_loss: 48.67750549
Epoch: [12] time: 51.2167, d_loss: 845.70031738, g_loss: 49.71894455
Epoch: [13] time: 54.7719, d_loss: 953.14160156, g_loss: 43.22514343
E