In [None]:
import sys
import os
sys.path.append(os.path.abspath("../../")) # Add path to root for imports

import tensorflow as tf
import numpy as np
import png # Image saving
import math

# Fetch (or reuse a previously downloaded copy of) the MNIST dataset used by
# the experiments below. Labels are requested in one-hot form.
from tensorflow.examples.tutorials.mnist import input_data

mnist_data_directory = "MNIST_data"
mnist = input_data.read_data_sets(mnist_data_directory, one_hot=True)

In [None]:
# Hyperparameters shared by the model and training cells.
z_size = 100      # Number of random noise entries in each generator input vector
batch_size = 128  # Number of samples per training step
class_count = 10  # Number of real classes; the model cells add one extra "fake" slot

## Model

In [None]:
import importlib
import molanet.models.dcgan as model
importlib.reload(model)  # Pick up edits to the model module without restarting the kernel

# Start from an empty graph so that re-running this cell does not pile up nodes
tf.reset_default_graph()

# Every label vector has one slot per real class plus a trailing "fake" slot
label_width = class_count + 1


def _mean_sigmoid_xent(logits, labels):
    """Return the mean element-wise sigmoid cross-entropy of `logits` against `labels`."""
    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))


def _variables_with_prefix(variables, prefix):
    """Return the entries of `variables` whose name starts with `prefix`."""
    return [variable for variable in variables if variable.name.startswith(prefix)]


# Inputs
z = tf.placeholder(tf.float32, shape=[None, z_size + label_width])  # Noise + class condition
x = tf.placeholder(tf.float32, shape=[None, 32, 32, 1])             # Real images
y = tf.placeholder(tf.float32, shape=[None, label_width])           # Target labels

# Models: the second discriminator call is built with reuse=True and scores the
# generator output instead of the real images
gen_net, w_gen, b_gen = model.dcgan_generator(z, batch_size=batch_size)
disc_real_net, disc_real_logits, w_disc, b_disc = model.dcgan_discriminator(x, label_width)
disc_gen_net, disc_gen_logits, _, _ = model.dcgan_discriminator(gen_net, label_width, reuse=True)

# Loss functions
# NOTE: despite its name, `fake_logits` holds the constant *labels* for generated
# images: all zeros except for a one in the last ("fake") column.
fake_logits = np.zeros((batch_size, label_width), dtype=np.float32)
fake_logits[:, -1] = 1
disc_loss_real = _mean_sigmoid_xent(disc_real_logits, y)
disc_loss_fake = _mean_sigmoid_xent(disc_gen_logits, fake_logits)
disc_loss = disc_loss_real + disc_loss_fake
# The generator is scored against the labels fed through `y`
gen_loss = _mean_sigmoid_xent(disc_gen_logits, y)

# Split the trainable variables by name prefix so each optimizer only updates its own network
trainable_variables = tf.trainable_variables()
disc_vars = _variables_with_prefix(trainable_variables, "discriminator")
gen_vars = _variables_with_prefix(trainable_variables, "generator")

# Optimizers: both networks use identical Adam settings
adam_settings = dict(learning_rate=0.0002, beta1=0.5)
disc_optim = tf.train.AdamOptimizer(**adam_settings)
gen_optim = tf.train.AdamOptimizer(**adam_settings)

disc_update_step = disc_optim.minimize(disc_loss, var_list=disc_vars)
gen_update_step = gen_optim.minimize(gen_loss, var_list=gen_vars)

## Training

In [None]:
iterations = 500000
sample_directory = "./samples" # Generated samples
model_directory = "./models" # Model
sample_interval = 10 # Print the losses and write a sample figure every N iterations
checkpoint_interval = 1000 # Save a model checkpoint every N iterations


def _latent_batch(count, class_ids):
    """Build generator inputs: uniform noise followed by a one-hot class condition.

    Args:
        count: Number of rows to generate.
        class_ids: Integer class indices; row i is conditioned on class_ids[i].
            May be shorter than `count`, in which case the remaining rows keep
            an all-zero condition vector.

    Returns:
        Tuple `(latent, conditions)`; `latent` has shape
        (count, z_size + class_count + 1) and `conditions` has shape
        (count, class_count + 1). Both are float32.
    """
    noise = np.random.uniform(-1.0, 1.0, size=[count, z_size]).astype(np.float32)
    conditions = np.zeros((count, class_count + 1), dtype=np.float32)
    conditions[np.arange(len(class_ids)), class_ids] = 1.0
    return np.concatenate([noise, conditions], axis=1), conditions


def _to_pixels(images):
    """Map values from the range [-1, 1] to uint8 grey levels in [0, 255].

    Values are rounded and clipped so that the PNG writer always receives
    valid 8-bit integer samples, even for slightly out-of-range inputs.
    """
    return np.clip(np.rint((images + 1.0) / 2.0 * 255.0), 0, 255).astype(np.uint8)


def _tile_images(images):
    """Arrange a (count, height, width) stack into one near-square 2-D grid.

    Tiles are placed row by row; unused grid cells stay zero (black). The row
    count is derived from the column count so that every tile fits for any
    `count`, not just for counts close to a perfect square.
    """
    count, height, width = images.shape
    columns = max(1, int(math.ceil(math.sqrt(count))))
    rows = int(math.ceil(count / columns))
    figure = np.zeros((rows * height, columns * width), dtype=images.dtype)
    for idx, image in enumerate(images):
        row, column = divmod(idx, columns)
        figure[row * height:(row + 1) * height, column * width:(column + 1) * width] = image
    return figure


saver = tf.train.Saver()

# Create the output directories once instead of re-checking inside the loop
os.makedirs(sample_directory, exist_ok=True)
os.makedirs(model_directory, exist_ok=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for iteration in range(iterations):
        # Random z batch, each row conditioned on a random real class
        zs, zs_class = _latent_batch(batch_size, np.random.randint(0, class_count, size=batch_size))

        # "Real" input images
        batch = mnist.train.next_batch(batch_size)
        # Append a zero "fake" column to the one-hot labels
        ys = np.pad(batch[1], [[0, 0], [0, 1]], "constant").astype(np.float32)

        xs = (np.reshape(batch[0], [batch_size, 28, 28, 1]) - 0.5) * 2.0 # Transform into range -1, 1
        xs = np.pad(xs, [[0, 0], [2, 2], [2, 2], [0, 0]], "constant", constant_values=-1) # Pad with background (28x28 -> 32x32)

        # Update discriminator
        _, d_loss = sess.run([disc_update_step, disc_loss], feed_dict={z: zs, x: xs, y: ys})

        # Update generator (twice per discriminator step); its target labels
        # are the classes it was conditioned on
        for _ in range(2):
            _, g_loss = sess.run([gen_update_step, gen_loss], feed_dict={z: zs, y: zs_class})

        if iteration % sample_interval == 0:
            # Print loss
            print(f"Loss: generator={g_loss},\tdiscriminator={d_loss}")

            # Generate sample images: a full batch is fed (the generator was
            # built with batch_size), but only the first class_count rows
            # carry a class condition, one per class
            z_sample, _ = _latent_batch(batch_size, np.arange(class_count))
            sample_images = sess.run(gen_net, feed_dict={z: z_sample})

            # NOTE(review): the generator output is assumed to lie in [-1, 1]
            # (the original TODO mentions tanh); out-of-range values are clipped.
            raw_images = np.reshape(sample_images[0:class_count], [class_count, 32, 32])
            sample_figure = _tile_images(_to_pixels(raw_images))

            # Save sample images
            save_path = f"{sample_directory}/sample_{iteration}.png"
            with open(save_path, 'wb') as f:
                writer = png.Writer(sample_figure.shape[1], sample_figure.shape[0], greyscale=True)
                writer.write(f, sample_figure.tolist()) # Rows of plain Python ints

        # Save model
        if iteration % checkpoint_interval == 0:
            saver.save(sess, f"{model_directory}/model-{iteration}.cptk")
            print(f"Saved model from iteration {iteration}")