# Texture Synthesis with Spatial Generative Adversarial Networks

[Paper](https://arxiv.org/pdf/1611.08207v2.pdf)
[Sample implementation](https://github.com/ubergmann/spatial_gan)


In [None]:
import tensorflow as tf
import numpy as np
from nbutil import imshow_multi, to_pil
from tensorflow.contrib.layers.python.layers import batch_norm
from tensorflow.contrib.layers import xavier_initializer
import skimage
import skimage.io

In [None]:
# Training hyper-parameters and tensor geometry.
BATCH_SIZE = 32  # number of real crops / noise samples per training step
OUTPUT_WIDTH, OUTPUT_HEIGHT = (256, 256)  # spatial size of real crops and generated textures
R = 16  # ratio between the output size and the noise-grid size
# Floor division keeps the noise-grid dimensions integral (they are used as
# tensor shape entries) whether `/` means integer or true division.
Z_WIDTH, Z_HEIGHT = (OUTPUT_WIDTH // R, OUTPUT_HEIGHT // R)
Z_DEPTH = 20  # number of channels of the noise tensor

In [None]:
%matplotlib inline

def load_image(path):
    """Read the image stored at `path` and return it scaled by 1/255.

    Only the first three channels are kept, so a trailing alpha channel is
    discarded.
    """
    pixels = skimage.io.imread(path)
    without_alpha = pixels[:, :, :3]
    return without_alpha / 255.0

# Load the texture image (later wrapped in a constant tensor as the source of
# real crops) and preview it with nbutil's imshow_multi.
image = load_image('../data/bk.png')
imshow_multi([image])


In [None]:
def _phase_shift(I, r):
    """Rearrange the depth axis of `I` into space (sub-pixel / phase shift).

    Args:
        I: 4-D tensor with a fully defined static shape (bsize, a, b, r*r);
            the reshape below requires the last axis to hold exactly r*r values.
        r: integer upscaling factor.

    Returns:
        Tensor of shape (bsize, a*r, b*r, 1).
    """
    bsize, a, b, _ = I.get_shape().as_list()
    X = tf.reshape(I, (bsize, a, b, r, r))
    X = tf.transpose(X, (0, 1, 2, 4, 3))  # bsize, a, b, r, r
    X = tf.split(1, a, X)  # a, [bsize, 1, b, r, r]
    # Squeeze only the axis that was just split. An unrestricted squeeze would
    # also remove the batch axis when bsize == 1 (or any other axis of size 1),
    # which breaks the concatenations below.
    X = tf.concat(2, [tf.squeeze(x, [1]) for x in X])  # bsize, b, a*r, r
    X = tf.split(1, b, X)  # b, [bsize, 1, a*r, r]
    X = tf.concat(2, [tf.squeeze(x, [1]) for x in X])  # bsize, a*r, b*r
    return tf.reshape(X, (bsize, a*r, b*r, 1))

def PS(X, r, color=False):
    """Apply the phase-shift upscaling with factor `r` to `X`.

    With `color=True`, `X` is split into three equal groups along axis 3, each
    group is phase-shifted on its own and the results are concatenated back
    along axis 3. Otherwise `X` is phase-shifted as a single group.
    """
    if not color:
        return _phase_shift(X, r)
    channel_groups = tf.split(3, 3, X)
    shifted_groups = [_phase_shift(group, r) for group in channel_groups]
    return tf.concat(3, shifted_groups)

In [None]:
# Scalar keep-probability read by create_dropout; evaluates to 1.0 unless a
# different value is supplied through feed_dict.
dropout_keep_prob = tf.placeholder_with_default(tf.constant(1.0), [], name='dropout_keep_prob')        

def lrelu(x):
    """Leaky ReLU: element-wise maximum of `0.05 * x` and `x`."""
    return tf.maximum(0.05*x, x)

def create_batch_norm(inputs, name='bn'):
    """Batch-normalise `inputs` inside the variable scope `name`.

    The layer is always built with `is_training=True` and
    `updates_collections=None`.
    """
    with tf.variable_scope(name):
        normalized = batch_norm(inputs, is_training=True, updates_collections=None)
    return normalized

def create_avg_pool(inputs, ksize=2, stride=2):
    """Average-pool `inputs` with a square `ksize` window and `stride` step ('SAME' padding)."""
    window = [1, ksize, ksize, 1]
    step = [1, stride, stride, 1]
    return tf.nn.avg_pool(inputs, ksize=window, strides=step, padding='SAME')
    
def create_dropout(inputs):
    """Apply dropout to `inputs`, keeping units with the module-level `dropout_keep_prob`."""
    dropped = tf.nn.dropout(inputs, dropout_keep_prob)
    return dropped

def create_conv(input, out_channels, patch_size=5, stride=1, use_relu=True, name='conv'):
    """Build a 2-D convolution layer inside the variable scope `name`.

    Fetches a kernel 'w' of shape
    [patch_size, patch_size, in_channels, out_channels] (Xavier-initialised)
    and a bias 'b' of shape [out_channels] via `tf.get_variable`, convolves
    `input` with 'SAME' padding and the given `stride`, adds the bias and,
    when `use_relu` is true, applies `lrelu`.
    """
    with tf.variable_scope(name):
        # The input depth is read from the static shape of `input`.
        kernel_shape = [patch_size, patch_size, input.get_shape()[-1].value, out_channels]
        kernel = tf.get_variable('w', shape=kernel_shape, initializer=xavier_initializer())
        bias = tf.get_variable('b', shape=[out_channels])
        convolved = tf.nn.conv2d(input, kernel, strides=[1, stride, stride, 1], padding='SAME')
        biased = convolved + bias
        if use_relu:
            return lrelu(biased)
        return biased

def create_deconv(input, out_channels, patch_size=5, stride=1, use_relu=True, name='deconv'):
    """Build a transposed-convolution layer inside the variable scope `name`.

    Fetches a kernel 'w' of shape
    [patch_size, patch_size, out_channels, input_channels] (Xavier-initialised)
    and a bias 'b' of shape [out_channels] via `tf.get_variable`. The output
    has both spatial dimensions of `input` multiplied by `stride`; its batch
    dimension is fixed to the module-level BATCH_SIZE. `lrelu` is applied
    when `use_relu` is true.

    For best results, `patch_size` should be a multiple of `stride`.
    """
    with tf.variable_scope(name):
        # Last three static dimensions of the input: two spatial sizes and depth.
        spatial_a, spatial_b, in_depth = [dim.value for dim in input.get_shape()[-3:]]

        kernel = tf.get_variable('w',
                                 shape=[patch_size, patch_size, out_channels, in_depth],
                                 initializer=xavier_initializer())
        bias = tf.get_variable('b', shape=[out_channels])

        # Batch size comes from the constant rather than from tf.shape(input)[0].
        target_shape = tf.pack([BATCH_SIZE, spatial_a*stride, spatial_b*stride, out_channels])
        upsampled = tf.nn.conv2d_transpose(input, kernel, target_shape,
                                           strides=[1, stride, stride, 1], padding='SAME')

        biased = upsampled + bias
        return lrelu(biased) if use_relu else biased


In [None]:
# Constant float32 tensor holding the loaded texture image.
source_image = tf.constant(image, tf.float32)


def rand_crop():
    """Return an op producing a random [OUTPUT_WIDTH, OUTPUT_HEIGHT, 3] crop of `source_image`."""
    crop_shape = [OUTPUT_WIDTH, OUTPUT_HEIGHT, 3]
    return tf.random_crop(source_image, crop_shape)


# One separate random-crop op per batch element.
real_textures = [rand_crop() for _ in xrange(BATCH_SIZE)]

def generator(noise):
    """Map a noise tensor to textures with a stack of stride-2 deconvolutions.

    Args:
        noise: input tensor for the first deconvolution. Each of the four
            layers doubles both spatial dimensions, and the final assertion
            requires the result to be OUTPUT_WIDTH x OUTPUT_HEIGHT.

    Returns:
        Tensor whose non-batch dimensions are [OUTPUT_WIDTH, OUTPUT_HEIGHT, 3].
        The last layer has no activation and no batch norm.
    """
    with tf.variable_scope('generator'):
        image = noise
        # strides must multiply to R (16): four layers with stride 2 each
        layers = [256, 128, 64, 3]
        for i, channels in enumerate(layers):
            is_last_layer = i == len(layers)-1
            # create_deconv names its kernel-size parameter `patch_size`;
            # passing `size=` raises a TypeError.
            image = create_deconv(image,
                                  channels,
                                  patch_size=5,
                                  stride=2,
                                  name='deconv'+str(i),
                                  use_relu=(not is_last_layer))
            if not is_last_layer:
                image = create_batch_norm(image, name='bn'+str(i))
        # Static-shape sanity check on the graph being built.
        assert [d.value for d in image.get_shape()[1:]] == [OUTPUT_WIDTH, OUTPUT_HEIGHT, 3]
        return image

def discriminator(textures):
    """Score every spatial location of a batch of textures as real or synthetic.

    `textures` passes through four 5x5 convolutions (256, 128, 64, 2
    channels), each followed by average pooling and batch normalisation, and
    then through one 1x1 convolution with 2 output channels. All convolutions
    use `create_conv`'s default activation.

    Returns:
        Tensor whose non-batch dimensions are [Z_WIDTH, Z_HEIGHT, 2]; the
        last axis is used by the caller as two-class logits.
    """
    with tf.variable_scope('discriminator'):
        features = textures
        # Convolution + pooling + batch-norm stages.
        for index, depth in enumerate((256, 128, 64, 2)):
            suffix = str(index)
            features = create_conv(features, depth, stride=1, patch_size=5, name='conv'+suffix)
            features = create_avg_pool(features)
            features = create_batch_norm(features, name='bn'+suffix)
        # Final 1x1 convolution.
        features = create_conv(features, 2, stride=1, patch_size=1, name='1x1conv'+str(0))
        output_dims = [dim.value for dim in features.get_shape()[1:]]
        assert output_dims == [Z_WIDTH, Z_HEIGHT, 2]
        return features

# Build the whole training graph under one outer variable scope; the scope name
# is reused below as the prefix that selects each sub-network's variables.
scopename = '3'
with tf.variable_scope(scopename):

    # Uniform noise between -1 and 1 on a Z_WIDTH x Z_HEIGHT grid, turned into
    # textures by the generator.
    noise = tf.random_uniform([BATCH_SIZE, Z_WIDTH, Z_HEIGHT, Z_DEPTH], minval=-1, maxval=1)
    synthetic_textures = generator(noise)

    # Discriminator batch: real crops first, synthetic textures second. The
    # per-location integer targets follow the same order (1 = real, 0 = synthetic).
    disc_input = tf.concat(0, [real_textures, synthetic_textures])
    disc_target = tf.concat(0, [tf.ones([BATCH_SIZE, Z_WIDTH, Z_HEIGHT], tf.int32), tf.zeros([BATCH_SIZE, Z_WIDTH, Z_HEIGHT], tf.int32)])
    
    disc_output = discriminator(disc_input)
    # Softmax cross-entropy against the targets, summed over every location of
    # every sample in the combined batch.
    discriminator_loss = tf.reduce_sum(tf.nn.sparse_softmax_cross_entropy_with_logits(disc_output, disc_target))

    # Accuracy: fraction of locations whose arg-max class equals the target.
    disc_guess = tf.argmax(disc_output, 3)
    disc_correct = tf.equal(tf.cast(disc_guess, tf.int32), disc_target)
    disc_accuracy = tf.reduce_mean(tf.cast(disc_correct, tf.float32))

    # Trainable variables of each sub-network, selected by scope prefix.
    disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scopename+'/discriminator')
    gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scopename+'/generator')
    # Debug helper (disabled): list all global variables.
    # for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    #     print v.name

    global_step = tf.contrib.framework.get_or_create_global_step()

    # The generator minimises the negated discriminator loss (updating only
    # gen_vars); the discriminator minimises the loss itself (updating only
    # disc_vars). Both ops are handed the same global_step.
    train_gen = tf.train.AdamOptimizer(0.0001).minimize(-discriminator_loss, global_step=global_step, var_list=gen_vars)
    train_disc = tf.train.AdamOptimizer(0.0001).minimize(discriminator_loss, global_step=global_step, var_list=disc_vars)


In [None]:
import os

# Create the session and initialise all global and local variables.
session = tf.InteractiveSession()

# Directory for checkpoints; set to a falsy value to disable saving.
save_path = 'models/sgan256-2-1'

init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
session.run(init_op)

saver = None
if save_path:
    # makedirs rather than mkdir: save_path is nested, and mkdir raises when
    # the parent directory ('models') does not exist yet.
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    saver = tf.train.Saver()
    # Resume from the most recent checkpoint in save_path, if there is one.
    ckpt = tf.train.get_checkpoint_state(save_path)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(session, ckpt.model_checkpoint_path)
        print('Restored from checkpoint {}'.format(ckpt.model_checkpoint_path))
    else:
        print('Did not restore from checkpoint')
else:
    print('Will not save progress')

In [None]:
def avg(x):
    """Return the arithmetic mean of the sequence `x`, using float division."""
    total = sum(x)
    count = float(len(x))
    return total / count
# Running statistics since the last report; both lists are cleared after each report.
disc_accuracies = []
losses = []
last_saved_loss = None  # not read by this loop

# Train until interrupted.
while True:
    # Balance the two networks on the discriminator's recent accuracy: above
    # 0.8 only the generator trains, below 0.6 only the discriminator trains,
    # otherwise (or with no history yet) both train.
    recent_accuracy = avg(disc_accuracies) if disc_accuracies else None
    only_train_gen = recent_accuracy is not None and recent_accuracy > 0.8
    only_train_disc = recent_accuracy is not None and recent_accuracy < 0.6

    if not only_train_disc:
        disc_acc_, step_, loss_, _ = session.run([disc_accuracy, global_step, discriminator_loss, train_gen])
        losses.append(loss_)
        disc_accuracies.append(disc_acc_)

    if not only_train_gen:
        feed = {dropout_keep_prob: 0.5}
        disc_acc_, step_, loss_, _ = session.run([disc_accuracy, global_step, discriminator_loss, train_disc], feed_dict=feed)
        losses.append(loss_)
        disc_accuracies.append(disc_acc_)

    # Both train ops share global_step, so one iteration may advance it by 1
    # or 2; rounding down to an even number keeps the modulo checks from
    # being stepped over.
    step_rounded = int(step_ // 2) * 2
    if step_rounded % 50 == 0:
        print("Step: {}, loss: {}, disc accuracy: {}".format(step_rounded, avg(losses), avg(disc_accuracies)))

        # Checkpoint every 200 steps when saving is enabled.
        if step_rounded % 200 == 0 and saver:
            saver.save(session, save_path + '/model.ckpt', global_step=step_rounded)
            print('Saved')

        disc_accuracies = []
        losses = []


In [None]:
%matplotlib inline
def generate_sample():
    """Evaluate the graph once and display generated textures beside a real crop.

    Shows the first three synthetic textures (clipped to [0, 1]) followed by
    the first real crop, and returns the first synthetic texture converted
    with `to_pil`.
    """
    fake_batch, real_batch = session.run([synthetic_textures, real_textures])
    clipped_fakes = np.clip(fake_batch[:3], 0, 1)
    gallery = list(clipped_fakes) + list(real_batch[:1])
    imshow_multi(gallery)
    return to_pil(fake_batch[0])
generate_sample()