## True State VAE

Creating a simple VAE model to learn how to encode-decode true states

In [1]:
!pip install tensorflow-probability
from IPython import display

import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import PIL
import tensorflow as tf
import tensorflow_probability as tfp
import time

from ray.rllib.offline.json_reader import JsonReader
import numpy_indexed as npi
import pandas as pd
from true_state_viewer import TrueStateTreeGraphViz, display_tree_pairs

You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade pip' command.[0m


In [17]:
state_length = 24

# Code from https://www.tensorflow.org/tutorials/generative/cvae
class StateVAE(tf.keras.Model):
  """Convolutional variational autoencoder."""

  def __init__(self, latent_dim):
    super(StateVAE, self).__init__()
    self.latent_dim = latent_dim
    self.encoder = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=(state_length)),
            tf.keras.layers.Dense(16, activation="relu"),
            tf.keras.layers.Dense(16, activation="relu"),
            tf.keras.layers.Dense(latent_dim + latent_dim)
        ]
    )

    self.decoder = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
            tf.keras.layers.Dense(16, activation="relu"),
            tf.keras.layers.Dense(16, activation="relu"),
            tf.keras.layers.Dense(state_length)
        ]
    )

#   @tf.function
#   def get_dist(self, x):
#     mean, logvar = tf.split(self.dist_fc(x), num_or_size_splits=2, axis=1)
#     return mean, logvar

  @tf.function
  def sample(self, eps=None):
    if eps is None:
        eps = tf.random.normal(shape=(100, self.latent_dim))
    return self.decode(eps, apply_sigmoid=True)
  
  @tf.function
  def encode(self, x):
    return tf.split(self.encoder(x), num_or_size_splits=2, axis=1)

  @tf.function
  def reparameterize(self, mean, logvar):
#     eps = tf.random.normal(shape=mean.shape)
    eps = tf.random.normal(shape=tf.shape(mean))
    
    return eps * tf.exp(logvar * .5) + mean

  @tf.function
  def decode(self, z, apply_sigmoid=False):
    logits = self.decoder(z)
    if apply_sigmoid:
      probs = tf.sigmoid(logits)
      return probs
    return logits

  @tf.function
  def get_oh_output(self, predictions):
    pred_oh = tf.one_hot(tf.argmax(tf.nn.softmax(tf.reshape(predictions,(-1,2,3))),axis=2),depth=3)
    return pred_oh
    
  def call(self, inputs, training):
    mean, logvar = self.encode(inputs)
    z = self.reparameterize(mean, logvar)
    x_logit = self.decode(z)
    return x_logit


In [18]:


def log_normal_pdf(sample, mean, logvar, raxis=1):
  log2pi = tf.math.log(2. * np.pi)
  return tf.reduce_sum(
      -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),
      axis=raxis)


# TODO: make this a combination of cat cross entropy
def compute_loss(model, x):
    mean, logvar = model.encode(x)
    z = model.reparameterize(mean, logvar)
    x_logit = model.decode(z)

    cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)
    logpx_z = -tf.reduce_sum(cross_ent, axis=[1])

    #   THIS SHOULD WORK BETTER....
    #   as suggested in https://medium.com/p/53eefdfdbcc7
    #   acc = 0
    #   for i in range(26):
    #     cross_ent = tf.nn.softmax_cross_entropy_with_logits(logits=x_logit[:,i*3:(i*3)+3], labels=x[:,i*3:(i*3)+3])
    #     acc += cross_ent#tf.reduce_sum(cross_ent, axis=[1])
    #   logpx_z = -acc

    #   logpx_z = logpx_z/26

    logpz = log_normal_pdf(z, 0., 0.)
    logqz_x = log_normal_pdf(z, mean, logvar)
    return -tf.reduce_mean(logpx_z + logpz - logqz_x)


@tf.function
def train_step(model, x, optimizer):
    """Executes one training step and returns the loss.

    This function computes the loss and gradients, and uses the latter to
    update the model's parameters.
    """
    with tf.GradientTape() as tape:
        loss = compute_loss(model, x)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))


In [19]:
epochs = 10000
# set the dimensionality of the latent space to a plane for visualization later
latent_dim = 8
# num_examples_to_generate = 16

# # keeping the random vector constant for generation (prediction) so
# # it will be easier to see the improvement.
# random_vector_for_generation = tf.random.normal(
#     shape=[num_examples_to_generate, latent_dim])
VAE_model = StateVAE(latent_dim)

train_test_split = 0.95
batch_size = 64

data_path = 'logs/APPO/TrueStates_200_1000_Meander_small_4_bit/data'

states = np.load(data_path + '/states.npy')
afterstates = np.load(data_path + '/afterstates.npy')
X = np.concatenate([states, afterstates])
X = [tuple(row) for row in X]
X = np.unique(X, axis=0)
X.shape

(991, 24)

In [20]:
train_dataset = (tf.data.Dataset.from_tensor_slices(X).shuffle(X.shape[0]).batch(batch_size))
test_dataset = (tf.data.Dataset.from_tensor_slices(X).shuffle(X.shape[0]).batch(batch_size))

In [21]:
optimizer = tf.keras.optimizers.Adam(5e-5)
optimizer.build(VAE_model.trainable_variables)

for epoch in range(1, epochs + 1):
    start_time = time.time()
    for train_x in train_dataset:
        train_x = tf.cast(train_x,tf.float32)
        train_step(VAE_model, train_x, optimizer)
    end_time = time.time()

    loss = tf.keras.metrics.Mean()
    for test_x in test_dataset:
        test_x = tf.cast(test_x,tf.float32)
        loss(compute_loss(VAE_model, test_x))
    elbo = -loss.result()
    if epoch % 20 == 0:
        display.clear_output(wait=False)
        print('Epoch: {}, Test set ELBO: {}, time elapse for current epoch: {}'
            .format(epoch, elbo, end_time - start_time))
    #   generate_and_save_images(model, epoch, test_sample)

Epoch: 8080, Test set ELBO: -7.877871036529541, time elapse for current epoch: 0.056746482849121094


KeyboardInterrupt: 

In [22]:
encoding = VAE_model.encode(X)
encoding = VAE_model.reparameterize(encoding[0], encoding[1])
decoding = VAE_model.decode(encoding)

decoding = np.array(decoding>0, dtype=np.float32)

corrects = 0

for i in range(X.shape[0]):
    result = decoding[i,:] - X[i,:]
    corrects += np.abs(result).sum() == 0
corrects/X.shape[0]

0.0020181634712411706

In [16]:
X.shape[0]

991

In [None]:
def get_state_pred_pair(model, state):
    state = tf.cast(state, tf.float32)
    mean, logvar = model.encode(state)
    z = model.reparameterize(mean, logvar)
    predictions = model.sample(z)

    state_oh = tf.reshape(state,(-1,2,3)) #tf.one_hot(tf.argmax(tf.reshape(state,(-1,3)),axis=1),depth=3)
    pred_oh = model.get_oh_output(predictions)#tf.one_hot(tf.argmax(tf.nn.softmax(tf.reshape(predictions,(-1,2,3))),axis=2),depth=3)
    #   print(f"{state_oh.shape}, {pred_oh.shape}")
    return state_oh, pred_oh

In [None]:
get_state_pred_pair(VAE_model, X[0:1,:])

In [None]:
# test_dataset = (tf.data.Dataset.from_tensor_slices((dataset.iloc[:200,1:].values)).shuffle(train_size).batch(batch_size))
#train_dataset2 = (tf.data.Dataset.from_tensor_slices((train_df.iloc[:,1:].values)).shuffle(train_size).batch(1))

total_matches = 0
total = 0
nodes = 0
sum_diffs_sqrd = 0
state_pred_pairs = []
state_pred_pair_tree_vis = []
for test_x in test_dataset:
    total +=1
    state_oh, pred_oh = get_state_pred_pair(VAE_model, test_x)
    
    print(state_oh)
    
    state_pred_pairs.append([state_oh, pred_oh])
    state_pred_pair_tree_vis.append([TrueStateTreeGraphViz(state_oh), TrueStateTreeGraphViz(pred_oh)])
    
    diffs = np.rint(state_oh.numpy()) - np.rint(pred_oh.numpy())
#     diffs = get_state_diff(VAE_model,test_x)
    nodes += len(diffs.flatten())
    diffs_sqrd = np.sum(diffs*diffs)
    sum_diffs_sqrd += diffs_sqrd
    if not diffs_sqrd >0:
#       print(diffs)
#     else:
      total_matches += 1
#       print("Match")
#       print(diffs)

print(f"accuracy = {total_matches}/{total} = {total_matches/total}, \nmean of squared diffs = {sum_diffs_sqrd}/{nodes}={sum_diffs_sqrd/nodes}\npercentage wrong = ({sum_diffs_sqrd}/{2})/({nodes}/{3})={(sum_diffs_sqrd/2)/(nodes/3)}")
display_tree_pairs(state_pred_pair_tree_vis)


In [None]:
# true_state_model.build((1,78))
true_state_model(np.ones((1,78)))
true_state_model.save('models/trueStateVAE_11_L8',overwrite=True)# best one is L8 3 or 4 or 6 or 7 (78%).. need to check
# m2 = Tr
# # true_state_model.save('models/trueStateVAE_1')
# true_state_model.compile()
# # true_state_model._set_inputs(test_x)
# true_state_model.fit(dataset, epochs=1)
# tf.keras.models.save_model(
#     true_state_model,
#     'models/trueStateVAE_1',
#     overwrite=False,
#     include_optimizer=True
# )

In [None]:
m2 = tf.keras.models.load_model(
    'models/trueStateVAE_11_L8',
)

In [None]:


# m2.sample()
total_matches = 0
total = 0
nodes = 0
sum_diffs_sqrd = 0
state_pred_pairs = []
state_pred_pair_tree_vis = []
for test_x in test_dataset:
    total +=1
    state_oh, pred_oh = get_state_pred_pair(m2, test_x)
    
    state_pred_pairs.append([state_oh, pred_oh])
    state_pred_pair_tree_vis.append([TrueStateTreeGraphViz(state_oh), TrueStateTreeGraphViz(pred_oh)])
    
    diffs = np.rint(state_oh.numpy()) - np.rint(pred_oh.numpy())
#     diffs = get_state_diff(true_state_model,test_x)
    nodes += len(diffs.flatten())
    diffs_sqrd = np.sum(diffs*diffs)
    sum_diffs_sqrd += diffs_sqrd
    if not diffs_sqrd >0:
#       print(diffs)
#     else:
      total_matches += 1
#       print("Match")
#       print(diffs)

print(f"accuracy = {total_matches}/{total} = {total_matches/total}, \nmean of squared diffs = {sum_diffs_sqrd}/{nodes}={sum_diffs_sqrd/nodes}\npercentage wrong = ({sum_diffs_sqrd}/{2})/({nodes}/{3})={(sum_diffs_sqrd/2)/(nodes/3)}")
display_tree_pairs(state_pred_pair_tree_vis)

In [None]:
m2.encoder.summary()