In [30]:
import numpy as np
import tensorflow as tf 

from tensorflow.contrib.layers import fully_connected
def reset_graph(seed=42):
    """Start from an empty default TF graph and seed the RNGs used here.

    Args:
        seed: value used both for ``np.random.seed`` and as the TF
            graph-level seed of the freshly created default graph.
    """
    # NumPy's global RNG is unrelated to the TF graph, so seed it first.
    np.random.seed(seed)
    # The graph-level seed is stored on the default graph, so it must be set
    # after the reset or it would be discarded with the old graph.
    tf.reset_default_graph()
    tf.set_random_seed(seed)


reset_graph()
# Weight-initialisation helper used for every weight matrix below.
def variable_init(size):
    """Build a random-normal initial value for a weight matrix.

    The standard deviation is ``1 / sqrt(size[0] / 2)``, i.e.
    ``sqrt(2 / fan_in)`` with ``fan_in = size[0]``.

    Args:
        size: shape list such as ``[784, 128]``; the first entry is used as
            the fan-in.

    Returns:
        A ``tf.random_normal`` tensor of shape ``size``.
    """
    # Standard deviation of the normal distribution the weights are drawn from.
    stddev = 1. / tf.sqrt(size[0] / 2.)
    return tf.random_normal(shape=size, stddev=stddev)

# Placeholder for real images: 784 input units per row; None is the batch-size
# dimension. X carries the real pictures as 32-bit floats.
# NOTE(review): with the graph-level seed set by reset_graph, TF derives the
# per-op seeds deterministically from the graph, so reordering or adding ops in
# this section likely changes the initial weights — keep the order as is.
X = tf.placeholder(tf.float32, shape=[None, 784])

# Discriminator weights and biases: 784 -> 128 (hidden) -> 1 (logit), i.e. a
# fully connected net with input, hidden and output layers.
D_W1 = tf.Variable(variable_init([784, 128]))
D_b1 = tf.Variable(tf.zeros(shape=[128]))

D_W2 = tf.Variable(variable_init([128, 1]))
D_b2 = tf.Variable(tf.zeros(shape=[1]))

# Variables the discriminator's optimizer is allowed to update (var_list).
theta_D = [D_W1, D_W2, D_b1, D_b2]

# Generator input: a batch of 100-dimensional noise vectors; None is the batch size.
Z = tf.placeholder(tf.float32, shape=[None, 100])

# Generator weights and biases: 100 input units receive the noise, 128 hidden
# units, and 784 output units (one per pixel of a 28x28 handwritten digit).
G_W1 = tf.Variable(variable_init([100, 128]))
G_b1 = tf.Variable(tf.zeros(shape=[128]))

G_W2 = tf.Variable(variable_init([128, 784]))
G_b2 = tf.Variable(tf.zeros(shape=[784]))

# Variables the generator's optimizer is allowed to update (var_list).
theta_G = [G_W1, G_W2, G_b1, G_b2]

def generator(Z):
    """Map a batch of noise vectors to generated images.

    One ReLU hidden layer followed by a sigmoid output layer, using the
    module-level parameters G_W1, G_b1, G_W2 and G_b2.

    Args:
        Z: noise tensor; it is multiplied by ``G_W1`` (shape ``[100, 128]``),
            so presumably ``[batch, 100]``.

    Returns:
        Sigmoid activations of the 784-unit output layer (values in (0, 1)).
        Only the probabilities are returned. NOTE: a later cell redefines
        ``generator`` to return ``(logits, proba)`` instead.
    """
    G_h1 = tf.nn.relu(tf.matmul(Z, G_W1) + G_b1)
    G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
    return tf.nn.sigmoid(G_log_prob)

def discriminator(X, reuse=None):
    """Score a batch of images with the shared discriminator parameters.

    Every call uses the same module-level variables D_W1, D_b1, D_W2, D_b2,
    so real and generated images are judged by one network.

    Args:
        X: image tensor; it is multiplied by ``D_W1`` (shape ``[784, 128]``),
            so presumably ``[batch, 784]``.
        reuse: unused. It is kept only so call sites match the
            ``tf.layers``-based version; weight sharing already happens
            because the weights are module-level globals.

    Returns:
        ``(D_prob, D_logit)``: the sigmoid probability that each input is
        real, and the raw pre-sigmoid logit (what
        ``sigmoid_cross_entropy_with_logits`` expects). NOTE: the later
        ``tf.layers`` version returns ``(logits, proba)`` — opposite order.
    """
    D_h1 = tf.nn.relu(tf.matmul(X, D_W1) + D_b1)
    D_logit = tf.matmul(D_h1, D_W2) + D_b2
    D_prob = tf.nn.sigmoid(D_logit)
    return D_prob, D_logit

# Feed random noise Z through the generator to get generated samples.
g_proba = generator(Z)

# Score real images and generated images with the (shared) discriminator.
D_real, D_logit_real = discriminator(X)
D_fake, D_logit_fake = discriminator(g_proba)

# The discriminator/generator losses as written in the original GAN paper.
# They are kept for reference only; this implementation does not use them.
# D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
# G_loss = -tf.reduce_mean(tf.log(D_fake))

# Cross-entropy is used for both losses instead. sigmoid_cross_entropy_with_logits
# applies the sigmoid internally, so it is given the discriminator's last-layer
# value before activation, i.e. D_h1*D_W2+D_b2 (the logits).
# tf.ones_like(D_logit_real) builds all-ones labels with D_logit_real's shape: "real".
loss_data = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
# Discriminator's loss on generated samples (labels 0 = "fake").
# Not to be confused with the generator loss loss_G below.
loss_g = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))

# Discriminator loss has two parts:
#   loss_D = -E[log D(x)] - E[log(1 - D(G(z)))]   (batch means)
# i.e. the negative of the paper's value function. It penalises judging real
# images as fake and judging generated images as real.
loss_D = loss_data + loss_g

# Generator loss, also as cross-entropy: fake logits with labels 1 give
# -E[log D(G(z))], the same objective as the commented G_loss above.
loss_G = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))

# Both networks use Adam; var_list restricts each minimize() to its own network's
# weights, so a D step never changes G's parameters and vice versa.
training_op_d = tf.train.AdamOptimizer().minimize(loss_D, var_list=theta_D)
training_op_g = tf.train.AdamOptimizer().minimize(loss_G, var_list=theta_G)

saver = tf.train.Saver()
init = tf.global_variables_initializer()


# Training hyper-parameters. Each "epoch" of the training loop is a single
# mini-batch step (one D update + one G update), not a full pass over MNIST.
n_epochs, batch_size, z_dim = 20000, 128, 100

# Checkpoint locations; the active code in the visible cells does not use them
# (only the commented-out resume/save logic does).
checkpoint_path = "/tmp/my_mnistGAN_model.ckpt"
checkpoint_epoch_path = "{}.epoch".format(checkpoint_path)
final_model_path = "./my_mnistGAN_model"

import os
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

# Directory the training loop writes sample-grid PNGs into.
# exist_ok=True replaces the separate exists() check, which could race with
# another process creating the directory in between.
os.makedirs('out/', exist_ok=True)
    
# CREATE PHOTO: grid of generated samples.
def plot(samples, grid_size=4, image_shape=(28, 28)):
    """Draw samples as a square grid of grayscale images and return the figure.

    Args:
        samples: iterable of flat image arrays; each one is reshaped to
            ``image_shape`` (784 values -> 28x28 by default).
        grid_size: rows and columns of the grid. The default of 4 gives room
            for 16 samples, which is what the training loop passes.
        image_shape: shape each sample is reshaped to before display.

    Returns:
        The matplotlib Figure. It is also pyplot's current figure, so a
        following ``plt.savefig`` saves it.

    Raises:
        ValueError: if there are more samples than grid cells.
    """
    samples = list(samples)
    n_cells = grid_size * grid_size
    if len(samples) > n_cells:
        raise ValueError("plot() got %d samples but the grid only has %d cells"
                         % (len(samples), n_cells))

    fig = plt.figure(figsize=(grid_size, grid_size))
    gs = gridspec.GridSpec(grid_size, grid_size)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        # Work on the explicit Axes rather than pyplot's implicit "current" one.
        ax = fig.add_subplot(gs[i])
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.imshow(sample.reshape(image_shape), cmap='Greys_r')

    return fig


In [31]:
# Train the GAN built in the previous cell. Each iteration runs one Adam step on
# the discriminator and one on the generator, each with a fresh noise batch.
# Every `sample_every` iterations a 4x4 grid of generated digits is written to
# out/NNN.png and the latest losses are printed.

# Load MNIST here (same call/path as the later script cell) so this cell does
# not depend on that cell having been run first.
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/")

# NOTE(review): sample_Z is not defined in any visible cell. It is presumably a
# helper returning an (m, n) noise array for the Z placeholder — it must be
# defined in the kernel before this cell runs on a fresh restart.
sample_every = 2000  # iterations between saved sample grids / loss prints

with tf.Session() as sess:
    sess.run(init)
    # if not os.path.exists(checkpoint_epoch_path):
    #     start_epoch = 0
    #     sess.run(init)
    # else:
    #     with open(checkpoint_epoch_path, 'rb') as f:
    #         start_epoch = int(f.read())
    #     print("Training was interrupted. Continuing at epoch", start_epoch)
    #     saver.restore(sess, checkpoint_path)
    n_saved = 0  # index used in the saved image file names
    # for epoch in range(start_epoch, n_epochs):
    for epoch in range(n_epochs):
        X_batch, _ = mnist.train.next_batch(batch_size)  # labels are not needed
        # Discriminator step: a real batch plus a freshly sampled noise batch.
        _, loss_D_val = sess.run([training_op_d, loss_D], feed_dict={X: X_batch, Z: sample_Z(batch_size, z_dim)})
        # Generator step with its own fresh noise batch.
        _, loss_G_val = sess.run([training_op_g, loss_G], feed_dict={Z: sample_Z(batch_size, z_dim)})

        if epoch % sample_every == 0:
            samples = sess.run(g_proba, feed_dict={Z: sample_Z(16, z_dim)})
            fig = plot(samples)
            plt.savefig('out/{}.png'.format(str(n_saved).zfill(3)), bbox_inches='tight')
            n_saved += 1
            plt.close(fig)  # avoid accumulating open figures across the run

            print('Iter: {}'.format(epoch))
            print('D loss: {:.4}'.format(loss_D_val))
            print('G_loss: {:.4}'.format(loss_G_val))

        # if epoch % 5000 == 0:
        #     with open(checkpoint_epoch_path, 'wb') as f:
        #         saver.save(sess, checkpoint_path)
        #         f.write(b'%d' % (epoch + 1))

Iter: 0
D loss: 1.603
G_loss: 2.079


Iter: 2000
D loss: 0.1213
G_loss: 3.874


Iter: 4000
D loss: 0.2483
G_loss: 4.692


Iter: 6000
D loss: 0.3041
G_loss: 4.098


Iter: 8000
D loss: 0.5982
G_loss: 2.967


Iter: 10000
D loss: 0.511
G_loss: 3.484


Iter: 12000
D loss: 0.6221
G_loss: 2.95


Iter: 14000
D loss: 0.7866
G_loss: 2.576


Iter: 16000
D loss: 0.5479
G_loss: 2.429


Iter: 18000
D loss: 0.774
G_loss: 2.441


In [45]:
import os

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST from /tmp/data/ (the "Extracting ..." lines in this cell's output
# come from this call).
mnist = input_data.read_data_sets(train_dir="/tmp/data/")

# Kernel initializer shared by every dense layer below: TF's variance-scaling
# initializer with its default arguments (He / "MSRA" initialisation).
he_init = tf.contrib.layers.variance_scaling_initializer()

# Layer widths: the discriminator reads a flattened 28x28 image and emits one
# logit; the generator maps a 100-d noise vector to a flattened 28x28 image.
n_D_inputs, n_D_outputs = 28*28, 1
n_G_inputs, n_G_outputs = 100, 28*28

# Training schedule. n_epochs counts mini-batch iterations, not data passes.
n_epochs = 20000
batch_size = 128
z_dim = 100  # must equal n_G_inputs: sampled noise is fed to the Z placeholder

# Checkpoint locations (only referenced by the commented-out resume/save code).
checkpoint_path = "/tmp/my_mnistGAN_model.ckpt"
checkpoint_epoch_path = "%s.epoch" % checkpoint_path
final_model_path = "./my_mnistGAN_model"

def reset_graph(seed=42):
    """Replace the default TF graph with an empty one and seed the RNGs.

    Args:
        seed: used for ``np.random.seed`` and as the new graph's TF
            graph-level seed.
    """
    # NumPy's global RNG does not depend on the TF graph, so seed it first.
    np.random.seed(seed)
    # The graph-level seed belongs to the default graph: reset, then seed.
    tf.reset_default_graph()
    tf.set_random_seed(seed)

def dnn(inputs, n_hidden_layers=1, n_neurons=128, pre_name=None, activation=tf.nn.elu,
        initializer=he_init, reuse=None):
    """Stack ``n_hidden_layers`` fully connected layers on top of ``inputs``.

    Layers are named ``<pre_name>hidden1``, ``<pre_name>hidden2``, ..., so a
    second call with the same ``pre_name`` and ``reuse=True`` shares weights.

    Args:
        inputs: input tensor.
        n_hidden_layers: number of dense layers; 0 returns ``inputs`` unchanged.
        n_neurons: units per layer.
        pre_name: prefix for the layer names (e.g. ``"g_"``/``"d_"``);
            None means no prefix.
        activation: activation applied by every layer.
        initializer: kernel initializer for every layer.
        reuse: forwarded to ``tf.layers.dense``.

    Returns:
        The output tensor of the last layer.
    """
    if pre_name is None:
        pre_name = ""
    # name_scope only groups the ops (e.g. in TensorBoard); tf.layers variable
    # names come from the `name` argument, so they start with pre_name.
    with tf.name_scope("dnn"):
        for layer in range(n_hidden_layers):
            inputs = tf.layers.dense(inputs, n_neurons, activation=activation,
                                     kernel_initializer=initializer,
                                     name="%shidden%d" % (pre_name, layer + 1),
                                     reuse=reuse)
    return inputs

def generator(Z):
    """Generator network built from ``tf.layers``.

    A ``dnn`` trunk (layers prefixed ``g_``) followed by an ``n_G_outputs``-wide
    dense layer named ``g_logits``.

    Args:
        Z: noise tensor fed to the first layer.

    Returns:
        ``(g_logits, g_proba)``: the output-layer logits and their sigmoid
        (op named ``g_proba``).
    """
    hidden = dnn(Z, pre_name="g_")
    logits = tf.layers.dense(hidden, n_G_outputs, kernel_initializer=he_init,
                             name="g_logits")
    return logits, tf.nn.sigmoid(logits, name="g_proba")

def discriminator(X, reuse=None):
    """Discriminator network built from ``tf.layers``.

    A ``dnn`` trunk (layers prefixed ``d_``) followed by an ``n_D_outputs``-wide
    dense layer named ``d_logits``.

    Args:
        X: image tensor to score.
        reuse: pass True on later calls to share the weights created by the
            first call.

    Returns:
        ``(d_logits, d_proba)``: raw logits and their sigmoid.
    """
    hidden = dnn(X, pre_name="d_", reuse=reuse)
    logits = tf.layers.dense(hidden, n_D_outputs, kernel_initializer=he_init,
                             name="d_logits", reuse=reuse)
    return logits, tf.nn.sigmoid(logits)

# CREATE PHOTO: grid of generated samples.
def plot(samples, grid_size=4, image_shape=(28, 28)):
    """Draw samples as a square grid of grayscale images and return the figure.

    Args:
        samples: iterable of flat image arrays; each one is reshaped to
            ``image_shape`` (784 values -> 28x28 by default).
        grid_size: rows and columns of the grid. The default of 4 gives room
            for 16 samples, which is what the training loop passes.
        image_shape: shape each sample is reshaped to before display.

    Returns:
        The matplotlib Figure. It is also pyplot's current figure, so a
        following ``plt.savefig`` saves it.

    Raises:
        ValueError: if there are more samples than grid cells.
    """
    samples = list(samples)
    n_cells = grid_size * grid_size
    if len(samples) > n_cells:
        raise ValueError("plot() got %d samples but the grid only has %d cells"
                         % (len(samples), n_cells))

    fig = plt.figure(figsize=(grid_size, grid_size))
    gs = gridspec.GridSpec(grid_size, grid_size)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        # Work on the explicit Axes rather than pyplot's implicit "current" one.
        ax = fig.add_subplot(gs[i])
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.imshow(sample.reshape(image_shape), cmap='Greys_r')

    return fig

def main():
    """Build the ``tf.layers`` GAN graph and train it on MNIST.

    Each iteration runs one Adam step on the discriminator and one on the
    generator, each with a freshly sampled noise batch. Every 2000
    iterations a 4x4 grid of generated images is saved to ``out/NNN.png``
    and the latest losses are printed.

    Relies on module-level names: ``mnist``, the ``n_*``/``batch_size``/
    ``z_dim`` constants, ``generator``, ``discriminator``, ``plot`` and
    ``sample_Z``. NOTE(review): no visible cell defines ``sample_Z``; it is
    presumably a helper returning an ``(m, n)`` noise array for ``Z`` and must
    exist before ``main()`` is called on a fresh kernel.
    """
    reset_graph()
    Z = tf.placeholder(tf.float32, shape=(None, n_G_inputs), name="Z")
    X = tf.placeholder(tf.float32, shape=(None, n_D_inputs), name="X")

    g_logits, g_proba = generator(Z)
    d_logits_data, d_proba_data = discriminator(X)
    # Score the generated images with the *same* discriminator weights.
    d_logits_g, d_proba_g = discriminator(g_proba, reuse=True)

    # Select each network's parameters by layer-name prefix ("g_..." / "d_...")
    # instead of by position in the global collection: slicing [:4] / [4:]
    # silently mixes the two networks as soon as dnn's layer count changes.
    theta_G = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="g_")
    theta_D = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="d_")
    assert theta_G and theta_D, "generator/discriminator variables not found"

    # Discriminator loss = BCE(real logits, 1) + BCE(fake logits, 0).
    # loss_g is the discriminator's loss on generated samples, not the
    # generator loss (that is loss_G).
    xentropy_g = tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_g, labels=tf.zeros_like(d_logits_g))
    loss_g = tf.reduce_mean(xentropy_g, name="loss_g")
    xentropy_data = tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_data, labels=tf.ones_like(d_logits_data))
    loss_data = tf.reduce_mean(xentropy_data, name="loss_data")
    loss_D = loss_data + loss_g
    # Generator loss = BCE(fake logits, 1), i.e. -mean(log D(G(z))).
    xentropy_G = tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_g, labels=tf.ones_like(d_logits_g))
    loss_G = tf.reduce_mean(xentropy_G, name="loss_G")

    # var_list keeps each optimizer from updating the other network's weights.
    training_op_d = tf.train.AdamOptimizer().minimize(loss_D, var_list=theta_D)
    training_op_g = tf.train.AdamOptimizer().minimize(loss_G, var_list=theta_G)

    saver = tf.train.Saver()  # only used by the disabled checkpoint code below
    init = tf.global_variables_initializer()

    sample_every = 2000  # iterations between saved sample grids / loss prints

    with tf.Session() as sess:
        sess.run(init)

        if not os.path.exists('out/'):
            os.makedirs('out/')

        # if not os.path.exists(checkpoint_epoch_path):
        #     start_epoch = 0
        #     sess.run(init)
        # else:
        #     with open(checkpoint_epoch_path, 'rb') as f:
        #         start_epoch = int(f.read())
        #     print("Training was interrupted. Continuing at epoch", start_epoch)
        #     saver.restore(sess, checkpoint_path)
        n_saved = 0  # index used in the saved image file names
        # for epoch in range(start_epoch, n_epochs):
        for epoch in range(n_epochs):
            X_batch, _ = mnist.train.next_batch(batch_size)  # labels are not needed
            # Discriminator step: a real batch plus a freshly sampled noise batch.
            _, loss_D_val = sess.run([training_op_d, loss_D], feed_dict={X: X_batch, Z: sample_Z(batch_size, z_dim)})
            # Generator step with its own fresh noise batch.
            _, loss_G_val = sess.run([training_op_g, loss_G], feed_dict={Z: sample_Z(batch_size, z_dim)})

            if epoch % sample_every == 0:
                samples = sess.run(g_proba, feed_dict={Z: sample_Z(16, z_dim)})
                fig = plot(samples)
                plt.savefig('out/{}.png'.format(str(n_saved).zfill(3)), bbox_inches='tight')
                n_saved += 1
                plt.close(fig)  # avoid accumulating open figures across the run

                print('Iter: {}'.format(epoch))
                print('D loss: {:.4}'.format(loss_D_val))
                print('G_loss: {:.4}'.format(loss_G_val))

            # if epoch % 5000 == 0:
            #     with open(checkpoint_epoch_path, 'wb') as f:
            #         saver.save(sess, checkpoint_path)
            #         f.write(b'%d' % (epoch + 1))

Extracting /tmp/data/train-images-idx3-ubyte.gz


Extracting /tmp/data/train-labels-idx1-ubyte.gz
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz


In [46]:
# Build the tf.layers GAN graph and train it.
# NOTE(review): main() calls sample_Z, which no visible cell defines — it must
# already exist in the kernel for this call to succeed.
main()

Iter: 0
D loss: 1.875
G_loss: 2.783


Iter: 2000
D loss: 0.2635
G_loss: 3.857


Iter: 4000
D loss: 0.5207
G_loss: 3.113


Iter: 6000
D loss: 0.4477
G_loss: 3.209


Iter: 8000
D loss: 0.2436
G_loss: 3.115


Iter: 10000
D loss: 0.3834
G_loss: 2.934


Iter: 12000
D loss: 0.3561
G_loss: 2.891


Iter: 14000
D loss: 0.4913
G_loss: 2.904


Iter: 16000
D loss: 0.6247
G_loss: 2.594


Iter: 18000
D loss: 0.5452
G_loss: 2.681
