In [None]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# import tensorflow and tensorflow_datasets
import tensorflow as tf
import tensorflow_datasets as tfds

# import random stuff that hopefully?? is going to be useful
import numpy as np
import random
import time
from matplotlib import pyplot as plt

In [1]:
# model / optimizer hyperparameters
n_z_input = 150        # size of the last axis of the z input
learning_rate = 2e-4
batch_size = 64

# training schedule: iterations per epoch and number of epochs
iterations_per_epoch = 300
n_epoch = 40

In [None]:
# takes in a batch of graph_images ([#images, :, :, 0]) and plots them
# in a grid of shape[0] rows by shape[1] columns (10 rows x 20 columns by default)
# image_size is handed to matplotlib as figsize, i.e. the size of the whole
# figure in inches, not the pixel size of a single image
def display_graph(graph_images, title, shape=(10, 20), image_size=(50, 50)):
    """Plot channel 0 of each image in a grid of subplots.

    Args:
        graph_images: array indexed as [image, row, col, channel].
        title: title text for the figure.
        shape: (rows, cols) of the subplot grid.
        image_size: figsize passed to plt.figure (whole figure, in inches).

    At most rows * cols images are drawn. If fewer images are supplied the
    remaining grid cells are left empty instead of raising IndexError.
    """
    fig = plt.figure(figsize=image_size) # define figure
    plt.title(title) # define title
    plt.axis('off') # remove axis
    n_cells = shape[0] * shape[1]
    # never index past the end of the batch
    n_images = min(len(graph_images), n_cells)
    for i in range(n_images):
        ax = fig.add_subplot(shape[0], shape[1], i + 1)
        ax.axis('off')
        ax.imshow(graph_images[i, :, :, 0], cmap="gray")
    plt.show()

## Model

In [None]:
# generator: upsamples a z tensor into an image with transposed convolutions
def generator(features, reuse=False):
    """Build the generator network inside the 'generator' variable scope.

    Args:
        features: input tensor; the notebook feeds a z tensor of shape
            [batch, 1, 1, n_z_input].
        reuse: passed to tf.variable_scope to share existing variables.

    Returns:
        tanh of the last transposed-convolution output. For a
        [batch, 1, 1, n] input the spatial size grows 1 -> 4 -> 16 -> 32 -> 64,
        giving a [batch, 64, 64, 1] tensor.
    """
    with tf.variable_scope('generator', reuse=reuse):
        # tf.layers.conv2d_transpose
        # # kernels | kernel dimension | stride | padding | activation
        #  512        [4, 4]            (1, 1),   "VALID"   relu
        #  256        [4, 4]            (4, 4),   "SAME"    relu
        #  128        [4, 4]            (2, 2),   "SAME"    relu
        #  1          [4, 4]            (2, 2),   "SAME"    relu
        conv1 = tf.layers.conv2d_transpose(
            features, 512, [4, 4], strides=(1, 1), padding="VALID",
            activation=tf.nn.relu)
        conv2 = tf.layers.conv2d_transpose(
            conv1, 256, [4, 4], strides=(4, 4), padding="SAME",
            activation=tf.nn.relu)
        conv3 = tf.layers.conv2d_transpose(
            conv2, 128, [4, 4], strides=(2, 2), padding="SAME",
            activation=tf.nn.relu)
        # NOTE(review): the table above asks for relu on this layer too, which
        # limits tanh(conv4) to [0, 1) while the real images are meant to span
        # [-1, 1]. Consider activation=None here -- confirm intent.
        conv4 = tf.layers.conv2d_transpose(
            conv3, 1, [4, 4], strides=(2, 2), padding="SAME",
            activation=tf.nn.relu)
        # conv4 is the output of the last conv2d_transpose layer
        return tf.nn.tanh(conv4)

In [None]:
def discriminator(features, reuse=False):
    """Build the discriminator network inside the 'discriminator' scope.

    Args:
        features: image tensor; the notebook feeds [batch, 64, 64, 1] images.
        reuse: passed to tf.variable_scope so a second call can share the
            variables created by the first.

    Returns:
        (output, logits): logits is the [batch, 1] output of the final dense
        layer and output is sigmoid(logits).
    """
    with tf.variable_scope('discriminator', reuse=reuse):
        # tf.layers.conv2d
        # # kernels | kernel dimension | stride | padding | activation
        #  128        [4, 4]             (2, 2),  "SAME"    leaky_relu
        #  256        [4, 4]             (2, 2),  "SAME"    leaky_relu
        #  512        [4, 4]             (4, 4),  "SAME"    leaky_relu
        #  1024       [3, 3]             (1, 1),  "VALID"   leaky_relu
        # for a 64x64 input the spatial size shrinks 64 -> 32 -> 16 -> 4 -> 2
        conv1 = tf.layers.conv2d(
            features, 128, [4, 4], strides=(2, 2), padding="SAME",
            activation=tf.nn.leaky_relu)
        conv2 = tf.layers.conv2d(
            conv1, 256, [4, 4], strides=(2, 2), padding="SAME",
            activation=tf.nn.leaky_relu)
        conv3 = tf.layers.conv2d(
            conv2, 512, [4, 4], strides=(4, 4), padding="SAME",
            activation=tf.nn.leaky_relu)
        conv4 = tf.layers.conv2d(
            conv3, 1024, [3, 3], strides=(1, 1), padding="VALID",
            activation=tf.nn.leaky_relu)
        # conv4 is the output of the last conv2d layer
        flatten = tf.contrib.layers.flatten(conv4)
        logits = tf.layers.dense(flatten, 1)
        # use sigmoid to squash output into a probability
        output = tf.nn.sigmoid(logits)

        return output, logits

## Placeholders

In [None]:
# x is fed to the discriminator as 64x64 single-channel images
x = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64, 1])
# z is fed to the generator; its last axis has n_z_input entries
z = tf.placeholder(dtype=tf.float32, shape=[None, 1, 1, n_z_input])

In [None]:
# generator is generating an image (g is the fake image)
g = generator(z, False)
# discriminator is classifying real images
# the first output is the probability and the second is the logits
# to feed into sigmoid_cross_entropy_with_logits
disc_real, disc_real_logits = discriminator(x, False)
# second call reuses the variables created by the first (reuse=True)
disc_fake, disc_fake_logits = discriminator(g, True)

# get accuracy of the discriminator in
# predicting that the image is real or fake

# a real image counts as correct when its probability is above 0.5 and a
# fake image when its probability is below 0.5; casting the booleans to
# float32 and averaging gives the fraction of correct predictions
real_accuracy = tf.reduce_mean(tf.cast(disc_real > 0.5, tf.float32))
fake_accuracy = tf.reduce_mean(tf.cast(disc_fake < 0.5, tf.float32))

Important! The labels are `tf.ones([batch_size, 1])` instead of `tf.ones([batch_size, 1, 1, 1])` because the discriminator outputs a single value per image: its output has shape `[batch_size, 1]` rather than `[batch_size, 1, 1, 1]`, since it is no longer the output of a conv layer but the output of a dense layer.

In [None]:
# loss function for discriminator
# real images should be classified as 1; use the logits (not the sigmoid
# output) because sigmoid_cross_entropy_with_logits applies sigmoid itself
disc_loss_real = tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_real_logits, labels=tf.ones([batch_size, 1]))
disc_loss_real = tf.reduce_mean(disc_loss_real)

# generated images should be classified as 0
disc_loss_fake = tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake_logits, labels=tf.zeros([batch_size, 1]))
disc_loss_fake = tf.reduce_mean(disc_loss_fake)

disc_loss_total = disc_loss_real + disc_loss_fake

In [None]:
# loss function for generator
# the generator is rewarded when the discriminator labels its images as 1;
# the label shape must be concrete, so use batch_size rather than None
gen_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake_logits, labels=tf.ones([batch_size, 1]))
gen_loss = tf.reduce_mean(gen_loss)

In [None]:
# split the trainable variables by the prefix of their name
t_vars = tf.trainable_variables()
disc_var = list(filter(lambda v: v.name.startswith('discriminator'), t_vars))
gen_var = list(filter(lambda v: v.name.startswith('generator'), t_vars))

In [None]:
# make each training op depend on whatever is in the UPDATE_OPS collection;
# each optimizer only updates its own network's variables
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    gen_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=0.5).minimize(gen_loss, var_list=gen_var)
    disc_optimizer = tf.train.AdamOptimizer(learning_rate, beta1=0.5).minimize(disc_loss_total, var_list=disc_var)

### Create Dataset

In [None]:
mnist_builder = tfds.builder("mnist")
# use https://www.tensorflow.org/datasets/overview#datasetbuilder
# download the data (if needed) and write it to disk
mnist_builder.download_and_prepare()
# dont use split: without it as_dataset() returns all splits keyed by name
datasets = mnist_builder.as_dataset()

In [None]:
# pull the train and test splits out of datasets
train_dataset = datasets["train"]
test_dataset = datasets["test"]

In [None]:
def apply_map(inputs):
    """Turn one dataset example into a 64x64 float32 image scaled to [-1, 1].

    Args:
        inputs: dataset example; only inputs['image'] is read.

    Returns:
        The image cast to float32, rescaled and resized to 64 by 64.
    """
    img = inputs['image']
    # 1. cast the img into float32
    img = tf.cast(img, tf.float32)
    # 2. bring the values into the range 0 to 1, then convert to -1 to 1
    # by subtracting 0.5 (range -0.5 to 0.5) and dividing by 0.5
    # NOTE(review): assumes the raw pixels are 0..255 (uint8 images) -- confirm
    img = img / 255.0
    img = (img - 0.5) / 0.5
    # 3. use tf.image.resize to convert the image into a 64 by 64 image
    img = tf.image.resize(img, (64, 64))
    return img

In [None]:
# input pipeline: preprocess each example, shuffle with a 1024-element
# buffer, then group into batches of batch_size
train_dataset = (
    train_dataset
    .map(apply_map)
    .shuffle(1024)
    .batch(batch_size)
)

# the iterator has to be initialized (iterator.initializer) before batch is read
iterator = train_dataset.make_initializable_iterator()
batch = iterator.get_next()

In [None]:
# fixed generator input for previews: 10 z vectors drawn from a
# normal distribution with mean 0 and standard deviation 1
z_test = np.random.normal(loc=0, scale=1, size=(10, 1, 1, n_z_input))

In [None]:
# let TensorFlow grow its GPU memory use on demand
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
    # initialize variables
    sess.run(tf.global_variables_initializer())

    for epoch in range(n_epoch):
        print("Starting epoch: " + str(epoch))
        # initialize iterator
        sess.run(iterator.initializer)
        step = 0

        while True:
            try:
                # grab the batch
                tra_images = sess.run(batch)

                # the loss labels are built with a fixed batch_size, so a
                # short final batch cannot be used; stop the epoch instead
                if tra_images.shape[0] != batch_size:
                    break

                # z vector of size batch_size
                z_batch = np.random.normal(0, 1, (batch_size, 1, 1, n_z_input))

                # one update of each network; the two optimizer ops return
                # None, which is discarded
                acc_fake, acc_real, loss_d, _, loss_g, _ = sess.run(
                    [fake_accuracy, real_accuracy,
                     disc_loss_total, disc_optimizer,
                     gen_loss, gen_optimizer],
                    feed_dict={x: tra_images, z: z_batch})
                step += 1

                if step % 200 == 0:
                    # generated images that result from feeding in z_test
                    generated_images = sess.run(g, feed_dict={z: z_test})
                    display_graph(generated_images, "MNIST Images", (5, 2), image_size=(20, 20))
                print('Epoch: %d, Iteration: %d, loss_d: %.3f, loss_g: %.3f, acc_fake: %.3f, acc_real: %.3f' % (
                        epoch, step, loss_d, loss_g, acc_fake, acc_real))

            except tf.errors.OutOfRangeError:
                # the iterator is exhausted: end of this epoch
                break