In [None]:
%pylab inline
import tensorflow as tf
from keras.models import Model
from keras.layers import Input, Dense, Lambda, CuDNNGRU, Conv1D, Activation
from keras import backend as K
import glob
import time
import tqdm

def LogNormal():
  """Return a Lambda layer that draws z from N(mean, exp(log_var)).

  The layer is called on [mean, log_var] and uses the reparameterization
  trick, z = mean + exp(log_var / 2) * eps with eps ~ N(0, I), so gradients
  flow back into mean and log_var. Despite the name, the sample is normally
  distributed, not log-normal.
  """
  def sample_latent(args):
    mean, log_var = args
    # Batch size is dynamic; the latent width is taken from the static shape.
    noise = K.random_normal(shape=(K.shape(mean)[0], K.int_shape(mean)[1]))
    std = K.exp(0.5 * log_var)
    return mean + std * noise
  return Lambda(sample_latent, name='log_normal')

def encoder(sequence, latent_dim):
  """Build the VAE encoder as a Keras Model wrapped around the tensor `sequence`.

  Pipeline: Conv1D (64 filters, width 9, ReLU, same padding) -> CuDNNGRU whose
  final state is projected by two independent Dense layers to `mu` and
  `logsig` (log-variance) -> a latent sample z drawn by `LogNormal`.

  Args:
    sequence: input tensor, assumed one-hot (batch, time, vocab) -- TODO confirm.
    latent_dim: GRU width and latent dimension.

  Returns:
    Model whose outputs are [z, mu, logsig].
  """
  seq_in = Input(tensor=sequence)
  features = Conv1D(64, 9, activation='relu', padding='same')(seq_in)
  _, final_state = CuDNNGRU(latent_dim, return_state=True)(features)
  mu = Dense(latent_dim)(final_state)
  logsig = Dense(latent_dim)(final_state)
  z = LogNormal()([mu, logsig])
  return Model(inputs=seq_in, outputs=[z, mu, logsig])

def decoder(sequence, state, latent_dim):
  """Build the VAE decoder as a Keras Model over the tensors `sequence`/`state`.

  A CuDNNGRU, initialized from `state`, runs over `sequence` and emits one
  output per step. A linear Dense layer maps each step to `vocab` logits,
  where `vocab` is the last dimension of `sequence`. No softmax is applied.

  Args:
    sequence: input sequence tensor; its last dimension is the vocabulary size.
    state: initial GRU state tensor, assumed (batch, latent_dim).
    latent_dim: GRU width.

  Returns:
    Model with inputs [sequence, state] and per-step logits as output.
  """
  vocab = int(sequence.shape[-1])
  seq_in = Input(tensor=sequence)
  state_in = Input(tensor=state)
  gru = CuDNNGRU(latent_dim, return_sequences=True, return_state=True)
  step_outputs, _ = gru(seq_in, initial_state=state_in)
  logits = Dense(vocab)(step_outputs)
  return Model(inputs=[seq_in, state_in], outputs=logits)

def Switch():
  """Return a Lambda layer that picks one of two tensors by learning phase.

  Called on [A, B], it yields A when K.learning_phase() is true and B when it
  is false, using tf.cond.
  """
  def select(tensors):
    when_learning, otherwise = tensors
    return tf.cond(K.learning_phase(),
                   lambda: when_learning,
                   lambda: otherwise)
  return Lambda(select, name='switch')

def recurrent_VAE(sequence, latent_dim):
  """Assemble the encoder and decoder into a single Keras Model.

  The decoder's input sequence and initial GRU state come from `Switch`
  layers. When K.learning_phase() is true they are the real `sequence` and the
  encoder's sampled latent z. When it is false they are two free placeholders,
  so the decoder can be driven from any prefix and latent vector.

  Keras Model.__call__(X) reconnects a model's outputs to X but does not
  necessarily reconnect stateful pieces such as the GRU initial state or
  batchnorm moving averages. The graph is therefore built explicitly with
  placeholders. Drawback: the model always expects the placeholders to be
  fed, even when the learning-phase branch is taken.

  Args:
    sequence: one-hot input tensor, assumed (batch, time, vocab).
    latent_dim: latent size and GRU width.

  Returns:
    Model with inputs [encoder input, decoder inputs, sequence placeholder,
    state placeholder] and outputs [z, mu, logsig, logits].
  """
  enc = encoder(sequence, latent_dim)
  z = enc.outputs[0]  # sampled latent; becomes the decoder's initial state
  vocab_size = int(sequence.shape[-1])
  # Use plain-int shapes (not TensorShape/Dimension objects) for both
  # placeholders so the recorded Input shapes are consistent and serializable.
  sequence_ph = Input(shape=(None, vocab_size))
  state_ph = Input(shape=K.int_shape(z)[1:])
  placeholders = [sequence_ph, state_ph]
  dec_seq = Switch()([sequence, sequence_ph])
  dec_state = Switch()([z, state_ph])
  dec = decoder(dec_seq, dec_state, latent_dim)
  return Model(inputs=enc.inputs + dec.inputs + placeholders,
               outputs=enc.outputs + dec.outputs)

def xentropy(X, Xhat):
  """Mean softmax cross-entropy of logits `Xhat` against one-hot targets `X`.

  Targets are converted to integer class ids with argmax over the last axis,
  and the per-position losses are averaged over all positions.
  """
  class_ids = tf.argmax(X, axis=-1)
  per_position = tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=Xhat, labels=class_ids)
  return tf.reduce_mean(per_position)

def KL_divergence(mu, logsig):
  """Batch-mean KL divergence of N(mu, exp(logsig)) from N(0, I).

  `logsig` is treated as a log-variance, matching the sampler in `LogNormal`.
  The closed form -0.5 * sum(1 + logsig - mu^2 - exp(logsig)) is summed over
  axis 1 and then averaged over the remaining axis. This assumes
  (batch, latent_dim) inputs -- TODO confirm.
  """
  per_dim = 1 + logsig - mu**2 - tf.exp(logsig)
  per_example = -0.5 * tf.reduce_sum(per_dim, axis=1)
  return tf.reduce_mean(per_example)

def matrix_to_string(m):
  """Decode a 2-D array of per-character scores into a string.

  Each row is mapped to the character whose code point is the row's argmax.
  """
  codes = np.argmax(m, axis=-1)
  return ''.join(map(chr, codes))

def test(sess, Xhat, state, sequence, desired_length):
  """Greedily decode `desired_length` characters from a random latent vector.

  Generation starts from the one-hot code 0 and a latent vector drawn from
  N(0, I). Each step runs `Xhat` with the learning phase set to False,
  feeding the whole prefix to the `sequence` placeholder and the latent to
  the `state` placeholder. The argmax of the last step's logits is appended.
  The prefix is re-run from scratch on every step.

  Returns:
    The decoded string, including the leading chr(0) seed.
  """
  eye = np.eye(int(sequence.shape[-1]))
  z = np.random.normal(size=int(state.shape[-1]))
  generated = eye[[0]]
  for _ in range(desired_length):
    logits = sess.run(Xhat, feed_dict={
      K.learning_phase(): False,
      sequence: [generated],
      state: [z]})
    next_code = logits[0, -1].argmax()
    generated = np.concatenate((generated, eye[[next_code]]))
  return matrix_to_string(generated)

def get_text_dataset(pattern, batch_size, length):
  """Build an endless tf.data pipeline of random one-hot text windows.

  Every file matching `pattern` is read as bytes and decoded as ASCII, with
  undecodable bytes dropped. Whitespace runs are collapsed to single spaces,
  and all files are concatenated into one uint8 array. Each sample is a
  random window of `length` codes whose first entry is replaced by 0, the
  start token.

  Args:
    pattern: glob pattern for the input text files.
    batch_size: number of windows per batch.
    length: characters per window, including the leading 0.

  Returns:
    The next-element tensor of a one-shot iterator: float32 one-hot
    encodings with static shape (None, length, 128).

  Raises:
    ValueError: if no file matches `pattern`, or the corpus has no more than
      `length` characters (no window can be sampled).
  """
  txt_files = glob.glob(pattern)
  if not txt_files:
    raise ValueError('no files match {!r}'.format(pattern))
  texts = []
  for fname in txt_files:
    with open(fname, 'rb') as f:
      string = f.read().decode('ascii', 'ignore')
    string = ' '.join(string.split())
    texts.append(np.frombuffer(string.encode('ascii'), dtype=np.uint8))
  texts = np.concatenate(texts)
  texts = texts[texts < 128]  # guard: one_hot depth below is 128
  maxind = texts.size - length
  if maxind <= 0:
    raise ValueError('corpus has {} characters; need more than length={}'
                     .format(texts.size, length))

  def sample():
    while True:
      start = np.random.choice(maxind)
      # Skip texts[start] so that, with the start token prepended, the
      # window is exactly `length` long.
      window = texts[start + 1:start + length]
      yield np.concatenate([[0], window], axis=0)

  # The shape comes from the `length` argument, not the module-level global.
  dataset = tf.data.Dataset.from_generator(sample, tf.uint8, [length])
  dataset = dataset.batch(batch_size)
  dataset = dataset.map(lambda batch: tf.one_hot(batch, 128))
  dataset = dataset.prefetch(1)
  x = dataset.make_one_shot_iterator().get_next()
  return x
  
def train(model, clip_norm, lambda_r):
  """Train `model` indefinitely with Adam, reporting every 100 steps.

  Objective:
  - next-character cross-entropy,
  - plus KL(q(z|x) || N(0, I)),
  - plus `lambda_r` times the global L2 norm of the trainable weights.

  Gradients are clipped to global norm `clip_norm` before the Adam update.
  After each 100-step chunk, the function prints mean losses, the mean
  pre-clip gradient norm and a greedily decoded sample.

  Args:
    model: a model built by `recurrent_VAE`. Its first input is the one-hot
      sequence, its last two inputs are the decoder placeholders, and its
      outputs are (z, mu, logsig, logits).
    clip_norm: maximum global gradient norm.
    lambda_r: weight of the weight-norm regularizer.

  Never returns; interrupt the kernel to stop.
  """
  X = model.inputs[0]
  sequence_ph, state_ph = model.inputs[-2:]
  _, mu, logsig, Xhat = model.outputs
  print('got inputs')
  # Logits at step t are trained to predict the character at step t + 1.
  # The final step has no successor, so it is dropped rather than wrapped
  # around to the start token X[:, 0]. The old concat-rotation did that and
  # taught the model to emit a 0 at the last position.
  XE = xentropy(X[:, 1:], Xhat[:, :-1])
  KL = KL_divergence(mu, logsig)
  RG = lambda_r * tf.global_norm(model.trainable_weights)
  L = XE + KL + RG
  print('got objective')
  adam = tf.train.AdamOptimizer()
  grads, variables = zip(*adam.compute_gradients(
    L, var_list=model.trainable_weights))
  # `norm` is the global gradient norm measured before clipping.
  grads, norm = tf.clip_by_global_norm(grads, clip_norm=clip_norm)
  opt = adam.apply_gradients(zip(grads, variables))
  print('got optimizer')
  # The Switch placeholders must be fed even in the learning phase;
  # single-step zero tensors are enough.
  dummy_dict = {
    K.learning_phase(): True,
    sequence_ph: np.zeros((1, 1, int(sequence_ph.shape[-1]))),
    state_ph: np.zeros((1, int(state_ph.shape[-1])))}
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print('initialized variables')
    iteration = 0
    while True:
      losses = []
      progbar = tqdm.tqdm_notebook(range(100))
      for _ in progbar:
        # Drop the optimizer op's result; keep [L, XE, KL, RG, norm].
        losses.append(sess.run([opt, L, XE, KL, RG, norm], feed_dict=dummy_dict)[1:])
        progbar.set_description('L={:.3f}'.format(losses[-1][0]))
        iteration += 1
      print('iter: {}'.format(iteration))
      print('L:    {}\nXE:   {}\nKL:   {}\nRG:   {}\ngrad: {}'.format(*np.mean(losses, axis=0)))
      print('sample:\n\t'+test(sess, Xhat, state_ph, sequence_ph, int(X.shape[1])-1))

# Run configuration.
batch_size = 4096  # windows per training batch
latent_dim = 512   # GRU width and latent dimension
max_length = 64    # characters per window, including the leading 0 token
clip_norm = 5.0    # global gradient-norm clipping threshold
lambda_r = 0.0     # weight-norm regularizer weight (0 turns it off)

# Build the input pipeline and the model, then train indefinitely.
corpus_glob = '/Volumes/1TBSSD/classic-literature-in-ascii/NONFICTION/*.txt'
X = get_text_dataset(corpus_glob, batch_size, max_length)
model = recurrent_VAE(X, latent_dim)
train(model, clip_norm, lambda_r)