In [None]:
__author__ = 'Fan Fan, Kwonjoon Lee and Weijian Xu'

# Python libraries.
import os
import tensorflow as tf
import numpy as np
import scipy.misc
from datetime import datetime

# Libraries related to tflib.
import functools
import tflib as lib
import tflib.ops.linear
import tflib.ops.conv2d
import tflib.ops.batchnorm
import tflib.ops.deconv2d
import tflib.save_images
import tflib.small_imagenet
import tflib.ops.layernorm
import tflib.plot

# Custom libraries.
from Utils import *

In [None]:
# Defined hyper-parameters.

units         = 5              # Residual units per stage; 5 gives ResNet-32.
mode          = 0              # Training mode: 0 = baseline, 1 = WINN.
gpu           = 1              # GPU exposed via CUDA_VISIBLE_DEVICES (note: set to 1 here, not the usual 0).
epochs        = 200            # Number of training epochs.
batch_size    = 100            # Mini-batch size.
cats          = 10             # Number of categories (10 for MNIST).
critics       = 1              # Classifier passes over the training set per epoch.
max_opt_steps = 2000           # Upper bound on synthesis steps per pseudo-negative batch.
image_shape   = [28, 28, 1]    # [height, width, channels]; [28, 28, 1] for MNIST.

# Derived hyper-parameters.
half_batch_size = batch_size // 2
height, width, channels = image_shape

In [None]:
# Set the root dir of data and logs. Each training mode writes to its own
# directory. An unknown mode fails here, rather than leaving `root_dir`
# undefined (or silently stale from an earlier run of this cell).
ROOT_DIRS = {0: './baseline', 1: './winn'}
if mode not in ROOT_DIRS:
    raise ValueError(
        'Unknown mode {}: expected 0 (baseline) or 1 (WINN).'.format(mode))
root_dir = ROOT_DIRS[mode]

In [None]:
def layer_norm(scope, input_layer, is_training, reuse):
    """Apply layer normalization with a learned scale.

    Args:
        scope: Variable scope name for the normalization parameters.
        input_layer: Tensor to normalize.
        is_training: Unused; kept so the signature matches the other layer
            helpers in this notebook.
        reuse: Forwarded to `tf.contrib.layers.layer_norm` so that existing
            variables in `scope` are reused.

    Returns:
        The normalized tensor.
    """
    return tf.contrib.layers.layer_norm(input_layer, scale=True,
                                        reuse=reuse, scope=scope)

def conv2d_res(scope, input_layer, output_dim, use_bias=False,
               filter_size=3, strides=[1, 1, 1, 1]):
    """2-D 'SAME' convolution used by the residual network.

    The kernel is variance-scaling initialized and carries an L2 regularizer
    (scale 2e-4). An optional zero-initialized bias can be added.

    NOTE(review): the regularizer only registers a loss term in TF's
    regularization-loss collection; build_train_op does not appear to add
    that collection to the training loss. Confirm whether weight decay is
    intended.

    Args:
        scope: Variable scope holding 'conv_weight' (and 'conv_bias').
        input_layer: Input tensor; its last dimension is the input depth.
        output_dim: Number of output channels.
        use_bias: Whether to add a learned bias.
        filter_size: Spatial size of the square kernel.
        strides: Strides passed to `tf.nn.conv2d`.

    Returns:
        The convolution output tensor.
    """
    in_channels = input_layer.get_shape().as_list()[-1]
    kernel_shape = [filter_size, filter_size, in_channels, output_dim]

    with tf.variable_scope(scope):
        kernel = tf.get_variable(
            'conv_weight',
            shape=kernel_shape,
            dtype=tf.float32,
            initializer=tf.contrib.layers.variance_scaling_initializer(),
            regularizer=tf.contrib.layers.l2_regularizer(scale=0.0002))
        conv = tf.nn.conv2d(input_layer, kernel, strides, 'SAME')

        if not use_bias:
            return conv

        bias = tf.get_variable(
            'conv_bias',
            shape=[output_dim],
            dtype=tf.float32,
            initializer=tf.constant_initializer(0.0))
        # Restore the static conv shape after adding the bias.
        return tf.reshape(tf.nn.bias_add(conv, bias), conv.get_shape())

def residual(scope, input_layer, is_training, reuse,
             increase_dim=False, first=False):
    """Pre-activation residual unit: (LN -> swish -> conv) twice, plus shortcut.

    Args:
        scope: Variable scope for the unit.
        input_layer: NHWC input tensor.
        is_training: Forwarded to `layer_norm`.
        reuse: Forwarded to `layer_norm`.
        increase_dim: If True, the first conv uses stride 2 and the channel
            count doubles. The shortcut is then pooled and zero-padded on
            the channel axis to match.
        first: If True, skip the leading LN + swish. In `network` this is
            used for the first unit, whose input is already activated.

    Returns:
        Sum of the residual branch and the shortcut.
    """
    in_dim = input_layer.get_shape().as_list()[-1]
    out_dim = in_dim * 2 if increase_dim else in_dim
    stride = [1, 2, 2, 1] if increase_dim else [1, 1, 1, 1]

    with tf.variable_scope(scope):
        if first:
            pre_act = input_layer
        else:
            pre_act = swish(layer_norm('h0_ln', input_layer, is_training, reuse))

        branch = conv2d_res('h1_conv', pre_act, out_dim, strides=stride)
        branch = swish(layer_norm('h1_ln', branch, is_training, reuse))
        branch = conv2d_res('h2_conv', branch, out_dim)

        if increase_dim:
            # avg_pool (from Utils) is expected to match the stride-2 branch
            # spatially. Pad in_dim // 2 zero channels on each side so the
            # shortcut has out_dim channels.
            half = in_dim // 2
            shortcut = tf.pad(avg_pool('l_pool', input_layer),
                              [[0, 0], [0, 0], [0, 0], [half, half]])
        else:
            shortcut = input_layer

        return tf.add(branch, shortcut)
    
def network(images, is_training, reuse):
    """ResNet-style classifier with an extra scalar Wasserstein head.

    Architecture:
      * a 16-channel stem conv with LN + swish;
      * three stages of `units` residual units at 16, 32 and 64 channels,
        where stages 2 and 3 start with a stride-2 unit;
      * LN + swish and global average pooling;
      * a `cats`-way fully connected layer (logits);
      * a linear layer mapping the logits to one critic value per image.

    NOTE: the scope names are inconsistent ('r1.0' vs 'res1.k', 'r3.0' vs
    'r3.k', 'r0_bn' for a layer norm). They are kept unchanged so that
    existing checkpoints still load.

    Args:
        images: NHWC image batch with a static batch dimension.
        is_training: Forwarded to the layer helpers.
        reuse: Whether to reuse the variables under the 'layers' scope.

    Returns:
        (probs, logits, wass): softmax probabilities, logits, and the
        critic output (presumably [batch, 1]; `linear` lives in Utils).
    """
    with tf.variable_scope('layers', reuse=reuse):
        init_dim = 16
        batch_size = images.get_shape().as_list()[0]

        r0_conv = conv2d_res('r0_conv', images, init_dim)
        r0_ln = layer_norm('r0_bn', r0_conv, is_training, reuse)
        r0 = swish(r0_ln)

        r1_res = residual('r1.0', r0, is_training, reuse, first=True)
        for k in range(1, units):
            r1_res = residual('res1.{}'.format(k), r1_res, is_training, reuse)

        r2_res = residual('r2.0', r1_res, is_training, reuse, increase_dim=True)
        for k in range(1, units):
            r2_res = residual('res2.{}'.format(k), r2_res, is_training, reuse)

        r3_res = residual('r3.0', r2_res, is_training, reuse, increase_dim=True)
        for k in range(1, units):
            r3_res = residual('r3.{}'.format(k), r3_res, is_training, reuse)

        r4_bn = layer_norm('r4_ln', r3_res, is_training, reuse)
        r4 = swish(r4_bn)

        # Global average pooling over the spatial axes.
        r5 = tf.reduce_mean(r4, axis=[1, 2])

        # One logit per category (was hard-coded to 10).
        fc = fully_connected('fc', tf.reshape(r5, [batch_size, -1]), cats)
        wass = linear(tf.reshape(fc, [batch_size, -1]), 1, 'wass')
        return tf.nn.softmax(fc), fc, wass

In [None]:
def build_train_op():
    """Build the classifier training graph.

    Creates placeholders for a batch of real and pseudo-negative (fake)
    images/labels. It then builds the shared `network` on the real, fake and
    interpolated images and assembles:

      * a softmax cross-entropy loss on the real images; and
      * a WGAN-GP style loss: the negative Wasserstein estimate
        (mean fake critic - mean real critic) plus 10 x a gradient penalty
        on random real/fake interpolations.

    In baseline mode (mode == 0) only the softmax loss is minimized. In WINN
    mode (mode == 1) the Wasserstein loss is added with weight 0.01.

    Returns:
        [d_real_images_place, d_fake_images_place, d_real_labels_place,
         d_fake_labels_place, d_loss, d_softmax_loss, d_wass_loss,
         d_real_loss, d_fake_loss, d_train_op]

    Raises:
        ValueError: if the global `mode` is neither 0 nor 1.
    """
    # Fail fast instead of hitting an unbound `d_loss` further down.
    if mode not in (0, 1):
        raise ValueError(
            'Unknown mode {}: expected 0 (baseline) or 1 (WINN).'.format(mode))

    # Placeholders.
    batch_shape = [batch_size, height, width, channels]
    d_real_images_place = tf.placeholder(
        tf.float32, shape=batch_shape, name='d_real_images_place')
    d_fake_images_place = tf.placeholder(
        tf.float32, shape=batch_shape, name='d_fake_images_place')
    d_real_labels_place = tf.placeholder(
        tf.int32, shape=[batch_size, ], name='d_real_labels_place')
    # Fed by the training loop, but not used by any loss term below.
    d_fake_labels_place = tf.placeholder(
        tf.int32, shape=[batch_size, ], name='d_fake_labels_place')

    # Build network. The first call creates the variables; later calls reuse them.
    d_real_probs, d_real_logits, d_real_wass = \
        network(d_real_images_place, is_training=True, reuse=False)
    d_fake_probs, d_fake_logits, d_fake_wass = \
        network(d_fake_images_place, is_training=False, reuse=True)

    # Loss term 1: Softmax cross-entropy loss on real images.
    d_softmax_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=d_real_labels_place,
        logits=d_real_logits)
    d_softmax_loss = tf.reduce_mean(d_softmax_losses)

    # Loss term 2: Wasserstein loss with gradient penalty.
    # 2.1: Negative Wasserstein distance.
    d_real_loss = tf.reduce_mean(d_real_wass)
    d_fake_loss = tf.reduce_mean(d_fake_wass)
    d_neg_wass_dist = d_fake_loss - d_real_loss
    # 2.2: Gradient penalty on interpolated images. eps has shape
    # [batch, 1, 1, 1] and broadcasts over H, W, C, so each image gets a
    # single mixing coefficient.
    eps = tf.random_uniform([batch_size, 1, 1, 1], 0.0, 1.0)
    d_inter_images = eps * d_real_images_place + (1 - eps) * d_fake_images_place
    d_inter_prob, d_inter_logits, d_inter_wass = \
        network(d_inter_images, is_training=False, reuse=True)
    d_inter_grad = tf.gradients(d_inter_wass, d_inter_images)[0]
    d_inter_grad_norm = tf.sqrt(
        tf.reduce_sum(tf.square(d_inter_grad), axis=[1, 2, 3]))
    d_inter_grad_penalty = tf.reduce_mean(tf.square(d_inter_grad_norm - 1.0))
    # 2.3: Wasserstein loss with gradient penalty.
    gp_scale = 10.0
    d_wass_loss = d_neg_wass_dist + gp_scale * d_inter_grad_penalty

    # Combine the loss terms for baseline or WINN.
    if mode == 0:
        d_loss = d_softmax_loss
    else:
        d_loss = d_softmax_loss + 0.01 * d_wass_loss
    # NOTE(review): variables created with a regularizer register terms in
    # TF's regularization-loss collection, which is not added to d_loss here.

    # Optimizer: only variables whose name contains 'layers' (the classifier).
    d_trainable_vars = [x for x in tf.trainable_variables()
                        if 'layers' in x.name]
    d_optimizer = tf.train.AdamOptimizer(0.001, beta1=0.0, beta2=0.9)
    d_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(d_update_ops):
        d_train_op = d_optimizer.minimize(d_loss, var_list=d_trainable_vars)

    return [d_real_images_place, d_fake_images_place,
            d_real_labels_place, d_fake_labels_place,
            d_loss, d_softmax_loss, d_wass_loss, d_real_loss, d_fake_loss,
            d_train_op]

def build_synthesis_op():
    """Build the graph that synthesizes pseudo-negative images.

    A batch of images lives in the trainable variable `g_images`. It is
    optimized with Adam (lr 0.02, beta2 0.9) through the shared classifier
    `network` to minimize

        |mean(wass) - threshold| + 0.5 * softmax_xent(labels).

    Only variables whose name contains 'g_images' are in the optimizer's
    var_list, so the classifier weights are not updated by this op.

    Returns:
        [g_images_place, g_labels_place, g_thres_place, g_images,
         g_images_op, g_loss, g_softmax_loss, g_synthesis_op,
         g_abs_wass_dist]
    """
    # Uniform(-1, 1) start; the training loop overwrites it via g_images_op.
    g_images = tf.Variable(
        np.random.uniform(low=-1.0, high=1.0,
                          size=[batch_size, height, width, channels]
                          ).astype('float32'),
        name='g_images')

    g_thres_place = tf.placeholder(tf.float32, shape=[], name='g_thres_place')
    g_images_place = tf.placeholder(tf.float32, shape=g_images.get_shape(),
                                    name='g_images_place')
    g_labels_place = tf.placeholder(tf.int32, shape=[batch_size, ],
                                    name='g_labels_place')
    # Lets the caller overwrite the images (new init, clipped copy, ...).
    g_images_op = g_images.assign(g_images_place)

    # Classifier shared with training (reuse=True).
    g_probs, g_logits, g_wass = network(g_images, is_training=False,
                                        reuse=True)

    g_softmax_loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=g_labels_place, logits=g_logits))

    # Pull the mean critic output towards the sampled threshold.
    g_abs_wass_dist = tf.abs(tf.reduce_mean(g_wass) - g_thres_place)
    g_loss = g_abs_wass_dist + 0.5 * g_softmax_loss

    image_vars = [v for v in tf.trainable_variables() if 'g_images' in v.name]
    g_synthesis_op = tf.train.AdamOptimizer(0.02, beta2=0.9).minimize(
        g_loss, var_list=image_vars)

    return [g_images_place, g_labels_place, g_thres_place,
            g_images, g_images_op, g_loss, g_softmax_loss,
            g_synthesis_op, g_abs_wass_dist]
    
def build_test_op():
    """Build the evaluation graph used for validation and test batches.

    Returns:
        [c_images_place, c_labels_place, c_softmax_losses, c_loss,
         c_preds, c_acc]: the image/label placeholders, per-example and mean
        softmax cross-entropy, a per-example boolean "prediction correct"
        tensor, and the batch accuracy.
    """
    c_images_place = tf.placeholder(
        tf.float32, shape=[batch_size, height, width, channels],
        name='c_images_place')
    c_labels_place = tf.placeholder(
        tf.int32, shape=[batch_size, ], name='c_labels_place')

    c_probs, c_logits, c_wass = network(c_images_place, is_training=False,
                                        reuse=True)

    c_softmax_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=c_labels_place, logits=c_logits)
    c_loss = tf.reduce_mean(c_softmax_losses)

    predicted_labels = tf.cast(tf.argmax(c_probs, axis=1), tf.int32)
    c_preds = tf.equal(predicted_labels, c_labels_place)
    c_acc = tf.reduce_mean(tf.cast(c_preds, tf.float32))

    return [c_images_place, c_labels_place, c_softmax_losses,
            c_loss, c_preds, c_acc]

In [None]:
def dcgan_normalize(name, inputs):
    """Layer-normalize `inputs` over axes 1, 2 and 3 using tflib's Layernorm.

    For the NCHW tensors built in `dcgan_generator`, these are the channel
    and spatial axes.
    """
    norm_axes = [1, 2, 3]
    return lib.ops.layernorm.Layernorm(name, norm_axes, inputs)

def _resize_nchw(x, size, method):
    """Spatially resize an NCHW tensor by round-tripping through NHWC."""
    x = tf.transpose(x, [0, 2, 3, 1])
    x = tf.image.resize_images(x, size, method = method)
    return tf.transpose(x, [0, 3, 1, 2])


def dcgan_generator(n_samples, noise = None, dim = 64, ln = True):
    """DCGAN-style generator used to produce initial pseudo-negative images.

    Maps 128-d noise through a linear layer to an [N, 8*dim, 4, 4] NCHW
    tensor. Four stages of (optional layer norm -> 5x5 conv ->
    nearest-neighbor resize) then bring it to [N, 1, 28, 28]. Finally tanh
    is applied and the result is returned as NHWC images [N, 28, 28, 1]
    in (-1, 1).

    tflib's weight-init stddevs are global settings. They are set to 0.1
    while these layers are created and are always reset afterwards, even
    if graph construction raises.

    Args:
        n_samples: Batch size used when `noise` is None.
        noise: Optional [N, 128] noise tensor; drawn from N(0, 1) if None.
        dim: Base channel width.
        ln: Whether to layer-normalize before each conv.

    Returns:
        The generated image tensor (NHWC).
    """
    # Set std for weight initialization in tflib.
    lib.ops.conv2d.set_weights_stdev(0.1)
    lib.ops.deconv2d.set_weights_stdev(0.1)
    lib.ops.linear.set_weights_stdev(0.1)

    try:
        method = tf.image.ResizeMethod.NEAREST_NEIGHBOR

        with tf.variable_scope("dcgan_gen", reuse = False):
            # External noise or internal noise.
            if noise is None:
                noise = tf.random_normal([n_samples, 128])

            # Linear layer, reshaped to NCHW.
            x = lib.ops.linear.Linear('dcgan_linear', 128, 4*4*8*dim, noise)
            x = tf.reshape(x, [-1, 8*dim, 4, 4])

            # (stage index, input channels, output channels, output size).
            # The parameter names 'dcgan_ln<i>' / 'dcgan_conv<i>' are unchanged.
            stages = [
                (1, 8*dim, 4*dim, [8, 8]),
                (2, 4*dim, 2*dim, [16, 16]),
                (3, 2*dim,   dim, [28, 28]),
                (4,   dim,     1, [28, 28]),
            ]
            for idx, in_dim, out_dim, size in stages:
                if ln:
                    x = dcgan_normalize('dcgan_ln{}'.format(idx), x)
                x = lib.ops.conv2d.Conv2D('dcgan_conv{}'.format(idx),
                                          in_dim, out_dim, 5, x, stride=1)
                x = _resize_nchw(x, size, method)

            # Tanh non-linearity, then back to NHWC.
            x = tf.tanh(x)
            x = tf.transpose(x, [0, 2, 3, 1])
    finally:
        # Reset std for weight initialization in tflib.
        lib.ops.conv2d.unset_weights_stdev()
        lib.ops.deconv2d.unset_weights_stdev()
        lib.ops.linear.unset_weights_stdev()

    return x

def build_dcgan_init_op():
    """Build a DCGAN generator whose input noise is a settable variable.

    The generator's own weights are never placed in any optimizer var_list
    in this notebook, so it stays randomly initialized.

    Returns:
        (i_features, i_features_op, i_features_place, i_dcgan_inits): the
        [batch_size, 128] noise variable, an op that assigns
        `i_features_place` into it, that placeholder, and the generated
        image tensor.
    """
    # The initial value is irrelevant: i_features_op overwrites it before
    # the generator output is read.
    noise_init = np.random.uniform(low=-1.0, high=1.0,
                                   size=[batch_size, 128]).astype('float32')
    i_features = tf.Variable(noise_init, name='i_features')

    i_features_place = tf.placeholder(dtype=i_features.dtype,
                                      shape=i_features.get_shape())
    i_features_op = i_features.assign(i_features_place)

    i_dcgan_inits = dcgan_generator(n_samples=batch_size, noise=i_features,
                                    dim=64, ln=True)
    return i_features, i_features_op, i_features_place, i_dcgan_inits

In [None]:
def extract_matching_images(batch_real_labels, neg_images, neg_labels):
    """Draw one random pseudo-negative image for every real label.

    For each label `c` in `batch_real_labels` (in order), a uniformly random
    index into class `c`'s pool is drawn with `np.random.randint`. The image
    at that index in `neg_images[c]` is taken.

    Args:
        batch_real_labels: Iterable of integer class labels.
        neg_images: Per-class pools of pseudo-negative images, indexable by
            label.
        neg_labels: Per-class label arrays; only `shape[0]` is used, as the
            pool size.

    Returns:
        (images, labels): numpy arrays of the drawn images and of the labels,
        which equal `batch_real_labels`.
    """
    labels = list(batch_real_labels)
    picked = [
        neg_images[label][np.random.randint(0, neg_labels[label].shape[0])]
        for label in labels
    ]
    return np.array(picked), np.array(labels)

In [None]:
def main(sess):
    """Train the classifier (baseline or WINN) and evaluate it every epoch.

    Steps:
      1. Load and normalize the data. The first 45000 training images are
         used for training and the remainder for validation.
      2. Build the generator-initialization, training, synthesis and test
         graphs, then initialize all variables.
      3. Seed the per-class pseudo-negative pools with images from the
         DCGAN generator.
      4. For each epoch:
         a. train the classifier for `critics` passes;
         b. checkpoint every 20 epochs;
         c. in WINN mode, synthesize new pseudo negatives and append them
            to the pools;
         d. evaluate on the validation and test sets, checkpointing on a
            new best accuracy.

    Args:
        sess: A tf.Session bound to the graph in which all ops are built.
    """
    # Load images from dataset and normalize.
    unnorm_all_train_images, all_train_labels = load_train_data()
    all_train_images = normalize(unnorm_all_train_images)
    train_images, train_labels = \
        all_train_images[:45000], all_train_labels[:45000]
    val_images, val_labels = \
        all_train_images[45000:], all_train_labels[45000:]
    unnorm_test_images, test_labels = load_test_data()
    test_images = normalize(unnorm_test_images)

    # Count of positive images used in one classifier pass. The same images
    # are reused every epoch, and each real batch is paired with an equally
    # sized batch of pseudo negatives.
    d_pos_images_count = train_images.shape[0]

    # Count of pseudo negative images to synthesize per epoch.
    g_neg_images_count = 1000
    g_neg_batches_per_cat = g_neg_images_count // (batch_size * cats)

    # Subsample the training set only if fewer positives are requested.
    if d_pos_images_count != train_images.shape[0]:
        train_images, train_labels = extract_data(train_images, train_labels,
                                                  d_pos_images_count)

    # Create directories for pseudo negative images and models.
    neg_dir = os.path.join(root_dir, 'neg')
    model_dir = os.path.join(root_dir, 'model')
    for directory in (root_dir, neg_dir, model_dir):
        if not os.path.exists(directory):
            os.mkdir(directory)

    # Set log file and write experiment settings.
    log_file_path = os.path.join(root_dir, 'log.txt')
    log(log_file_path, 'Mode: {} (0 for baseline, 1 for WINN).'.format(mode))
    log(log_file_path, ('Images: {} for training, {} for validation, ' +
                        '{} for test.').format(
                           len(train_images), len(val_images),
                           len(test_images)))
    log(log_file_path, 'Epochs: {}.'.format(epochs))
    log(log_file_path, 'Batch size: {}.'.format(batch_size))
    log(log_file_path, 'Number of critics: {}.'.format(critics))
    log(log_file_path, 'Max. optimizing steps: {}.'.format(max_opt_steps))

    # Build DCGAN initialization operators.
    i_features, i_features_op, i_features_place, i_dcgan_inits = \
        build_dcgan_init_op()

    # Build training operators.
    [d_real_images_place, d_fake_images_place,
     d_real_labels_place, d_fake_labels_place,
     d_loss, d_softmax_loss, d_wass_loss, d_real_loss, d_fake_loss,
     d_train_op] = build_train_op()

    # Build synthesis operators.
    [g_images_place, g_labels_place, g_thres_place,
     g_images, g_images_op, g_loss, g_softmax_loss,
     g_synthesis_op, g_abs_wass_dist] = build_synthesis_op()

    # Build validation/test operators.
    [c_images_place, c_labels_place, c_softmax_losses,
     c_loss, c_preds, c_acc] = build_test_op()

    # After creating network, build saver.
    saver = tf.train.Saver(max_to_keep = 100)

    # Initialize all variables.
    all_initializer_op = tf.global_variables_initializer()
    sess.run(all_initializer_op)

    def get_dcgan_init():
        """Run the generator on fresh Uniform(-1, 1) noise; return one batch."""
        i_batch_features = \
            np.random.uniform(low = -1.0, high = 1.0, size = (batch_size, 128))
        sess.run(i_features_op, {i_features_place: i_batch_features})
        return sess.run(i_dcgan_inits)

    init_dir = os.path.join(neg_dir, 'init')
    if not os.path.exists(init_dir):
        os.mkdir(init_dir)

    # Create initial pseudo negative images, one pool per category.
    init_images = []
    init_labels = []
    for cat in range(cats):
        cat_init_images = np.concatenate(
            [get_dcgan_init() for _ in range(g_neg_batches_per_cat)], axis = 0)
        init_images.append(cat_init_images)
        init_labels.append(np.full(g_neg_batches_per_cat * batch_size, cat))
        assert(len(init_images[-1]) == len(init_labels[-1]))

    # Save all initial images (unnormalized for viewing).
    for cat in range(cats):
        unnorm_cat_init_images = unnormalize(init_images[cat])
        for batch in range(unnorm_cat_init_images.shape[0] // batch_size):
            images = unnorm_cat_init_images[   batch    * batch_size :
                                            (batch + 1) * batch_size]
            images = images.reshape((batch_size, height, width))
            path   = os.path.join(init_dir,
                                  'cat_{}_batch_{}.png'.format(cat, batch))
            save_batch_images_to_path(image_shape, images, path)

    # The normalized initial images form the starting pseudo negative pools.
    neg_images = init_images
    neg_labels = init_labels

    # Log all global variables.
    global_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='')
    log(log_file_path, 'Global variables:')
    for i, var in enumerate(global_vars):
        log(log_file_path, '{}, {}, {}.'.format(i, var.name, var.get_shape()))

    # Main procedure.
    val_acc_list  = []
    test_acc_list = []
    for epoch in range(epochs):
        log(log_file_path, '[{}] Epoch {}.'.format(str(datetime.now()), epoch))
        shuffle = np.random.permutation(d_pos_images_count)

        # Classifier training.
        for critic in range(critics):

            d_critic_loss          = 0.0
            d_batches_per_critic   = d_pos_images_count // batch_size
            d_batch_real_loss_list = []

            for batch in range(d_batches_per_critic):
                batch_idx = shuffle[batch * batch_size : (batch + 1) * batch_size]

                # Positive images to train.
                d_batch_real_images = train_images[batch_idx]
                d_batch_real_labels = train_labels[batch_idx]
                # Pseudo negative images to train.
                d_batch_fake_images, d_batch_fake_labels = \
                    extract_matching_images(d_batch_real_labels,
                                            neg_images, neg_labels)

                # Training step.
                [_, d_batch_loss, d_batch_softmax_loss,
                 d_batch_wass_loss, d_batch_real_loss, d_batch_fake_loss] = \
                    sess.run([d_train_op, d_loss, d_softmax_loss,
                              d_wass_loss, d_real_loss, d_fake_loss],
                             feed_dict = \
                                 {d_real_images_place: d_batch_real_images,
                                  d_fake_images_place: d_batch_fake_images,
                                  d_real_labels_place: d_batch_real_labels,
                                  d_fake_labels_place: d_batch_fake_labels})

                # Record the loss.
                d_batch_real_loss_list.append(d_batch_real_loss)
                d_critic_loss += (d_batch_loss / d_batches_per_critic)

            # Log some losses.
            log(log_file_path,
                '[{}] Critic {}: Discriminator loss {}.'.format(
                    str(datetime.now()), critic, d_critic_loss))
            log(log_file_path,
                ('[{}] Last batch of critic {}: Discriminator loss {}, ' +
                 'softmax loss {}, Wass. loss {}, ' +
                 'real loss {}, fake loss {}.').format(
                    str(datetime.now()), critic, d_batch_loss,
                    d_batch_softmax_loss, d_batch_wass_loss,
                    d_batch_real_loss, d_batch_fake_loss))

        # Save the model every 20 epochs.
        if epoch % 20 == 0:
            checkpoint_path = os.path.join(model_dir,
                                           'epoch_{}_model.ckpt'.format(epoch))
            saver.save(sess, checkpoint_path)

        # Synthesizing new pseudo negatives. Only enabled under WINN mode.
        if mode == 1:
            # Directory of synthesized images for current epoch.
            neg_epoch_dir = os.path.join(neg_dir, 'epoch_{}'.format(epoch))
            if not os.path.exists(neg_epoch_dir):
                os.mkdir(neg_epoch_dir)

            # The range of real-image critic means from the last critic pass
            # stays fixed during synthesis, so compute it once.
            thres_low  = min(d_batch_real_loss_list)
            thres_high = max(d_batch_real_loss_list)

            for cat in range(cats):
                for batch in range(g_neg_batches_per_cat):
                    g_batch_init_images = get_dcgan_init()
                    g_batch_labels = np.full(batch_size, cat)

                    sess.run(g_images_op, feed_dict = {g_images_place:
                                                       g_batch_init_images})

                    # Choose the threshold uniformly within that range.
                    # Sample with NumPy: building a tf.random_uniform op here
                    # would add new nodes to the graph on every batch and
                    # make it grow without bound.
                    g_thres = np.random.uniform(low = thres_low,
                                                high = thres_high)

                    for step in range(max_opt_steps):
                        # Take one step of synthesis.
                        sess.run(g_synthesis_op,
                                 feed_dict={g_labels_place: g_batch_labels,
                                            g_thres_place:  g_thres})
                        # Clip the synthesized images back into [-1, 1].
                        sess.run(g_images_op,
                                 feed_dict={g_images_place:
                                            np.clip(sess.run(g_images),
                                                    -1.0, 1.0)})
                        # Fetch the loss.
                        g_batch_loss, g_batch_abs_wass_dist = \
                            sess.run([g_loss, g_abs_wass_dist],
                                feed_dict = {g_labels_place: g_batch_labels,
                                             g_thres_place:  g_thres})

                        # Use (absolute) Wasserstein distance as target.
                        if g_batch_abs_wass_dist <= 1e-3:
                            break

                    g_batch_images = sess.run(g_images)
                    neg_images[cat] = np.concatenate(
                        (neg_images[cat], g_batch_images), axis = 0)
                    neg_labels[cat] = np.concatenate(
                        (neg_labels[cat], g_batch_labels), axis = 0)

                    # Save synthesized images, unnormalized like the initial
                    # images above.
                    images = unnormalize(g_batch_images).reshape(
                        (batch_size, height, width))
                    path   = os.path.join(neg_epoch_dir,
                             'cat_{}_batch_{}.png'.format(cat, batch))
                    save_batch_images_to_path(image_shape, images, path)

                    log(log_file_path,
                        ('[{}] Generator loss {}, abs. Wass. dist. {}, ' +
                         'opt. steps {}.').format(
                        str(datetime.now()), g_batch_loss,
                            g_batch_abs_wass_dist, step))

        # Validation.
        val_loss, val_acc = 0.0, 0.0
        val_batches = val_images.shape[0] // batch_size
        for batch in range(val_batches):
            val_batch_images = val_images[   batch    * batch_size :
                                          (batch + 1) * batch_size]
            val_batch_labels = val_labels[   batch    * batch_size :
                                          (batch + 1) * batch_size]

            val_batch_loss, val_batch_acc = \
                sess.run([c_loss, c_acc],
                    feed_dict={c_images_place: val_batch_images,
                               c_labels_place: val_batch_labels})

            # The number of validation images should be divisible by
            # batch_size; any remainder is not evaluated.
            val_loss += (val_batch_loss / val_batches)
            val_acc  += (val_batch_acc  / val_batches)

        # Save model with max. val. acc.
        val_acc_list.append(val_acc)
        if val_acc >= max(val_acc_list):
            checkpoint = os.path.join(model_dir,
                                      'epoch_{}_model.ckpt'.format(epoch))
            saver.save(sess, checkpoint)

        # Test.
        test_loss, test_acc = 0.0, 0.0
        test_batches = test_images.shape[0] // batch_size
        for batch in range(test_batches):
            test_batch_images = test_images[   batch    * batch_size :
                                            (batch + 1) * batch_size]
            test_batch_labels = test_labels[   batch    * batch_size :
                                            (batch + 1) * batch_size]
            test_batch_loss, test_batch_acc = \
                sess.run([c_loss, c_acc],
                    feed_dict={c_images_place: test_batch_images,
                               c_labels_place: test_batch_labels})

            # The number of test images should be divisible by batch_size;
            # any remainder is not evaluated.
            test_loss += (test_batch_loss / test_batches)
            test_acc  += (test_batch_acc  / test_batches)

        # Save model with max. test acc.
        test_acc_list.append(test_acc)
        if test_acc >= max(test_acc_list):
            checkpoint = os.path.join(model_dir,
                                      'epoch_{}_model.ckpt'.format(epoch))
            saver.save(sess, checkpoint)

        log(log_file_path,
            '[{}] Val. loss {}, val. error {}, min. val. error {}.'.format(
            str(datetime.now()), val_loss, 1 - val_acc,
            1 - max(val_acc_list)))
        log(log_file_path,
            '[{}] Test loss {}, test error {}, min test error {}.'.format(
            str(datetime.now()), test_loss, 1 - test_acc,
            1 - max(test_acc_list)))

    log(log_file_path, 'Min test error {}.'.format(1 - max(test_acc_list)))
    log(log_file_path, 'Min val. error {}.'.format(1 - max(val_acc_list)))

In [None]:
if __name__ == '__main__':
    # Expose only the selected physical GPU; TensorFlow then sees it as '/gpu:0'.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)

    # Allow fallback placement for ops without a GPU kernel, and grow GPU
    # memory on demand.
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True

    graph = tf.Graph()
    with graph.as_default(), \
            tf.device('/gpu:0'), \
            tf.Session(config=config) as sess, \
            tf.variable_scope('WINN', reuse=None):
        main(sess)