### InfoGAN

* generator

10 noise samples from U(0, 1), concatenated with the 10-dim one-hot code c -> dense 128 -> relu -> dense 784 -> sigm -> reshape to 28x28

* discriminator

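28x28 image (real or generated) -> flatten to 784 -> 784 -> relu -> 128 -> sigm

* Q network (predicts the categorical code c from a generated image)

28x28 generated image -> flatten to 784 -> 784 -> relu -> 128 -> relu -> 10 -> softmax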

In [None]:
import random

import numpy as np
import tensorflow as tf

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Load the Fashion-MNIST training split; the test split is discarded.
(X, y), _ = tf.keras.datasets.fashion_mnist.load_data()
# Scale uint8 pixels from [0, 255] into [0, 1] so real images span the same
# range as the generator's sigmoid output (dividing by 128 gave [0, ~1.99]).
X = X / 255.0

# Train on a single class: the first 1000 images with label 6 ("Shirt" in the
# Fashion-MNIST label list), reshaped to (N, 28, 28, 1) to match the `x`
# placeholder.
fmnist = X[y == 6][:1000].reshape([-1, 28, 28, 1])

In [None]:
# Preview the first selected training image as a 28x28 grayscale picture.
plt.imshow(fmnist[0].reshape((28, 28)), cmap='gray')

In [None]:
def generator(z, c):
    """Map noise ``z`` and code ``c`` to a batch of 28x28x1 images.

    ``z`` and ``c`` are concatenated along the feature axis, passed through a
    128-unit ReLU layer and a 784-unit sigmoid layer, so every pixel lies in
    (0, 1). Weights live in the "generator" variable scope and are shared
    across calls (AUTO_REUSE).
    """
    zc = tf.concat([z, c], axis=1)
    with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
        hidden = tf.layers.dense(zc, 128, activation=tf.nn.relu)
        pixels = tf.layers.dense(hidden, 784, activation=tf.nn.sigmoid)
    # Flat 784-vectors -> NHWC images matching the `x` placeholder layout.
    return tf.reshape(pixels, (-1, 28, 28, 1))

In [None]:
def discriminator(x):
    """Return D(x): the probability that each image in ``x`` is real.

    ``x`` is flattened, passed through a 784-unit ReLU layer, then a single
    sigmoid unit, giving one probability per image with shape (batch, 1).
    (The previous 128-unit output produced 128 independent "probabilities"
    per image instead of one real/fake score.) Weights live in the
    "discriminator" variable scope and are shared between the real and fake
    branches via AUTO_REUSE.

    Assumes ``x`` is an image batch shaped like the ``x`` placeholder,
    i.e. (batch, 28, 28, 1).
    """
    with tf.variable_scope("discriminator", reuse=tf.AUTO_REUSE):
        x = tf.contrib.layers.flatten(x)
        hidden = tf.layers.dense(x, units=784,
                                 activation=tf.nn.relu)
        prob = tf.layers.dense(hidden, units=1,
                               activation=tf.nn.sigmoid)
    return prob

In [None]:
def Q(q):
    """Predict a 10-way softmax distribution from the image batch ``q``.

    ``q`` is flattened and passed through ReLU layers of 784 and 128 units,
    then a 10-unit softmax layer. In this notebook the output is trained
    against the categorical code ``c`` (see ``qloss``). Weights live in the
    "Q" variable scope and are shared across calls (AUTO_REUSE).
    """
    with tf.variable_scope("Q", reuse=tf.AUTO_REUSE):
        h = tf.contrib.layers.flatten(q)
        # Hidden layers are created in the same order as before (784, then 128).
        for width in (784, 128):
            h = tf.layers.dense(h, units=width, activation=tf.nn.relu)
        return tf.layers.dense(h, units=10, activation=tf.nn.softmax)

In [None]:
# Graph inputs; the leading (batch) dimension is left open.
#   z: 10-dim noise, fed from sample_z during training
#   x: real images, (batch, 28, 28, 1)
#   c: 10-dim latent code, fed from sample_c during training
z = tf.placeholder(tf.float32, (None, 10), name='z')
x = tf.placeholder(tf.float32, (None, 28, 28, 1), name='x')
c = tf.placeholder(tf.float32, (None, 10), name='c')

In [None]:
# Wire the InfoGAN graph. One generator output feeds both Q (trained to recover
# the code c) and the discriminator's "fake" branch; real images from `x` go
# through the same discriminator weights thanks to the AUTO_REUSE scope.
sample = generator(z, c)
qcx = Q(sample)
dreal = discriminator(x)
dfake = discriminator(sample)

In [None]:
# Small constant that keeps every log() finite when a sigmoid/softmax output
# saturates at exactly 0 or 1; previously only qloss was guarded, so dloss and
# gloss could become -inf/NaN once D became confident.
EPS = 1e-8

# Discriminator: maximize log D(x) + log(1 - D(G(z, c))).
dloss = -tf.reduce_mean(
    tf.log(dreal + EPS) + tf.log(1. - dfake + EPS))
# Generator (non-saturating form): maximize log D(G(z, c)).
gloss = -tf.reduce_mean(tf.log(dfake + EPS))
# Mutual-information term: cross-entropy between the code c that was fed in
# and Q's prediction from the generated image.
qloss = tf.reduce_mean(
    -tf.reduce_sum(tf.log(qcx + EPS) * c, 1))

In [None]:
# Each step must update only its own network. Without var_list, minimize()
# differentiates w.r.t. every trainable variable: the D step would also move
# the generator toward helping D, and the G step would also move the
# discriminator toward being fooled.
d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="discriminator")
g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="generator")
q_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="Q")

learning_rate = 0.001
dopt = tf.train.AdamOptimizer(learning_rate).minimize(dloss, var_list=d_vars)
gopt = tf.train.AdamOptimizer(learning_rate).minimize(gloss, var_list=g_vars)
# The mutual-information term trains both the generator and Q (as in InfoGAN).
qopt = tf.train.AdamOptimizer(learning_rate).minimize(qloss, var_list=g_vars + q_vars)

In [None]:
def sample_c(m):
    """Draw ``m`` one-hot codes over 10 equally likely categories.

    Returns a float32 array of shape (m, 10) in which every row contains a
    single 1 and zeros elsewhere.
    """
    uniform_probs = [0.1] * 10
    one_hot = np.random.multinomial(1, uniform_probs, size=m)
    return one_hot.astype(np.float32)

In [None]:
def sample_z(m, n):
    """Return an (m, n) float64 array of noise drawn uniformly from [0, 1)."""
    return np.random.uniform(low=0.0, high=1.0, size=(m, n))

In [None]:
sess = tf.Session()
epochs = 10
batch_size = 200
sess.run(tf.global_variables_initializer())

# Alternate one D step, one G step and one Q step per minibatch; Ctrl-C /
# interrupting the kernel stops training but keeps the session usable.
n_examples = len(fmnist)
steps_per_epoch = n_examples // batch_size
try:
    for e in range(epochs):
        print(f"Epoch: {e+1}/{epochs}")
        for i in range(steps_per_epoch):
            # Random minibatch of real images (distinct indices within a batch).
            batch = fmnist[random.sample(range(n_examples), batch_size)]
            z_noise = sample_z(batch_size, 10)
            c_noise = sample_c(batch_size)
            gen_feed = {z: z_noise, c: c_noise}
            _, dlossc = sess.run([dopt, dloss], feed_dict={x: batch, **gen_feed})
            _, glossc = sess.run([gopt, gloss], feed_dict=gen_feed)
            sess.run([qopt], feed_dict=gen_feed)
            if i % 20 == 0:
                print(f"iter: {i}, dloss: {str(round(dlossc, 2))}, gloss: {str(round(glossc, 2))}")
except KeyboardInterrupt:
    print("Stopped")

### Sample latent space

In [None]:
# Sample from the same distributions used in training: z ~ U(0, 1) (via
# sample_z) and a one-hot categorical code c (as produced by sample_c). The
# old cell drew z from N(0, 1) and used a multi-hot c, both unseen in training.
# One fixed z is paired with every category so each panel isolates the
# effect of the code.
n_codes = 10
latent = np.repeat(sample_z(1, 10), n_codes, axis=0)
c_passed = np.eye(n_codes, dtype=np.float32)
decoded = sess.run(sample, feed_dict={z: latent, c: c_passed})
# Lay the images side by side: panel k was generated with category k.
plt.imshow(np.hstack(decoded.reshape((-1, 28, 28))), cmap='gray')

In [None]:
# Release the TensorFlow session opened for training; `sess` cannot run ops
# after this, so re-run the training cell before sampling again.
sess.close()