In [None]:
%matplotlib inline
from skimage import io, filters
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from lgan.diffeomorphism import tf_diffeomorphism
from tqdm import tqdm

In [None]:
batch_size = 32  # images per training batch
input_dim = 128  # length of the noise vector z (the generator reshapes it to 16x8)

# MNIST data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
X_train = mnist.train.images
Y_train = mnist.train.labels
n_train = X_train.shape[0]  # derived instead of hard-coding 55000
X_train = X_train.reshape(n_train, 28, 28, 1)

# One-hot labels -> integer digit per image. (The previous per-row loop stored a
# (1, 1) argwhere result into a scalar slot, which NumPy >= 1.25 deprecates.)
y_train = np.argmax(Y_train, axis=1)

# Mean image of every digit class 0-9, shape (10, 28, 28, 1).
numberImages = np.zeros((10, 28, 28, 1))
for digit in range(10):
    numberImages[digit] = X_train[y_train == digit].mean(axis=0)

X_train = X_train.reshape(n_train, 784)

In [None]:
#simple plot function
def plot(samples, labels, y, x):
    """Show `samples` as a y-by-x grid of 28x28 grayscale images.

    Args:
        samples: iterable of images, each reshapeable to (28, 28); at most
            y * x of them fit in the grid.
        labels: indexable of titles; labels[k] titles the k-th sample.
        y: number of grid rows.
        x: number of grid columns.

    Returns:
        The matplotlib Figure created for the grid.
    """
    fig = plt.figure(figsize=(10, 10))
    grid = gridspec.GridSpec(y, x)
    grid.update(wspace=0.05, hspace=0.05)
    # NOTE: top=1.3 is above the figure's top edge (figure fractions run 0..1).
    plt.subplots_adjust(left=None, bottom=None, right=1, top=1.3,
                        wspace=None, hspace=None)

    for k, image in enumerate(samples):
        cell_ax = plt.subplot(grid[k])
        plt.axis('off')
        cell_ax.set_xticklabels([])
        cell_ax.set_yticklabels([])
        cell_ax.set_aspect('equal')
        plt.title(labels[k])
        plt.imshow(image.reshape(28, 28), cmap='Greys_r')
    return fig

In [None]:
# Disabled experiment: the code is kept inside a bare string, so it never runs.
# It warps the 10 class-mean images with a random (10, 3, 3, 2) diff map and
# plots them before and after.
# NOTE(review): plot() takes (samples, labels, y, x), but the two plot() calls
# below pass only two arguments; update them before re-enabling this cell.
'''
#try Diffeomorphism
image = numberImages
image = np.reshape(image, (10,28,28,1))
print(image.shape)
session = tf.InteractiveSession()
diff_map =  np.random.uniform(-0.5,0.3, size=(10,3, 3, 2)) #batch_size, diff_height, diff_width, 2
dif_image = tf_diffeomorphism(image,diff_map)
plot(image,np.zeros(10))
div_image = dif_image.eval()
plot(div_image,np.ones(10))
'''

In [None]:
#get Mean Image for Dif Input
def getMeanImage(batch_size, label):
    """Build a batch of flattened class-mean images matching one-hot labels.

    Args:
        batch_size: number of rows to produce; label[0] .. label[batch_size-1]
            are read.
        label: indexable of one-hot rows; the first position equal to 1 in a
            row selects the digit.

    Returns:
        float array of shape (batch_size, 784) whose row k is
        numberImages[digit_k] flattened.

    Raises:
        IndexError: if one of the rows read contains no entry equal to 1.
    """
    mean_batch = np.zeros((batch_size, 784))
    for row in range(batch_size):
        # argwhere -> (k, 1) array of positions holding a 1; take the first.
        digit = np.argwhere(label[row] == 1)[0, 0]
        mean_batch[row] = numberImages[digit].reshape(784)
    return mean_batch

In [None]:
#sample for Generator: Random Mean Images 0-9 + Uniform Noise Vector + Label
def sample_z(batch_size):
    """Draw a generator input batch: uniform noise plus random one-hot labels.

    Args:
        batch_size: number of samples to draw.

    Returns:
        (batch, labels):
        batch  -- float array (batch_size, input_dim), entries uniform in [-1, 1).
        labels -- float array (batch_size, 10), one-hot, uniformly random digit per row.
    """
    batch = np.random.uniform(-1., 1., size=(batch_size, input_dim))
    labels = np.zeros((batch_size, 10))
    # randint yields plain integer indices; the previous int(np.random.rand(1)*10)
    # converted a size-1 array to a scalar, which NumPy >= 1.25 deprecates.
    digits = np.random.randint(0, 10, size=batch_size)
    labels[np.arange(batch_size), digits] = 1
    return (batch, labels)

In [None]:
def convLayer(layerinput, maps, kernel, stride, scope, act="relu"):
    """2-D convolution ('SAME' padding) followed by exactly one nonlinearity.

    Args:
        layerinput: 4-D input tensor (NHWC).
        maps: number of output channels.
        kernel: kernel size, e.g. [3, 3].
        stride: convolution stride.
        scope: variable scope name for the convolution's weights.
        act: "tanh" -> conv -> tanh; anything else -> conv -> batch norm -> ReLU.

    Returns:
        The activated output tensor.
    """
    # slim.conv2d applies ReLU by default (activation_fn=tf.nn.relu). Disable it
    # so `act` is the only nonlinearity: otherwise "tanh" layers computed
    # tanh(relu(x)) (never negative) and "relu" layers ran ReLU before batch norm.
    net = slim.conv2d(layerinput, maps, kernel, stride,
                      weights_initializer=tf.truncated_normal_initializer(stddev=1e-1),
                      activation_fn=None, scope=scope, padding='SAME')
    if act == "tanh":
        return tf.nn.tanh(net)
    net = slim.batch_norm(net)
    return tf.nn.relu(net)

In [None]:
#Generator Variables
# NOTE(review): both variables are created at module level, outside any
# tf.variable_scope, so (presumably) their names default to 'Variable' /
# 'Variable_1' rather than starting with 'generator'; a scope-filtered lookup
# such as get_collection(TRAINABLE_VARIABLES, 'generator') will not include them.
# `scalar` is not referenced anywhere else in the visible notebook.
scalar = tf.Variable(tf.reduce_mean(tf.random_normal([1], stddev=0.1)))
# Per-sample (8 x 49) projection used by generator_detail: a (16 x 8) noise
# matrix times (8 x 49) gives 16 * 49 = 784 = 28 * 28 values.
multMatrix = tf.Variable(tf.random_normal([batch_size,8,49], stddev=0.1))

#Generator - add Detail
def generator_detail(noise, dif_image):
    """Refine the warped mean image with a small conv encoder/decoder.

    Args:
        noise: (batch_size, input_dim) tensor; reshaped to (batch_size, 16, 8),
            so input_dim must be 128.
        dif_image: warped image tensor, reshapeable to (batch_size, 28, 28, 1).

    Returns:
        (batch_size, 28, 28, 1) tensor from a final tanh layer. Spatial sizes are
        28 -> 14 -> 7 via two stride-2 convs, then 14 -> 28 via two stride-2
        transposed convs.
    """
    with tf.variable_scope('generator_detail'):
        # Noise branch: (B,16,8) @ (B,8,49) -> (B,16,49) = 784 = 28*28 values.
        noise = tf.reshape(noise, (batch_size, 16, 8))
        noise = tf.matmul(noise, multMatrix)
        noise = tf.reshape(noise, (batch_size, 28, 28, 1))
        noise = slim.batch_norm(noise)
        # Scaling by 0 switches the noise branch off: only dif_image reaches the convs.
        noise = tf.scalar_mul(0, noise)

        details = tf.reshape(dif_image, [batch_size, 28, 28, 1])
        detail_image = tf.add(details, noise)
        detail_image = convLayer(detail_image, 64, [3, 3], 2, scope='convDetail_1', act="tanh")
        detail_image = convLayer(detail_image, 128, [3, 3], 2, scope='convDetail_2', act="tanh")
        # slim's transposed conv applies ReLU by default; disable it so tanh is the
        # only nonlinearity (tanh(relu(x)) can never be negative).
        detail_image = slim.convolution2d_transpose(detail_image, 64, [3, 3], 2,
                                                    activation_fn=None, scope='convDetail_4')
        detail_image = tf.nn.tanh(detail_image)
        detail_image = slim.convolution2d_transpose(detail_image, 32, [3, 3], 2,
                                                    activation_fn=None, scope='convDetail_5')
        detail_image = tf.nn.tanh(detail_image)
        detail_image = convLayer(detail_image, 1, [3, 3], 1, scope='convDetail_6', act="tanh")

        return detail_image

#Generator - predict parameters for Diffeomorphism
def generator_dif_paras(noise_vector):
    """Predict the diffeomorphism control parameters from the noise vector.

    Args:
        noise_vector: (batch_size, input_dim) tensor; reshaped to
            (batch_size, 16, 8, 1), so input_dim must be 128.

    Returns:
        (batch_size, 2, 2, 2) tensor of parameters for tf_diffeomorphism, with
        values in (-1, 1).
    """
    with tf.variable_scope('generator_diffeo'):
        params = tf.reshape(noise_vector, [batch_size, 16, 8, 1])
        params = convLayer(params, 32, [3, 3], 2, scope='convG_1')  # -> (B, 8, 4, 32)
        # Linear conv followed by tanh so parameters can be negative. The previous
        # convLayer(..., relu path) ended in ReLU, so tanh only ever saw values >= 0
        # and every parameter was confined to [0, 1).
        params = slim.conv2d(params, 1, [3, 3], 2,
                             weights_initializer=tf.truncated_normal_initializer(stddev=1e-1),
                             activation_fn=None, scope='convG_2', padding='SAME')  # -> (B, 4, 2, 1)
        params = tf.nn.tanh(params)
        # 4 * 2 * 1 = 8 values per sample -> (2, 2, 2)
        return tf.reshape(params, [batch_size, 2, 2, 2])
    
#Main Generator
def generator(z, class_z, mean_image):
    """Build the generator graph: warp the class-mean image, then refine it.

    Args:
        z: (batch_size, input_dim) noise tensor.
        class_z: (batch_size, 10) one-hot class-label tensor.
        mean_image: mean image of each sample's class, reshapeable to
            (batch_size, 28, 28, 1).

    Returns:
        [detailed_image, dif_image]: the output of generator_detail and the
        mean image warped by tf_diffeomorphism.
    """
    with tf.variable_scope('generator'):
        mean_image = tf.reshape(mean_image, [batch_size, 28, 28, 1])
        # Project the label into noise space (FC + ReLU) and mix it into z.
        # Previously a stray `class_z = tf.nn.relu(z)` overwrote this embedding,
        # so the label never reached the noise path and the FC weights got no gradient.
        class_embedding = slim.fully_connected(
            class_z, input_dim, activation_fn=None,
            weights_initializer=tf.truncated_normal_initializer(stddev=1e-1))
        class_embedding = tf.nn.relu(class_embedding)
        noise = tf.add(z, class_embedding)

        dif_params = generator_dif_paras(noise)
        dif_image = tf_diffeomorphism(mean_image, dif_params)
        detailed_image = generator_detail(noise, dif_image)
        return [detailed_image, dif_image]



In [None]:
def discriminator(x, reuse=None):
    """WGAN critic: map a batch of images to unbounded (linear) scores.

    Args:
        x: images, reshapeable to (batch_size, 28, 28, 1).
        reuse: whether to share variables with an earlier call. None (default)
            decides automatically: reuse iff 'discriminator/' variables exist.

    Returns:
        Critic scores with no output activation. Before the final layer the
        tensor is (batch_size, 28, 28, 4); slim.fully_connected acts on the last
        axis, so scores are presumably (batch_size, 28, 28, 1) -- TODO confirm.
    """
    if reuse is None:
        # The critic is applied to both real and generated images; the second
        # call must reuse the first call's weights, otherwise re-creating
        # variables under the same explicit scopes fails with "already exists".
        reuse = bool(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='discriminator/'))
    with tf.variable_scope('discriminator', reuse=reuse):
        x = tf.reshape(x, [batch_size, 28, 28, 1])
        net = slim.conv2d(x, 16, [3, 3], 2, weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
                          activation_fn=None, scope='convD_2', padding='SAME')  # -> 14x14
        net = tf.nn.relu(net)
        net = convLayer(net, 32, [3, 3], 1, scope='convD_3')
        # activation_fn=None: slim's default ReLU would run before batch norm.
        net = slim.convolution2d_transpose(net, 128, [3, 3], 2, weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
                                           activation_fn=None, scope='convD_4', padding='SAME')  # -> 28x28
        net = slim.batch_norm(net)
        net = tf.nn.relu(net)
        net = convLayer(net, 32, [3, 3], 1, scope='convD_5')
        net = convLayer(net, 4, [3, 3], 1, scope='convD_6')
        # A WGAN critic output must be linear; slim.fully_connected defaults to ReLU,
        # which would clamp scores at 0 and can kill the critic's gradient.
        return slim.fully_connected(net, num_outputs=1, activation_fn=None,
                                    weights_initializer=tf.truncated_normal_initializer(stddev=0.1))


In [None]:
#Session
# allow_growth: let TensorFlow take GPU memory as needed instead of all at once.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Inputs
Mean_image = tf.placeholder(tf.float32, shape=[batch_size, 784])  # flattened mean image of each sample's class
X = tf.placeholder(tf.float32, shape=[batch_size, 784])           # flattened real MNIST images
Z = tf.placeholder(tf.float32, shape=[batch_size, input_dim])     # noise vector
Class_z = tf.placeholder(tf.float32, shape=[batch_size, 10])      # one-hot class label

# Models
detail_image = generator(Z, Class_z, Mean_image)  # [refined image, warped mean image]
D_real = discriminator(X)
D_fake = discriminator(detail_image[0])

# Trainable variables, selected by variable-scope name prefix.
theta_D = sess.graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
theta_G = sess.graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator')
# generator_detail opens its scope inside generator's 'generator' scope, so its
# variables live under 'generator/generator_detail/'; the bare prefix
# 'generator_detail' matched nothing and always produced an empty list.
theta_G_detail = sess.graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator/generator_detail')

# WGAN losses: the critic maximises D_loss (hence minimize(-D_loss) below);
# the generator maximises the critic's score on its samples.
D_loss = tf.reduce_mean(D_real) - tf.reduce_mean(D_fake)
# Small L1 penalty between the refined image and the warped mean image.
G_l1_loss = tf.multiply(0.001, tf.reduce_mean(tf.abs(detail_image[0] - detail_image[1])))
G_loss = -tf.reduce_mean(D_fake) + G_l1_loss

D_solver = (tf.train.RMSPropOptimizer(learning_rate=1e-4)
            .minimize(-D_loss, var_list=theta_D))
G_solver = (tf.train.RMSPropOptimizer(learning_rate=1e-4)
            .minimize(G_loss, var_list=theta_G))

# Weight clipping keeps every critic weight in [-0.01, 0.01]. Run these ops in
# their own sess.run after D_solver: fetches of a single run have no guaranteed order.
clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in theta_D]

if not os.path.exists('out/'):
    os.makedirs('out/')

# Initialise all variables.
sess.run(tf.global_variables_initializer())

In [None]:
# TensorBoard log directory, created up front so a writer can point at it.
logs_missing = not os.path.exists('logs/')
if logs_missing:
    os.makedirs('logs/')
# Writer disabled. tf.train.SummaryWriter is the pre-1.0 name; TF 1.x calls it
# tf.summary.FileWriter.
#writer = tf.train.SummaryWriter("logs/", graph=tf.get_default_graph())



In [None]:
#pretrain disc:
def _sample_generator_feed():
    """Sample noise and labels; return (feed for Z/Class_z/Mean_image, labels).

    Mean images are always looked up from the freshly sampled labels, so the
    two stay consistent.
    """
    z, class_z = sample_z(batch_size)
    feed = {Z: z, Class_z: class_z, Mean_image: getMeanImage(batch_size, class_z)}
    return feed, class_z


def _critic_step():
    """Run one critic update on a fresh real batch, clip its weights, return D_loss."""
    Xdata, _ = mnist.train.next_batch(batch_size)
    feed, _ = _sample_generator_feed()
    feed[X] = Xdata
    _, loss = sess.run([D_solver, D_loss], feed_dict=feed)
    # Clip in a separate run: fetches in one sess.run call have no guaranteed
    # execution order, so clip_D could otherwise run before the update.
    sess.run(clip_D)
    return loss


def _show_samples(n_rows, n_cols, save_path=None):
    """Generate a batch and plot its first n_rows * n_cols refined images.

    The images are titled by digit. The figure is saved to save_path when one
    is given, then displayed and closed.
    """
    feed, class_z = _sample_generator_feed()
    generated = sess.run(detail_image, feed_dict=feed)
    images = np.reshape(generated[0], (batch_size, 28, 28))
    n_shown = n_rows * n_cols
    fig = plot(images[:n_shown], np.argwhere(class_z != 0)[:, 1], n_rows, n_cols)
    if save_path is not None:
        plt.savefig(save_path, bbox_inches='tight')
    plt.show()
    # Close every figure: one is drawn every 100 iterations, and open figures
    # would otherwise pile up for the whole run.
    plt.close(fig)


# Pretrain the critic before alternating updates.
for _ in tqdm(range(500)):
    D_loss_curr = _critic_step()

i = 0  # index of the next sample grid written to out/
for it in range(100000):
    # 5 critic updates per generator update. Each critic step samples labels AND
    # matching mean images (previously the stale mean_image from the last
    # generator step was fed alongside the new labels).
    for _ in range(5):
        D_loss_curr = _critic_step()

    feed, _ = _sample_generator_feed()
    _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict=feed)

    if it % 100 == 0:
        print('Iter: {}; D loss: {:.4}; G_loss: {:.4}'
              .format(it, D_loss_curr, G_loss_curr))
        if it % 1000 == 0:
            _show_samples(4, 4, 'out/{}.png'.format(str(i).zfill(3)))
            i += 1
        else:
            _show_samples(1, 4)