## Parameters

In this cell we define the parameters that control how the rest of this notebook runs, including parameters for model construction, testing, layer size, etc.

In [18]:
input_dir        = "/home/jovyan/local-volume-claim/notebooks/data/landmark_images/"
target_dir       = "/home/jovyan/local-volume-claim/notebooks/data/original_images/"
output_dir       = "./models/face2face-model/"

max_steps        = 50
max_epochs       = 100
batch_size       = 5
channels         = 3

summary_freq     = 50
display_freq     = 100
progress_freq    = 200
save_freq        = 2000
trace_freq       = 0


separable_conv   = True
aspect_ratio     = 1.0
lab_colorization = False
which_direction  = "AtoB" # choices: ["AtoB", "BtoA"]
ngf              = 64
NGF              = 64
ndf              = 64
scale_size       = 256
no_flip          = False
flip             = True
lr               = 0.0002
beta1            = 0.5
l1_weight        = 100.0
gan_weight       = 1.0
output_filetype  = "png" # choices: ["png", "jpeg"]
input_filetype   = "jpg"
EPS              = 1e-12
CROP_SIZE        = 256
seed             = 42

## Import dependencies

In this cell we will import all the Python modules needed to build a TensorFlow model.

In [19]:
import os
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '10'

import tensorflow as tf
tf.reset_default_graph()
import numpy as np
import argparse
import os
import json
import glob
import random
import collections
import math
import time
import cv2

## Set random seeds
tf.set_random_seed(seed)
np.random.seed(seed)
random.seed(seed)

## Model containers

In this next cell we will take advantage of Python's `namedtuple` to store the model properties we will use later.

In [20]:
Model = collections.namedtuple("Model", "outputs, predict_real, predict_fake, discrim_loss, discrim_grads_and_vars, gen_loss_GAN, gen_loss_L1, gen_grads_and_vars, train")
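
As a small illustration of why a `namedtuple` suits this purpose, the sketch below builds a hypothetical two-field container (`Losses` is not part of this notebook) and reads its fields by name and by position.

```python
import collections

# Hypothetical container, used only to illustrate the namedtuple API
Losses = collections.namedtuple("Losses", "discrim_loss, gen_loss_L1")

losses = Losses(discrim_loss=0.85, gen_loss_L1=0.09)

# Fields are available by name and by position
print(losses.discrim_loss)
print(losses[1])

# namedtuples are immutable; _replace returns a modified copy
updated = losses._replace(gen_loss_L1=0.07)
print(updated)
```

The same access pattern applies to `Model`, for example `model.outputs` or `model.train`.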

## Defining process functions

In this next cell we will define several utility functions in TensorFlow for processing images consistently across the model.

In [21]:
def preprocess(image):
    with tf.name_scope("preprocess"):
        # [0, 1] => [-1, 1]
        return image * 2 - 1

def deprocess(image):
    with tf.name_scope("deprocess"):
        # [-1, 1] => [0, 1]
        return (image + 1) / 2

def preprocess_lab(lab):
    with tf.name_scope("preprocess_lab"):
        L_chan, a_chan, b_chan = tf.unstack(lab, axis=2)
        # L_chan: black and white with input range [0, 100]
        # a_chan/b_chan: color channels with input range ~[-110, 110], not exact
        # [0, 100] => [-1, 1],  ~[-110, 110] => [-1, 1]
        return [L_chan / 50 - 1, a_chan / 110, b_chan / 110]

def deprocess_lab(L_chan, a_chan, b_chan):
    with tf.name_scope("deprocess_lab"):
        # this is axis=3 instead of axis=2 because we process individual images but deprocess batches
        return tf.stack([(L_chan + 1) / 2 * 100, a_chan * 110, b_chan * 110], axis=3)
    
def augment(image, brightness):
    # (a, b) color channels, combine with L channel and convert to rgb
    a_chan, b_chan = tf.unstack(image, axis=3)
    L_chan = tf.squeeze(brightness, axis=3)
    lab = deprocess_lab(L_chan, a_chan, b_chan)
    rgb = lab_to_rgb(lab)
    return rgb
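
The range mapping performed by `preprocess` and `deprocess` can be checked with plain Python floats. This is a minimal sketch; `preprocess_value` and `deprocess_value` are hypothetical scalar stand-ins for the tensor functions above.

```python
def preprocess_value(x):
    # [0, 1] => [-1, 1], same arithmetic as preprocess()
    return x * 2 - 1

def deprocess_value(x):
    # [-1, 1] => [0, 1], same arithmetic as deprocess()
    return (x + 1) / 2

samples = [0.0, 0.25, 0.5, 1.0]
scaled = [preprocess_value(x) for x in samples]
restored = [deprocess_value(x) for x in scaled]
print(scaled)    # values now lie in [-1, 1]
print(restored)  # round trip recovers the original samples
```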

## Convolution layers

With the image processing utilities in place, we will now define our convolutional layers.

In [22]:
def discrim_conv(batch_input, out_channels, stride):
    padded_input = tf.pad(batch_input, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="CONSTANT")
    return tf.layers.conv2d(padded_input, 
                            out_channels, 
                            kernel_size=4, 
                            strides=(stride, stride), 
                            padding="valid",
                            kernel_initializer=tf.random_normal_initializer(0, 0.02))

def gen_conv(batch_input, out_channels, dtype=tf.float32):
    # generator/encoder_1/separable_conv2d/depthwise_kernel:0
    # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels]
    initializer = tf.random_normal_initializer(0, 0.02)
    if separable_conv:
        return tf.layers.separable_conv2d(batch_input, 
                                          out_channels, 
                                          kernel_size=4, 
                                          strides=(2, 2), 
                                          padding="same", 
                                          depthwise_initializer=initializer, 
                                          pointwise_initializer=initializer)
    else:
        return tf.layers.conv2d(batch_input, 
                                out_channels, 
                                kernel_size=4, 
                                strides=(2, 2), 
                                padding="same",
                                kernel_initializer=initializer)

def gen_deconv(batch_input, out_channels):
    # [batch, in_height, in_width, in_channels] => [batch, out_height, out_width, out_channels]
    initializer = tf.random_normal_initializer(0, 0.02)
    if separable_conv:
        _b, h, w, _c = batch_input.shape
        resized_input = tf.image.resize_images(batch_input, 
                                               tf.constant([h * 2, w * 2], dtype=tf.int32), 
                                               method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        return tf.layers.separable_conv2d(resized_input, out_channels, kernel_size=4, strides=(1, 1), padding="same", depthwise_initializer=initializer, pointwise_initializer=initializer)
    else:
        return tf.layers.conv2d_transpose(batch_input, out_channels, kernel_size=4, strides=(2, 2), padding="same", kernel_initializer=initializer)
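
The spatial sizes these layers produce follow from simple arithmetic. The sketch below assumes 256x256 inputs and uses two hypothetical helpers: one for the `"same"` padding used by `gen_conv`, and one for the explicit one-pixel padding followed by a `"valid"` convolution used by `discrim_conv`.

```python
import math

def same_conv_size(size, stride):
    # TensorFlow "same" padding: output = ceil(input / stride)
    return math.ceil(size / stride)

def padded_valid_conv_size(size, stride, kernel=4, pad=1):
    # "valid" convolution after explicit padding of `pad` pixels per side
    return (size + 2 * pad - kernel) // stride + 1

# gen_conv halves the spatial size at each of eight stride-2 layers
size = 256
encoder_sizes = []
for _ in range(8):
    size = same_conv_size(size, stride=2)
    encoder_sizes.append(size)
print(encoder_sizes)

# discrim_conv: three stride-2 layers followed by two stride-1 layers
size = 256
discrim_sizes = []
for stride in (2, 2, 2, 1, 1):
    size = padded_valid_conv_size(size, stride)
    discrim_sizes.append(size)
print(discrim_sizes)
```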

## Utility functions

Now that our main layer functions have been constructed, we will define some utility functions that `tensorflow` does not give us out of the box: a leaky ReLU activation, a batch normalization wrapper, an image shape check, and conversions between the RGB and Lab color spaces.

In [23]:
def lrelu(x, a):
    with tf.name_scope("lrelu"):
        # adding these together creates the leak part and linear part
        # then cancels them out by subtracting/adding an absolute value term
        # leak: a*x/2 - a*abs(x)/2
        # linear: x/2 + abs(x)/2

        # this block looks like it has 2 inputs on the graph unless we do this
        x = tf.identity(x)
        return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x)

def batchnorm(inputs):
    return tf.layers.batch_normalization(inputs, axis=3, epsilon=1e-5, momentum=0.1, training=True, gamma_initializer=tf.random_normal_initializer(1.0, 0.02))

def check_image(image):
    assertion = tf.assert_equal(tf.shape(image)[-1], 3, message="image must have 3 color channels")
    with tf.control_dependencies([assertion]):
        image = tf.identity(image)

    if image.get_shape().ndims not in (3, 4):
        raise ValueError("image must be either 3 or 4 dimensions")

    # make the last dimension 3 so that you can unstack the colors
    shape = list(image.get_shape())
    shape[-1] = 3
    image.set_shape(shape)
    return image

# based on https://github.com/torch/image/blob/9f65c30167b2048ecbe8b7befdc6b2d6d12baee9/generic/image.c
def rgb_to_lab(srgb):
    with tf.name_scope("rgb_to_lab"):
        srgb = check_image(srgb)
        srgb_pixels = tf.reshape(srgb, [-1, 3])

        with tf.name_scope("srgb_to_xyz"):
            linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32)
            exponential_mask = tf.cast(srgb_pixels > 0.04045, dtype=tf.float32)
            rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask
            rgb_to_xyz = tf.constant([
                #    X        Y          Z
                [0.412453, 0.212671, 0.019334], # R
                [0.357580, 0.715160, 0.119193], # G
                [0.180423, 0.072169, 0.950227], # B
            ])
            xyz_pixels = tf.matmul(rgb_pixels, rgb_to_xyz)

        # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
        with tf.name_scope("xyz_to_cielab"):
            # convert to fx = f(X/Xn), fy = f(Y/Yn), fz = f(Z/Zn)

            # normalize for D65 white point
            xyz_normalized_pixels = tf.multiply(xyz_pixels, [1/0.950456, 1.0, 1/1.088754])

            epsilon = 6/29
            linear_mask = tf.cast(xyz_normalized_pixels <= (epsilon**3), dtype=tf.float32)
            exponential_mask = tf.cast(xyz_normalized_pixels > (epsilon**3), dtype=tf.float32)
            fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon**2) + 4/29) * linear_mask + (xyz_normalized_pixels ** (1/3)) * exponential_mask

            # convert to lab
            fxfyfz_to_lab = tf.constant([
                #  l       a       b
                [  0.0,  500.0,    0.0], # fx
                [116.0, -500.0,  200.0], # fy
                [  0.0,    0.0, -200.0], # fz
            ])
            lab_pixels = tf.matmul(fxfyfz_pixels, fxfyfz_to_lab) + tf.constant([-16.0, 0.0, 0.0])

        return tf.reshape(lab_pixels, tf.shape(srgb))


def lab_to_rgb(lab):
    with tf.name_scope("lab_to_rgb"):
        lab = check_image(lab)
        lab_pixels = tf.reshape(lab, [-1, 3])

        # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
        with tf.name_scope("cielab_to_xyz"):
            # convert to fxfyfz
            lab_to_fxfyfz = tf.constant([
                #   fx      fy        fz
                [1/116.0, 1/116.0,  1/116.0], # l
                [1/500.0,     0.0,      0.0], # a
                [    0.0,     0.0, -1/200.0], # b
            ])
            fxfyfz_pixels = tf.matmul(lab_pixels + tf.constant([16.0, 0.0, 0.0]), lab_to_fxfyfz)

            # convert to xyz
            epsilon = 6/29
            linear_mask = tf.cast(fxfyfz_pixels <= epsilon, dtype=tf.float32)
            exponential_mask = tf.cast(fxfyfz_pixels > epsilon, dtype=tf.float32)
            xyz_pixels = (3 * epsilon**2 * (fxfyfz_pixels - 4/29)) * linear_mask + (fxfyfz_pixels ** 3) * exponential_mask

            # denormalize for D65 white point
            xyz_pixels = tf.multiply(xyz_pixels, [0.950456, 1.0, 1.088754])

        with tf.name_scope("xyz_to_srgb"):
            xyz_to_rgb = tf.constant([
                #     r           g          b
                [ 3.2404542, -0.9692660,  0.0556434], # x
                [-1.5371385,  1.8760108, -0.2040259], # y
                [-0.4985314,  0.0415560,  1.0572252], # z
            ])
            rgb_pixels = tf.matmul(xyz_pixels, xyz_to_rgb)
            # avoid a slightly negative number messing up the conversion
            rgb_pixels = tf.clip_by_value(rgb_pixels, 0.0, 1.0)
            linear_mask = tf.cast(rgb_pixels <= 0.0031308, dtype=tf.float32)
            exponential_mask = tf.cast(rgb_pixels > 0.0031308, dtype=tf.float32)
            srgb_pixels = (rgb_pixels * 12.92 * linear_mask) + ((rgb_pixels ** (1/2.4) * 1.055) - 0.055) * exponential_mask

        return tf.reshape(srgb_pixels, tf.shape(lab))
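
The absolute-value form used in `lrelu` above is equivalent to the usual piecewise definition of a leaky ReLU. The sketch below checks this on a few scalar values; both helper names are hypothetical and operate on plain floats rather than tensors.

```python
def lrelu_abs_form(x, a):
    # same arithmetic as lrelu(): linear and leak parts combined via abs()
    return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * abs(x)

def lrelu_reference(x, a):
    # piecewise leaky ReLU: x for positive inputs, a * x otherwise
    return x if x > 0 else a * x

for x in [-2.0, -0.5, 0.0, 0.5, 2.0]:
    difference = abs(lrelu_abs_form(x, 0.2) - lrelu_reference(x, 0.2))
    assert difference < 1e-12
    print(x, lrelu_abs_form(x, 0.2))
```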

## Loading our data

Now that these utility functions have been defined, we will use them when loading data from our input directories. To do this we create a `create_image_iterator` function, which looks in the specified directories and returns the dataset size, the zipped `tf.data` dataset, and a one-shot iterator over it.

The data consists of `.jpg` or `.png` images.

In [24]:
seed = random.randint(0, 2**31 - 1)
def transform(image):
    r = image

    # area produces a nice downscaling, but does nearest neighbor for upscaling
    # assume we're going to be doing downscaling here
    r = tf.image.resize_images(r, [scale_size, scale_size], method=tf.image.ResizeMethod.AREA)
    return r

def create_image_iterator(dirs):

    ## Identify data type
    if input_filetype == "png":
        decode = tf.image.decode_png
    elif input_filetype == "jpg":
        decode = tf.image.decode_jpeg
    else:
        raise Exception("Unknown input filetype")

    def create_image_dataset(filename_dataset):
        ## Convert data via mapping
        with tf.name_scope("decode"):
            image_dataset = filename_dataset.map(lambda x:
                                                 decode(tf.read_file(x), channels=channels))
        with tf.name_scope("convert_preprocess"):
            image_dataset = image_dataset.map(lambda x:
                                              preprocess(tf.image.convert_image_dtype(x,dtype=tf.float32)))
        with tf.name_scope("transform"):
            image_dataset = image_dataset.map(lambda x:
                                              transform(x))

        return image_dataset
    
    glob_files = [glob.glob(os.path.join(d,"**/**/*.%s"%input_filetype)) for d in dirs]
    sorted_glob_files = [sorted(d, key=os.path.basename) for d in glob_files]
    assert len(set([len(i) for i in sorted_glob_files])) == 1, "All dataset dirs must be the same size"
    
    filename_datasets = [tf.data.Dataset.from_tensor_slices(d) for d in sorted_glob_files]
    image_datasets = [create_image_dataset(fd) for fd in filename_datasets]
    image_datasets = tf.data.Dataset.zip(tuple(id for id in image_datasets))

    ## Create repetitions and batches
    image_datasets = image_datasets.repeat(max_epochs).batch(batch_size=batch_size)
    
    ## Create iterator from datasets
    iterator = image_datasets.make_one_shot_iterator()
    
    ## Create dataset size from files
    dataset_size = len(sorted_glob_files[0])

    return dataset_size, image_datasets, iterator
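
The pairing of input and target images relies on both file lists being sorted by base name and having the same length. The sketch below illustrates that idea without TensorFlow; `pair_by_basename` and the file names are hypothetical.

```python
import os

def pair_by_basename(input_files, target_files):
    # Sort each list by file name, mirroring sorted(..., key=os.path.basename)
    inputs = sorted(input_files, key=os.path.basename)
    targets = sorted(target_files, key=os.path.basename)
    if len(inputs) != len(targets):
        raise ValueError("All dataset dirs must be the same size")
    return list(zip(inputs, targets))

# Hypothetical file names for illustration only
landmarks = ["landmarks/b/2.jpg", "landmarks/a/1.jpg"]
originals = ["originals/a/1.jpg", "originals/b/2.jpg"]
pairs = pair_by_basename(landmarks, originals)
print(pairs)
```

Note that this pairing only lines up correctly when matching images share the same base name in both directories.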

## Creating the generator

In this next cell we will create the `G` in `GAN`. The generator is an encoder-decoder network with skip connections between mirrored encoder and decoder layers, and the function returns the output of the final decoder layer.

In [25]:
def create_generator(generator_inputs, generator_outputs_channels, ngf=64):
    layers = []

    # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf]
    with tf.variable_scope("encoder_1"):
        output = gen_conv(generator_inputs, ngf)
        layers.append(output)

    layer_specs = [
        ngf * 2, # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2]
        ngf * 4, # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4]
        ngf * 8, # encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8]
        ngf * 8, # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8]
        ngf * 8, # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8]
        ngf * 8, # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8]
        ngf * 8, # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8]
    ]

    for out_channels in layer_specs:
        with tf.variable_scope("encoder_%d" % (len(layers) + 1)):
            rectified = lrelu(layers[-1], 0.2)
            # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels]
            convolved = gen_conv(rectified, out_channels)
            output = batchnorm(convolved)
            layers.append(output)

    layer_specs = [
        (ngf * 8, 0.5),   # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2]
        (ngf * 8, 0.5),   # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2]
        (ngf * 8, 0.5),   # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2]
        (ngf * 8, 0.0),   # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2]
        (ngf * 4, 0.0),   # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2]
        (ngf * 2, 0.0),   # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2]
        (ngf, 0.0),       # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2]
    ]

    # generator/encoder_1/separable_conv2d/depthwise_kernel:0
    num_encoder_layers = len(layers)
    for decoder_layer, (out_channels, dropout) in enumerate(layer_specs):
        skip_layer = num_encoder_layers - decoder_layer - 1
        with tf.variable_scope("decoder_%d" % (skip_layer + 1)):
            if decoder_layer == 0:
                # first decoder layer doesn't have skip connections
                # since it is directly connected to the skip_layer
                input = layers[-1]
            else:
                input = tf.concat([layers[-1], layers[skip_layer]], axis=3)

            rectified = tf.nn.relu(input)
            # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels]
            output = gen_deconv(rectified, out_channels)
            output = batchnorm(output)

            if dropout > 0.0:
                output = tf.nn.dropout(output, keep_prob=1 - dropout)

            layers.append(output)

    # decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels]
    with tf.variable_scope("decoder_1"):
        input = tf.concat([layers[-1], layers[0]], axis=3)
        rectified = tf.nn.relu(input)
        output = gen_deconv(rectified, generator_outputs_channels)
        output = tf.tanh(output)
        layers.append(output)

    return layers[-1]
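
To make the shape comments in `create_generator` easier to follow, the sketch below traces spatial sizes and channel counts through the encoder and decoder in plain Python. It assumes a 256x256 input, `ngf=64`, and three output channels; `generator_shapes` is a hypothetical helper, not part of the model.

```python
def generator_shapes(size=256, ngf=64, out_channels=3):
    # Encoder: eight stride-2 convolutions, channels capped at ngf * 8
    encoder_channels = [ngf, ngf * 2, ngf * 4] + [ngf * 8] * 5
    encoder = []
    for channels in encoder_channels:
        size //= 2
        encoder.append((size, channels))

    # Decoder: each layer doubles the spatial size; every layer except the
    # first receives the previous output concatenated with a skip connection
    decoder_channels = [ngf * 8] * 4 + [ngf * 4, ngf * 2, ngf, out_channels]
    decoder = []
    previous = encoder[-1][1]
    for index, channels in enumerate(decoder_channels):
        skip = 0 if index == 0 else encoder[len(encoder) - 1 - index][1]
        size *= 2
        decoder.append((size, previous + skip, channels))
        previous = channels
    return encoder, decoder

encoder, decoder = generator_shapes()
for size, channels in encoder:
    print("encoder", size, channels)
for size, in_channels, channels in decoder:
    print("decoder", size, in_channels, channels)
```

Each decoder tuple lists the output size, the input channel count after concatenation, and the output channel count.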

## Creating the discriminator

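In the next cell we will create the discriminator, which takes an input image paired with either a target or a generated image and outputs a grid of per-patch probabilities that the pair is real.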

In [26]:
def create_discriminator(discrim_inputs, discrim_targets):
    n_layers = 3
    layers = []
    
    # 2x [batch, height, width, in_channels] => [batch, height, width, in_channels * 2]
    input = tf.concat([discrim_inputs, discrim_targets], axis=3)

    # layer_1: [batch, 256, 256, in_channels * 2] => [batch, 128, 128, ndf]
    with tf.variable_scope("layer_1"):
        convolved = discrim_conv(input, ndf, stride=2)
        rectified = lrelu(convolved, 0.2)
        layers.append(rectified)

    # layer_2: [batch, 128, 128, ndf] => [batch, 64, 64, ndf * 2]
    # layer_3: [batch, 64, 64, ndf * 2] => [batch, 32, 32, ndf * 4]
    # layer_4: [batch, 32, 32, ndf * 4] => [batch, 31, 31, ndf * 8]
    for i in range(n_layers):
        with tf.variable_scope("layer_%d" % (len(layers) + 1)):
            out_channels = ndf * min(2**(i+1), 8)
            stride = 1 if i == n_layers - 1 else 2  # last layer here has stride 1
            convolved = discrim_conv(layers[-1], out_channels, stride=stride)
            normalized = batchnorm(convolved)
            rectified = lrelu(normalized, 0.2)
            layers.append(rectified)

    # layer_5: [batch, 31, 31, ndf * 8] => [batch, 30, 30, 1]
    with tf.variable_scope("layer_%d" % (len(layers) + 1)):
        convolved = discrim_conv(rectified, out_channels=1, stride=1)
        output = tf.sigmoid(convolved)
        layers.append(output)

    return layers[-1]
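
Each value in the discriminator's 30x30 output looks at only a patch of the input image. The size of that patch can be worked out from the kernel sizes and strides used above; the helper below is a hypothetical sketch of the standard receptive-field recurrence.

```python
def receptive_field(layers):
    # layers: (kernel_size, stride) pairs listed from input to output.
    # Walk backwards from a single output unit to the input.
    field = 1
    for kernel, stride in reversed(layers):
        field = (field - 1) * stride + kernel
    return field

# kernel_size=4 everywhere; strides follow create_discriminator
discriminator_layers = [(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]
print(receptive_field(discriminator_layers))  # → 70
```

With these settings each output unit scores a 70x70 region of the input pair.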

## Creating the model

Now we will combine the generator and discriminator to create the GAN we will be training, along with its losses and optimizers.

In [27]:
def create_model(inputs, targets):
    with tf.variable_scope("generator"):
        out_channels = channels # int(targets.get_shape()[-1])
        outputs = create_generator(inputs, out_channels)

    # create two copies of discriminator, one for real pairs and one for fake pairs
    # they share the same underlying variables
    with tf.name_scope("real_discriminator"):
        with tf.variable_scope("discriminator"):
            # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
            predict_real = create_discriminator(inputs, targets)

    with tf.name_scope("fake_discriminator"):
        with tf.variable_scope("discriminator", reuse=True):
            # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
            predict_fake = create_discriminator(inputs, outputs)

    with tf.name_scope("discriminator_loss"):
        # minimizing -tf.log will try to get inputs to 1
        # predict_real => 1
        # predict_fake => 0
        discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) + tf.log(1 - predict_fake + EPS)))

    with tf.name_scope("generator_loss"):
        # predict_fake => 1
        # abs(targets - outputs) => 0
        gen_loss_GAN = tf.reduce_mean(-tf.log(predict_fake + EPS))
        gen_loss_L1 = tf.reduce_mean(tf.abs(targets - outputs))
        gen_loss = gen_loss_GAN * gan_weight + gen_loss_L1 * l1_weight

    with tf.name_scope("discriminator_train"):
        discrim_tvars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")]
        discrim_optim = tf.train.AdamOptimizer(lr, beta1)
        discrim_grads_and_vars = discrim_optim.compute_gradients(discrim_loss, var_list=discrim_tvars)
        discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars)

    with tf.name_scope("generator_train"):
        with tf.control_dependencies([discrim_train]):
            gen_tvars = [var for var in tf.trainable_variables() if var.name.startswith("generator")]
            gen_optim = tf.train.AdamOptimizer(lr, beta1)
            gen_grads_and_vars = gen_optim.compute_gradients(gen_loss, var_list=gen_tvars)
            gen_train = gen_optim.apply_gradients(gen_grads_and_vars)

    ema = tf.train.ExponentialMovingAverage(decay=0.99)
    update_losses = ema.apply([discrim_loss, gen_loss_GAN, gen_loss_L1])

    global_step = tf.train.get_or_create_global_step()
    incr_global_step = tf.assign(global_step, global_step+1)

    return Model(
        predict_real=predict_real,
        predict_fake=predict_fake,
        discrim_loss=ema.average(discrim_loss),
        discrim_grads_and_vars=discrim_grads_and_vars,
        gen_loss_GAN=ema.average(gen_loss_GAN),
        gen_loss_L1=ema.average(gen_loss_L1),
        gen_grads_and_vars=gen_grads_and_vars,
        outputs=outputs,
        train=tf.group(update_losses, incr_global_step, gen_train),
    )
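
The loss expressions in `create_model` can be explored on scalar values. The sketch below reproduces the same formulas with the standard library; the probabilities and the L1 error are hypothetical numbers chosen only for illustration.

```python
import math

EPS = 1e-12

def discriminator_loss(predict_real, predict_fake):
    # -(log D(real) + log(1 - D(fake))): small when real -> 1 and fake -> 0
    return -(math.log(predict_real + EPS) + math.log(1 - predict_fake + EPS))

def generator_loss(predict_fake, l1_error, gan_weight=1.0, l1_weight=100.0):
    # GAN term rewards fooling the discriminator; L1 term rewards
    # staying close to the target image
    gan_term = -math.log(predict_fake + EPS)
    return gan_term * gan_weight + l1_error * l1_weight

# Hypothetical per-patch probabilities, chosen only for illustration
confident = discriminator_loss(predict_real=0.9, predict_fake=0.1)
fooled = discriminator_loss(predict_real=0.5, predict_fake=0.5)
print(confident, fooled)
print(generator_loss(predict_fake=0.5, l1_error=0.1))
```

With `l1_weight = 100.0`, the L1 term dominates the generator loss unless the reconstruction error is already small.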

## Saving model image output

While training, we would like visual feedback on how our model is performing, so we save sample images and list them in an HTML index page.

In [28]:
def save_images(fetches, step=None):
    
    ## create image directory
    image_dir = os.path.join(output_dir, "images")
    if not os.path.exists(image_dir):
        os.makedirs(image_dir)

    filesets = []
    for kind, contents in fetches["display"].items():

        fileset = {"name": kind, "step": step}

        filename = kind + ".png"
        if step is not None:
            filename = "%08d-%s" % (step, filename)

        fileset[kind] = filename
        out_path = os.path.join(image_dir, filename)
        # each fetch holds one encoded PNG per batch element;
        # writing only the first one is an assumption about the intent
        with open(out_path, "wb") as f:
            f.write(contents[0])
        filesets.append(fileset)

    return filesets


def append_index(filesets, step=False):
    index_path = os.path.join(output_dir, "index.html")
    write_header = not os.path.exists(index_path)

    # the context manager closes the file so rows are flushed to disk
    with open(index_path, "a") as index:
        if write_header:
            index.write("<html><body><table><tr>")
            if step:
                index.write("<th>step</th>")
            index.write("<th>name</th><th>image</th></tr>")

        for fileset in filesets:
            index.write("<tr>")

            if step:
                index.write("<td>%d</td>" % fileset["step"])
            index.write("<td>%s</td>" % fileset["name"])
            # the file name is stored under the key named by fileset["name"]
            index.write("<td><img src='images/%s'></td>" % fileset[fileset["name"]])
            index.write("</tr>")
    return index_path

## Training setup and evaluation

We now have all the functions needed to define our model. We will now build it and run training using the parameters defined at the top of the notebook.

## Load data for model

The training data is loaded with the `create_image_iterator` function defined above, which builds a `tf.data` pipeline over the input and target directories. The pipeline is created on the CPU in the same cell that builds the model.

## Create model

In this next cell we build the input pipeline on the CPU and create our model on the GPU from the next batch of inputs and targets.

In [29]:
from tensorflow.python.client import device_lib

gpus = [x.name for x in device_lib.list_local_devices() if x.device_type == 'GPU']
assert len(gpus) > 0, "GPUS should be available :("

tf.reset_default_graph()

with tf.device('/device:CPU:0'):
    dataset_size, global_image_dataset, global_image_iterator = create_image_iterator(dirs = [input_dir, target_dir])
    input_image_next, target_image_next = global_image_iterator.get_next()

with tf.device('/device:GPU:0'):
    model = create_model(input_image_next, target_image_next)

In [30]:
inputs = deprocess(input_image_next)
targets = deprocess(target_image_next)
outputs = deprocess(model.outputs)

## Model visibility

To view the output in a format that is readable on most systems, we define a `convert` function and apply it to the input, output, and target images. We then define a `display_fetches` dictionary that PNG-encodes the inputs, outputs, and targets.

In [31]:
def convert(image):
    if aspect_ratio != 1.0:
        # upscale to correct aspect ratio
        size = [CROP_SIZE, int(round(CROP_SIZE * aspect_ratio))]
        image = tf.image.resize_images(image, size=size, method=tf.image.ResizeMethod.BICUBIC)

    return tf.image.convert_image_dtype(image, dtype=tf.uint8, saturate=True)

# reverse any processing on images so they can be written to disk or displayed to user
with tf.name_scope("convert_inputs"):
    converted_inputs = convert(inputs)

with tf.name_scope("convert_targets"):
    converted_targets = convert(targets)

with tf.name_scope("convert_outputs"):
    converted_outputs = convert(outputs)

with tf.name_scope("encode_images"):
    display_fetches = {
        "inputs": tf.map_fn(tf.image.encode_png, converted_inputs, dtype=tf.string, name="input_pngs"),
        "targets": tf.map_fn(tf.image.encode_png, converted_targets, dtype=tf.string, name="target_pngs"),
        "outputs": tf.map_fn(tf.image.encode_png, converted_outputs, dtype=tf.string, name="output_pngs"),
    }

## Summaries

In the next cell we will define multiple summaries to provide insight into the model, including image outputs, losses, and histograms of variables and gradients.

In [32]:
# summaries
with tf.name_scope("inputs_summary"):
    tf.summary.image("inputs", converted_inputs)

with tf.name_scope("targets_summary"):
    tf.summary.image("targets", converted_targets)

with tf.name_scope("outputs_summary"):
    tf.summary.image("outputs", converted_outputs)

with tf.name_scope("predict_real_summary"):
    tf.summary.image("predict_real", tf.image.convert_image_dtype(model.predict_real, dtype=tf.uint8))

with tf.name_scope("predict_fake_summary"):
    tf.summary.image("predict_fake", tf.image.convert_image_dtype(model.predict_fake, dtype=tf.uint8))

tf.summary.scalar("discriminator_loss", model.discrim_loss)
tf.summary.scalar("generator_loss_GAN", model.gen_loss_GAN)
tf.summary.scalar("generator_loss_L1", model.gen_loss_L1)

for var in tf.trainable_variables():
    tf.summary.histogram(var.op.name + "/values", var)

for grad, var in model.discrim_grads_and_vars + model.gen_grads_and_vars:
    tf.summary.histogram(var.op.name + "/gradients", grad)
    
with tf.name_scope("parameter_count"):
    parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])
    
summary_merged = tf.summary.merge_all()
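
The `parameter_count` tensor above sums the product of each trainable variable's dimensions. The same arithmetic can be sketched in plain Python; `count_parameters` and the shapes below are hypothetical and do not come from the trained model.

```python
from functools import reduce
from operator import mul

def count_parameters(shapes):
    # Sum over variables of the product of each variable's dimensions
    return sum(reduce(mul, shape, 1) for shape in shapes)

# Hypothetical shapes: a 4x4 conv kernel from 6 to 64 channels plus its bias
shapes = [(4, 4, 6, 64), (64,)]
print(count_parameters(shapes))  # → 6208
```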

## Training

In the next cell we will train our model, save intermediate checkpoints, and write log summaries to our output directory.
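
The training loop decides what to fetch and log at each step by testing the global step against the frequency parameters. The sketch below mirrors those checks with the values from the parameters cell; `actions_at` is a hypothetical helper used only for illustration.

```python
summary_freq, display_freq, progress_freq, save_freq = 50, 100, 200, 2000

def actions_at(global_step):
    # Mirror the `global_steps % freq == 0` checks in the training loop
    actions = []
    if global_step % summary_freq == 0:
        actions.append("summary")
    if global_step % display_freq == 0:
        actions.append("display")
    if global_step % progress_freq == 0:
        actions.append("progress")
    if global_step % save_freq == 0:
        actions.append("save")
    return actions

for step in (50, 100, 200, 2000):
    print(step, actions_at(step))
```

This is why the training log shows two summary lines and one display line between consecutive progress reports.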

In [None]:
options = None
run_metadata = None

## Create tf session configuration
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.log_device_placement = True
config.allow_soft_placement=True

def get_session(sess):
    session = sess
    while type(session).__name__ != 'Session':
        session = session._sess
    return session

saver = tf.train.Saver(max_to_keep=5)

with tf.train.MonitoredTrainingSession(config=config) as sess:
    
    ## Unfinalize and print parameter count
    sess.graph._unsafe_unfinalize()
    print("parameter_count =", sess.run(parameter_count))
    
    ## Load the latest checkpoint
    ## if it exists
    if output_dir is not None:
        checkpoint_restore = tf.train.latest_checkpoint(output_dir)
        if checkpoint_restore is not None:
            saver.restore(sess, checkpoint_restore)  
            print ("Restoring, ", checkpoint_restore)
        else:
            print ("No model to load ...")
            
    ## Create summary writer
    summary_writer = tf.summary.FileWriter(output_dir,sess.graph)

    ## start training :)
    print ("Training has started! ...")
    start = time.time()
    
    global_steps = 1
    for epoch in range(max_epochs):
        for datum in range(0, dataset_size, batch_size):

            try:
                # fetch both tensors in one call so the pair comes from the same batch
                input_image, target_image = sess.run([input_image_next, target_image_next])
            except Exception:
                print ("Error in loading data :(((")
                continue

            for step in range(max_steps):        

                fetches = {
                    "train": model.train,
                }
                
                if global_steps % progress_freq == 0:
                    fetches["discrim_loss"] = model.discrim_loss
                    fetches["gen_loss_GAN"] = model.gen_loss_GAN
                    fetches["gen_loss_L1"] = model.gen_loss_L1

                if global_steps % summary_freq == 0:
                    fetches["summary"] = summary_merged

                if global_steps % display_freq == 0:                    
                    fetches["display"] = display_fetches

                try:
                    results = sess.run(fetches, options=options, run_metadata=run_metadata)
                    results["global_step"] = global_steps
                except tf.errors.InvalidArgumentError as e:
                    print (e)
                    print ("Error in training :(")
                    continue
                    
                if global_steps % summary_freq == 0:
                    print("recording summary")
                    summary_writer.add_summary(results["summary"], results["global_step"])

                if global_steps % display_freq == 0:
                    print("saving display images")
                    display_images = results
                    filesets = save_images(results, 
                                           step=results["global_step"])
                    append_index(filesets, step=True)

                if global_steps % progress_freq == 0:
                    
                    steps_per_epoch = int(math.ceil(dataset_size / batch_size))
                    train_epoch = math.ceil(results["global_step"] / steps_per_epoch)
                    train_step = (results["global_step"] - 1) % steps_per_epoch + 1
                    rate = (step + 1) * batch_size / (time.time() - start)
                    remaining = (max_steps - step) * batch_size / rate
                    print("progress  epoch %d  step %d  datum %d global_step %d image/sec %0.1f  remaining %dm" % (train_epoch, train_step, datum, global_steps, rate, remaining / 60))
                    print("discrim_loss", results["discrim_loss"])
                    print("gen_loss_GAN", results["gen_loss_GAN"])
                    print("gen_loss_L1", results["gen_loss_L1"])
                    print("-------------------------")

                if global_steps % save_freq == 0:
                    print("saving model")
                    
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    
                    saver.save(get_session(sess),
                               os.path.join(output_dir, "model"),
                               global_step=global_steps)
                    
                global_steps += 1

INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
parameter_count = 6314292
No model to load ...
Training has started! ...
recording summary
recording summary
saving display images
recording summary
recording summary
saving display images
progress  epoch 1  step 200  datum 15 global_step 200 image/sec 1.1  remaining 0m
discrim_loss 0.8497088
gen_loss_GAN 1.3617435
gen_loss_L1 0.0956818
-------------------------
recording summary
recording summary
saving display images
recording summary
recording summary
saving display images
progress  epoch 1  step 400  datum 35 global_step 400 image/sec 0.6  remaining 0m
discrim_loss 0.818709
gen_loss_GAN 1.514611
gen_loss_L1 0.095655695
-------------------------
recording summary
recording summary
saving display images
recording summary
recording summary
saving display images
progress  epoch 2  step 196  datum 55 global_step 600 image/sec 0.4  remaining 0m
discrim_loss 0.7857881
...

recording summary
saving display images
recording summary
recording summary
saving display images
progress  epoch 14  step 348  datum 555 global_step 5600 image/sec 0.0  remaining 1m
discrim_loss 0.45811945
gen_loss_GAN 2.7557523
gen_loss_L1 0.06770127
-------------------------
recording summary
recording summary
saving display images
recording summary
recording summary
saving display images
progress  epoch 15  step 144  datum 575 global_step 5800 image/sec 0.0  remaining 1m
discrim_loss 0.36797342
gen_loss_GAN 3.0643728
gen_loss_L1 0.07570643
-------------------------
recording summary
recording summary
saving display images
recording summary
recording summary
saving display images
progress  epoch 15  step 344  datum 595 global_step 6000 image/sec 0.0  remaining 2m
discrim_loss 0.67117673
gen_loss_GAN 2.8199196
gen_loss_L1 0.06577058
-------------------------
saving model
recording summary
recording summary
saving display images
recording summary
recording summary
saving display images
...

recording summary
recording summary
saving display images
recording summary
recording summary
saving display images
progress  epoch 28  step 292  datum 1115 global_step 11200 image/sec 0.0  remaining 4m
discrim_loss 0.62537676
gen_loss_GAN 3.8298366
gen_loss_L1 0.07739013
-------------------------
recording summary
recording summary
saving display images
recording summary
recording summary
saving display images
progress  epoch 29  step 88  datum 1135 global_step 11400 image/sec 0.0  remaining 4m
discrim_loss 0.28715926
gen_loss_GAN 3.8929365
gen_loss_L1 0.07423585
-------------------------
recording summary
recording summary
saving display images
recording summary
recording summary
saving display images
progress  epoch 29  step 288  datum 1155 global_step 11600 image/sec 0.0  remaining 4m
discrim_loss 0.19929133
gen_loss_GAN 4.1187253
gen_loss_L1 0.0772158
-------------------------
recording summary
recording summary
saving display images
recording summary
recording summary
saving display images
...

recording summary
recording summary
saving display images
recording summary
recording summary
saving display images
progress  epoch 42  step 236  datum 1675 global_step 16800 image/sec 0.0  remaining 6m
discrim_loss 0.26509944
gen_loss_GAN 4.2765903
gen_loss_L1 0.07971326
-------------------------
recording summary
recording summary
saving display images
recording summary
recording summary
saving display images
progress  epoch 43  step 32  datum 1695 global_step 17000 image/sec 0.0  remaining 6m
discrim_loss 0.17849621
gen_loss_GAN 4.565488
gen_loss_L1 0.077225305
-------------------------
recording summary
recording summary
saving display images
recording summary
recording summary
saving display images
progress  epoch 43  step 232  datum 1715 global_step 17200 image/sec 0.0  remaining 6m
discrim_loss 0.22860649
gen_loss_GAN 4.3927093
gen_loss_L1 0.08040474
-------------------------
recording summary
recording summary
saving display images
recording summary
recording summary
saving display images
...

recording summary
recording summary
saving display images
recording summary
recording summary
saving display images
progress  epoch 56  step 180  datum 215 global_step 22400 image/sec 0.0  remaining 9m
discrim_loss 0.18201655
gen_loss_GAN 4.448633
gen_loss_L1 0.07371007
-------------------------
recording summary
recording summary
saving display images
recording summary
recording summary
saving display images
progress  epoch 56  step 380  datum 235 global_step 22600 image/sec 0.0  remaining 9m
discrim_loss 0.12543511
gen_loss_GAN 4.819841
gen_loss_L1 0.07846063
-------------------------
recording summary
recording summary
saving display images
recording summary
recording summary
saving display images
progress  epoch 57  step 176  datum 255 global_step 22800 image/sec 0.0  remaining 9m
discrim_loss 0.074490055
gen_loss_GAN 5.325014
gen_loss_L1 0.0770509
-------------------------
recording summary
recording summary
saving display images
recording summary
recording summary
saving display images
...
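
The `epoch` and `step` fields in the progress lines above come from the global step by simple modular arithmetic. The sketch below reproduces that arithmetic outside the training loop. The value `STEPS_PER_EPOCH = 404` is not stated in the notebook; it is inferred from the log (global step 400 is still in epoch 1, and global step 600 lands on step 196 of epoch 2). The logged `image/sec` value drifts toward 0.0 as training proceeds, which is what one would expect if the counter in the rate formula stops advancing, so `images_per_second` is a hypothetical alternative keyed to the global step, not the notebook's own formula.

```python
import math


def epoch_and_step(global_step, steps_per_epoch):
    """Return (epoch, step_in_epoch), both 1-based, for a 1-based global step."""
    epoch = math.ceil(global_step / steps_per_epoch)
    step_in_epoch = (global_step - 1) % steps_per_epoch + 1
    return epoch, step_in_epoch


def images_per_second(global_step, batch_size, elapsed_seconds):
    """Average throughput since training started (hypothetical helper)."""
    return global_step * batch_size / elapsed_seconds


# Assumption: inferred from the log lines, not defined in the notebook.
STEPS_PER_EPOCH = 404

# Print the (epoch, step) pair for a few global steps that appear in the log.
for gs in (200, 600, 5600, 22800):
    print(gs, epoch_and_step(gs, STEPS_PER_EPOCH))
```

With these inputs the function returns the same pairs as the log: `(1, 200)`, `(2, 196)`, `(14, 348)` and `(57, 176)`.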

In [None]:
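import io
from PIL import Image

# Decode the first input, target, and output image from the last display batch.
i = np.array(Image.open(io.BytesIO(display_images["display"]["inputs"][0])))
j = np.array(Image.open(io.BytesIO(display_images["display"]["targets"][0])))
o = np.array(Image.open(io.BytesIO(display_images["display"]["outputs"][0])))
# Show input, target, and generated output side by side.
Image.fromarray(np.hstack((i,j,o)))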
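
The cell above appears to treat each entry of `display_images["display"]` as encoded image bytes that are decoded and joined horizontally. The round trip can be tried without a trained model, as in the minimal sketch below. The helper names `encode_png` and `decode_image` and the 4x4 solid-colour tiles are hypothetical stand-ins, not part of the notebook.

```python
import io

import numpy as np
from PIL import Image


def encode_png(array):
    """Encode an HxWx3 uint8 array as PNG bytes (stand-in for the saved images)."""
    buffer = io.BytesIO()
    Image.fromarray(array).save(buffer, format="PNG")
    return buffer.getvalue()


def decode_image(data):
    """Decode encoded image bytes back into a NumPy array."""
    return np.array(Image.open(io.BytesIO(data)))


# Hypothetical 4x4 stand-ins for an input, a target, and an output image.
tiles = [np.full((4, 4, 3), value, dtype=np.uint8) for value in (0, 128, 255)]
encoded = [encode_png(tile) for tile in tiles]

# Decode each image and join them left to right, as the cell above does.
strip = np.hstack([decode_image(data) for data in encoded])
print(strip.shape)
```

`np.hstack` requires all three images to share the same height and channel count, which holds here because every tile has the same shape.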