# Reinforcement Learning GAN
Employ REINFORCE for training the Generator in a GAN formuluation.  As before, the reward to the Generator $G$ is simply the probability of real at each time point, $p_t$, assigned by the Discriminator $D$.

Classic GAN minimization objective,

$$\text{max}_{G_{\theta}} \text{min}_{D_{\theta}} \left[ \text{log}(D(x)) + \text{log}(1. - D(G(z)) \right]$$

In [49]:
import tensorflow as tf
import numpy as np
from tensorflow.contrib.distributions import Categorical

rnn_gen_size = 3
rnn_dis_size = 2
time_steps = 4
batch_size = 1
num_trials = 1000
print_every = 100
lr = 0.1
decay = 0.9

## Real Data.
Create a simple data distribution with sequential structure for testing the GAN.

In [48]:
num_examples = 1000

# Real Sequence (num_examples, time_steps, inp_dim)
real_sequences = []
real_seq = np.array([[1.,0.,0.],[0.,1.,0.],[0.,0.,1.],[1.,1.,1.]])

for i in xrange(num_examples):
    real_sequences.append(real_seq)
real_sequences = np.asarray(real_sequences)

# Real Labels (num_examples, time_steps)
real_labels = np.ones([num_examples, time_steps], dtype=np.int)


## Generator and Discriminator.

In [58]:
def generator():
    '''Define the Generator graph.'''
    # TODO: Generalize for random input.
    # batch_indices = gen_z
    batch_indices = tf.constant([[1., 0., 0.]], dtype=tf.float32)   
    
    with tf.variable_scope('gen'):
        cell_gen = tf.nn.rnn_cell.BasicRNNCell(rnn_gen_size)
        state_gen = cell_gen.zero_state(batch_size, tf.float32)

        with tf.variable_scope('rnn'):
            sequence = []

            for t in xrange(time_steps):        
                if t > 0:
                    tf.get_variable_scope().reuse_variables()
                    
                rnn_in = batch_indices
                rnn_out, state_gen = cell_gen(rnn_in, state_gen)
                log_probs = tf.nn.log_softmax(rnn_out)

                index = tf.contrib.bayesflow.stochastic_graph.DistributionTensor(
                    tf.contrib.distributions.Categorical,
                    logits=log_probs)
                batch_indices = tf.one_hot(index, rnn_gen_size, dtype=tf.float32)
                sequence.append(batch_indices)
    return tf.pack(sequence, axis=1)


def discriminator(sequence):
    '''Define the Discriminator graph.'''
    sequence = tf.unpack(sequence, axis=1)
    with tf.variable_scope('dis'):    
        cell_dis = tf.nn.rnn_cell.BasicRNNCell(rnn_dis_size)
        state_dis = cell_dis.zero_state(batch_size, tf.float32)

        with tf.variable_scope('rnn'):
            predictions = []

            for t, inp in enumerate(sequence):
                rnn_in = inp
                if t > 0:
                    tf.get_variable_scope().reuse_variables()

                rnn_out, state_dis = cell_dis(rnn_in, state_dis)
                pred = tf.nn.sigmoid(rnn_out)
                predictions.append(pred)
    return predictions

## Training Operations.

In [52]:
def train_generator(predictions):
    '''Train generator via REINFORCE.'''
    rewards = []
    for pred in predictions:
        rewards.append(pred)
    
    # Exponential baseline.
    ema = tf.train.ExponentialMovingAverage(decay = decay)
    rewards_tf = tf.pack(rewards)
    reduced_reward = tf.reduce_mean(rewards_tf)
    maintain_avg_op = ema.apply([reduced_reward])
    baseline = ema.average(reduced_reward)
    
    # Advantage.
    loss = []
    for reward in rewards:
        advantage = reward - baseline
        loss.append(-advantage)

    # Optimizer 
    optimizer = tf.train.AdamOptimizer(lr)
    final_loss = tf.contrib.bayesflow.stochastic_graph.surrogate_loss(loss)
    gen_vars = [v for v in tf.trainable_variables() if v.op.name.startswith('gen')]
    min_op = optimizer.minimize(final_loss, var_list = gen_vars)
    
    # Group operations.
    train_op = tf.group(min_op, maintain_avg_op)
    return final_loss, train_op


def train_discriminator(sequences, labels):
    '''Train Discriminator via cross entropy'''
#     loss = []
#     for seq, lab in zip(sequences, labels):
#         loss.append(tf.contrib.losses.log_loss(seq, lab))
    loss = tf.contrib.losses.log_loss(seq, lab)
    
    final_loss = tf.reduce_sum(loss) / batch_size / seq_length
    dis_vars = [v for v in tf.trainable_variables() if v.op.name.startswith('dis')]
    
    # Optimizer
    optimizer = tf.train.AdamOptimizer(lr)
    train_op = optimizer.minimize(final_loss, var_list = dis_vars)
    return final_loss, train_op

### Check Loss Calculation

In [50]:
logits = tf.constant([[0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5]], dtype=tf.float32)
labels = tf.constant([[0, 0, 0, 0], [0, 0, 0, 0]])

loss = tf.contrib.losses.log_loss(logits, labels)
with tf.Session() as sess:
    print sess.run(loss)


0.693147


## Generative Adversarial Networks Training.

In [59]:
tf.reset_default_graph()


fake_sequence = generator()
predictions = discriminator(fake_sequence)

with tf.Session() as sess:
    init_op = tf.initialize_all_variables()
    sess.run(init_op)
    
    for t in xrange(num_trials):
        sess.run(train_op)
        
        assert batch_size == 1
        # TODO: Better batching.
        for n in xrange(num_examples):
            x, y = real_sequences[n], real_labels[n]
            


            
#             gen_loss, gen_train_op = train_generator(preds)
#             dis_loss, dis_train_op = train_discriminator()
            
#             sess.run(dis_train_op)
#             sess.run(gen_train_op)

#             if t % print_every == 0:
#                 print sess.run(gen_loss)
                

#             print sess.run(fake_seq)

(1, 4, 3)


# TODO
* Generalize for random input $z_{gen}$
* Improve batching