# MNIST GAN

## Obtencion del dataset

In [1]:
#!rm -rf data
#!git clone https://github.com/juanma1982/GANNets.git data

In [3]:
!ls -l data/mnist_png/training/4 | head

total 23368
-rw-r----- 1 gaston gaston 255 dic 10  2015 10013.png
-rw-r----- 1 gaston gaston 219 dic 10  2015 10018.png
-rw-r----- 1 gaston gaston 265 dic 10  2015 10033.png
-rw-r----- 1 gaston gaston 159 dic 10  2015 1004.png
-rw-r----- 1 gaston gaston 256 dic 10  2015 10060.png
-rw-r----- 1 gaston gaston 233 dic 10  2015 1006.png
-rw-r----- 1 gaston gaston 275 dic 10  2015 1008.png
-rw-r----- 1 gaston gaston 180 dic 10  2015 10103.png
-rw-r----- 1 gaston gaston 237 dic 10  2015 10104.png
ls: write error: Broken pipe


## Imports & constants

In [4]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np
import tensorflow as tf

In [5]:
TRAINING_DATA_DIR = os.path.join("data", "mnist_png", "training")
CHECKPOINT_DIR = "checkpoints"

# Number of inputs counting both mnist data and generated data for the discriminator, and number of random inputs for
# the generator
BATCH_SIZE = 60

GEN_LEARNING_RATE = 1e-4
DISC_LEARNING_RATE = 1e-4

LATENT_SPACE_SHAPE = 100

GEN_VARIABLE_SCOPE = "generator"
DISC_VARIABLE_SCOPE = "discriminator"

MAX_STEPS = 1000000

## Funciones del modelo y entrenamiento

In [6]:
def generator(latent_space, label, training=True):
    """
    Defines the generator network using the latent_space as input.
    Args:
        latent_space: input for the generator network
        label: 10 dimensioanl one-hot tensor
    Returns:
        Generated images
    """
    with tf.variable_scope(GEN_VARIABLE_SCOPE):
        net = tf.concat([latent_space, label], axis=1)
        net = tf.layers.dense(net, 7 * 7 * 64, activation=None, use_bias=False)
        net = tf.layers.batch_normalization(net, training=training)

        # 7 x 7
        net = tf.reshape(net, [-1, 7, 7, 64])

        # 7 x 7
        net = tf.layers.conv2d_transpose(net, 64, kernel_size=(5, 5),
                                         strides=(1, 1),
                                         activation=None,
                                         padding='same',
                                         use_bias=False)
        net = tf.layers.batch_normalization(net, training=training)
        net = tf.nn.leaky_relu(net)

        # 14 x 14
        net = tf.layers.conv2d_transpose(net, 32, kernel_size=(5, 5),
                                         strides=(2, 2),
                                         activation=None,
                                         padding='same',
                                         use_bias=False)
        net = tf.layers.batch_normalization(net, training=training)
        net = tf.nn.leaky_relu(net)

        # 28 x 28
        images = tf.layers.conv2d_transpose(net, 1, kernel_size=(5, 5),
                                            strides=(2, 2),
                                            activation=tf.nn.sigmoid,
                                            padding='same',
                                            use_bias=False)
        return images


def discriminator(images, label, training=True):
    """Defines the discriminator network
    Args:
        images: input images as 28x28 tensors
        label: 10 dimensioanl one-hot tensor
    Returns:
        Logits and prediction for each image
    """
    with tf.variable_scope(DISC_VARIABLE_SCOPE, reuse=tf.AUTO_REUSE):
        net = images
        net = tf.layers.conv2d(net, 64, kernel_size=(5, 5), strides=(2, 2),
                               activation=tf.nn.leaky_relu, padding='same')
        net = tf.layers.dropout(net, training=training)

        net = tf.layers.conv2d(net, 128, kernel_size=(5, 5), strides=(2, 2),
                               activation=tf.nn.leaky_relu, padding='same')
        net = tf.layers.dropout(net, training=training)

        net_shape = net.shape
        net_reshaped = tf.reshape(net, [-1,
                                        net_shape[1] * net_shape[2] * net_shape[
                                            3]])
        net_with_label = tf.concat([net_reshaped, label], axis=1)
        logits = tf.layers.dense(net_with_label, 1,
                             activation=None)

        return logits


def _parse_function(filename, label):
    """
    Reads an image from a file, decodes it into a dense tensor, and resizes it to a fixed shape.
    """
    image_string = tf.read_file(filename)
    image_decoded = tf.image.decode_png(image_string)
    image_resized = tf.reshape(image_decoded, [28, 28, 1])
    return tf.cast(image_resized, tf.float32) / 255, label


def _mnist_filenames_and_labels():
    """
    Returns:
        A tuple of lists, where the first list contains the mnist png file paths, and the second list contains the
        label for each image.
    """
    images_paths = []
    images_labels = []
    for label in range(10):
        images_dir = os.path.join(TRAINING_DATA_DIR, str(label))
        current_images_paths = os.listdir(images_dir)
        images_paths += list(
            map(lambda image_path: os.path.join(images_dir, image_path),
                current_images_paths))
        images_labels += [label] * len(current_images_paths)
    return images_paths, images_labels


def shuffle(a, b):
    assert len(a) == len(b)
    p = np.random.permutation(len(a))
    return a[p], b[p]
  
def _generator_step(sess):
    latent_space_np = np.random.randn(BATCH_SIZE, LATENT_SPACE_SHAPE)
    label = np.random.randint(10, size=BATCH_SIZE)
    _, G_loss_np, step_value = sess.run([G_optimizer, G_loss, step],
                            feed_dict={latent_space: latent_space_np,
                                       G_label: label})
    if step_value % 97 == 0:
        print()
        print("Step: ", sess.run(step))
        print("G_loss: ", G_loss_np)


def _discriminator_step(sess):
    latent_space_np = np.random.randn(BATCH_SIZE // 2, LATENT_SPACE_SHAPE)
    label = np.random.randint(10, size=BATCH_SIZE // 2)
    _, D_loss_np, step_value = sess.run([D_optimizer, D_loss, step],
                            feed_dict={latent_space: latent_space_np,
                                       G_label: label})
    if step_value % 97 == 0:
        print()
        print("Step: ", sess.run(step))
        print("D_loss: ", D_loss_np)

## Definición del grafo de TF y ejecución de la Session

In [None]:
filenames, labels = _mnist_filenames_and_labels()
filenames, labels = shuffle(np.array(filenames), np.array(labels))

step = tf.train.get_or_create_global_step()

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)
dataset = dataset.shuffle(buffer_size=1000)

# The other half of the batch will come from the generator.
dataset = dataset.batch(batch_size=BATCH_SIZE // 2)

dataset = dataset.repeat()

iterator = dataset.make_one_shot_iterator()

# Iterator for tuples of images and labels
next = iterator.get_next()

latent_space = tf.placeholder(tf.float32, shape=[None, LATENT_SPACE_SHAPE])
G_label = tf.placeholder(tf.int32, shape=[None])
G_label_one_hot = tf.one_hot(G_label, 10)
G_images = generator(latent_space, G_label_one_hot)

D_fake_logits = discriminator(G_images, G_label_one_hot)

real_image = next[0]
real_label = tf.one_hot(next[1], 10)
D_real_logits = discriminator(real_image, real_label)

G_expected = tf.ones_like(D_fake_logits)
G_loss = tf.losses.sigmoid_cross_entropy(G_expected, D_fake_logits)

D_real_expected = tf.ones_like(D_real_logits)
D_fake_expected = tf.zeros_like(D_fake_logits)

D_real_loss = tf.losses.sigmoid_cross_entropy(D_real_expected, D_real_logits)
D_fake_loss = tf.losses.sigmoid_cross_entropy(D_fake_expected, D_fake_logits)
D_loss = D_real_loss + D_fake_loss

G_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                GEN_VARIABLE_SCOPE)

G_optimizer = tf.train.AdamOptimizer(
    learning_rate=GEN_LEARNING_RATE).minimize(G_loss, var_list=G_variables,
                                              global_step=tf.train.get_global_step())

D_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                DISC_VARIABLE_SCOPE)

D_optimizer = tf.train.AdamOptimizer(
    learning_rate=DISC_LEARNING_RATE).minimize(D_loss, var_list=D_variables,
                                               global_step=tf.train.get_global_step())

tf.summary.scalar("Gen loss", G_loss, family="Generator")
tf.summary.scalar("Disc loss", D_loss, family="Discriminator")
tf.summary.image("Gen images", G_images, max_outputs=8)


hooks = [tf.train.StopAtStepHook(num_steps=MAX_STEPS)]

with tf.train.MonitoredTrainingSession(checkpoint_dir=CHECKPOINT_DIR,
                                       hooks=hooks) as sess:
    while not sess.should_stop():
        _generator_step(sess)
        _discriminator_step(sess)

INFO:tensorflow:Summary name Gen loss is illegal; using Gen_loss instead.
INFO:tensorflow:Summary name Disc loss is illegal; using Disc_loss instead.
INFO:tensorflow:Summary name Gen images is illegal; using Gen_images instead.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into checkpoints/model.ckpt.

Step:  1
G_loss:  0.6910719

Step:  98
D_loss:  0.48531622
INFO:tensorflow:global_step/sec: 30.8224

Step:  195
G_loss:  3.454005
INFO:tensorflow:global_step/sec: 55.3916

Step:  292
D_loss:  0.114902675
INFO:tensorflow:global_step/sec: 55.0462

Step:  389
G_loss:  3.1525147
INFO:tensorflow:global_step/sec: 54.2652

Step:  486
D_loss:  0.08261413
INFO:tensorflow:global_step/sec: 53.3767

Step:  583
G_loss:  2.9732738
INFO:tensorflow:global_step/sec: 51.0967

Step:  680
D_loss:  0.20242587
INFO:tensorflow:global_step/sec: 51.3103

S

INFO:tensorflow:global_step/sec: 51.0843

Step:  10574
D_loss:  1.0468056
INFO:tensorflow:global_step/sec: 50.7788

Step:  10671
G_loss:  0.8620329
INFO:tensorflow:global_step/sec: 51.7008

Step:  10768
D_loss:  1.0273147
INFO:tensorflow:global_step/sec: 51.5911

Step:  10865
G_loss:  1.2551532
INFO:tensorflow:global_step/sec: 51.7048

Step:  10962
D_loss:  1.142829
INFO:tensorflow:global_step/sec: 52.1073

Step:  11059
G_loss:  1.0654374
INFO:tensorflow:global_step/sec: 50.3796

Step:  11156
D_loss:  1.1003778
INFO:tensorflow:global_step/sec: 50.3033

Step:  11253
G_loss:  0.8643943
INFO:tensorflow:global_step/sec: 50.6153

Step:  11350
D_loss:  1.2817659
INFO:tensorflow:global_step/sec: 51.657

Step:  11447
G_loss:  0.9242514
INFO:tensorflow:global_step/sec: 51.9604

Step:  11544
D_loss:  1.0074841
INFO:tensorflow:global_step/sec: 52.9202

Step:  11641
G_loss:  0.73917747
INFO:tensorflow:global_step/sec: 51.6229

Step:  11738
D_loss:  1.0351363
INFO:tensorflow:global_step/sec: 52.475


Step:  21438
D_loss:  1.3953488
INFO:tensorflow:global_step/sec: 53.0072

Step:  21535
G_loss:  0.9093867
INFO:tensorflow:global_step/sec: 53.5082

Step:  21632
D_loss:  1.1908513
INFO:tensorflow:global_step/sec: 53.4836

Step:  21729
G_loss:  0.8162834
INFO:tensorflow:global_step/sec: 53.7819

Step:  21826
D_loss:  1.1825252
INFO:tensorflow:global_step/sec: 53.803

Step:  21923
G_loss:  0.9027103
INFO:tensorflow:global_step/sec: 53.8258

Step:  22020
D_loss:  1.3576818
INFO:tensorflow:global_step/sec: 53.7559

Step:  22117
G_loss:  0.8049272
INFO:tensorflow:global_step/sec: 53.7737

Step:  22214
D_loss:  1.4398282
INFO:tensorflow:global_step/sec: 53.5491

Step:  22311
G_loss:  0.8382356
INFO:tensorflow:global_step/sec: 53.7735

Step:  22408
D_loss:  1.2645679
INFO:tensorflow:global_step/sec: 53.7626

Step:  22505
G_loss:  0.8024385
INFO:tensorflow:global_step/sec: 53.6904

Step:  22602
D_loss:  1.1836151

Step:  22699
G_loss:  0.9198995
INFO:tensorflow:global_step/sec: 53.7285

Step:

INFO:tensorflow:global_step/sec: 53.2876

Step:  32302
D_loss:  1.3430438

Step:  32399
G_loss:  0.8543028
INFO:tensorflow:global_step/sec: 53.3597

Step:  32496
D_loss:  1.2718141
INFO:tensorflow:global_step/sec: 53.449

Step:  32593
G_loss:  0.84265
INFO:tensorflow:global_step/sec: 52.3116

Step:  32690
D_loss:  1.2120482
INFO:tensorflow:global_step/sec: 53.6054

Step:  32787
G_loss:  0.92657065
INFO:tensorflow:global_step/sec: 53.7037

Step:  32884
D_loss:  1.0239387
INFO:tensorflow:global_step/sec: 53.5354

Step:  32981
G_loss:  0.8173711
INFO:tensorflow:global_step/sec: 53.6925

Step:  33078
D_loss:  1.24791
INFO:tensorflow:global_step/sec: 53.3165

Step:  33175
G_loss:  0.9486646
INFO:tensorflow:global_step/sec: 53.3107

Step:  33272
D_loss:  1.2804036
INFO:tensorflow:global_step/sec: 53.6072

Step:  33369
G_loss:  0.8640768
INFO:tensorflow:global_step/sec: 53.584

Step:  33466
D_loss:  1.2259965
INFO:tensorflow:global_step/sec: 53.4997

Step:  33563
G_loss:  0.8819663
INFO:tenso