In [40]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

from tensorflow.examples.tutorials.mnist import input_data

tf.reset_default_graph()
mnist = input_data.read_data_sets("../mnist/data/", one_hot=True)


total_epoch = 15
batch_size = 100
n_hidden = 256
n_input = 28 * 28
n_noise = 128
n_class = 10

X = tf.placeholder(tf.float32, [None, n_input])
Y = tf.placeholder(tf.float32, [None, n_class])
Z = tf.placeholder(tf.float32, [None, n_noise])


def generator(noise, labels):
    with tf.variable_scope('generator'):
        inputs = tf.concat([noise, labels], 1)
        print(inputs)
        hidden = tf.layers.dense(inputs, n_hidden, activation=tf.nn.relu, name="gen_hidden")
        
        output = tf.layers.dense(hidden, n_input, activation=tf.nn.sigmoid)

    return output


def discriminator(inputs, labels, reuse=None):
    with tf.variable_scope('discriminator') as scope:
        if reuse:
            scope.reuse_variables()

        inputs = tf.concat([inputs, labels], 1)

        hidden = tf.layers.dense(inputs, n_hidden, activation=tf.nn.relu, name="dis_hidden")
        output = tf.layers.dense(hidden, 1, activation=None)

    return output


def get_noise(batch_size, n_noise):
    return np.random.uniform(-1., 1., size=[batch_size, n_noise])

G = generator(Z, Y)
D_real = discriminator(X, Y)
D_gene = discriminator(G, Y, True)

loss_D_real = tf.reduce_mean(
                    tf.nn.sigmoid_cross_entropy_with_logits(
                        logits=D_real, labels=tf.ones_like(D_real)))
loss_D_gene = tf.reduce_mean(
                    tf.nn.sigmoid_cross_entropy_with_logits(
                        logits=D_gene, labels=tf.zeros_like(D_gene)))

loss_D = loss_D_real + loss_D_gene
loss_G = tf.reduce_mean(
                    tf.nn.sigmoid_cross_entropy_with_logits(
                        logits=D_gene, labels=tf.ones_like(D_gene)))


tf.summary.scalar('costD', loss_D)
tf.summary.scalar('costG', loss_G)


vars_D = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                           scope='discriminator')

vars_G = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                           scope='generator')

train_D = tf.train.AdamOptimizer().minimize(loss_D,
                                            var_list=vars_D)
train_G = tf.train.AdamOptimizer().minimize(loss_G,
                                            var_list=vars_G)

tf.summary.histogram('vars_D', vars_D[0])
tf.summary.histogram('vars_G', vars_G[0])

sess = tf.Session()
merged = tf.summary.merge_all()
writer = tf.summary.FileWriter('./logs', sess.graph)
sess.run(tf.global_variables_initializer())



total_batch = int(mnist.train.num_examples/batch_size)
loss_val_D, loss_val_G = 0, 0

for epoch in range(total_epoch):
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        noise = get_noise(batch_size, n_noise)

        m, loss_val_D, _  = sess.run([merged, loss_D, train_D],
                                 feed_dict={X: batch_xs, Y: batch_ys, Z: noise})

        loss_val_G, _  = sess.run([loss_G, train_G],
                                 feed_dict={Y: batch_ys, Z: noise})
        writer.add_summary(m, i + epoch * total_batch)
        
#         _, _, loss_val_G, loss_val_D  = sess.run([train_G, train_D, loss_G, loss_D], 
#                                                    feed_dict={X: batch_xs, Y: batch_ys, Z: noise})
        
        

    print('Epoch:', '%04d' % epoch,
          'D loss: {:.4}'.format(loss_val_D),
          'G loss: {:.4}'.format(loss_val_G))

    if epoch == 0 or (epoch + 1) % 10 == 0:
        sample_size = 10
        noise = get_noise(sample_size, n_noise)
        samples = sess.run(G,
                           feed_dict={Y: mnist.test.labels[:sample_size],
                                      Z: noise})

        fig, ax = plt.subplots(2, sample_size, figsize=(sample_size, 2))
        
        for i in range(sample_size):
            ax[0][i].set_axis_off()
            ax[1][i].set_axis_off()

            ax[0][i].imshow(np.reshape(mnist.test.images[i], (28, 28)))
            ax[1][i].imshow(np.reshape(samples[i], (28, 28)))

        plt.savefig('samples2/{}.png'.format(str(epoch).zfill(3)), bbox_inches='tight')
        plt.close(fig)

print('최적화 완료!')

Extracting ../mnist/data/train-images-idx3-ubyte.gz
Extracting ../mnist/data/train-labels-idx1-ubyte.gz
Extracting ../mnist/data/t10k-images-idx3-ubyte.gz
Extracting ../mnist/data/t10k-labels-idx1-ubyte.gz
Tensor("generator/concat:0", shape=(?, 138), dtype=float32)
Epoch: 0000 D loss: 0.002837 G loss: 6.606
Epoch: 0001 D loss: 0.008663 G loss: 9.067
Epoch: 0002 D loss: 0.006761 G loss: 8.978
Epoch: 0003 D loss: 0.001024 G loss: 8.86
Epoch: 0004 D loss: 0.0003957 G loss: 9.568
Epoch: 0005 D loss: 0.0002195 G loss: 9.774
Epoch: 0006 D loss: 0.0005519 G loss: 12.41
Epoch: 0007 D loss: 3.45e-05 G loss: 11.2
Epoch: 0008 D loss: 0.0001329 G loss: 11.49
Epoch: 0009 D loss: 9.401e-05 G loss: 14.4
Epoch: 0010 D loss: 6.054e-05 G loss: 24.28
Epoch: 0011 D loss: 0.0005868 G loss: 15.19
Epoch: 0012 D loss: 0.001249 G loss: 9.476
Epoch: 0013 D loss: 0.0004511 G loss: 9.549
Epoch: 0014 D loss: 3.248e-05 G loss: 12.67
최적화 완료!
