In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import math

  return f(*args, **kwds)


# Real data generation

In [2]:
def get_y(x):
    """Evaluate the target curve that the GAN is trained to reproduce.

    y = 10 + cos(x^2 / 200) + 0.6 * sin(0.6 x) + x / 50

    Args:
        x: a scalar (uses the `math` module, so not vectorised).

    Returns:
        The y value as a float.
    """
    oscillation = math.cos(x * x / 200)
    ripple = 0.6 * math.sin(0.6 * x)
    drift = x / 50
    # Summed left-to-right in the same order as the formula above.
    return 10 + oscillation + ripple + drift


def sample_data(n=10000, scale=100):
    """Draw n points lying on the target curve `get_y`.

    x is drawn uniformly from [-scale/2, scale/2) and y = get_y(x).

    Args:
        n: number of points to draw.
        scale: width of the x interval, centred on 0.

    Returns:
        A float array of shape (n, 2). Column 0 is x and column 1 is y.
        For n == 0 this is an empty (0, 2) array rather than shape (0,),
        so callers that index [:, 0] / [:, 1] keep working.
    """
    xs = scale * (np.random.random_sample((n,)) - 0.5)
    points = [[xi, get_y(xi)] for xi in xs]
    # reshape is a no-op for n > 0; it only fixes the empty-input shape.
    return np.array(points, dtype=float).reshape(-1, 2)

In [3]:
def sample_Z(m, n):
    """Return an (m, n) array of generator noise drawn uniformly from [-1, 1)."""
    shape = (m, n)
    return np.random.uniform(low=-1.0, high=1.0, size=shape)

def generator(Z, hsize=(96, 96), reuse=False):
    """Map noise Z to generated 2-D points.

    The network is two leaky-ReLU dense layers of widths hsize[0] and
    hsize[1], followed by a linear dense layer with 2 outputs. All variables
    are created under the "GAN/Generator" scope; the graph cell collects the
    generator's trainable variables by that prefix.

    Args:
        Z: input noise tensor (the notebook feeds a [None, 2] placeholder).
        hsize: hidden-layer widths; only the first two entries are used.
            The default is a tuple so no mutable default is shared across calls.
        reuse: forwarded to tf.variable_scope to reuse existing variables.

    Returns:
        The output tensor of the final dense layer (last dimension 2).
    """
    with tf.variable_scope("GAN/Generator", reuse=reuse):
        h1 = tf.layers.dense(Z, hsize[0], activation=tf.nn.leaky_relu)
        h2 = tf.layers.dense(h1, hsize[1], activation=tf.nn.leaky_relu)
        out = tf.layers.dense(h2, 2)

    return out

def discriminator(X, hsize=(96, 96), reuse=False):
    """Score points as real vs. generated and expose a 2-D feature layer.

    The network is two leaky-ReLU dense layers of widths hsize[0] and
    hsize[1], then a linear 2-unit layer (h3), then a linear 1-unit logit.
    All variables are created under the "GAN/Discriminator" scope.

    Args:
        X: input points tensor (the notebook passes the [None, 2] real
            placeholder and the generator output).
        hsize: hidden-layer widths; only the first two entries are used.
            The default is a tuple so no mutable default is shared across calls.
        reuse: forwarded to tf.variable_scope. Pass True on the second call
            so real and generated inputs are scored by the same weights.

    Returns:
        (out, h3): `out` is one unnormalised logit per input, and `h3` is
        the 2-unit layer feeding it (used for the feature-transform plots).
    """
    with tf.variable_scope("GAN/Discriminator", reuse=reuse):
        h1 = tf.layers.dense(X, hsize[0], activation=tf.nn.leaky_relu)
        h2 = tf.layers.dense(h1, hsize[1], activation=tf.nn.leaky_relu)
        # No activation: a 2-unit linear bottleneck kept 2-D so it can be
        # scatter-plotted.
        h3 = tf.layers.dense(h2, 2)
        out = tf.layers.dense(h3, 1)

    return out, h3

In [4]:
X = tf.placeholder(tf.float32, [None, 2])  # real (x, y) points
Z = tf.placeholder(tf.float32, [None, 2])  # generator noise

# Generator output, then the discriminator applied to real and generated
# points; the second call shares weights with the first via reuse=True.
G_sample = generator(Z)
r_logits, r_rep = discriminator(X)
f_logits, g_rep = discriminator(G_sample, reuse=True)

# Discriminator target: real -> 1, generated -> 0.
real_xent = tf.nn.sigmoid_cross_entropy_with_logits(logits=r_logits, labels=tf.ones_like(r_logits))
fake_xent = tf.nn.sigmoid_cross_entropy_with_logits(logits=f_logits, labels=tf.zeros_like(f_logits))
disc_loss = tf.reduce_mean(real_xent + fake_xent)

# Generator target: have its samples classified as real (label 1).
gen_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=f_logits, labels=tf.ones_like(f_logits)))

# Each optimiser only touches its own network, selected by variable-scope prefix.
gen_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="GAN/Generator")
disc_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="GAN/Discriminator")

gen_step = tf.train.RMSPropOptimizer(learning_rate=0.001).minimize(gen_loss, var_list=gen_vars)     # G train step
disc_step = tf.train.RMSPropOptimizer(learning_rate=0.001).minimize(disc_loss, var_list=disc_vars)  # D train step

In [5]:
# Session on the default graph, with every variable (both networks and
# their RMSProp slots) initialised.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

In [6]:
# Training hyper-parameters.
batch_size = 256           # points per real batch and per noise batch
nd_steps = ng_steps = 10   # discriminator / generator updates per outer iteration

In [7]:
# One fixed real sample, reused as the reference cloud in every snapshot plot.
x_plot = sample_data(batch_size)

In [8]:
# CSV loss log appended to by the training loop. It is line-buffered
# (buffering=1), so each row is flushed to disk as soon as it is written
# instead of sitting in the write buffer until the file is closed.
f = open('loss_logs.csv', 'w', buffering=1)
# The trailing ';' suppresses the character count returned by write().
f.write('Iteration,Discriminator Loss,Generator Loss\n');

44

In [9]:
N_ITERATIONS = 80001  # outer iterations; i runs 0..80000 inclusive
LOG_EVERY = 200       # iterations between loss printouts / CSV rows
PLOT_EVERY = 1000     # iterations between saved snapshot figures


def _save_sample_plot(i, real_points, fake_points):
    """Save 'iteration_<i>.png': the fixed real sample vs. generated points."""
    fig, ax = plt.subplots()
    xax = ax.scatter(real_points[:, 0], real_points[:, 1])
    gax = ax.scatter(fake_points[:, 0], fake_points[:, 1])
    ax.legend((xax, gax), ("Real Data", "Generated Data"))
    ax.set_title('Samples at Iteration %d' % i)
    fig.tight_layout()
    fig.savefig('iteration_%d.png' % i)
    plt.close(fig)  # figures left open in a loop accumulate in memory


def _save_feature_plot(i, rrep_dstep, rrep_gstep, grep_dstep, grep_gstep):
    """Save 'feature_transform_<i>.png': 2-D discriminator features before/after the G steps."""
    fig, ax = plt.subplots()
    rrd = ax.scatter(rrep_dstep[:, 0], rrep_dstep[:, 1], alpha=0.5)
    rrg = ax.scatter(rrep_gstep[:, 0], rrep_gstep[:, 1], alpha=0.5)
    grd = ax.scatter(grep_dstep[:, 0], grep_dstep[:, 1], alpha=0.5)
    grg = ax.scatter(grep_gstep[:, 0], grep_gstep[:, 1], alpha=0.5)
    ax.legend((rrd, rrg, grd, grg), ("Real Data Before G step", "Real Data After G step",
                                     "Generated Data Before G step", "Generated Data After G step"))
    ax.set_title('Transformed Features at Iteration %d' % i)
    fig.tight_layout()
    fig.savefig('feature_transform_%d.png' % i)
    plt.close(fig)


try:
    for i in range(N_ITERATIONS):
        X_batch = sample_data(n=batch_size)
        Z_batch = sample_Z(batch_size, 2)
        snapshot = i % PLOT_EVERY == 0

        for _ in range(nd_steps):
            _, dloss = sess.run([disc_step, disc_loss], feed_dict={X: X_batch, Z: Z_batch})
        if snapshot:
            # Features after the D updates but before the G updates. They are
            # only needed for the plots, so they are skipped on other iterations.
            rrep_dstep, grep_dstep = sess.run([r_rep, g_rep], feed_dict={X: X_batch, Z: Z_batch})

        for _ in range(ng_steps):
            _, gloss = sess.run([gen_step, gen_loss], feed_dict={Z: Z_batch})

        if i % LOG_EVERY == 0:
            print("Iterations: %d\t Discriminator loss: %.4f\t Generator loss: %.4f " % (i, dloss, gloss))
            f.write("%d,%f,%f\n" % (i, dloss, gloss))

        if snapshot:
            rrep_gstep, grep_gstep = sess.run([r_rep, g_rep], feed_dict={X: X_batch, Z: Z_batch})
            g_plot = sess.run(G_sample, feed_dict={Z: Z_batch})
            _save_sample_plot(i, x_plot, g_plot)
            _save_feature_plot(i, rrep_dstep, rrep_gstep, grep_dstep, grep_gstep)
finally:
    # Flush and release the CSV log even if training is interrupted.
    # Re-run the cell that opens `f` before re-running this one.
    f.close()

Iterations: 0	 Discriminator loss: 0.8785	 Generator loss: 0.6918 
Iterations: 200	 Discriminator loss: 1.2279	 Generator loss: 0.4187 
Iterations: 400	 Discriminator loss: 1.2044	 Generator loss: 0.5358 
Iterations: 600	 Discriminator loss: 1.1592	 Generator loss: 0.3710 
Iterations: 800	 Discriminator loss: 1.0999	 Generator loss: 0.3123 
Iterations: 1000	 Discriminator loss: 1.0432	 Generator loss: 0.2338 
Iterations: 1200	 Discriminator loss: 0.9898	 Generator loss: 0.2661 
Iterations: 1400	 Discriminator loss: 1.1071	 Generator loss: 0.1852 
Iterations: 1600	 Discriminator loss: 0.8000	 Generator loss: 0.2095 
Iterations: 1800	 Discriminator loss: 0.9775	 Generator loss: 0.2245 
Iterations: 2000	 Discriminator loss: 0.7400	 Generator loss: 0.2884 
Iterations: 2200	 Discriminator loss: 1.0551	 Generator loss: 0.3244 
Iterations: 2400	 Discriminator loss: 1.0606	 Generator loss: 0.2895 
Iterations: 2600	 Discriminator loss: 0.9423	 Generator loss: 0.2865 
Iterations: 2800	 Discrimin



Iterations: 20200	 Discriminator loss: 1.1912	 Generator loss: 0.3530 
Iterations: 20400	 Discriminator loss: 1.1803	 Generator loss: 0.4094 
Iterations: 20600	 Discriminator loss: 0.9357	 Generator loss: 0.3319 
Iterations: 20800	 Discriminator loss: 1.0165	 Generator loss: 0.2240 
Iterations: 21000	 Discriminator loss: 1.0040	 Generator loss: 0.3204 
Iterations: 21200	 Discriminator loss: 0.8166	 Generator loss: 0.3443 
Iterations: 21400	 Discriminator loss: 0.7333	 Generator loss: 0.2385 
Iterations: 21600	 Discriminator loss: 1.0687	 Generator loss: 0.4492 
Iterations: 21800	 Discriminator loss: 0.5302	 Generator loss: 0.1552 
Iterations: 22000	 Discriminator loss: 0.8801	 Generator loss: 0.2708 
Iterations: 22200	 Discriminator loss: 1.3397	 Generator loss: 0.4612 
Iterations: 22400	 Discriminator loss: 0.9379	 Generator loss: 0.2511 
Iterations: 22600	 Discriminator loss: 0.9358	 Generator loss: 0.9745 
Iterations: 22800	 Discriminator loss: 0.7984	 Generator loss: 0.2496 
Iterat

<matplotlib.figure.Figure at 0x7f087c1acd30>

<matplotlib.figure.Figure at 0x7f087c113eb8>

<matplotlib.figure.Figure at 0x7f087c1684e0>

<matplotlib.figure.Figure at 0x7f087c06add8>

<matplotlib.figure.Figure at 0x7f0874f84940>

<matplotlib.figure.Figure at 0x7f0874d52e48>

<matplotlib.figure.Figure at 0x7f0874fa5e10>

<matplotlib.figure.Figure at 0x7f0874c957b8>

<matplotlib.figure.Figure at 0x7f0874bef668>

<matplotlib.figure.Figure at 0x7f0874a26e48>

<matplotlib.figure.Figure at 0x7f0874967978>

<matplotlib.figure.Figure at 0x7f08749ecac8>

<matplotlib.figure.Figure at 0x7f087488d3c8>

<matplotlib.figure.Figure at 0x7f08747db2b0>

<matplotlib.figure.Figure at 0x7f087459f518>

<matplotlib.figure.Figure at 0x7f087447fd68>

<matplotlib.figure.Figure at 0x7f087456de48>

<matplotlib.figure.Figure at 0x7f08743be9b0>

<matplotlib.figure.Figure at 0x7f0874320e80>

<matplotlib.figure.Figure at 0x7f08742aafd0>

<matplotlib.figure.Figure at 0x7f087418a6a0>

<matplotlib.figure.Figure at 0x7f08740de6a0>

<matplotlib.figure.Figure at 0x7f0841eab978>

<matplotlib.figure.Figure at 0x7f0841ed0048>

<matplotlib.figure.Figure at 0x7f0841ccbcf8>

<matplotlib.figure.Figure at 0x7f0841d02cf8>

<matplotlib.figure.Figure at 0x7f0841b2ffd0>

<matplotlib.figure.Figure at 0x7f0841b42748>

<matplotlib.figure.Figure at 0x7f0841967f98>

<matplotlib.figure.Figure at 0x7f08418b5518>

<matplotlib.figure.Figure at 0x7f084186df28>

<matplotlib.figure.Figure at 0x7f0841804dd8>

<matplotlib.figure.Figure at 0x7f0841763b70>

<matplotlib.figure.Figure at 0x7f084164d588>

<matplotlib.figure.Figure at 0x7f0841591748>

<matplotlib.figure.Figure at 0x7f0841395208>

<matplotlib.figure.Figure at 0x7f08412cfd30>

<matplotlib.figure.Figure at 0x7f084131a8d0>

<matplotlib.figure.Figure at 0x7f08413c63c8>

<matplotlib.figure.Figure at 0x7f0841010ac8>

<matplotlib.figure.Figure at 0x7f0840f67278>

<matplotlib.figure.Figure at 0x7f0840e43e10>

<matplotlib.figure.Figure at 0x7f0840db0f60>

<matplotlib.figure.Figure at 0x7f0840cf5eb8>

<matplotlib.figure.Figure at 0x7f0840d08fd0>

<matplotlib.figure.Figure at 0x7f0840b394e0>

<matplotlib.figure.Figure at 0x7f0840ba06d8>

<matplotlib.figure.Figure at 0x7f084096cfd0>

<matplotlib.figure.Figure at 0x7f08408bb550>

<matplotlib.figure.Figure at 0x7f08408c03c8>

<matplotlib.figure.Figure at 0x7f08406c5d30>

<matplotlib.figure.Figure at 0x7f084062d160>

<matplotlib.figure.Figure at 0x7f084064a588>

<matplotlib.figure.Figure at 0x7f08404b0c18>

<matplotlib.figure.Figure at 0x7f08404ff3c8>

<matplotlib.figure.Figure at 0x7f08404a1630>

<matplotlib.figure.Figure at 0x7f08401f63c8>

<matplotlib.figure.Figure at 0x7f084010d160>

<matplotlib.figure.Figure at 0x7f08400c6240>

<matplotlib.figure.Figure at 0x7f0840099588>

<matplotlib.figure.Figure at 0x7f0833eb2c18>

<matplotlib.figure.Figure at 0x7f0833f0d4a8>

<matplotlib.figure.Figure at 0x7f0833dc3128>

<matplotlib.figure.Figure at 0x7f0833d0a080>

<matplotlib.figure.Figure at 0x7f0833c96908>

<matplotlib.figure.Figure at 0x7f0833ceb780>

<matplotlib.figure.Figure at 0x7f0833a81f98>

<matplotlib.figure.Figure at 0x7f0833891f60>

<matplotlib.figure.Figure at 0x7f083392d7f0>

<matplotlib.figure.Figure at 0x7f0833807860>

<matplotlib.figure.Figure at 0x7f08337a55f8>

<matplotlib.figure.Figure at 0x7f0833543240>

<matplotlib.figure.Figure at 0x7f08335902b0>

<matplotlib.figure.Figure at 0x7f08333a9b00>

<matplotlib.figure.Figure at 0x7f0833406898>

<matplotlib.figure.Figure at 0x7f0833308240>

<matplotlib.figure.Figure at 0x7f083325aa58>

<matplotlib.figure.Figure at 0x7f08331b21d0>

<matplotlib.figure.Figure at 0x7f0833260a58>

<matplotlib.figure.Figure at 0x7f0832ff3f98>

<matplotlib.figure.Figure at 0x7f0832de0828>