In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
import time
import matplotlib.pyplot as plt
import itertools, imageio, pickle
# Show only ERROR-level (and more severe) TensorFlow log messages in the notebook output.
tf.logging.set_verbosity(tf.logging.ERROR)

In [None]:
# Load MNIST with one-hot labels. A falsy `reshape` keeps each image as an
# (H, W, 1) array instead of flattening it, which the resize step below needs.
mnist = input_data.read_data_sets(train_dir='./MNIST_data', one_hot=True, reshape=False)

class DataSampler(object):
    """Serve fixed-size mini-batches from an in-memory array, cycling through epochs.

    The array is shuffled along axis 0 before the first batch and again each
    time an epoch is exhausted (when `shuffle` is true). A batch that runs past
    the end of the array is completed with samples from the start of the next
    epoch.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.num_sample = dataset.shape[0]
        self.index_in_epoch = 0
        self.epochs_completed = 0

    def _permute(self):
        # Swap in a randomly reordered copy of the samples (axis 0).
        order = np.arange(self.num_sample)
        np.random.shuffle(order)
        self.dataset = self.dataset[order]

    def __call__(self, batch_size, shuffle=True):
        """Return the next `batch_size` samples, advancing the epoch cursor."""
        begin = self.index_in_epoch
        very_first_draw = self.epochs_completed == 0 and begin == 0
        if shuffle and very_first_draw:
            self._permute()

        stop = begin + batch_size
        if stop <= self.num_sample:
            self.index_in_epoch = stop
            return self.dataset[begin:stop]

        # Epoch exhausted: keep the leftover tail, start a new (optionally
        # reshuffled) epoch, and top the batch up from its head.
        self.epochs_completed += 1
        tail = self.dataset[begin:self.num_sample]
        if shuffle:
            self._permute()
        self.index_in_epoch = batch_size - (self.num_sample - begin)
        head = self.dataset[0:self.index_in_epoch]
        return np.concatenate((tail, head), axis=0)
        
class NoiseSampler(object):
    """Draw standard-normal latent codes laid out as (batch, 1, 1, z_dim) tensors."""

    def __call__(self, batch_size, z_dim):
        # The singleton spatial dims let the codes feed a transposed conv directly.
        latent_shape = [batch_size, 1, 1, z_dim]
        return np.random.normal(loc=0., scale=1.0, size=latent_shape)

In [None]:
def lrelu(x, th=0.2):
    """Leaky ReLU: passes positive values through and scales negatives by `th`.

    Implemented as max(th * x, x), which is only a leaky ReLU for 0 <= th <= 1.
    """
    leaked = th * x
    return tf.maximum(leaked, x)

def layer_norm(x, trainable=True, name='layer_norm'):
    """Apply tf.contrib layer normalization inside its own variable scope `name`.

    The dedicated scope gives each call site distinct variables, e.g. 'ln1', 'ln2'.
    """
    with tf.variable_scope(name):
        normalized = tf.contrib.layers.layer_norm(x, trainable=trainable)
    return normalized

# G(z)
def generator(x, isTrain=True, reuse=False):
    """DCGAN-style generator: latent codes -> single-channel images in [-1, 1].

    Args:
        x: latent batch; assumed (N, 1, 1, z_dim), as fed by the `z` placeholder.
        isTrain: bool (or bool tensor) selecting batch statistics in batch norm.
        reuse: reuse the variables of the 'generator' scope.

    Each hidden stage is a 4x4 transposed conv, then batch norm, then leaky ReLU.
    For a 1x1 input the spatial size goes 1 -> 4 -> 8 -> 16 -> 32 -> 64.
    """
    # (filters, strides, padding) per hidden stage, in creation order.
    hidden_stages = [
        (1024, (1, 1), 'valid'),  # 1x1   -> 4x4
        (512, (2, 2), 'same'),    # 4x4   -> 8x8
        (256, (2, 2), 'same'),    # 8x8   -> 16x16
        (128, (2, 2), 'same'),    # 16x16 -> 32x32
    ]
    with tf.variable_scope('generator', reuse=reuse):
        h = x
        for filters, strides, padding in hidden_stages:
            h = tf.layers.conv2d_transpose(h, filters, [4, 4], strides=strides, padding=padding)
            h = lrelu(tf.layers.batch_normalization(h, training=isTrain), 0.2)

        # Output stage: 32x32 -> 64x64, one channel, squashed by tanh.
        logits = tf.layers.conv2d_transpose(h, 1, [4, 4], strides=(2, 2), padding='same')
        return tf.nn.tanh(logits)
    
def discriminator(x, reuse=False):
    """WGAN critic: images -> one unbounded score per example.

    Args:
        x: image batch; assumed (N, 64, 64, 1), matching the `x` placeholder.
        reuse: reuse the variables of the 'discriminator' scope.

    Returns the raw output of the last conv (no sigmoid), i.e. a critic score.
    Hidden layers use layer norm instead of batch norm. Layer norm keeps each
    example's score independent of the rest of the batch, which the
    per-example gradient penalty relies on.
    """
    with tf.variable_scope('discriminator', reuse=reuse):
        # First stage: strided conv + leaky ReLU, no normalization.
        h = lrelu(tf.layers.conv2d(x, 128, [4, 4], strides=(2, 2), padding='same'), 0.2)

        # Remaining strided stages, each followed by a named layer norm.
        for filters, ln_scope in ((256, 'ln1'), (512, 'ln2'), (1024, 'ln3')):
            h = tf.layers.conv2d(h, filters, [4, 4], strides=(2, 2), padding='same')
            h = lrelu(layer_norm(h, name=ln_scope), 0.2)

        # A 4x4 'valid' conv collapses the final 4x4 map (for 64x64 input) to 1x1.
        return tf.layers.conv2d(h, 1, [4, 4], strides=(1, 1), padding='valid')

In [None]:
# Latent codes reused at every checkpoint so saved sample grids are comparable
# across iterations (25 codes fill the grid drawn by show_result).
fixed_z_ = np.random.normal(loc=0, scale=1, size=(25, 1, 1, 100))
def show_result(num_epoch, show=False, save=False, path='result.png'):
    """Render generator samples for the fixed latent batch as a square grid.

    Reads the notebook globals `sess`, `Gz`, `z`, `isTrain` and `fixed_z_`.
    Batch norm runs in inference mode (isTrain=False).

    Args:
        num_epoch: number shown in the caption ('Epoch {num_epoch}').
        show: display the figure; when false the figure is closed instead.
        save: write the figure to `path`.
        path: output file used when `save` is true.
    """
    test_images = sess.run(Gz, {z: fixed_z_, isTrain: False})

    # Square grid sized from the number of fixed codes (25 -> 5x5).
    size_figure_grid = int(np.sqrt(len(test_images)))
    # squeeze=False keeps `ax` a 2-D array even for a 1x1 grid.
    fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5), squeeze=False)
    for k, axis in enumerate(ax.flat):
        # Drop singleton dims, e.g. (64, 64, 1) -> (64, 64), so imshow gets a 2-D image.
        axis.imshow(np.squeeze(test_images[k]), cmap='gray')
        # Hide ticks after drawing so nothing re-enables them.
        axis.get_xaxis().set_visible(False)
        axis.get_yaxis().set_visible(False)

    label = 'Epoch {0}'.format(num_epoch)
    fig.text(0.5, 0.04, label, ha='center')

    # Act on this figure explicitly rather than pyplot's "current" figure.
    if save:
        fig.savefig(path)

    if show:
        plt.show()
    else:
        plt.close(fig)

In [None]:
x = tf.placeholder(tf.float32, [None, 64, 64, 1])
z = tf.placeholder(tf.float32, shape=(None, 1, 1, 100))
isTrain = tf.placeholder(dtype=tf.bool)

Gz = generator(z, isTrain)

D_real = discriminator(x)
D_fake = discriminator(Gz, reuse=True)

# Sign convention: the critic minimizes mean(D_real) - mean(D_fake), so it learns to
# give real images LOW scores. The generator minimizes mean(D_fake), pushing its
# samples toward real-looking (low) scores. This is the usual WGAN objective with
# the critic negated.
g_loss = tf.reduce_mean(D_fake)
d_loss = tf.reduce_mean(D_real) - tf.reduce_mean(D_fake)

# --- WGAN-GP gradient penalty ---------------------------------------------
LAMBDA_GP = 10.  # penalty weight (lambda)

# One interpolation coefficient PER EXAMPLE, broadcast over H, W, C.
# A single scalar would put the whole batch on the same line segment fraction.
# Assumes x and Gz carry the same batch size.
epsilon = tf.random_uniform([tf.shape(x)[0], 1, 1, 1], 0., 1.0)
x_hat = epsilon * x + (1 - epsilon) * Gz

d_hat = discriminator(x_hat, reuse=True)

ddx = tf.gradients(d_hat, x_hat)[0]
# Per-example L2 norm of the gradient over ALL non-batch axes (H, W, C).
# Summing over axis=1 alone gave one norm per image row, not per example.
ddx = tf.sqrt(tf.reduce_sum(tf.square(ddx), axis=[1, 2, 3]))
ddx = tf.reduce_mean(tf.square(ddx - 1.0) * LAMBDA_GP)

d_loss = d_loss + ddx

# Partition the trainable variables by the variable scope of the network that owns them.
T_vars = tf.trainable_variables()
D_vars = [v for v in T_vars if v.name.startswith('discriminator')]
G_vars = [v for v in T_vars if v.name.startswith('generator')]
print(D_vars)
print("=========================================")
print(G_vars)


def _make_adam():
    """Adam with the WGAN-GP paper's settings: lr=1e-4, beta1=0, beta2=0.9."""
    return tf.train.AdamOptimizer(learning_rate=1e-4, beta1=0., beta2=0.9)


# Each optimizer step depends on its network's UPDATE_OPS, so batch-norm moving
# statistics are refreshed on every training step. The discriminator uses only
# layer norm, so its collection is presumably empty; the generator's holds the
# batch-norm updates.
d_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='discriminator')
with tf.control_dependencies(d_update_ops):
    D_opt = _make_adam().minimize(d_loss, var_list=D_vars)

g_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='generator')
with tf.control_dependencies(g_update_ops):
    G_opt = _make_adam().minimize(g_loss, var_list=G_vars)

In [None]:
# Session that grows GPU memory on demand instead of reserving it all up front.
gpu_options = tf.GPUOptions(allow_growth=True)
session_config = tf.ConfigProto(gpu_options=gpu_options)
sess = tf.InteractiveSession(config=session_config)

tf.global_variables_initializer().run()

# Upsample the digits to the 64x64 size both networks use, then map pixels
# from [0, 1] to [-1, 1], the output range of the generator's tanh.
x_train = 2.0 * tf.image.resize_images(mnist.train.images, [64, 64]).eval() - 1.0

x_sampler = DataSampler(x_train)
z_sampler = NoiseSampler()


# Output folders for the sample grids saved during training.
root = 'MNIST_WDCGAN-GP_results/'
model = 'MNIST_WDCGAN-GP_'
# makedirs also creates `root`. exist_ok makes re-running this cell a no-op and
# avoids the check-then-create race of isdir() + mkdir().
os.makedirs(root + 'Fixed_results', exist_ok=True)
    
# training-loop
# NOTE(review): reseeding numpy from the wall clock makes every run (data order,
# noise) different; use a fixed seed here if runs must be reproducible.
np.random.seed(int(time.time()))
print("training start!")
start_time = time.time()

batch_size = 64
num_iters = 100000
Z_DIM = 100            # must match the last dimension of the `z` placeholder
N_CRITIC = 5           # critic steps per generator step
N_CRITIC_BOOST = 100   # critic steps during warm-up and on every 500th iteration

# `step`, not `iter`, so the builtin iter() is not shadowed in the notebook namespace.
for step in range(num_iters):
    # Train the critic much longer for the first 25 generator steps and on every
    # 500th step (presumably to keep it near-optimal before the generator moves).
    ncritic = N_CRITIC_BOOST if (step < 25 or step % 500 == 0) else N_CRITIC
    for _ in range(ncritic):
        x_batch = x_sampler(batch_size)
        z_batch = z_sampler(batch_size, Z_DIM)
        sess.run(D_opt, feed_dict={x: x_batch, z: z_batch, isTrain: True})

    # One generator step (g_loss depends only on z; x is fed for uniformity).
    z_batch = z_sampler(batch_size, Z_DIM)
    sess.run(G_opt, feed_dict={x: x_batch, z: z_batch, isTrain: True})

    if step % 100 == 0:
        # Evaluate both losses in one run so they share a single forward pass.
        D_loss, G_loss = sess.run([d_loss, g_loss],
                                  feed_dict={x: x_batch, z: z_batch, isTrain: True})
        print("Iter: {},d_loss: {}, g_loss: {}".format(step, D_loss, G_loss))
    if step % 500 == 0:
        fixed_p = root + 'Fixed_results/' + model + str(step) + '.png'
        show_result((step + 1), save=True, path=fixed_p)

print("training finished in {:.1f}s".format(time.time() - start_time))
        