# Unsupervised Image-image Translation(UNIT) TensorFlow Implementation

In [1]:
import tensorflow as tf
from collections import OrderedDict

  return f(*args, **kwds)


In [2]:
#Funciton to chose the activation function
def activate(linear, activation='leaky_relu'):
        if activation == 'sigmoid':
            return tf.nn.sigmoid(linear)
        elif activation == 'softmax':
            return tf.nn.softmax(linear)
        elif activation == 'tanh':
            return tf.nn.tanh(linear)
        elif activation == 'relu':
            return tf.nn.relu(linear)
        elif activation == 'leaky_relu':
            return tf.nn.leaky_relu(linear)
        elif activation == 'linear':
            return linear

In [3]:
def gaussian_noise_layer(input_layer, std=1.0):
    noise = tf.random_normal(shape=tf.shape(input_layer), mean=0.0, stddev=std, dtype=tf.float32) 
    return tf.add(input_layer, noise)

In [4]:
def batch_normalization_layer(input_layer):
    
    dimension = input_layer.get_shape().as_list()[-1]
    
    mean, variance = tf.nn.moments(input_layer, axes=[0, 1, 2])
    
    beta = tf.Variable(tf.constant(0.0, shape=[dimension]))
    
    gamma = tf.Variable(tf.constant(1.0, shape=[dimension]))

    bn_layer = tf.nn.batch_normalization(input_layer, mean, variance, beta, gamma, variance_epsilon=1e-10)

    return bn_layer

In [408]:
def create_conv_layer(input_layer,     # The previous layer.
                   filter_size,        # Width and height of each filter.
                   num_filters,        # Number of filters.
                   use_pooling=False,  
                   pad = [],
                   strides=[1,1,1,1],
                   deconv=False,
                   out_shape = [],      # Output shape in case of deconv 
                   batch_normalization=False,
                   activation='leaky_relu'): # Use 2x2 max-pooling.

    
    num_input_channels = input_layer.get_shape().as_list()[-1]
    # Shape of the filter-weights for the convolution. 
    # This format is determined by the TensorFlow API.
    if deconv:
        shape = [filter_size, filter_size, num_filters, num_input_channels]
    else:
        shape = [filter_size, filter_size, num_input_channels, num_filters]

    # Create new weights aka. filters with the given shape.
    weights = tf.Variable(tf.truncated_normal(shape, stddev=0.05))

    # Create new biases, one for each filter.
    biases = tf.Variable(tf.constant(1.0, shape=[num_filters]))

    if len(pad) > 0:
        input_layer = tf.pad(input_layer, [[0, 0], [pad[0], pad[0]], [pad[1], pad[1]], [0, 0]], "CONSTANT")
    
    # Create the TensorFlow operation for de-convolution.
    if deconv:
        in_shape = input_layer.get_shape().as_list()
        #in_shape = tf.shape(input_layer)
    
        out_h = ((in_shape[1] - 1) * strides[1]) + filter_size - 2 * pad[0]
        
        out_w = ((in_shape[2] - 1) * strides[2]) + filter_size - 2 * pad[1]
        
        output_shape = tf.stack([in_shape[0], out_h, out_w, num_filters])
        print("in:",in_shape,"out:",[in_shape[0], out_h, out_w, num_filters])
        layer = tf.nn.conv2d_transpose(value=input_layer, 
                                       filter=weights, 
                                       output_shape=output_shape, 
                                       strides=strides, 
                                       padding='SAME')
        
    else:
        layer = tf.nn.conv2d(input=input_layer,
                         filter=weights,
                         strides=strides,
                         padding='VALID')

    # Add the biases to the results of the convolution.
    # A bias-value is added to each filter-channel.
    layer = tf.add(layer,biases)

    # Use pooling to down-sample the image resolution?
    if use_pooling:
        # This is 2x2 max-pooling, which means that we
        # consider 2x2 windows and select the largest value
        # in each window. Then we move 2 pixels to the next window.
        layer = tf.nn.max_pool(value=layer,
                               ksize=[1, 2, 2, 1],
                               strides=[1, 2, 2, 1],
                               padding='SAME')

    # Batch normalization
    if batch_normalization:
        layer = batch_normalization_layer(layer)
    
    # Activation of the layers (ReLU).
    layer = activate(layer, activation=activation)

    return layer

In [409]:
def create_res_block(X, filter_size=3, num_filters=3, pad=[1,1]):
    
    num_filters = X.get_shape().as_list()[-1]
    
    layer_conv1 = create_conv_layer(input_layer=X,
                                   filter_size=filter_size,
                                   num_filters=num_filters,
                                   pad=pad,
                                   use_pooling=False,
                                   batch_normalization=True,
                                   activation='relu')
    
    layer_conv2 = create_conv_layer(input_layer=layer_conv1,
                                   filter_size=filter_size,
                                   num_filters=num_filters,
                                   pad=pad,
                                   use_pooling=False,
                                   batch_normalization=False,
                                   activation='linear')
    
    layer_conv2 = batch_normalization_layer(layer_conv2)
    
    layer_conv2 += X
    
    layer_res = activate(layer_conv2, activation='relu')
    
    return layer_res

In [410]:
def create_encoder(X, layer_n, res_block_n):
    
    num_filters = X.get_shape().as_list()[-1]
    
    encoder = OrderedDict()
    
    encoder['layer_conv1'] = create_conv_layer(input_layer=X,
                                       num_filters=num_filters,
                                       filter_size=7,
                                       strides=[1,1,1,1],
                                       pad=[3,3])
    
    for i in range(1, layer_n):
        
        encoder['layer_conv'+str(i+1)] = create_conv_layer(encoder[next(reversed(encoder))],
                                       num_filters=num_filters*2,
                                       filter_size=3,
                                       strides=[1,2,2,1],
                                       pad=[1,1])
        
        num_filters = num_filters*2
        
    for i in range(0, res_block_n):
        encoder['block_en_res'+str(i+1)] = create_res_block(encoder[next(reversed(encoder))],
                                                           num_filters=num_filters)
        
    return encoder#[next(reversed(encoder))]

In [411]:
def create_shared_layers(X, block_shared_n):
    encoder_shared = OrderedDict()
    decoder_shared = OrderedDict()
    
    encoder_shared_ = create_res_block(X)
    encoder_shared['block_shared_res1'] = gaussian_noise_layer(encoder_shared_)
    
    for i in range(1, block_shared_n+1):
        encoder_shared_ = create_res_block(encoder_shared[next(reversed(encoder_shared))])
        encoder_shared['block_shared_res'+str(i+1)] = gaussian_noise_layer(encoder_shared_)
    
    decoder_shared['block_shared_res1'] = create_res_block(encoder_shared[next(reversed(encoder_shared))])
    
    for i in range(1, block_shared_n+1):
        decoder_shared['block_shared_res'+str(i+1)] = create_res_block(decoder_shared[next(reversed(decoder_shared))])
        
    return decoder_shared#[next(reversed(decoder_shared))]

In [420]:
def create_generator(X, layer_n, res_block_n, output_shape):
    
    num_filters = X.get_shape().as_list()[-1]
    
    decoder = OrderedDict()
    
    decoder['block_res1'] = create_res_block(X, num_filters=num_filters)
        
    for i in range(1, res_block_n):
        decoder['block_res'+str(i+1)] = create_res_block(decoder[next(reversed(decoder))],
                                                           num_filters=num_filters)

    for i in range(0, layer_n-1):
        
        decoder['layer_deconv'+str(i+1)] = create_conv_layer(decoder[next(reversed(decoder))],
                                       num_filters=num_filters,
                                       filter_size=3,
                                       strides=[1,2,2,1],
                                       pad=[1,1],
                                       deconv=True,
                                       batch_normalization=True,
                                       activation='relu')
        
        num_filters = num_filters//2
    
    decoder['layer_deconv'+str(layer_n)] = tf.nn.conv2d_transpose(decoder[next(reversed(decoder))],
                                       filter=tf.Variable(tf.truncated_normal([1,1,], stddev=0.05)),
                                       output_shape=output_shape, 
                                       strides=[1,1,1,1], 
                                       padding='SAME')
    
    return decoder#[next(reversed(decoder))]

In [421]:
def create_discrimiator(X, num_filters, layer_n):
    
    discriminator = OrderedDict()
    
    discriminator['layer_discrim_conv1'] = create_conv_layer(input_layer=X,
                                       num_filters=num_filters,
                                       filter_size=3,
                                       strides=[1,2,2,1],
                                       pad=[1,1])
    
    for i in range(1, layer_n):
        
        num_filters = num_filters * 2
        
        discriminator['layer_discrim_conv'+str(i+1)] = create_conv_layer(discriminator[next(reversed(discriminator))],
                                       num_filters=num_filters,
                                       filter_size=3,
                                       strides=[1,2,2,1],
                                       pad=[1,1])
    
    discriminator['layer_deconv_final'] = create_conv_layer(discriminator[next(reversed(discriminator))],
                                       num_filters=1,
                                       filter_size=1,
                                       strides=[1,1,1,1],
                                       pad=[0,0])
    
    return discriminator#[next(reversed(discriminator))]

In [422]:
def compute_encoding_loss(x):
    x_2 = tf.pow(x, 2)
    encoding_loss = tf.reduce_mean(x_2)
    return encoding_loss

In [423]:
def compute_l1_loss(x,y):
    return tf.reduce_mean(tf.subtract(x,y))

In [424]:
#Build Graph
X_A = tf.placeholder(tf.float32, shape=[1, 256, 256, 3], name='X_A')
X_B = tf.placeholder(tf.float32, shape=[1, 256, 256, 3], name='X_B')

encode_A = create_encoder(X_A, layer_n=3, res_block_n=3)
encode_B = create_encoder(X_B, layer_n=3, res_block_n=3)

encode_AB = tf.concat([encode_A[next(reversed(encode_A))], encode_B[next(reversed(encode_B))]], axis=1)

shared = create_shared_layers(encode_AB, block_shared_n=3)

decode_A = create_generator(shared[next(reversed(shared))], layer_n=3, res_block_n=3, output_shape=X_A.get_shape().as_list())
decode_B = create_generator(shared[next(reversed(shared))], layer_n=3, res_block_n=3, output_shape=X_A.get_shape().as_list())

X_aa, X_ba = tf.split(decode_A[next(reversed(decode_A))], num_or_size_splits=X_A.get_shape().as_list()[1], axis=1)
X_ab, X_bb = tf.split(decode_B[next(reversed(decode_A))], num_or_size_splits=X_B.get_shape().as_list()[1], axis=1)

#dis
outs_a = create_discrimiator(X_ba, num_filters=3, layer_n=6)
outs_b = create_discrimiator(X_ab, num_filters=3, layer_n=6)

in: [1, 130, 66, 12] out: [1, 259, 131, 12]
in: [1, 263, 135, 12] out: [1, 525, 269, 6]
in: [1, 130, 66, 12] out: [1, 259, 131, 12]
in: [1, 263, 135, 12] out: [1, 525, 269, 6]


ValueError: Dimension size must be evenly divisible by 256 but is 527
	Number of ways to split should evenly divide the split dimension for 'split_8' (op: 'Split') with input shapes: [], [1,527,271,3] and with computed input tensors: input[0] = <1>.

In [425]:
encode_A

OrderedDict([('layer_conv1',
              <tf.Tensor 'LeakyRelu_84/Maximum:0' shape=(1, 256, 256, 3) dtype=float32>),
             ('layer_conv2',
              <tf.Tensor 'LeakyRelu_85/Maximum:0' shape=(1, 128, 128, 6) dtype=float32>),
             ('layer_conv3',
              <tf.Tensor 'LeakyRelu_86/Maximum:0' shape=(1, 64, 64, 12) dtype=float32>),
             ('block_en_res1',
              <tf.Tensor 'Relu_551:0' shape=(1, 64, 64, 12) dtype=float32>),
             ('block_en_res2',
              <tf.Tensor 'Relu_553:0' shape=(1, 64, 64, 12) dtype=float32>),
             ('block_en_res3',
              <tf.Tensor 'Relu_555:0' shape=(1, 64, 64, 12) dtype=float32>)])

In [426]:
shared

OrderedDict([('block_shared_res1',
              <tf.Tensor 'Relu_571:0' shape=(1, 128, 64, 12) dtype=float32>),
             ('block_shared_res2',
              <tf.Tensor 'Relu_573:0' shape=(1, 128, 64, 12) dtype=float32>),
             ('block_shared_res3',
              <tf.Tensor 'Relu_575:0' shape=(1, 128, 64, 12) dtype=float32>),
             ('block_shared_res4',
              <tf.Tensor 'Relu_577:0' shape=(1, 128, 64, 12) dtype=float32>)])

In [427]:
decode_A

OrderedDict([('block_res1',
              <tf.Tensor 'Relu_579:0' shape=(1, 128, 64, 12) dtype=float32>),
             ('block_res2',
              <tf.Tensor 'Relu_581:0' shape=(1, 128, 64, 12) dtype=float32>),
             ('block_res3',
              <tf.Tensor 'Relu_583:0' shape=(1, 128, 64, 12) dtype=float32>),
             ('layer_deconv1',
              <tf.Tensor 'Relu_584:0' shape=(1, 261, 133, 12) dtype=float32>),
             ('layer_deconv2',
              <tf.Tensor 'Relu_585:0' shape=(1, 527, 271, 6) dtype=float32>),
             ('layer_deconv3',
              <tf.Tensor 'Tanh_6:0' shape=(1, 527, 271, 3) dtype=float32>)])

In [None]:
def forward(sess, X_A, X_B, hyperparameters):
    
    #a2b
    X_bab = sess.run([encode_A], feed_dict={X:X_ba})
    shared_bab = sess.run([shared], feed_dict={X:X_bab})
    X_bab = sess.run([decode_B], feed_dict={X:shared_bab})
    
    #b2a
    X_aba = sess.run([encode_B], feed_dict={X:X_ab})
    shared_aba = sess.run([shared], feed_dict={X:X_aba})
    X_aba = sess.run([decode_A], feed_dict={X:shared_aba})
    
    #loss
    all_ones = tf.Variable(tf.constant(1.0, shape=[tf.shape(outs_a)[0]]))
    
    ad_loss_a = tf.nn.sigmoid_cross_entropy_with_logits(logits=outs_a, labels=all_ones)
    ad_loss_b = tf.nn.sigmoid_cross_entropy_with_logits(logits=outs_b, labels=all_ones)
    
    enc_loss  = compute_encoding_loss(shared)
    enc_bab_loss = compute_encoding_loss(shared_bab)
    enc_aba_loss = compute_encoding_loss(shared_aba)
    ll_loss_a = compute_l1_loss(X_aa, X_A)
    ll_loss_b = compute_l1_loss(X_bb, X_B)
    ll_loss_aba = compute_l1_loss(X_aba, X_A)
    ll_loss_bab = compute_l1_loss(X_bab, X_B)
    
    total_loss = hyperparameters['gan_w'] * (ad_loss_a + ad_loss_b) + \
                 hyperparameters['ll_direct_link_w'] * (ll_loss_a + ll_loss_b) + \
                 hyperparameters['ll_cycle_link_w'] * (ll_loss_aba + ll_loss_bab) + \
                 hyperparameters['kl_direct_link_w'] * (enc_loss + enc_loss) + \
                 hyperparameters['kl_cycle_link_w'] * (enc_bab_loss + enc_aba_loss)

In [15]:
tf.reset_default_graph()

with tf.Session() as sess:
    

In [307]:
tf.reset_default_graph()
# create_UNIT(tf.placeholder(tf.float32, shape=[None, 256, 256, 3], name='X_A'),
#            tf.placeholder(tf.float32, shape=[None, 256, 256, 3], name='X_B'))