In [None]:
%matplotlib inline

import pickle

from keras.datasets import mnist
from keras.utils import to_categorical
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf

In [None]:
# Use seaborn's white-background grid theme for every figure in this notebook.
sns.set_style(style='whitegrid')

In [None]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Map pixel values (presumably 0..255) onto [-1, 1] -- the range of the
# generator's tanh output -- and append a channel axis at position 3 so the
# arrays match the (height, width, channels) layout of the model's input.
x_train, x_test = (
    np.expand_dims((images.astype(np.float32) - 127.5) / 127.5, axis=3)
    for images in (x_train, x_test)
)

In [None]:
def dense_discriminator(x, is_training, num_classes=10, reuse=False):
    """Fully connected discriminator / classifier for the semi-supervised GAN.

    The network produces ``num_classes + 1`` logits: one per real class plus
    a final extra column used elsewhere in this notebook as the "fake" class.

    Args:
        x: Input image tensor; it is flattened before the dense layers.
        is_training: Not used by this network (it has no batch norm or
            dropout); accepted so it can be called like ``dense_generator``.
        num_classes: Number of real classes.
        reuse: Forwarded to ``tf.variable_scope('discriminator', ...)`` so a
            second call shares the weights created by the first.

    Returns:
        ``(output, logits, features)``: the softmax of ``logits``, the raw
        ``num_classes + 1`` logits, and the 128-unit penultimate activations.
    """
    with tf.variable_scope('discriminator', reuse=reuse):
        hidden = tf.layers.flatten(x)
        for width in (512, 256):
            hidden = tf.layers.dense(hidden, width, activation=tf.nn.leaky_relu)
        features = tf.layers.dense(hidden, 128, activation=tf.nn.leaky_relu)
        logits = tf.layers.dense(features, num_classes + 1)
        return tf.nn.softmax(logits), logits, features

In [None]:
def dense_generator(x, is_training, output_shape=(28, 28, 1), reuse=False):
    """Fully connected generator mapping latent vectors to images.

    Three ReLU dense layers (256, 512, 1024 units), each followed by batch
    normalization, feed a tanh output layer, so generated pixels lie in
    [-1, 1].

    Args:
        x: Latent tensor (the notebook feeds a ``(None, latent_size)``
            placeholder).
        is_training: Boolean switching batch norm between batch statistics
            (True) and its moving averages (False).
        output_shape: Per-sample output shape. Any sequence of ints is
            accepted; it is normalised to a tuple (a list previously failed in
            the ``(-1,) + output_shape`` concatenation).
        reuse: Forwarded to ``tf.variable_scope('generator', ...)``.

    Returns:
        Tensor of shape ``(batch,) + tuple(output_shape)``.
    """
    output_shape = tuple(output_shape)
    with tf.variable_scope('generator', reuse=reuse):
        hidden = x
        # Layer creation order (dense, BN, dense, BN, ...) is kept, so the
        # auto-generated variable names are unchanged.
        for width in (256, 512, 1024):
            hidden = tf.layers.dense(hidden, width, activation=tf.nn.relu)
            hidden = tf.layers.batch_normalization(hidden, training=is_training)
        flat = tf.layers.dense(hidden, int(np.prod(output_shape)), activation=tf.nn.tanh)
        return tf.reshape(flat, (-1,) + output_shape)

In [None]:
def build_dense_model(x_real, z, is_training, num_classes=10, output_shape=(28, 28, 1)):
    """Connect the generator to two weight-sharing discriminator passes.

    The first discriminator call (on real data) creates the variables with
    ``reuse=False``. The generator is then built from ``z``. The second
    discriminator call (on generated data) uses ``reuse=True`` to share those
    variables. This call order must stay as it is.

    Returns:
        7-tuple ``(d_real_prob, d_real_logits, d_real_features,
        d_fake_prob, d_fake_logits, d_fake_features, x_fake)``.
    """
    real_outputs = dense_discriminator(x_real, is_training,
                                       num_classes=num_classes, reuse=False)
    x_fake = dense_generator(z, is_training, output_shape=output_shape)
    fake_outputs = dense_discriminator(x_fake, is_training,
                                       num_classes=num_classes, reuse=True)
    return tuple(real_outputs) + tuple(fake_outputs) + (x_fake,)

In [None]:
def standard_loss_accuracy(d_real_prob, d_real_logits, d_real_features,
                           d_fake_prob, d_fake_logits, d_fake_features,
                           extended_label, labeled_mask):
    """Semi-supervised GAN losses and classification accuracy.

    The discriminator emits K + 1 logits: K real classes followed by a final
    "fake" column. ``extended_label`` is the one-hot label with that extra
    (always zero) column appended.

    The probabilities P(fake) and P(not fake) are computed in log space from
    the logits (log-sum-exp / log-softmax). The previous ``log(softmax + 1e-8)``
    form clipped at log(1e-8) and lost its gradient whenever the float32
    softmax saturated to exactly 0 or 1.

    Args:
        d_real_prob: Softmax of ``d_real_logits``; used only for accuracy.
        d_real_logits, d_fake_logits: Discriminator logits for real / generated data.
        d_real_features, d_fake_features: Penultimate activations, used for
            the generator's feature-matching term.
        d_fake_prob: Accepted for interface compatibility; not used (the fake
            logits are used instead).
        extended_label: One-hot labels with an appended fake column.
        labeled_mask: Per-row 0/1 weights selecting which rows contribute to
            the supervised cross-entropy.

    Returns:
        ``(d_loss_supervised, d_loss_unsupervised1, d_loss_unsupervised2,
        d_loss, g_loss, accuracy)``.
    """
    epsilon = 1e-8  # only guards the division when no row is labeled

    def log_prob_not_fake(logits):
        # log P(class < K) = logsumexp(real-class logits) - logsumexp(all logits)
        return (tf.reduce_logsumexp(logits[:, :-1], axis=1)
                - tf.reduce_logsumexp(logits, axis=1))

    ### Discriminator loss
    # Supervised: masked mean cross-entropy over the labeled rows only.
    d_ce = tf.nn.softmax_cross_entropy_with_logits_v2(logits=d_real_logits,
                                                      labels=extended_label)
    d_loss_supervised = tf.reduce_sum(labeled_mask * d_ce) / (tf.reduce_sum(labeled_mask) + epsilon)
    # Unsupervised: real data should not be classified as fake ...
    d_loss_unsupervised1 = -tf.reduce_mean(log_prob_not_fake(d_real_logits))
    # ... and generated data should be classified as fake (last column).
    d_loss_unsupervised2 = -tf.reduce_mean(tf.nn.log_softmax(d_fake_logits)[:, -1])

    d_loss = d_loss_supervised + d_loss_unsupervised1 + d_loss_unsupervised2

    ### Generator loss
    # Non-saturating term: generated data should be mistaken for real.
    g_loss_probs = -tf.reduce_mean(log_prob_not_fake(d_fake_logits))
    # Feature matching: squared distance between mean penultimate activations.
    mean_real_features = tf.reduce_mean(d_real_features, axis=0)
    mean_fake_features = tf.reduce_mean(d_fake_features, axis=0)
    g_loss_fm = tf.reduce_mean(tf.square(mean_real_features - mean_fake_features))

    g_loss = g_loss_probs + g_loss_fm

    ### Accuracy
    # Argmax over the real classes only, over every row regardless of labeled_mask.
    correct_prediction = tf.equal(tf.argmax(d_real_prob[:, :-1], 1),
                                  tf.argmax(extended_label[:, :-1], 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    return d_loss_supervised, d_loss_unsupervised1, d_loss_unsupervised2, d_loss, g_loss, accuracy

In [None]:
def optimizer(d_loss, g_loss, d_learning_rate, g_learning_rate):
    """Create one Adam training op per network.

    Each op only updates the trainable variables whose names start with its
    scope prefix ('discriminator' / 'generator'). Both ops carry a control
    dependency on the graph's UPDATE_OPS collection. Batch normalization
    registers its moving-average updates there, so they run on every step.

    Returns:
        ``(d_optimizer, g_optimizer)`` training ops.
    """
    def vars_with_prefix(prefix):
        return [v for v in tf.trainable_variables() if v.name.startswith(prefix)]

    discriminator_vars = vars_with_prefix('discriminator')
    generator_vars = vars_with_prefix('generator')

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        d_train_op = tf.train.AdamOptimizer(d_learning_rate).minimize(
            d_loss, var_list=discriminator_vars)
        g_train_op = tf.train.AdamOptimizer(g_learning_rate).minimize(
            g_loss, var_list=generator_vars)
    return d_train_op, g_train_op

In [None]:
def extend_labels(labels):
    """Append an all-zero "fake" column to rank-2 one-hot labels.

    ``(batch, K) -> (batch, K + 1)``. Real examples never carry the extra
    fake class.

    The padding is built with ``zeros_like`` on a one-column slice, so it has
    the same batch size *and dtype* as ``labels``. The previous ``tf.zeros``
    was always float32, so ``tf.concat`` failed for other label dtypes.
    """
    fake_column = tf.zeros_like(labels[:, :1])
    return tf.concat([labels, fake_column], axis=1)

In [None]:
def moving_average(x, n=10):
    """Simple moving average of ``x`` over a window of ``n`` samples.

    Element ``i`` of the result is ``mean(x[i:i + n])``, so the output has
    ``len(x) - n + 1`` elements (empty when ``x`` is shorter than ``n``).

    Raises:
        ValueError: if ``n`` < 1.
    """
    if n < 1:
        raise ValueError('window n must be >= 1, got {}'.format(n))
    # Accumulate in float64. With float32 input (such as the training losses
    # collected here), a float32 running sum over tens of thousands of steps
    # loses precision, and subtracting two large cumulative sums then cancels
    # catastrophically.
    cumulative = np.cumsum(x, dtype=np.float64)
    cumulative[n:] = cumulative[n:] - cumulative[:-n]
    return cumulative[n - 1:] / n

In [None]:
def _build_labeled_mask(y_train, num_examples, num_labeled_examples, num_classes):
    """Return a float 0/1 mask of length ``num_examples`` marking which labels are used.

    ``num_labeled_examples=None`` marks every example. Otherwise
    ``num_labeled_examples // num_classes`` examples are drawn without
    replacement from each class present in ``y_train``.

    Raises:
        ValueError: if that per-class count is zero. That would leave no
            labeled data and make periodic labeled batches empty.
    """
    if num_labeled_examples is None:
        return np.ones(num_examples)
    per_class = num_labeled_examples // num_classes
    if per_class < 1:
        raise ValueError(
            'num_labeled_examples ({}) must be at least the number of classes ({})'.format(
                num_labeled_examples, num_classes))
    mask = np.zeros(num_examples)
    for cls in np.unique(y_train):
        candidates = np.flatnonzero(y_train == cls)
        mask[np.random.choice(candidates, per_class, replace=False)] = 1.0
    return mask


def execute(x_train, y_train, x_test, y_test,
            epochs=50000, batch_size=32, test_steps=500, num_labeled_examples=None,
            periodic_labeled_batch=False, periodic_labeled_batch_frequency=10,
            x_height=28, x_width=28, num_channels=1, latent_size=100,
            d_lr=0.001, g_lr=0.001):
    """Build the semi-supervised GAN graph, train it, and record metrics.

    Note: ``epochs`` counts training *iterations*. Each iteration draws one
    random mini-batch of ``batch_size`` examples; it is not a full pass over
    the data.

    Args:
        x_train, x_test: Image arrays shaped ``(N, x_height, x_width, num_channels)``.
        y_train, y_test: Integer class labels (one-hot encoded internally).
        epochs: Number of training iterations.
        batch_size: Mini-batch size for real data and for latent samples.
        test_steps: Evaluate on the full test set every ``test_steps`` iterations.
        num_labeled_examples: Total labeled training examples, split evenly
            per class; ``None`` keeps every label.
        periodic_labeled_batch: If True (and labels are limited), every
            ``periodic_labeled_batch_frequency``-th iteration uses the whole
            labeled subset as the real batch.
        periodic_labeled_batch_frequency: See ``periodic_labeled_batch``.
        x_height, x_width, num_channels: Image shape. Also used as the
            generator's output shape.
        latent_size: Dimension of the Gaussian latent vector ``z``.
        d_lr, g_lr: Adam learning rates for the discriminator / generator.

    Returns:
        ``(train_d_losses, train_g_losses, train_accuracies, test_d_losses,
        test_g_losses, test_accuracies)``. Training lists have one entry per
        iteration. Test lists have one entry per ``test_steps`` iterations
        plus a final one after training.

    Raises:
        ValueError: if ``num_labeled_examples`` is smaller than the number of classes.
    """
    tf.reset_default_graph()

    num_classes = np.unique(y_train).shape[0]
    y_test_onehot = to_categorical(y_test, num_classes=num_classes)

    x = tf.placeholder(tf.float32, name='x', shape=(None, x_height, x_width, num_channels))
    label = tf.placeholder(tf.float32, name='label', shape=(None, num_classes))
    labeled_mask = tf.placeholder(tf.float32, name='labeled_mask', shape=(None,))
    z = tf.placeholder(tf.float32, name='z', shape=(None, latent_size))
    is_training = tf.placeholder(tf.bool, name='is_training')
    g_learning_rate = tf.placeholder(tf.float32, name='g_learning_rate')
    d_learning_rate = tf.placeholder(tf.float32, name='d_learning_rate')

    # Pass the real class count and image shape to the model. Previously the
    # model always used its defaults (10 classes, 28x28x1), whatever
    # num_classes / x_height / x_width / num_channels were.
    model = build_dense_model(x, z, is_training, num_classes=num_classes,
                              output_shape=(x_height, x_width, num_channels))
    extended_label = extend_labels(label)
    (d_real_prob, d_real_logits, d_real_features,
     d_fake_prob, d_fake_logits, d_fake_features, x_fake) = model
    _, _, _, d_loss, g_loss, accuracy = standard_loss_accuracy(
        d_real_prob, d_real_logits, d_real_features,
        d_fake_prob, d_fake_logits, d_fake_features,
        extended_label, labeled_mask)
    d_optimizer, g_optimizer = optimizer(d_loss, g_loss, d_learning_rate, g_learning_rate)
    # Fetched together: one forward pass per evaluation instead of three.
    metrics = [d_loss, g_loss, accuracy]

    global_mask = _build_labeled_mask(y_train, x_train.shape[0], num_labeled_examples, num_classes)
    if num_labeled_examples is None:
        # Everything is labeled, so dedicated labeled batches are disabled.
        periodic_labeled_batch = False
    labeled_idx = np.flatnonzero(global_mask)  # constant; hoisted out of the loop

    train_d_losses, train_g_losses, train_accuracies = [], [], []
    test_d_losses, test_g_losses, test_accuracies = [], [], []

    test_size = x_test.shape[0]
    test_mask = np.ones(test_size)

    def test_gan(sess, epoch):
        """Evaluate on the whole test set with batch norm in inference mode."""
        test_dictionary = {
            x: x_test,
            z: np.random.normal(0, 1, (test_size, latent_size)),
            label: y_test_onehot,
            labeled_mask: test_mask,
            is_training: False,
        }
        test_d_loss, test_g_loss, test_accuracy = sess.run(metrics, feed_dict=test_dictionary)

        test_d_losses.append(test_d_loss)
        test_g_losses.append(test_g_loss)
        test_accuracies.append(test_accuracy)

        print(epoch, test_d_loss, test_g_loss, test_accuracy)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for epoch in range(epochs):
            if periodic_labeled_batch and epoch % periodic_labeled_batch_frequency == 0:
                idx = labeled_idx
            else:
                idx = np.random.randint(0, x_train.shape[0], batch_size)
            z_batch = np.random.normal(0, 1, (batch_size, latent_size))
            train_dictionary = {
                x: x_train[idx],
                z: z_batch,
                label: to_categorical(y_train[idx], num_classes=num_classes),
                labeled_mask: global_mask[idx],
                g_learning_rate: g_lr,
                d_learning_rate: d_lr,
                is_training: True,
            }
            # Discriminator step, then generator step, on the same batch.
            sess.run(d_optimizer, feed_dict=train_dictionary)
            sess.run(g_optimizer, feed_dict=train_dictionary)

            train_d_loss, train_g_loss, train_accuracy = sess.run(metrics, feed_dict=train_dictionary)

            train_d_losses.append(train_d_loss)
            train_g_losses.append(train_g_loss)
            train_accuracies.append(train_accuracy)

            if epoch % test_steps == 0:
                test_gan(sess, epoch)
        test_gan(sess, epochs)

    return train_d_losses, train_g_losses, train_accuracies, test_d_losses, test_g_losses, test_accuracies

In [None]:
def run_test(output_file, num_labeled_examples=None,
             periodic_labeled_batch=False, periodic_labeled_batch_frequency=10,
             **execute_kwargs):
    """Train on the notebook-level MNIST arrays, pickle the metrics, and return them.

    Previously the labeling arguments were accepted but never passed to
    ``execute``, so every run silently used all labels.

    Args:
        output_file: Path of the pickle file to write.
        num_labeled_examples, periodic_labeled_batch,
        periodic_labeled_batch_frequency: Forwarded to ``execute``.
        **execute_kwargs: Any other ``execute`` keyword (e.g. ``epochs``).

    Returns:
        The 6-tuple returned by ``execute``.
    """
    results = execute(
        x_train, y_train, x_test, y_test,
        num_labeled_examples=num_labeled_examples,
        periodic_labeled_batch=periodic_labeled_batch,
        periodic_labeled_batch_frequency=periodic_labeled_batch_frequency,
        **execute_kwargs
    )
    keys = ('train_d_losses', 'train_g_losses', 'train_accuracies',
            'test_d_losses', 'test_g_losses', 'test_accuracies')
    with open(output_file, 'wb') as f:
        pickle.dump(dict(zip(keys, results)), f)
    return results

In [None]:
def plot_losses(results, test_steps=500, window=10):
    """Plot smoothed per-iteration training losses and periodic test losses.

    Args:
        results: The 6-tuple returned by ``execute`` / ``run_test``.
        test_steps: Iterations between test evaluations. It must match the
            ``test_steps`` used in training (``execute`` defaults to 500). The
            test points are placed at ``k * test_steps``; this assumes the
            number of training iterations is a multiple of ``test_steps``.
        window: Moving-average window applied to the training losses.
    """
    train_d_losses, train_g_losses, _, test_d_losses, test_g_losses, _ = results

    _, ax = plt.subplots(figsize=(15, 8))
    for series, name in ((train_d_losses, 'discriminator training loss'),
                         (train_g_losses, 'generator training loss')):
        smoothed = moving_average(series, window)
        # Smoothed point i averages iterations i .. i + window - 1; plot it at
        # the last of those so it lines up with the test curve's x axis.
        ax.plot(np.arange(len(smoothed)) + window - 1, smoothed, label=name)
    ax.plot(np.arange(len(test_d_losses)) * test_steps, test_d_losses, label='discriminator test loss')
    ax.plot(np.arange(len(test_g_losses)) * test_steps, test_g_losses, label='generator test loss')
    ax.set_xlabel('training iteration')
    ax.set_ylabel('loss')
    ax.set_title('Discriminator and generator losses')
    ax.legend()
    plt.legend()

In [None]:
def plot_accuracies(results, test_steps=500, window=10):
    """Plot smoothed per-iteration training accuracy and periodic test accuracy.

    Args:
        results: The 6-tuple returned by ``execute`` / ``run_test``.
        test_steps: Iterations between test evaluations. It must match the
            ``test_steps`` used in training (``execute`` defaults to 500).
        window: Moving-average window applied to the training accuracy.
    """
    _, _, train_accuracies, _, _, test_accuracies = results

    smoothed = moving_average(train_accuracies, window)

    _, ax = plt.subplots(figsize=(15, 8))
    # Place each smoothed point at the last iteration of its window.
    ax.plot(np.arange(len(smoothed)) + window - 1, smoothed, label='training accuracy')
    ax.plot(np.arange(len(test_accuracies)) * test_steps, test_accuracies, label='test accuracy')
    ax.set_xlabel('training iteration')
    ax.set_ylabel('accuracy')
    ax.set_title('Discriminator classification accuracy')
    ax.legend()

In [None]:
# Baseline run with every training label available (num_labeled_examples=None).
results = run_test(output_file='improved-gan-all.pkl')

In [None]:
plot_losses(results=results)

In [None]:
plot_accuracies(results=results)