## Parameters

In [1]:
# Notebook-wide settings read as globals by generate() and construct_gan().
# Number of sequences sampled in parallel (shape of the start indices and
# of the RNN zero state).
batch_size = 3
# Number of unrolled RNN steps per sequence.
seq_length = 4
# Width passed to the RNN cells; also the second dimension of the embeddings.
rnn_size   = 10
# Number of rows in the embedding table / columns in the generator softmax.
vocab_size = 5
# Directory and file name handed to load_vocab().
save_dir_GAN = 'models_GAN'
vocab_file = 'simple_vocab.pkl'
# Cell type selected in construct_gan(): 'rnn', 'gru' or 'lstm'.
model = 'lstm'
# Number of stacked cells wrapped by MultiRNNCell in construct_gan().
num_layers = 2

## Generate

In [4]:
def load_vocab(save_dir_GAN, vocab_file):
    '''Load vocabulary objects.

    Args:
        save_dir_GAN:  Directory containing vocab_files.
        vocab_file: Vocab file to use.

    Returns:
        The (chars, vocab) pair unpickled from the file. `chars` is
        presumably the index -> token sequence consumed by map_to_text and
        `vocab` the reverse mapping -- confirm against the code that wrote
        the pickle.
    '''
    # Pickles are binary data: text mode corrupts them on platforms that
    # translate line endings and is rejected outright by Python 3.
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(os.path.join(save_dir_GAN, vocab_file), 'rb') as f:
        chars, vocab = cPickle.load(f)
    return chars, vocab

# Return text from indices
def map_to_text(indices, chars):
    '''Translate a sequence of indices into the corresponding tokens.

    Args:
        indices: Iterable of positions (or keys) into `chars`.
        chars: Indexable lookup from index to token.

    Returns:
        A list with one token per entry of `indices`, in the same order.
    '''
    return [chars[position] for position in indices]

In [5]:
import tensorflow as tf
import os
import cPickle
from tensorflow.contrib.distributions import Categorical
tf.reset_default_graph()

def generate():
    '''Build a graph that samples a batch of index sequences from an RNN.

    Reads the notebook-level settings batch_size, seq_length, rnn_size and
    vocab_size. Every step feeds the embedding of the current indices to a
    BasicRNNCell, projects the cell output to vocab_size logits and samples
    the next indices from a Categorical distribution over those logits.

    Returns:
        outputs: List of seq_length logits tensors, one per step, each the
            product of the cell output with the [rnn_size, vocab_size]
            projection.
        indices: List of seq_length index tensors. indices[i] holds the
            indices fed as input at step i, so indices[0] is the constant
            all-zero start batch and the sample drawn at the last step is
            not included.
    '''
    # Initial indices
    batch_indices = tf.constant(0, shape=[batch_size])

    # RNN
    outputs, indices = [], []
    cell = tf.nn.rnn_cell.BasicRNNCell(rnn_size)
    state = cell.zero_state(batch_size, tf.float32)

    with tf.variable_scope('rnn'):
        # Embeddings and Logits
        embedding = tf.get_variable('embedding', [vocab_size, rnn_size])
        softmax   = tf.get_variable('softmax', [rnn_size, vocab_size])

        inp = tf.nn.embedding_lookup(embedding, batch_indices)
        for i in xrange(seq_length):
            indices.append(batch_indices)
            if i > 0:
                # The cell creates its weights on the first call; later
                # steps must reuse them instead of creating new ones.
                tf.get_variable_scope().reuse_variables()
            rnn_out, state = cell(inp, state)
            logits_out = tf.matmul(rnn_out, softmax)
            outputs.append(logits_out)
            # One draw per batch row. Reshape explicitly rather than
            # squeeze: a bare tf.squeeze also removes the batch axis when
            # batch_size == 1, leaving a scalar that breaks the next step.
            sample = Categorical(logits_out).sample(n=1)
            batch_indices = tf.reshape(sample, [batch_size])
            inp = tf.nn.embedding_lookup(embedding, batch_indices)

    return outputs, indices
            

with tf.Session() as sess: 
    outputs, indices = generate()
    init_op = tf.initialize_all_variables()
    sess.run(init_op)

    chars, vocab = load_vocab(save_dir_GAN, vocab_file)
    
    for i, output in enumerate(outputs):
        print 'Iteration %d'%i
        print sess.run(output),'\n'
    
    for i, batch_indices in enumerate(indices):
        print 'Iteration %d'%i
        indices_eval = sess.run(batch_indices)
        print map_to_text(indices_eval, chars),'\n'
    
#     for line in indices:
#         print ''.join(line)

Iteration 0
[[-0.09419005  0.56803495  0.47307166  0.50544196 -0.40804175]
 [-0.09419005  0.56803495  0.47307166  0.50544196 -0.40804175]
 [-0.09419005  0.56803495  0.47307166  0.50544196 -0.40804175]] 

Iteration 1
[[ 0.32359368 -0.10683148  0.78633469  0.33877015  0.06092686]
 [-0.62299073 -0.01256511 -0.08685041 -0.06053785  0.04928803]
 [-0.27598929  0.52696258  0.57884568  0.35097671 -0.70507175]] 

Iteration 2
[[-0.59284848  0.75424421  0.20050494  0.55502558 -0.43934166]
 [-0.59284848  0.75424421  0.20050494  0.55502558 -0.43934166]
 [ 0.29675144 -0.01560468  0.94438457  0.32209209  0.04407423]] 

Iteration 3
[[ 0.00524406  0.37031507  0.54975694  0.40046287 -0.64646858]
 [-0.45774862  0.02466834 -0.08280412 -0.061278    0.13488948]
 [ 0.12253881 -0.29317296  0.64915192  0.31881556 -0.01258015]] 

Iteration 0
['b', 'b', 'b'] 

Iteration 1
['b', ' ', ' '] 

Iteration 2
['b', '\n', 'r'] 

Iteration 3
[' ', 'b', 'b'] 



## Train Generator

In [15]:
from tensorflow.python.ops.nn import rnn_cell
from tensorflow.python.ops.nn import rnn
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops.nn import seq2seq 
# TODO: Eliminate dependence on seq2seq

tf.reset_default_graph()

def construct_gan():
    '''Build the generator and discriminator graph of a sequence GAN.

    Reads the notebook-level settings model, batch_size, seq_length,
    rnn_size, vocab_size and num_layers.

    Generator: a stacked RNN unrolled for seq_length steps. At each step
    the next indices are sampled from a Categorical over the step's logits
    (no gradient flows through the sample) and the logits are masked with
    the one-hot of the sample, so only the sampled entry stays non-zero.

    Discriminator: a second stacked RNN that consumes the masked logits,
    projected through its own embedding matrix, and emits a 2-way softmax
    at every step.

    Returns:
        List of seq_length tensors holding the discriminator's 2-class
        probabilities for each step. Which column means "real" is not
        fixed here (the commented-out summary code below treats column 1
        as real -- confirm before relying on it).

    Raises:
        NotImplementedError: If `model` is not 'rnn', 'gru' or 'lstm'.
    '''
    if model == 'rnn':
        cell_gen = rnn_cell.BasicRNNCell(rnn_size)
        cell_dis = rnn_cell.BasicRNNCell(rnn_size)
    elif model == 'gru':
        cell_gen = rnn_cell.GRUCell(rnn_size)
        cell_dis = rnn_cell.GRUCell(rnn_size)
    elif model == 'lstm':
        cell_gen = rnn_cell.BasicLSTMCell(rnn_size, state_is_tuple=False)
        cell_dis = rnn_cell.BasicLSTMCell(rnn_size, state_is_tuple=False)
    else:
        raise NotImplementedError('Model type not supported: {}'
                                  .format(model))

    # Initial indices.
    batch_indices = tf.constant(0, shape=[batch_size])

    # Generator Portion of GAN.
    with tf.variable_scope('generator'):
        outputs_gen, logit_sequence = [], []
        cell_gen = rnn_cell.MultiRNNCell([cell_gen] * num_layers)
        state_gen = cell_gen.zero_state(batch_size, tf.float32)

        with tf.variable_scope('rnn'):
            softmax_w = tf.get_variable('softmax_w', [rnn_size, vocab_size])
            softmax_b = tf.get_variable('softmax_b', [vocab_size])
            embedding = tf.get_variable('embedding', [vocab_size, rnn_size])
            inp = tf.nn.embedding_lookup(embedding, batch_indices)

            for i in xrange(seq_length):
                if i > 0:
                    # Cell weights are created on the first step only.
                    tf.get_variable_scope().reuse_variables()

                # RNN.
                output_gen, state_gen = cell_gen(inp, state_gen)
                logits_gen = tf.nn.xw_plus_b(output_gen, softmax_w,
                                             softmax_b)

                # Sampling. Reshape explicitly rather than squeeze: a bare
                # tf.squeeze also drops the batch axis when batch_size == 1.
                sample_op = tf.stop_gradient(Categorical(
                                            logits_gen).sample(n=1))
                batch_indices = tf.reshape(sample_op, [batch_size])
                inp = tf.nn.embedding_lookup(embedding, batch_indices)

                # Use Only Logit Sampled.
                one_hot = tf.stop_gradient(tf.one_hot(batch_indices,
                                                      depth = vocab_size,
                                                      dtype = tf.float32))
                logit_gen = one_hot * logits_gen
                logit_sequence.append(logit_gen)
                outputs_gen.append(output_gen)

    # Discriminator Portion of GAN.
    with tf.variable_scope('discriminator'):
        cell_dis = rnn_cell.MultiRNNCell([cell_dis] * num_layers)
        state_dis = cell_dis.zero_state(batch_size, tf.float32)

        with tf.variable_scope('rnn'):
            softmax_w_dis = tf.get_variable('softmax_w', [rnn_size, 2])
            softmax_b_dis = tf.get_variable('softmax_b', [2])
            embedding_dis = tf.get_variable('embedding', [vocab_size, rnn_size])

            # Input sequence to Discriminator: the masked logits act as a
            # weighted lookup into the discriminator's embedding table.
            inputs_dis = []
            for logit in logit_sequence:
                inputs_dis.append(tf.matmul(logit, embedding_dis))

            # RNN.
            assert len(inputs_dis) == len(outputs_gen)
            outputs_dis, _ = seq2seq.rnn_decoder(inputs_dis,
                state_dis, cell_dis, loop_function=None)

            # Predictions.
            probs = []
            for output_dis in outputs_dis:
                logit = tf.nn.xw_plus_b(output_dis, softmax_w_dis, softmax_b_dis)
                probs.append(tf.nn.softmax(logit))

    return probs
#         with tf.name_scope('summary'):
#             probs      = tf.pack(probs)
#             probs_real = tf.slice(probs, [0,0,1], [seq_length, batch_size, 1])
#             variable_summaries(probs_real, 'probability of real')

#         final_state_dis = last_state_dis

with tf.Session() as sess:
    probs = construct_gan()
    init_op = tf.initialize_all_variables()
    sess.run(init_op)

    for i, prob in enumerate(probs):
        print 'Iteration %d'%i
        print sess.run(prob),'\n'



Iteration 0
[[ 0.95938396  0.04061598]
 [ 0.95938396  0.04061598]
 [ 0.96078378  0.0392162 ]] 

Iteration 1
[[ 0.96151656  0.03848345]
 [ 0.9606083   0.03939169]
 [ 0.95846027  0.04153975]] 

Iteration 2
[[ 0.9603138   0.0396862 ]
 [ 0.96201485  0.03798509]
 [ 0.95925593  0.04074399]] 

Iteration 3
[[ 0.9602375   0.03976251]
 [ 0.95982772  0.0401723 ]
 [ 0.96159273  0.03840725]] 



## Random

In [None]:
# Scratch check of joining per-row token lists into printable lines.
# NOTE: this rebinds the global name `indices`, which the Generate cell also
# assigns (there it holds the list returned by generate()).
indices = [['0','1','2'],['0','1','2'],['0','1','2']]

# Each inner list is concatenated without separators and printed on its own line.
for line in indices:
    print ''.join(line)

In [None]:
def variable_summaries(var, name):
    '''Attach mean, stddev, max, min and histogram summaries to a tensor.

    Args:
        var: Tensor to summarize.
        name: Suffix appended to each summary tag, e.g. 'mean/<name>'.
    '''
    mean = tf.reduce_mean(var)
    tf.scalar_summary('mean/' + name, mean)
    with tf.name_scope('stddev'):
        # Standard deviation is the root of the MEAN squared deviation;
        # summing instead would make the value grow with the tensor size.
        stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
    # Tag previously read 'sttdev/', a typo that hid it from 'stddev' searches.
    tf.scalar_summary('stddev/' + name, stddev)
    tf.scalar_summary('max/' + name, tf.reduce_max(var))
    tf.scalar_summary('min/' + name, tf.reduce_min(var))
    tf.histogram_summary(name, var)