In [None]:
%pylab inline
import tensorflow as tf
from keras.models import Model
from keras.layers import Input, GRU, Dense, Lambda
from keras import backend as K
import tqdm

def get_data(data_file, batch_size, length):
  """Build a queue-backed batch tensor of one-hot encoded sequences.

  The integer array stored in `data_file` is truncated to its first `length`
  columns, every 0 entry is replaced by the tab code, and a tab column is
  prepended to every row. Rows are then drawn uniformly at random, one-hot
  encoded and batched with `tf.train.shuffle_batch`.

  Args:
    data_file: path passed to `np.load`; the loaded array is indexed as 2-D.
    batch_size: number of sequences per batch.
    length: number of columns kept from the file (the prepended tab column
      makes each served sequence `length + 1` steps long).

  Returns:
    The value returned by `tf.train.shuffle_batch` for the single sampled
    tensor, whose per-example static shape is (length + 1, vocab_size).
  """
  tab = ord('\t')
  codes = np.load(data_file)[:, :length]
  # Zeros (presumably padding -- confirm against the data file) become tabs.
  codes[codes == 0] = tab
  start_column = tab * np.ones((codes.shape[0], 1))
  codes = np.concatenate([start_column, codes], axis=-1).astype(np.int32)
  seq_len = length + 1
  vocab_size = codes.max() + 1
  one_hot = np.eye(vocab_size, dtype=np.float32)

  def sample():
    # Pick one row at random and expand its codes into one-hot rows.
    row = np.random.randint(codes.shape[0])
    return one_hot[codes[row]]

  X = tf.py_func(sample, [], [tf.float32])[0]
  # py_func loses static shape information, so restore it by hand.
  X.set_shape((seq_len, vocab_size))
  return tf.train.shuffle_batch(
    [X], batch_size,
    capacity=batch_size*10,
    min_after_dequeue=0,
    num_threads=1)

def Sample():
  """Return a Lambda layer (named 'sample') that draws a latent sample.

  The layer expects a pair [z_mean, z_log_var] and computes
  z_mean + exp(0.5 * z_log_var) * noise, where noise comes from
  `K.random_normal` with its default parameters and has shape
  (dynamic batch size, static second dimension of z_mean).
  """
  def sampling(args):
    mean, log_var = args
    n_rows = K.shape(mean)[0]       # dynamic batch size
    n_cols = K.int_shape(mean)[1]   # static latent width
    noise = K.random_normal(shape=(n_rows, n_cols))
    std = K.exp(0.5 * log_var)
    return mean + std * noise
  return Lambda(sampling, name='sample')

def encoder(sequence, latent_dim):
  """Build the encoder Model on top of the tensor `sequence`.

  A GRU with `latent_dim` units consumes the sequence; its returned state
  feeds two separate Dense(latent_dim) layers producing `mu` and `logsig`,
  which are combined by the `Sample` layer into `z`.

  Returns:
    A Model with the wrapped sequence as input and [z, mu, logsig] as outputs.
  """
  inputs = Input(tensor=sequence)
  gru = GRU(latent_dim, return_state=True, activation='tanh')
  _, final_state = gru(inputs)
  # Two independent linear heads on the same GRU state.
  mu = Dense(latent_dim)(final_state)
  logsig = Dense(latent_dim)(final_state)
  z = Sample()([mu, logsig])
  return Model(inputs=inputs, outputs=[z, mu, logsig])

def decoder(sequence, state, latent_dim):
  """Build the decoder Model on top of the tensors `sequence` and `state`.

  `state` is projected by a Dense(latent_dim) layer and used as the initial
  state of a GRU that reads `sequence`; every GRU output step is mapped by a
  Dense layer to `vocab_size` scores (no activation is applied).

  Returns:
    A Model with inputs [sequence, state] and the per-step scores as output.
  """
  vocab_size = int(sequence.shape[-1])
  seq_in = Input(tensor=sequence)
  state_in = Input(tensor=state)
  initial = Dense(latent_dim)(state_in)
  gru = GRU(latent_dim, return_sequences=True, return_state=True,
            activation='tanh')
  hidden, _ = gru(seq_in, initial_state=initial)
  scores = Dense(vocab_size)(hidden)
  return Model(inputs=[seq_in, state_in], outputs=scores)

def Switch():
  """Return a Lambda layer (named 'switch') that selects one of two tensors.

  Given [A, B], the layer yields A when `K.learning_phase()` is true and B
  otherwise, using `tf.cond`.
  """
  def cond(args):
    when_training, when_testing = args

    def pick_training():
      return when_training

    def pick_testing():
      return when_testing

    return tf.cond(K.learning_phase(), pick_training, pick_testing)
  return Lambda(cond, name='switch')

def recurrent_VAE(sequence, latent_dim):
  """Assemble the encoder and decoder into a single Model.

  The encoder is built directly on `sequence`. The decoder is built on two
  `Switch` outputs: while the learning phase is true it reads `sequence` and
  the encoder's sampled state, otherwise it reads two placeholder inputs.

  Returns:
    A Model whose inputs are the encoder inputs, then the decoder inputs, then
    the [sequence, state] placeholders, and whose outputs are the encoder
    outputs (state, mu, logsig) followed by the decoder output.
  """
  enc = encoder(sequence, latent_dim)
  state, mu, logsig = enc.outputs
  # Keras Model.__call__(X) connects model.outputs to the given X, but does not
  # connect state variables (such as the GRU initial state and batchnorm
  # params) to X, so we explicitly build the model with a "switch" input that
  # uses placeholder inputs when learning_phase == False (but alas, it also
  # always expects dummy placeholders to be fed).
  sequence_ph = Input(shape=(None, int(sequence.shape[-1])))
  state_ph = Input(shape=state.shape[1:])
  placeholders = [sequence_ph, state_ph]
  # Training phase -> real tensors; otherwise -> the placeholders above.
  dec_seq = Switch()([sequence, sequence_ph])
  dec_state = Switch()([state, state_ph])
  dec = decoder(dec_seq, dec_state, latent_dim)
  return Model(inputs=enc.inputs+dec.inputs+placeholders, 
               outputs=enc.outputs+dec.outputs)

def xentropy(X, Xhat):
  """Mean sparse softmax cross-entropy between targets `X` and logits `Xhat`.

  `X` is reduced to class indices with an argmax over its last axis, so it is
  expected to be one-hot (or at least score-like) along that axis.
  """
  labels = tf.argmax(X, axis=-1)
  per_step = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=labels, logits=Xhat)
  return tf.reduce_mean(per_step)

def KL_divergence(mu, logsig):
  """Return -mean(1 + logsig - mu**2 - exp(logsig)) over all elements.

  NOTE(review): if `logsig` is a log-variance, this is proportional to the
  Gaussian KL to a unit normal, but it averages over every element and omits
  the usual 0.5 factor -- confirm that this weighting is intended.
  """
  term = (1 + logsig) - mu**2
  term = term - tf.exp(logsig)
  return tf.negative(tf.reduce_mean(term))

def regularize(weights):
  """Return the global norm of `weights` divided by the number of tensors."""
  norm = tf.global_norm(weights)
  return norm / len(weights)

def matrix_to_string(m):
  """Decode a 2-D score matrix into a string.

  Each row is mapped to the character whose code point is the index of the
  row's largest entry (argmax over the last axis).
  """
  code_points = np.argmax(m, axis=-1)
  return ''.join(map(chr, code_points))

def test(sess, Xhat, state, sequence, desired_length):
  """Greedily decode a string from a randomly drawn initial state.

  Starting from a single tab character, the decoder output `Xhat` is evaluated
  `desired_length` times with the learning phase set to False; after each run
  the argmax of the last step's scores is appended to the sequence.

  Args:
    sess: session used to evaluate `Xhat`.
    Xhat: tensor of per-step decoder scores.
    state: tensor fed with the random state; its last dimension sets the size.
    sequence: tensor fed with the one-hot sequence decoded so far; its last
      dimension sets the vocabulary size.
    desired_length: number of characters to generate after the initial tab.

  Returns:
    The decoded string, including the leading tab.
  """
  vocab_size = int(sequence.shape[-1])
  latent_dim = int(state.shape[-1])
  one_hot = np.eye(vocab_size)
  decoded = one_hot[[ord('\t')]]
  z = np.random.normal(size=latent_dim)
  for _ in range(desired_length):
    # Re-run the decoder on the whole prefix and keep only the newest step.
    scores = sess.run(Xhat, feed_dict={
      K.learning_phase(): False,
      sequence: [decoded],
      state: [z]})
    best = scores[0, -1].argmax()
    decoded = np.concatenate((decoded, one_hot[[best]]))
  return matrix_to_string(decoded)
  
def train(model):
  """Optimise `model` indefinitely, reporting every 1000 steps.

  The objective is the sum of the cross-entropy between the decoder output and
  the input rolled by one step along axis 1 (so step t is scored against step
  t+1, with the last step wrapping around to the first), the KL term, and
  1e-4 times the weight regulariser. It is minimised with Adam after clipping
  the gradients to a global norm of 5.

  After every 1000 steps the mean cross-entropy, KL, regulariser and gradient
  norm are printed, followed by a string decoded with `test`.

  Args:
    model: Keras Model whose first input is the training batch, whose last two
      inputs are the (sequence, state) placeholders and whose outputs are
      (state, mu, logsig, Xhat) -- the layout built by `recurrent_VAE`.

  This function does not return normally: it loops until an exception (for
  example a KeyboardInterrupt) propagates. The queue-runner threads are asked
  to stop and are joined before the session is closed.
  """
  X = model.inputs[0]
  X_shifted = tf.manip.roll(X, shift=-1, axis=1)
  sequence_ph, state_ph = model.inputs[-2:]
  _, mu, logsig, Xhat = model.outputs
  print('got data')
  XE = xentropy(X_shifted, Xhat)
  KL = KL_divergence(mu, logsig)
  RG = 1e-4*regularize(model.trainable_weights)
  L = XE + KL + RG
  print('got objective')
  adam = tf.train.AdamOptimizer()
  grad, var = zip(*adam.compute_gradients(
    L, var_list=model.trainable_weights))
  grad, norm = tf.clip_by_global_norm(grad, clip_norm=5)
  opt = adam.apply_gradients(zip(grad, var))
  print('got optimizer')
  # The placeholders must be fed even in the training phase, where the switch
  # layers ignore them, so feed minimal all-zero dummies.
  dummy_dict = {
    K.learning_phase(): True,
    sequence_ph: np.zeros((1, 1, int(sequence_ph.shape[-1]))),
    state_ph: np.zeros((1, int(state_ph.shape[-1])))}
  coord = tf.train.Coordinator()
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    print('initialized variables')
    try:
      while True:
        losses = []
        print(flush=True)
        for _ in tqdm.trange(1000):
          # Drop the first result (the optimiser op) and keep the metrics.
          losses.append(
            sess.run([opt, XE, KL, RG, norm], feed_dict=dummy_dict)[1:])
        print(flush=True)
        print('xent: {}\nKL:   {}\nRG:   {}\nnorm: {}'.format(*np.mean(losses, axis=0)))
        print()
        print(test(sess, Xhat, state_ph, sequence_ph, int(X.shape[1])-1))
    finally:
      # Shut the input threads down while the session is still open so they
      # do not keep enqueueing into a closed session.
      coord.request_stop()
      coord.join(threads)

# Hyperparameters.
batch_size = 1024   # sequences per training batch
max_length = 64     # columns kept from each row of the data file
latent_dim = 512    # GRU / latent state width

# Build the input pipeline and the model, then train until interrupted.
X = get_data('./tweets.npy', batch_size=batch_size, length=max_length)
model = recurrent_VAE(X, latent_dim=latent_dim)
train(model)