In [76]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

In [77]:
def xavier_init(size):
    in_dim = size[0]
    xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
    return tf.random_normal(shape=size, stddev=xavier_stddev)

def sample_z(n,m):
    return np.random.uniform(-1.,1., size=[n,m])

In [78]:
# Defining input placeholder and parameters for the generator

Z = tf.placeholder(tf.float32, shape=[None, 100])

G_W1 = tf.Variable(xavier_init([100, 128]))
G_b1 = tf.Variable(tf.zeros(shape=[128]))

G_W2 = tf.Variable(xavier_init([128, 784]))
G_b2 = tf.Variable(tf.zeros(shape=[784]))

theta_G = [G_W1, G_W2, G_b1, G_b2]

In [79]:
# Defining input placeholder and parameters for the discriminator

X = tf.placeholder(tf.float32, shape=[None, 784])

D_W1 = tf.Variable(xavier_init([784, 128]))
D_b1 = tf.Variable(tf.zeros(shape=[128]))

D_W2 = tf.Variable(xavier_init([128, 1]))
D_b2 = tf.Variable(tf.zeros(shape=[1]))

theta_D = [D_W1, D_W2, D_b1, D_b2]

In [80]:
def generator(z):
    G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
    G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
    G_prob = tf.nn.sigmoid(G_log_prob)

    return G_prob
    

In [81]:
def discriminator(x):
    D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
    D_log_prob = tf.matmul(D_h1, D_W2) + D_b2
    D_prob = tf.nn.sigmoid(D_log_prob)
    
    return D_prob

In [82]:
def plot(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig

In [83]:
G_sample = generator(Z)
D_real = discriminator(X)
D_fake = discriminator(G_sample)

In [84]:
D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
G_loss = -tf.reduce_mean(tf.log(D_fake))

In [85]:
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list = theta_D)
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list = theta_G)

In [86]:
z_size = 100
mb_size = 16

mnist = input_data.read_data_sets('../data/MNIST_data', one_hot=True)

Extracting ../data/MNIST_data/train-images-idx3-ubyte.gz
Extracting ../data/MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../data/MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../data/MNIST_data/t10k-labels-idx1-ubyte.gz


In [87]:
sess = tf.Session()
sess.run(tf.initialize_all_variables())

if not os.path.exists('../out/'):
    os.makedirs('../out/')

i = 0

for it in range(1000000):
    if it % 100 == 0:
        samples = sess.run(G_sample, feed_dict={Z: sample_z(16, z_size)})

        fig = plot(samples)
        plt.savefig('../out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
        i += 1
        plt.close(fig)

    X_mb, _ = mnist.train.next_batch(mb_size)

    _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_z(mb_size, z_size)})
    _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_z(mb_size, z_size)})


    if it % 100 == 0:
        print('Iter: {}'.format(it))
        print('D loss: {:.4}'. format(D_loss_curr))
        print('G_loss: {:.4}'.format(G_loss_curr))
        print()

Iter: 0
D loss: 1.931
G_loss: 1.447

Iter: 100
D loss: 0.201
G_loss: 2.865

Iter: 200
D loss: 0.1256
G_loss: 4.069

Iter: 300
D loss: 0.05122
G_loss: 5.108

Iter: 400
D loss: 0.03928
G_loss: 4.418

Iter: 500
D loss: 0.01165
G_loss: 5.415

Iter: 600
D loss: 0.0198
G_loss: 5.793

Iter: 700
D loss: 0.02997
G_loss: 5.906

Iter: 800
D loss: 0.02053
G_loss: 5.663

Iter: 900
D loss: 0.02277
G_loss: 6.541

Iter: 1000
D loss: 0.00657
G_loss: 5.82

Iter: 1100
D loss: 0.007508
G_loss: 6.908

Iter: 1200
D loss: 0.007557
G_loss: 6.701

Iter: 1300
D loss: 0.008096
G_loss: 6.952

Iter: 1400
D loss: 0.08425
G_loss: 5.466

Iter: 1500
D loss: 0.009643
G_loss: 6.054

Iter: 1600
D loss: 0.01788
G_loss: 5.21

Iter: 1700
D loss: 0.0343
G_loss: 6.221

Iter: 1800
D loss: 0.02855
G_loss: 7.148

Iter: 1900
D loss: 0.01495
G_loss: 8.293

Iter: 2000
D loss: 0.01449
G_loss: 8.317

Iter: 2100
D loss: 0.03741
G_loss: 6.827

Iter: 2200
D loss: 0.04139
G_loss: 5.882

Iter: 2300
D loss: 0.12
G_loss: 4.22

Iter: 2400
D 

Iter: 19900
D loss: 0.9584
G_loss: 2.052

Iter: 20000
D loss: 0.4758
G_loss: 1.977

Iter: 20100
D loss: 0.469
G_loss: 2.825

Iter: 20200
D loss: 1.161
G_loss: 2.11

Iter: 20300
D loss: 0.3483
G_loss: 2.664

Iter: 20400
D loss: 0.5321
G_loss: 1.978

Iter: 20500
D loss: 0.7435
G_loss: 2.395

Iter: 20600
D loss: 0.6217
G_loss: 2.216

Iter: 20700
D loss: 0.7262
G_loss: 1.695

Iter: 20800
D loss: 0.7475
G_loss: 1.871

Iter: 20900
D loss: 0.7721
G_loss: 2.007

Iter: 21000
D loss: 0.6784
G_loss: 2.39

Iter: 21100
D loss: 0.4128
G_loss: 2.097

Iter: 21200
D loss: 0.6011
G_loss: 1.957

Iter: 21300
D loss: 0.8755
G_loss: 2.183

Iter: 21400
D loss: 0.5624
G_loss: 1.784

Iter: 21500
D loss: 0.533
G_loss: 2.184

Iter: 21600
D loss: 0.7923
G_loss: 2.03

Iter: 21700
D loss: 0.746
G_loss: 2.018

Iter: 21800
D loss: 0.5703
G_loss: 1.951

Iter: 21900
D loss: 0.8785
G_loss: 2.145

Iter: 22000
D loss: 0.7207
G_loss: 2.553

Iter: 22100
D loss: 0.4529
G_loss: 1.812

Iter: 22200
D loss: 0.7776
G_loss: 2.609


Iter: 39600
D loss: 0.7183
G_loss: 1.277

Iter: 39700
D loss: 0.5792
G_loss: 1.964

Iter: 39800
D loss: 0.7553
G_loss: 2.274

Iter: 39900
D loss: 1.029
G_loss: 2.676

Iter: 40000
D loss: 0.8432
G_loss: 1.952

Iter: 40100
D loss: 0.6053
G_loss: 1.966

Iter: 40200
D loss: 1.201
G_loss: 1.934

Iter: 40300
D loss: 0.5597
G_loss: 1.787

Iter: 40400
D loss: 0.642
G_loss: 2.492

Iter: 40500
D loss: 0.88
G_loss: 1.998

Iter: 40600
D loss: 0.4501
G_loss: 2.452

Iter: 40700
D loss: 0.8737
G_loss: 1.949

Iter: 40800
D loss: 0.5452
G_loss: 2.471

Iter: 40900
D loss: 0.5603
G_loss: 1.674

Iter: 41000
D loss: 1.021
G_loss: 2.205

Iter: 41100
D loss: 0.5172
G_loss: 2.222

Iter: 41200
D loss: 0.9747
G_loss: 1.993

Iter: 41300
D loss: 0.9335
G_loss: 2.674

Iter: 41400
D loss: 0.3913
G_loss: 2.397

Iter: 41500
D loss: 0.373
G_loss: 2.065

Iter: 41600
D loss: 0.655
G_loss: 1.806

Iter: 41700
D loss: 0.9852
G_loss: 2.356

Iter: 41800
D loss: 0.9145
G_loss: 1.522

Iter: 41900
D loss: 1.117
G_loss: 1.821

I

Iter: 59300
D loss: 0.7703
G_loss: 2.032

Iter: 59400
D loss: 0.855
G_loss: 1.541

Iter: 59500
D loss: 0.7133
G_loss: 1.705

Iter: 59600
D loss: 0.8965
G_loss: 2.157

Iter: 59700
D loss: 0.6311
G_loss: 2.309

Iter: 59800
D loss: 0.5197
G_loss: 1.977

Iter: 59900
D loss: 0.7405
G_loss: 1.966

Iter: 60000
D loss: 0.6391
G_loss: 2.243

Iter: 60100
D loss: 0.5722
G_loss: 2.087

Iter: 60200
D loss: 0.3966
G_loss: 2.453

Iter: 60300
D loss: 0.6132
G_loss: 1.768

Iter: 60400
D loss: 0.5284
G_loss: 1.766

Iter: 60500
D loss: 0.9968
G_loss: 2.791

Iter: 60600
D loss: 0.7946
G_loss: 2.601

Iter: 60700
D loss: 0.3236
G_loss: 2.382

Iter: 60800
D loss: 0.7271
G_loss: 2.422

Iter: 60900
D loss: 1.085
G_loss: 2.068

Iter: 61000
D loss: 0.5682
G_loss: 2.869

Iter: 61100
D loss: 0.5483
G_loss: 1.9

Iter: 61200
D loss: 0.5678
G_loss: 2.23

Iter: 61300
D loss: 0.7972
G_loss: 2.156

Iter: 61400
D loss: 0.7756
G_loss: 2.135

Iter: 61500
D loss: 0.7089
G_loss: 2.724

Iter: 61600
D loss: 0.3723
G_loss: 2.78

Iter: 79000
D loss: 0.783
G_loss: 3.409

Iter: 79100
D loss: 0.8262
G_loss: 2.694

Iter: 79200
D loss: 0.542
G_loss: 1.895

Iter: 79300
D loss: 0.6954
G_loss: 2.216

Iter: 79400
D loss: 0.4691
G_loss: 2.089

Iter: 79500
D loss: 0.6429
G_loss: 3.269

Iter: 79600
D loss: 0.5915
G_loss: 2.997

Iter: 79700
D loss: 0.4638
G_loss: 2.614

Iter: 79800
D loss: 0.3392
G_loss: 3.004

Iter: 79900
D loss: 0.6663
G_loss: 1.962

Iter: 80000
D loss: 0.5957
G_loss: 2.119

Iter: 80100
D loss: 0.6005
G_loss: 2.102

Iter: 80200
D loss: 0.6397
G_loss: 2.475

Iter: 80300
D loss: 1.098
G_loss: 3.365

Iter: 80400
D loss: 0.808
G_loss: 2.961

Iter: 80500
D loss: 0.3965
G_loss: 2.155

Iter: 80600
D loss: 0.3256
G_loss: 2.908

Iter: 80700
D loss: 0.6767
G_loss: 2.921

Iter: 80800
D loss: 0.5793
G_loss: 2.492

Iter: 80900
D loss: 0.1688
G_loss: 2.085

Iter: 81000
D loss: 0.6503
G_loss: 2.172

Iter: 81100
D loss: 0.6432
G_loss: 2.477

Iter: 81200
D loss: 0.7398
G_loss: 2.671

Iter: 81300
D loss: 0.2063
G_loss: 3.2

Iter: 98700
D loss: 0.2963
G_loss: 2.638

Iter: 98800
D loss: 0.9836
G_loss: 2.998

Iter: 98900
D loss: 0.7176
G_loss: 2.278

Iter: 99000
D loss: 0.651
G_loss: 3.087

Iter: 99100
D loss: 0.7889
G_loss: 2.441

Iter: 99200
D loss: 0.49
G_loss: 2.509

Iter: 99300
D loss: 0.6354
G_loss: 2.328

Iter: 99400
D loss: 0.3629
G_loss: 2.495

Iter: 99500
D loss: 0.9369
G_loss: 3.141

Iter: 99600
D loss: 0.3937
G_loss: 2.927

Iter: 99700
D loss: 0.3924
G_loss: 3.063

Iter: 99800
D loss: 0.8488
G_loss: 1.869

Iter: 99900
D loss: 0.375
G_loss: 2.505

Iter: 100000
D loss: 0.6736
G_loss: 2.779

Iter: 100100
D loss: 0.825
G_loss: 3.01

Iter: 100200
D loss: 0.6031
G_loss: 2.438

Iter: 100300
D loss: 0.4209
G_loss: 2.25

Iter: 100400
D loss: 0.597
G_loss: 2.456

Iter: 100500
D loss: 0.3456
G_loss: 2.61

Iter: 100600
D loss: 0.5841
G_loss: 3.322

Iter: 100700
D loss: 0.7911
G_loss: 2.027

Iter: 100800
D loss: 0.3785
G_loss: 2.814

Iter: 100900
D loss: 0.6058
G_loss: 2.941

Iter: 101000
D loss: 0.4559
G_los

Iter: 117900
D loss: 0.5728
G_loss: 2.246

Iter: 118000
D loss: 0.3385
G_loss: 2.159

Iter: 118100
D loss: 0.8794
G_loss: 2.908

Iter: 118200
D loss: 0.4917
G_loss: 2.39

Iter: 118300
D loss: 0.4197
G_loss: 3.284

Iter: 118400
D loss: 0.5018
G_loss: 2.979

Iter: 118500
D loss: 0.484
G_loss: 2.734

Iter: 118600
D loss: 0.3077
G_loss: 2.481

Iter: 118700
D loss: 0.4931
G_loss: 3.287

Iter: 118800
D loss: 0.2313
G_loss: 2.58

Iter: 118900
D loss: 0.542
G_loss: 2.528

Iter: 119000
D loss: 0.632
G_loss: 1.61

Iter: 119100
D loss: 0.6195
G_loss: 3.467

Iter: 119200
D loss: 0.3836
G_loss: 2.051

Iter: 119300
D loss: 0.4156
G_loss: 3.635

Iter: 119400
D loss: 0.5654
G_loss: 2.533

Iter: 119500
D loss: 0.6352
G_loss: 2.264

Iter: 119600
D loss: 0.4369
G_loss: 2.794

Iter: 119700
D loss: 0.3412
G_loss: 3.05

Iter: 119800
D loss: 0.4444
G_loss: 2.559

Iter: 119900
D loss: 0.8284
G_loss: 2.989

Iter: 120000
D loss: 0.4716
G_loss: 2.37

Iter: 120100
D loss: 0.4058
G_loss: 2.464

Iter: 120200
D loss

Iter: 137100
D loss: 0.5902
G_loss: 2.419

Iter: 137200
D loss: 0.5225
G_loss: 2.41

Iter: 137300
D loss: 0.4543
G_loss: 2.898

Iter: 137400
D loss: 0.8566
G_loss: 2.423

Iter: 137500
D loss: 0.653
G_loss: 2.587

Iter: 137600
D loss: 0.2082
G_loss: 2.707

Iter: 137700
D loss: 0.6104
G_loss: 2.249

Iter: 137800
D loss: 0.4501
G_loss: 1.817

Iter: 137900
D loss: 0.341
G_loss: 3.042

Iter: 138000
D loss: 0.7171
G_loss: 2.489

Iter: 138100
D loss: 0.615
G_loss: 2.417

Iter: 138200
D loss: 0.392
G_loss: 2.568

Iter: 138300
D loss: 0.3351
G_loss: 3.355

Iter: 138400
D loss: 0.4987
G_loss: 1.939

Iter: 138500
D loss: 0.5686
G_loss: 2.083

Iter: 138600
D loss: 0.488
G_loss: 2.594

Iter: 138700
D loss: 0.3986
G_loss: 2.298

Iter: 138800
D loss: 0.3699
G_loss: 2.277

Iter: 138900
D loss: 0.3118
G_loss: 2.532

Iter: 139000
D loss: 0.6108
G_loss: 2.133

Iter: 139100
D loss: 0.9255
G_loss: 2.487

Iter: 139200
D loss: 0.4445
G_loss: 2.635

Iter: 139300
D loss: 0.5309
G_loss: 1.598

Iter: 139400
D lo

Iter: 156300
D loss: 0.4984
G_loss: 1.762

Iter: 156400
D loss: 0.2554
G_loss: 2.86

Iter: 156500
D loss: 0.2169
G_loss: 1.854

Iter: 156600
D loss: 0.6696
G_loss: 3.207

Iter: 156700
D loss: 0.4459
G_loss: 2.989

Iter: 156800
D loss: 0.6038
G_loss: 2.234

Iter: 156900
D loss: 0.6566
G_loss: 2.164

Iter: 157000
D loss: 0.2953
G_loss: 2.621

Iter: 157100
D loss: 0.5267
G_loss: 2.623

Iter: 157200
D loss: 0.3813
G_loss: 2.027

Iter: 157300
D loss: 0.4148
G_loss: 2.975

Iter: 157400
D loss: 0.3946
G_loss: 1.788

Iter: 157500
D loss: 0.5339
G_loss: 3.341

Iter: 157600
D loss: 0.5921
G_loss: 3.04

Iter: 157700
D loss: 0.3822
G_loss: 2.8

Iter: 157800
D loss: 0.4935
G_loss: 3.062

Iter: 157900
D loss: 0.5313
G_loss: 2.144

Iter: 158000
D loss: 0.78
G_loss: 1.838

Iter: 158100
D loss: 0.4565
G_loss: 2.904

Iter: 158200
D loss: 0.4598
G_loss: 3.163

Iter: 158300
D loss: 0.387
G_loss: 3.031

Iter: 158400
D loss: 0.7668
G_loss: 2.942

Iter: 158500
D loss: 0.5205
G_loss: 2.567

Iter: 158600
D los

Iter: 175500
D loss: 0.174
G_loss: 2.572

Iter: 175600
D loss: 0.7762
G_loss: 3.068

Iter: 175700
D loss: 0.8696
G_loss: 3.178

Iter: 175800
D loss: 0.8229
G_loss: 3.086

Iter: 175900
D loss: 0.5462
G_loss: 2.316

Iter: 176000
D loss: 0.9682
G_loss: 2.4

Iter: 176100
D loss: 0.7435
G_loss: 3.718

Iter: 176200
D loss: 0.4367
G_loss: 2.553

Iter: 176300
D loss: 0.5826
G_loss: 2.0

Iter: 176400
D loss: 0.5966
G_loss: 2.819

Iter: 176500
D loss: 0.4403
G_loss: 3.445

Iter: 176600
D loss: 1.035
G_loss: 2.945

Iter: 176700
D loss: 0.7669
G_loss: 2.341

Iter: 176800
D loss: 0.6736
G_loss: 2.78

Iter: 176900
D loss: 0.1506
G_loss: 3.422

Iter: 177000
D loss: 1.176
G_loss: 2.665

Iter: 177100
D loss: 0.463
G_loss: 2.323

Iter: 177200
D loss: 0.7541
G_loss: 2.5

Iter: 177300
D loss: 0.5453
G_loss: 2.778

Iter: 177400
D loss: 0.4583
G_loss: 3.414

Iter: 177500
D loss: 0.4869
G_loss: 2.473

Iter: 177600
D loss: 0.4339
G_loss: 2.278

Iter: 177700
D loss: 0.3542
G_loss: 2.701

Iter: 177800
D loss: 0

  dtype = np.min_scalar_type(value)
  order=order, subok=True, ndmin=ndmin)


Iter: 181100
D loss: nan
G_loss: nan

Iter: 181200
D loss: nan
G_loss: nan

Iter: 181300
D loss: nan
G_loss: nan

Iter: 181400
D loss: nan
G_loss: nan

Iter: 181500
D loss: nan
G_loss: nan

Iter: 181600
D loss: nan
G_loss: nan

Iter: 181700
D loss: nan
G_loss: nan

Iter: 181800
D loss: nan
G_loss: nan

Iter: 181900
D loss: nan
G_loss: nan

Iter: 182000
D loss: nan
G_loss: nan

Iter: 182100
D loss: nan
G_loss: nan

Iter: 182200
D loss: nan
G_loss: nan

Iter: 182300
D loss: nan
G_loss: nan

Iter: 182400
D loss: nan
G_loss: nan

Iter: 182500
D loss: nan
G_loss: nan

Iter: 182600
D loss: nan
G_loss: nan

Iter: 182700
D loss: nan
G_loss: nan

Iter: 182800
D loss: nan
G_loss: nan

Iter: 182900
D loss: nan
G_loss: nan

Iter: 183000
D loss: nan
G_loss: nan

Iter: 183100
D loss: nan
G_loss: nan

Iter: 183200
D loss: nan
G_loss: nan

Iter: 183300
D loss: nan
G_loss: nan

Iter: 183400
D loss: nan
G_loss: nan

Iter: 183500
D loss: nan
G_loss: nan

Iter: 183600
D loss: nan
G_loss: nan

Iter: 183700

KeyboardInterrupt: 