## True State VAE

Create a simple variational autoencoder (VAE) that learns to encode and decode true states.

In [141]:
!pip install tensorflow-probability
from IPython import display

import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import PIL
import tensorflow as tf
import tensorflow_probability as tfp
import time

from ray.rllib.offline.json_reader import JsonReader
import numpy_indexed as npi
import pandas as pd
from true_state_viewer import TrueStateTreeGraphViz, display_tree_pairs

You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade pip' command.[0m


In [142]:
# Code from https://www.tensorflow.org/tutorials/generative/cvae
class TrueStateVAE(tf.keras.Model):
  """Fully connected variational autoencoder for flattened true-state vectors.

  Adapted from the TensorFlow CVAE tutorial linked above; the tutorial's
  convolutional layers are replaced by a stack of Dense layers.

  Args:
    latent_dim: Size of the latent code z. The encoder emits
      2 * latent_dim values, split into (mean, logvar) by `encode`.
    state_dim: Length of each input vector and of the decoder output.
      Defaults to 78, the width used throughout this notebook.
  """

  def __init__(self, latent_dim, state_dim=78):
    super(TrueStateVAE, self).__init__()
    self.latent_dim = latent_dim
    self.state_dim = state_dim
    # Layer widths are the "RUN 10 (bs64, L8)" configuration. Earlier
    # experiments used wider hidden layers, e.g. 10000-6000 (RUN 7) and
    # 15000-10000 (RUNs 8/9).
    self.encoder = tf.keras.Sequential(
        [
            # input_shape must be a tuple; `(78)` is just the int 78.
            tf.keras.layers.InputLayer(input_shape=(state_dim,)),
            tf.keras.layers.Dense(512, activation="relu"),
            tf.keras.layers.Dense(512, activation="relu"),
            tf.keras.layers.Dense(256, activation="relu"),
            tf.keras.layers.Dense(256, activation="relu"),
            tf.keras.layers.Dense(128, activation="relu"),
            tf.keras.layers.Dense(128, activation="relu"),
            # Linear output: split into [mean, logvar] by encode().
            tf.keras.layers.Dense(latent_dim + latent_dim)
        ]
    )

    self.decoder = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
            tf.keras.layers.Dense(128, activation="relu"),
            tf.keras.layers.Dense(256, activation="relu"),
            tf.keras.layers.Dense(256, activation="relu"),
            tf.keras.layers.Dense(512, activation="relu"),
            tf.keras.layers.Dense(512, activation="relu"),
            # Linear output: logits, one per input feature.
            tf.keras.layers.Dense(state_dim)
        ]
    )

  @tf.function
  def sample(self, eps=None):
    """Decode `eps` (or 100 standard-normal draws if omitted) to sigmoid probabilities."""
    if eps is None:
      eps = tf.random.normal(shape=(100, self.latent_dim))
    return self.decode(eps, apply_sigmoid=True)

  @tf.function
  def encode(self, x):
    """Return [mean, logvar], each of shape (batch, latent_dim)."""
    return tf.split(self.encoder(x), num_or_size_splits=2, axis=1)

  @tf.function
  def reparameterize(self, mean, logvar):
    """Draw z = mean + eps * exp(logvar / 2) with eps ~ N(0, I)."""
    # tf.shape (not mean.shape) so this also works with an unknown batch size.
    eps = tf.random.normal(shape=tf.shape(mean))
    return eps * tf.exp(logvar * .5) + mean

  @tf.function
  def decode(self, z, apply_sigmoid=False):
    """Map latent codes to logits, or to probabilities when apply_sigmoid=True."""
    logits = self.decoder(z)
    if apply_sigmoid:
      probs = tf.sigmoid(logits)
      return probs
    return logits

  @tf.function
  def get_oh_output(self, predictions):
    """Turn decoder outputs into one-hot vectors over groups of 3 features.

    `predictions` is reshaped to (-1, 2, 3); within each group of 3 the
    largest entry becomes 1 and the others 0.
    """
    # The softmax does not change the argmax; it is kept for parity with
    # earlier runs.
    pred_oh = tf.one_hot(tf.argmax(tf.nn.softmax(tf.reshape(predictions, (-1, 2, 3))), axis=2), depth=3)
    return pred_oh

  def call(self, inputs, training=None):
    """Encode, sample z, and return reconstruction logits for `inputs`."""
    mean, logvar = self.encode(inputs)
    z = self.reparameterize(mean, logvar)
    x_logit = self.decode(z)
    return x_logit
    

In [162]:
# Adam with a small fixed learning rate (2.5e-5, i.e. .25e-4).
optimizer = tf.keras.optimizers.Adam(learning_rate=2.5e-5)


def log_normal_pdf(sample, mean, logvar, raxis=1):
  """Log-density of `sample` under a diagonal Gaussian, summed over `raxis`.

  Args:
    sample: Points to evaluate.
    mean: Gaussian mean (tensor or scalar, broadcast against `sample`).
    logvar: Log-variance (tensor or scalar, broadcast against `sample`).
    raxis: Axis summed over to combine the per-dimension log-densities.
  """
  squared_dev = (sample - mean) ** 2.
  per_dim = -.5 * (squared_dev * tf.exp(-logvar) + logvar + tf.math.log(2. * np.pi))
  return tf.reduce_sum(per_dim, axis=raxis)


# TODO: replace the element-wise sigmoid cross-entropy with a combination of
# per-group categorical cross-entropies (see the note inside compute_loss).
def compute_loss(model, x):
  """Single-sample Monte Carlo estimate of the negative ELBO, averaged over the batch.

  Reconstruction term: element-wise sigmoid cross-entropy between the decoder
  logits and `x`. KL term: estimated as log p(z) - log q(z|x) at the sampled z,
  with a standard-normal prior p(z).

  Args:
    model: A model exposing encode / reparameterize / decode (e.g. TrueStateVAE).
    x: Batch of 0/1 feature vectors; any numeric dtype is accepted.

  Returns:
    Scalar loss tensor.
  """
  # The tf.data pipeline may yield int64 one-hot rows; the cross-entropy op
  # requires labels with the same float dtype as the logits.
  x = tf.cast(x, tf.float32)
  mean, logvar = model.encode(x)
  z = model.reparameterize(mean, logvar)
  x_logit = model.decode(z)

  cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)
  logpx_z = -tf.reduce_sum(cross_ent, axis=[1])

  # NOTE: https://medium.com/p/53eefdfdbcc7 suggests a softmax cross-entropy
  # over each consecutive group of 3 columns (26 groups for 78 columns),
  # i.e. logpx_z = -sum_i softmax_ce(x_logit[:, 3i:3i+3], x[:, 3i:3i+3]),
  # which may suit the one-hot encoding better.

  logpz = log_normal_pdf(z, 0., 0.)
  logqz_x = log_normal_pdf(z, mean, logvar)
  return -tf.reduce_mean(logpx_z + logpz - logqz_x)


@tf.function
def train_step(model, x, optimizer):
  """Executes one training step and returns the loss.

  Computes `compute_loss(model, x)` under a GradientTape, then applies the
  gradients w.r.t. `model.trainable_variables` with `optimizer`.

  Returns:
    The scalar batch loss, computed before the parameter update.
  """
  with tf.GradientTape() as tape:
    loss = compute_loss(model, x)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  # The docstring promised this; previously nothing was returned.
  return loss

  and should_run_async(code)


In [163]:
epochs = 300
# Dimensionality of the latent code z. 16 is too many dimensions for direct
# plotting; a 2-D latent space would be needed for that.
latent_dim = 16

# Fix TensorFlow's global seed before the model is built, so weight
# initialisation, reparameterisation noise and tf.data shuffling repeat
# across fresh-kernel runs.
SEED = 200
tf.random.set_seed(SEED)
true_state_model = TrueStateVAE(latent_dim)

# Fraction of rows used for training; the rest is held out for testing.
train_test_split = 0.95
batch_size = 64

In [164]:
# Layer-by-layer table of the encoder: output shapes and parameter counts.
encoder_net = true_state_model.encoder
encoder_net.summary()

Model: "sequential_28"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_138 (Dense)           (None, 512)               40448     
                                                                 
 dense_139 (Dense)           (None, 512)               262656    
                                                                 
 dense_140 (Dense)           (None, 256)               131328    
                                                                 
 dense_141 (Dense)           (None, 256)               65792     
                                                                 
 dense_142 (Dense)           (None, 128)               32896     
                                                                 
 dense_143 (Dense)           (None, 128)               16512     
                                                                 
 dense_144 (Dense)           (None, 32)              

In [165]:
# Load the combined true-state CSV. Earlier experiments instead concatenated
# per-scenario files (TrueStates_1221_4000_B_Line.csv,
# TrueStates_200_4000_Meander_badblue.csv) and dropped duplicate rows.
meander = pd.read_csv('csv_data/all_true_states.csv')
dataset = meander
print(f"number of rows = {len(meander)}")

number of rows = 4515


In [147]:
# 'Unnamed: 0' is presumably an index column written out with the CSV.
# errors='ignore' keeps the cell idempotent: re-running it after the column
# is gone no longer raises KeyError.
dataset = dataset.drop(columns='Unnamed: 0', errors='ignore')
dataset

  and should_run_async(code)


Unnamed: 0,Unnamed: 0.1,0_unknown,0_known,0_scanned,0_none,0_user,0_privileged,1_unknown,1_known,1_scanned,...,11_scanned,11_none,11_user,11_privileged,12_unknown,12_known,12_scanned,12_none,12_user,12_privileged
0,0,1,0,0,1,0,0,1,0,0,...,0,1,0,0,1,0,0,1,0,0
1,1,1,0,0,1,0,0,1,0,0,...,0,1,0,0,0,1,0,1,0,0
2,2,1,0,0,1,0,0,1,0,0,...,1,1,0,0,0,1,0,1,0,0
3,3,1,0,0,1,0,0,1,0,0,...,1,1,0,0,0,1,0,1,0,0
4,4,1,0,0,1,0,0,1,0,0,...,1,1,0,0,0,0,1,1,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4510,4510,0,0,1,1,0,0,0,0,1,...,1,0,0,1,0,0,1,0,0,1
4511,4511,0,0,1,1,0,0,0,0,1,...,1,0,0,1,0,0,1,0,0,1
4512,4512,0,0,1,1,0,0,0,0,1,...,1,1,0,0,0,0,1,1,0,0
4513,4513,0,0,1,1,0,0,0,0,1,...,1,0,1,0,0,0,1,1,0,0


In [148]:
# Randomly take a `train_test_split` fraction of rows for training (fixed
# random_state, so the split is reproducible); the remaining rows form the
# test set.
train_df = dataset.sample(frac=train_test_split, random_state=200)
test_df = dataset.drop(index=train_df.index)

# Number of training rows, used below as a shuffle-buffer size.
train_size = len(train_df)

  and should_run_async(code)


In [149]:
# Drop the first remaining column (presumably a leftover index column,
# 'Unnamed: 0.1' in the displayed frame). Cast to float32 once here rather
# than on every batch.
train_features = train_df.iloc[:, 1:].to_numpy(dtype=np.float32)
test_features = test_df.iloc[:, 1:].to_numpy(dtype=np.float32)

train_dataset = (tf.data.Dataset.from_tensor_slices(train_features)
                 .shuffle(train_size)
                 .batch(batch_size))
# The whole test set is averaged each epoch, so order does not affect the
# metric. Leaving it unshuffled keeps evaluation and visualisation order
# reproducible (it was previously shuffled with the *train* buffer size).
test_dataset = tf.data.Dataset.from_tensor_slices(test_features).batch(1)

In [150]:

for epoch in range(1, epochs + 1):
  # Timed training pass over every batch.
  start_time = time.time()
  for train_x in train_dataset:
    train_x = tf.cast(train_x, tf.float32)
    train_step(true_state_model, train_x, optimizer)
  end_time = time.time()

  # Evaluation: the ELBO estimate is the negated mean test-set loss.
  loss = tf.keras.metrics.Mean()
  for test_x in test_dataset:
    test_x = tf.cast(test_x, tf.float32)
    loss.update_state(compute_loss(true_state_model, test_x))
  elbo = -loss.result()

  display.clear_output(wait=False)
  print(f'Epoch: {epoch}, Test set ELBO: {elbo}, '
        f'time elapse for current epoch: {end_time - start_time}')

Epoch: 300, Test set ELBO: -16.70627212524414, time elapse for current epoch: 0.39061713218688965


In [166]:
def get_state_pred_pair(model, state, use_mean=False):
  """Reconstruct `state` with `model` and return (true, predicted) one-hot tensors.

  Args:
    model: Model exposing encode / reparameterize / sample / get_oh_output
      (e.g. TrueStateVAE or a model reloaded from disk).
    state: Batch of flattened state vectors (any numeric dtype).
    use_mean: If True, decode the posterior mean instead of a random sample
      z ~ q(z|x), giving a deterministic reconstruction. Defaults to False,
      which matches the previous behaviour.

  Returns:
    (state_oh, pred_oh), both shaped (-1, 2, 3). The last axis presumably
    matches the CSV layout <node>_unknown/known/scanned and
    <node>_none/user/privileged; confirm column order if that changes.
  """
  state = tf.cast(state, tf.float32)
  mean, logvar = model.encode(state)
  z = mean if use_mean else model.reparameterize(mean, logvar)
  # For TrueStateVAE, sample(eps) decodes eps and applies a sigmoid.
  predictions = model.sample(z)

  state_oh = tf.reshape(state, (-1, 2, 3))
  pred_oh = model.get_oh_output(predictions)
  return state_oh, pred_oh

  and should_run_async(code)


In [168]:
def evaluate_reconstructions(model, dataset):
    """Reconstruct every example in `dataset` with `model` and score the result.

    Each example goes through `get_state_pred_pair`. The true and predicted
    one-hot tensors are rounded and subtracted element-wise.

    Returns:
        (total_matches, total, nodes, sum_diffs_sqrd, state_pred_pairs,
        state_pred_pair_tree_vis), where
          total_matches: examples whose reconstruction equals the input exactly,
          total: number of examples seen,
          nodes: number of one-hot entries compared,
          sum_diffs_sqrd: sum of squared element-wise differences,
          state_pred_pairs: [state_oh, pred_oh] per example,
          state_pred_pair_tree_vis: matching TrueStateTreeGraphViz pairs.
    """
    total_matches = 0
    total = 0
    nodes = 0
    sum_diffs_sqrd = 0
    state_pred_pairs = []
    state_pred_pair_tree_vis = []
    for test_x in dataset:
        total += 1
        state_oh, pred_oh = get_state_pred_pair(model, test_x)

        state_pred_pairs.append([state_oh, pred_oh])
        state_pred_pair_tree_vis.append(
            [TrueStateTreeGraphViz(state_oh), TrueStateTreeGraphViz(pred_oh)])

        diffs = np.rint(state_oh.numpy()) - np.rint(pred_oh.numpy())
        nodes += diffs.size
        diffs_sqrd = np.sum(diffs * diffs)
        sum_diffs_sqrd += diffs_sqrd
        # A perfect reconstruction has no differing entries.
        if diffs_sqrd == 0:
            total_matches += 1
    return (total_matches, total, nodes, sum_diffs_sqrd,
            state_pred_pairs, state_pred_pair_tree_vis)


(total_matches, total, nodes, sum_diffs_sqrd,
 state_pred_pairs, state_pred_pair_tree_vis) = evaluate_reconstructions(
    true_state_model, test_dataset)

# "percentage wrong" is a fraction of 3-way groups. Assuming each true group
# is exactly one-hot, a wrong group adds 2 to the squared-diff sum (one -1 and
# one +1), and there are nodes / 3 groups in total.
if total == 0:
    print("test_dataset is empty; nothing to evaluate")
else:
    print(f"accuracy = {total_matches}/{total} = {total_matches/total}, \nmean of squared diffs = {sum_diffs_sqrd}/{nodes}={sum_diffs_sqrd/nodes}\npercentage wrong = ({sum_diffs_sqrd}/{2})/({nodes}/{3})={(sum_diffs_sqrd/2)/(nodes/3)}")
    display_tree_pairs(state_pred_pair_tree_vis)


  and should_run_async(code)


accuracy = 0/226 = 0.0, 
mean of squared diffs = 6662.0/17628=0.3779214885409576
percentage wrong = (6662.0/2)/(17628/3)=0.5668822328114363
226


object.__init__() takes exactly one argument (the instance to initialize)
This is deprecated in traitlets 4.2.This error will be raised in a future release of traitlets.
  super().__init__(**kwargs)


HBox(children=(Button(description='<<', style=ButtonStyle()), Button(description='>>', style=ButtonStyle()), I…

HBox(children=(Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\xfa\x00\x00\x01[\x08\x02\x00\x00\…

In [161]:
# Push one dummy batch through the subclassed model so its input shape is
# established before export (presumably needed for SavedModel saving).
true_state_model(np.ones((1, 78)))
# NOTE(review): earlier notes say the best checkpoints were the L8 runs
# 3, 4, 6 or 7 (~78%); this still needs checking.
true_state_model.save('models/trueStateVAE_11_L8', overwrite=True)

  and should_run_async(code)


In [157]:
# Reload the exported model from disk so it can be evaluated like the in-memory one.
saved_model_dir = 'models/trueStateVAE_11_L8'
m2 = tf.keras.models.load_model(saved_model_dir)

  and should_run_async(code)




In [158]:


# Score the reloaded model `m2` on the test set: exact-match count, mean of
# squared one-hot differences, and the fraction of wrong 3-way groups.
total_matches = 0
total = 0
nodes = 0
sum_diffs_sqrd = 0
state_pred_pairs = []
state_pred_pair_tree_vis = []
for test_x in test_dataset:
    total += 1
    state_oh, pred_oh = get_state_pred_pair(m2, test_x)

    state_pred_pairs.append([state_oh, pred_oh])
    state_pred_pair_tree_vis.append(
        [TrueStateTreeGraphViz(state_oh), TrueStateTreeGraphViz(pred_oh)])

    diffs = np.rint(state_oh.numpy()) - np.rint(pred_oh.numpy())
    nodes += diffs.size
    diffs_sqrd = np.sum(diffs * diffs)
    sum_diffs_sqrd += diffs_sqrd
    # A perfect reconstruction has no differing entries.
    if diffs_sqrd == 0:
        total_matches += 1

# "percentage wrong" is a fraction of 3-way groups. Assuming each true group
# is exactly one-hot, a wrong group adds 2 to the squared-diff sum, and there
# are nodes / 3 groups in total.
if total == 0:
    print("test_dataset is empty; nothing to evaluate")
else:
    print(f"accuracy = {total_matches}/{total} = {total_matches/total}, \nmean of squared diffs = {sum_diffs_sqrd}/{nodes}={sum_diffs_sqrd/nodes}\npercentage wrong = ({sum_diffs_sqrd}/{2})/({nodes}/{3})={(sum_diffs_sqrd/2)/(nodes/3)}")
    display_tree_pairs(state_pred_pair_tree_vis)

  and should_run_async(code)


accuracy = 0/226 = 0.0, 
mean of squared diffs = 1840.0/17628=0.10437939641479464
percentage wrong = (1840.0/2)/(17628/3)=0.15656909462219196
226


object.__init__() takes exactly one argument (the instance to initialize)
This is deprecated in traitlets 4.2.This error will be raised in a future release of traitlets.
  super().__init__(**kwargs)


HBox(children=(Button(description='<<', style=ButtonStyle()), Button(description='>>', style=ButtonStyle()), I…

HBox(children=(Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\xfa\x00\x00\x01[\x08\x02\x00\x00\…

In [24]:
# Layer-by-layer table of the reloaded model's encoder.
loaded_encoder = m2.encoder
loaded_encoder.summary()

Model: "sequential_8"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_30 (Dense)            (None, 10000)             790000    
                                                                 
 dense_31 (Dense)            (None, 6000)              60006000  
                                                                 
 dense_32 (Dense)            (None, 16)                96016     
                                                                 
Total params: 60,892,016
Trainable params: 60,892,016
Non-trainable params: 0
_________________________________________________________________
