In [1]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from ops import dense, conv2d, conv2d_transpose
import numpy as np

In [2]:
batch_size = 64

keep_prob = tf.placeholder(dtype=tf.float32)
z = tf.placeholder(dtype=tf.float32, shape=[None, 100])
x = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28, 1])


def build_generator(input_):
    with tf.variable_scope('generator', reuse=tf.AUTO_REUSE):
        dense_1 = dense(input_, 7 * 7 * 64, activation=tf.nn.relu, name='dense_1')
        drop_1 = tf.nn.dropout(dense_1, keep_prob)

        reshape_1 = tf.reshape(drop_1, shape=[-1, 7, 7, 64])

        deconv_1 = conv2d_transpose(reshape_1, batch_size, 14, 14, 5, 32, strides=[1, 2, 2, 1], activation=tf.nn.relu, name='deconv_1')
        drop_2 = tf.nn.dropout(deconv_1, keep_prob)

        deconv_2 = conv2d_transpose(drop_2, batch_size, 28, 28, 5, 1, strides=[1, 2, 2, 1], activation=tf.nn.relu, name='deconv_2')
        drop_3 = tf.nn.dropout(deconv_2, keep_prob)

        return drop_3


def build_discriminator(input_):
    with tf.variable_scope('discriminator', reuse=tf.AUTO_REUSE):
        conv_1 = conv2d(input_, 5, 32, activation=tf.nn.relu, name='conv_1')
        pool_1 = tf.nn.avg_pool(conv_1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        drop_1 = tf.nn.dropout(pool_1, keep_prob)

        conv_2 = conv2d(drop_1, 5, 64, activation=tf.nn.relu, name='conv_2')
        pool_2 = tf.nn.avg_pool(conv_2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        drop_2 = tf.nn.dropout(pool_2, keep_prob)

        flatten_1 = tf.reshape(drop_2, shape=[-1, 7 * 7 * 64])
        dense_1 = dense(flatten_1, 1024, activation=tf.nn.relu, name='dense_1')
        drop_3 = tf.nn.dropout(dense_1, keep_prob)

        dense_2 = dense(drop_3, 1, name='dense_2')
        return dense_2

d_x = build_discriminator(x)
d_z = build_discriminator(build_generator(z))




In [3]:
d_var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='discriminator')
g_var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator')

loss_d = tf.reduce_mean(- d_x) + tf.reduce_mean(d_z)
loss_g = tf.reduce_mean(- d_z)

d_minimizer = tf.train.RMSPropOptimizer(0.00005).minimize(loss_d, var_list=d_var)
g_minimizer = tf.train.RMSPropOptimizer(0.00005).minimize(loss_g, var_list=g_var)


mnist = input_data.read_data_sets('MNIST')

Extracting MNIST\train-images-idx3-ubyte.gz
Extracting MNIST\train-labels-idx1-ubyte.gz
Extracting MNIST\t10k-images-idx3-ubyte.gz
Extracting MNIST\t10k-labels-idx1-ubyte.gz


In [4]:
random_normal = tf.random_normal([batch_size, 100])
c= 0.01
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for i in range(1000):
        for j in range(10):
            
            batch_x, _ = mnist.train.next_batch(batch_size)
            batch_x = np.reshape(batch_x, [-1, 28,28,1])
            batch_z = sess.run(random_normal)
            sess.run(d_minimizer, feed_dict={x:batch_x, z:batch_z , keep_prob:0.5})
            clip = [v.assign(tf.clip_by_value(v, -c, c)) for v in d_var]
            sess.run(clip)
                
        
        sess.run(g_minimizer , feed_dict={z:batch_z, keep_prob:0.5})
#         if i%100==0:
        print(sess.run([loss_d, loss_g], feed_dict={x:batch_x, z:batch_z , keep_prob:0.5}))

[1.2156246, -1.2084754]
[1.1652998, -1.1573792]
[1.9606456, -1.953351]
[0.70705998, -0.69952548]
[0.78502476, -0.7769987]
[0.51420969, -0.50635535]
[-0.4649691, 0.47308618]
[-0.74141061, 0.74881399]
[-3.7664378, 3.7755303]
[-4.8016748, 4.8105597]
[-6.5179005, 6.5265741]
[-9.067605, 9.0757351]
[-11.585329, 11.593972]
[-15.238391, 15.246655]
[-19.731878, 19.74033]
[-28.409479, 28.419292]
[-45.314129, 45.325516]
[-75.624199, 75.63958]
[-141.33752, 141.36165]
[-279.70767, 279.74719]
[-547.10126, 547.16699]
[-963.03308, 963.13501]
[-1554.9847, 1555.1295]
[-2214.0955, 2214.28]
[-3028.7905, 3029.0259]
[-3871.6929, 3871.9888]
[-4995.6895, 4996.0215]
[-6043.9331, 6044.3115]
[-7540.7192, 7541.1533]
[-8505.1895, 8505.6797]
[-10393.359, 10393.871]
[-11626.28, 11626.838]
[-13511.719, 13512.336]
[-15528.387, 15529.027]
[-17076.723, 17077.41]
[-19205.6, 19206.391]
[-22246.488, 22247.336]
[-24097.969, 24098.879]
[-26795.908, 26796.891]
[-29464.865, 29465.943]
[-33117.293, 33118.438]
[-35880.555, 35881

KeyboardInterrupt: 

In [None]:
random_normal