In [None]:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

In [None]:
# --- Training configuration -------------------------------------------------
CRITIC_ITERS = 1      # discriminator updates per generator update
BATCH_SIZE = 64       # samples per batch, for both real and generated data
ITERS = 100000        # generator updates (the training loop runs ITERS + 1 times)
TEMPERATURE = 0.001   # temperature of the relaxed one-hot (Gumbel-softmax) samples
# NOTE(review): a temperature this low makes relaxed samples nearly one-hot,
# which tends to give high-variance gradients — confirm this is intentional.
EPS = 1e-8            # added inside log() so the GAN losses stay finite

# Seed both libraries so that a Restart & Run All gives repeatable results:
# numpy drives the real-data sampling in create_real_data, and the TF
# graph-level seed fixes the random ops built later (Gumbel sampling, the
# dense-layer initializer). It must be set before those ops are created.
SEED = 0
np.random.seed(SEED)
tf.set_random_seed(SEED)

In [None]:
# Target distribution the generator should learn: category index -> probability.
# create_real_data uses each sampled key directly as a one-hot column index,
# so the keys must be exactly 0 .. len(real_probs) - 1.
real_probs = {
    0: 0.1,
    1: 0.2,
    2: 0.3,
    3: 0.25,
    4: 0.15,
}

# Fail fast on a bad edit here instead of deep inside np.random.choice or training.
assert sorted(real_probs) == list(range(len(real_probs))), "real_probs keys must be 0..n-1"
assert abs(sum(real_probs.values()) - 1.0) < 1e-8, "real_probs values must sum to 1"

def create_real_data(batch_size=None, probs=None):
    """Draw a batch of one-hot samples from a categorical distribution.

    Args:
        batch_size: number of samples to draw; defaults to the global
            ``BATCH_SIZE``.
        probs: mapping ``category index -> probability``; defaults to the
            global ``real_probs``. Category ``k`` is encoded in column ``k``,
            so the keys must be exactly ``0 .. len(probs) - 1``.

    Returns:
        Array of shape ``(batch_size, len(probs))``; each row is the one-hot
        encoding of one sampled category.

    Raises:
        ValueError: if the keys of ``probs`` are not ``0 .. len(probs) - 1``.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    if probs is None:
        probs = real_probs

    categories = list(probs.keys())
    # Keys double as column indices below; anything else would silently pick
    # the wrong column or index out of bounds.
    if sorted(categories) != list(range(len(categories))):
        raise ValueError(
            "probs keys must be 0..n-1, got {}".format(sorted(categories))
        )

    samples = np.random.choice(categories, p=list(probs.values()), size=batch_size)
    # Row k of the identity matrix is the one-hot encoding of category k.
    return np.identity(len(probs))[samples]

def Generator(n_samples, temperature=None):
    """Build the generator: learnable logits plus relaxed one-hot samples.

    Args:
        n_samples: number of relaxed one-hot samples drawn each time
            ``outputs`` is evaluated.
        temperature: temperature of the ``RelaxedOneHotCategorical``
            (Gumbel-softmax) distribution; defaults to the global
            ``TEMPERATURE``.

    Returns:
        (outputs, probs):
            ``outputs`` holds ``n_samples`` relaxed one-hot vectors of length
            ``len(real_probs)``.
            ``probs`` is ``softmax(logits)``, the generator's current
            categorical distribution.
    """
    if temperature is None:
        temperature = TEMPERATURE
    # AUTO_REUSE (as in Discriminator) lets the graph cell be re-executed
    # without tf.get_variable failing on an already-existing 'logits'.
    with tf.variable_scope('Generator', reuse=tf.AUTO_REUSE):
        # All-ones initial logits -> uniform initial distribution.
        logits = tf.get_variable('logits', initializer=tf.ones([len(real_probs)]))
        gumbel_dist = tfp.distributions.RelaxedOneHotCategorical(temperature, logits=logits)
        probs = tf.nn.softmax(logits)
        outputs = gumbel_dist.sample(n_samples)
        return outputs, probs

def Discriminator(inputs):
    """Score each input row with one logistic unit: sigmoid(inputs @ W + b).

    The function is called once on real and once on generated batches; the
    ``tf.AUTO_REUSE`` scope makes those calls share the same dense-layer
    variables instead of creating a new layer each time.

    Returns:
        Tensor with one value in (0, 1) per input row. The losses treat this
        value as the probability that the row is real.
    """
    with tf.variable_scope('Discriminator', reuse=tf.AUTO_REUSE):
        realness = tf.layers.dense(
            inputs,
            units=1,
            activation=tf.nn.sigmoid,
        )
    return realness

In [None]:
# Batch of one-hot real samples. The width follows the number of categories
# instead of a hard-coded 5, so editing real_probs keeps the graph consistent.
real_data = tf.placeholder(tf.float32, shape=[None, len(real_probs)])

fake_data, fake_probs = Generator(BATCH_SIZE)
disc_real = Discriminator(real_data)
disc_fake = Discriminator(fake_data)

# Discriminator: binary cross-entropy, pushing D(real) -> 1 and D(fake) -> 0.
# Generator: non-saturating loss -log D(G(z)).
# EPS keeps log() finite when the sigmoid output reaches exactly 0 or 1.
disc_loss = -tf.reduce_mean(tf.log(disc_real + EPS) + tf.log(1.0 - disc_fake + EPS))
gen_loss = -tf.reduce_mean(tf.log(disc_fake + EPS))

In [None]:
# Each optimizer updates only the variables created under its own scope, so a
# discriminator step never moves the generator logits, and vice versa.
disc_params = tf.trainable_variables('Discriminator')
gen_params = tf.trainable_variables('Generator')


def _adam_step(loss, var_list, learning_rate=1e-3):
    """Return an Adam ``minimize`` op for ``loss`` restricted to ``var_list``."""
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    return optimizer.minimize(loss, var_list=var_list)


disc_train_op = _adam_step(disc_loss, disc_params)
gen_train_op = _adam_step(gen_loss, gen_params)

In [None]:
# A single session for the whole notebook. Every variable must be initialized
# before training: the generator logits, the discriminator weights, and the
# optimizer state created by the Adam ops.
init_op = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_op)

In [None]:
%%time

LOG_EVERY = 10000  # iterations between progress prints

for i in range(ITERS + 1):
    # CRITIC_ITERS discriminator updates, each on a freshly sampled real batch.
    for _ in range(CRITIC_ITERS):
        sess.run(disc_train_op, feed_dict={real_data: create_real_data()})

    # Generator update: gen_loss depends only on generated samples, so the
    # real-data placeholder does not need to be fed.
    sess.run(gen_train_op)

    if i % LOG_EVERY == 0:
        current_probs = sess.run(fake_probs)
        # Format every category instead of assuming exactly five; the output
        # is unchanged for the 5-category case.
        print(i, ' : ' + ', '.join('{:.4f}'.format(p) for p in current_probs))
