### MDN References:

[useful pytorch reference](https://github.com/tonyduan/mixture-density-network)

[keras version](https://github.com/cpmpercussion/keras-mdn-layer)

[another keras version](https://github.com/omimo/Keras-MDN/blob/master/kmdn/mdn.py)


In [1]:
!pip install tensorflow-probability
!pip install graphviz

You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/usr/bin/python3 -m pip install --upgrade pip' command.[0m


In [1]:
import MDN
import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
import pandas as pd
from ProcessTrueStateActionData import read_df_in_chunks

from true_state_viewer import TrueStateTreeGraphViz, display_tree_pairs

import time
from IPython import display

### Load and preprocess the data 
(produce tf train+test datasets)

In [2]:
batch_size = 64
train_test_split = 0.8

LATENT_SIZE = 8 # (mdn output_dimension)
NUMBER_MIXTURES = 5

EPOCHS = 10

DATA_CAP = 500000#1_000_000

  and should_run_async(code)


In [3]:
def get_columns_for_training():
    true_states = ["pre","blue","red"]
    ts_columns = {}
    for true_state in true_states:
        ts_columns[true_state] = []
        for node in range(13):
            ts_columns[true_state].append(f"{node}_ts_{true_state}_known_status")
            ts_columns[true_state].append(f"{node}_ts_{true_state}_access_status")
    return ts_columns

In [4]:
cols_dict = get_columns_for_training()
pre_cols, blue_cols, red_cols = cols_dict["pre"], cols_dict["blue"], cols_dict["red"]
all_cols = pre_cols + blue_cols + red_cols

In [5]:
df=pd.read_parquet("csv_data/TrueStatesObsActsRwds_1221_4000_B_Line.parquet").iloc[:DATA_CAP]
df = df[all_cols].astype("category")

In [6]:
df.memory_usage(deep=True)

  and should_run_async(code)


Index                      4000000
0_ts_pre_known_status       500124
0_ts_pre_access_status      500116
1_ts_pre_known_status       500132
1_ts_pre_access_status      500132
                            ...   
10_ts_red_access_status     500132
11_ts_red_known_status      500124
11_ts_red_access_status     500132
12_ts_red_known_status      500124
12_ts_red_access_status     500132
Length: 79, dtype: int64

In [7]:
train_df=df.sample(frac=train_test_split,random_state=42)
train_pre_df = train_df[pre_cols]
train_blue_df = train_df[blue_cols]
train_red_df = train_df[red_cols]

test_df=df.drop(train_df.index)
test_pre_df = test_df[pre_cols]
test_blue_df = test_df[blue_cols]
test_red_df = test_df[red_cols]

train_size = train_df.shape[0]
test_size = test_df.shape[0]

  and should_run_async(code)


In [8]:
print(train_size)

400000


  and should_run_async(code)


In [9]:
train_dataset = (tf.data.Dataset.from_tensor_slices(((train_pre_df.values,train_blue_df.values),train_red_df.values)).batch(batch_size))
test_dataset = (tf.data.Dataset.from_tensor_slices(((test_pre_df.values,test_blue_df.values),test_red_df.values)).shuffle(test_size).batch(1))

In [10]:
for row in train_dataset.take(1):
  print(row)

((<tf.Tensor: shape=(64, 26), dtype=int64, numpy=
array([[1, 0, 1, ..., 0, 1, 0],
       [0, 0, 2, ..., 2, 1, 0],
       [0, 0, 0, ..., 0, 1, 0],
       ...,
       [0, 0, 0, ..., 0, 1, 0],
       [0, 0, 2, ..., 2, 1, 0],
       [1, 0, 1, ..., 0, 1, 0]])>, <tf.Tensor: shape=(64, 26), dtype=int64, numpy=
array([[1, 0, 1, ..., 0, 1, 0],
       [0, 0, 2, ..., 2, 1, 0],
       [0, 0, 0, ..., 0, 1, 0],
       ...,
       [0, 0, 0, ..., 0, 1, 0],
       [0, 0, 2, ..., 2, 1, 0],
       [1, 0, 1, ..., 0, 1, 0]])>), <tf.Tensor: shape=(64, 26), dtype=int64, numpy=
array([[1, 0, 1, ..., 0, 1, 0],
       [0, 0, 2, ..., 2, 1, 0],
       [0, 0, 0, ..., 0, 1, 0],
       ...,
       [0, 0, 0, ..., 0, 1, 0],
       [0, 0, 2, ..., 2, 1, 0],
       [1, 0, 1, ..., 0, 1, 0]])>)


### Create an MDN based model with pretrained encoder/decoder layers

In [11]:
class RedTSPrediction(tf.keras.Model):

    def __init__(self, vae_path, latent_size, num_mixtures):
        super().__init__()
        self.ts_vae = tf.keras.models.load_model(vae_path)
        self.ts_vae.trainable = False
        self.encoder = self.ts_vae.encoder
        self.decoder = self.ts_vae.decoder

        self.ts_dense = tf.keras.layers.Dense(128, activation=tf.nn.relu)
        self.cross_dense = tf.keras.layers.Dense(2048, activation=tf.nn.relu)

        self.mdn = MDN.MDN(output_dimension=latent_size, num_mixtures=num_mixtures)

    @tf.function(input_signature=[tf.TensorSpec(shape=[None,78], dtype=tf.float32)])
    def encode(self, x):
        mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
        z = self.reparameterize(mean, logvar)
        return z
    
    @tf.function(input_signature=[tf.TensorSpec(shape=[None,LATENT_SIZE], dtype=tf.float32),tf.TensorSpec(shape=[None,LATENT_SIZE], dtype=tf.float32)])
    def reparameterize(self, mean, logvar):
        #     eps = tf.random.normal(shape=mean.shape)
        eps = tf.random.normal(shape=tf.shape(mean))

        return eps * tf.exp(logvar * .5) + mean
 
    @tf.function(input_signature=[tf.TensorSpec(shape=[None,LATENT_SIZE], dtype=tf.float32)])
    def decode(self, latent, apply_sigmoid=False):
        logits = self.decoder(latent)
        if apply_sigmoid:
          probs = tf.sigmoid(logits)
        return logits #self.decoder(latent_pred)

    @tf.function#(input_signature=[tf.TensorSpec(shape=[None,26], dtype=tf.uint8),tf.TensorSpec(shape=[None,26], dtype=tf.uint8)])
    def call(self, inputs):
        pre_ts = inputs[0]
        blue_ts = inputs[1]
        
        pre_ts_oh = tf.cast(tf.reshape(tf.one_hot(pre_ts,3),(-1,78)),tf.float32)
#         pre_ts_access = tf.reshape(tf.one_hot(pre_ts[:,13:],3),(-1,39))
        
        blue_ts_oh = tf.cast(tf.reshape(tf.one_hot(blue_ts,3),(-1,78)),tf.float32)
#         blue_ts_access = tf.reshape(tf.one_hot(blue_ts[:,13:],3),(-1,39))
        
#         print(pre_ts[:,:])
#         print(pre_ts_oh)
#         print(blue_ts)
#         blue_ts_kn = K.print_tensor(blue_ts[:,:13], message='blue known = ')
#         blue_ts_known = K.print_tensor(blue_ts_known[:,:13], message='blue known OH = ')
#         print(blue_ts_known.shape)

        pre_ts_encoded = self.encode(pre_ts_oh)
#         mean, logvar = self.ts_vae.encode(pre_ts_oh)
#         mean, logvar = tf.split(self.encoder(pre_ts_oh), num_or_size_splits=2, axis=1)
#         pre_ts_encoded = self.reparameterize(mean, logvar)
        pre_ts_encoded = self.ts_dense(pre_ts_encoded)
        
        blue_ts_encoded = self.encode(blue_ts_oh)
        blue_ts_encoded = self.ts_dense(blue_ts_encoded)

        combined = tf.keras.layers.concatenate([pre_ts_encoded, blue_ts_encoded])

        combined_hidden = self.cross_dense(combined)

        mdn_out = self.mdn(combined_hidden)

        return mdn_out
        


In [12]:
red_ts_predictor = RedTSPrediction('models/trueStateVAE_7_L8', latent_size=LATENT_SIZE, num_mixtures=NUMBER_MIXTURES)



### Training loop

In [13]:
for row in test_dataset.take(1):
#   print(row)
  out = red_ts_predictor(row[0])
  print(out)

  and should_run_async(code)


tf.Tensor(
[[ 0.04282317 -0.12412542  0.2482442   0.06926742 -0.21604858  0.13568708
  -0.00556406 -0.11086847 -0.13724126 -0.14385873 -0.01991572 -0.1749891
   0.12630635 -0.2629868  -0.16388357 -0.16506119  0.3151481  -0.01677262
  -0.04931451 -0.36933324  0.02224669  0.17810209  0.14818731 -0.01163785
   0.15364993 -0.2868906   0.12311868  0.07172373 -0.22601117 -0.04048943
   0.01212773  0.31921947 -0.01490611  0.02528222 -0.26694006  0.21455868
  -0.07336514  0.2854832  -0.16609979 -0.11801924  1.1694423   1.0336866
   1.1397274   1.1471629   0.8263541   0.97887087  1.2936279   0.8449799
   0.87034327  1.0315249   0.8567244   1.026631    0.84329104  1.0977898
   1.0605912   0.8532094   1.5071747   0.95168877  0.9699369   0.8032143
   0.9413977   0.9404692   0.79859227  1.051352    1.1604077   0.894228
   1.1419699   1.090362    0.95945257  0.92383623  0.7799355   0.93067586
   0.9201996   0.8948295   0.68122625  0.7172227   1.0967408   1.1549678
   1.0616311   1.1668551  -0.209983

In [14]:
@tf.function
def forward_pass(model, x, y):
    out = model(x)
    
    y_oh = tf.reshape(tf.one_hot(y,3),(-1,78))
    y_encoded = model.encode(y_oh)
    
    return out, y_encoded

@tf.function
def decode_zs(model, x, y):
    x_oh = model.ts_vae.get_oh_output(model.decode(tf.cast(x,tf.float32)))
    y_oh = model.ts_vae.get_oh_output(model.decode(tf.cast(y,tf.float32)))
    
    return x_oh, y_oh
    

@tf.function
def compute_loss(model, x, y, loss_func):
    out, y_encoded = forward_pass(model, x, y)
    
    loss = loss_func(y_encoded, out)
    
    return loss
    

@tf.function
def train_step(model, x, y, optimizer, loss_func):
    """Executes one training step and returns the loss.

    This function computes the loss and gradients, and uses the latter to
    update the model's parameters.
    """
#     y_encoded = model.encode(y)
    with tf.GradientTape() as tape:
        loss = compute_loss(model, x, y, loss_func)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

In [15]:
# red_ts_predictor.compile(loss=get_mixture_loss_func(LATENT_SIZE,NUMBER_MIXTURES,red_ts_predictor.encode), optimizer=tf.keras.optimizers.Adam(),metrics=['mean_squared_error'])
# red_ts_predictor.build(((1,78),(1,78)))
red_ts_predictor.summary()


Model: "red_ts_prediction"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 true_state_vae_4 (TrueState  multiple                 121736094 
 VAE)                                                            
                                                                 
 sequential_8 (Sequential)   (None, 16)                60892016  
                                                                 
 sequential_9 (Sequential)   (None, 78)                60844078  
                                                                 
 dense (Dense)               multiple                  1152      
                                                                 
 dense_1 (Dense)             multiple                  526336    
                                                                 
 mdn (MDN)                   multiple                  174165    
                                                 

In [16]:
loss_func = MDN.get_mixture_loss_func(LATENT_SIZE,NUMBER_MIXTURES)
optimizer = tf.keras.optimizers.Adam(1e-4)

In [None]:
# train_dataset=test_dataset.take(10)

# EPOCHS = 500
for epoch in range(1, EPOCHS + 1):
    count = 0
    start_time = time.time()
    for train_x, train_y in train_dataset:
        if (count %1000) == 0:
            print(f"{count}={count*batch_size} samples")
        train_step(red_ts_predictor, train_x, train_y, optimizer, loss_func)
        count += 1
    end_time = time.time()

    
    total_matches = 0
    total = 0
    nodes = 0
    sum_diffs_sqrd = 0
    state_pred_pairs = []
    state_pred_pair_tree_vis = []
    loss = tf.keras.metrics.Mean()
    for test_x, test_y in test_dataset.take(1000):#.take(10):#test_dataset:
        total += 1
        if (total %100) == 0:
            print(f"{total}={total*1} samples")
        out, y_encoded = forward_pass(red_ts_predictor, test_x, test_y)
        loss_val = loss_func(y_encoded, out)
        loss(loss_val)
        
#         sampled_out = MDN.sample_from_output(out[0].numpy(), output_dim=LATENT_SIZE, num_mixes=NUMBER_MIXTURES)
#         pred_oh, y_oh = decode_zs(red_ts_predictor, sampled_out, y_encoded)
        
        
#         state_pred_pairs.append([y_oh, pred_oh])
# #         state_pred_pair_tree_vis.append([TrueStateTreeGraphViz(y_oh), TrueStateTreeGraphViz(pred_oh)])
# #         loss(compute_loss(red_ts_predictor, test_x, test_y, loss_func))
#         diffs = np.rint(y_oh.numpy()) - np.rint(pred_oh.numpy())
#     #     diffs = get_state_diff(true_state_model,test_x)
#         nodes += len(diffs.flatten())
#         diffs_sqrd = np.sum(diffs*diffs)
#         sum_diffs_sqrd += diffs_sqrd
#         if not diffs_sqrd >0:
#     #       print(diffs)
#     #     else:
#           total_matches += 1
#     #       print("Match")
    #       print(diffs)

    loss = loss.result()
    display.clear_output(wait=False)
#     print(f"accuracy = {total_matches}/{total} = {total_matches/total}, \nmean of squared diffs = {sum_diffs_sqrd}/{nodes}={sum_diffs_sqrd/nodes}\npercentage wrong = ({sum_diffs_sqrd}/{2})/({nodes}/{3})={(sum_diffs_sqrd/2)/(nodes/3)}")
    print('Epoch: {}, Test set loss: {}, time elapse for current epoch: {}'
        .format(epoch, loss, end_time - start_time))
#     display_tree_pairs(state_pred_pair_tree_vis)

Epoch: 4, Test set loss: 4.265441417694092, time elapse for current epoch: 34.03437352180481
0=0 samples
1000=64000 samples
2000=128000 samples
3000=192000 samples
4000=256000 samples
5000=320000 samples
6000=384000 samples
100=100 samples
200=200 samples
300=300 samples
400=400 samples
500=500 samples
600=600 samples
700=700 samples
800=800 samples


In [105]:
# pre_ts_encoded = red_ts_predictor.ts_vae.encode(np.zeros((1,78),dtype=np.float32))
# red_ts_predictor.ts_dense(pre_ts_encoded)

red_ts_predictor((np.zeros((1,26),dtype=np.uint8),np.zeros((1,26),dtype=np.uint8)))

red_ts_predictor.save('models/RedTSPredictionMDN_1',overwrite=True)

  and should_run_async(code)


### Test and evaluate

In [106]:
m2 = tf.keras.models.load_model(
    'models/RedTSPredictionMDN_1',
)

  and should_run_async(code)






In [109]:
total_matches = 0
total = 0
nodes = 0
sum_diffs_sqrd = 0
state_pred_pairs = []
state_pred_pair_tree_vis = []
loss = tf.keras.metrics.Mean()
for test_x, test_y in test_dataset.take(200):#.take(10):#test_dataset:
    total += 1
    out, y_encoded = forward_pass(m2, (tf.cast(test_x[0],tf.uint8),tf.cast(test_x[1],tf.uint8)), tf.cast(test_y,tf.uint8))
    loss_val = loss_func(y_encoded, out)
    loss(loss_val)

    sampled_out = MDN.sample_from_output(out[0].numpy(), output_dim=LATENT_SIZE, num_mixes=NUMBER_MIXTURES)
    pred_oh, y_oh = decode_zs(m2, sampled_out, y_encoded)


    state_pred_pairs.append([y_oh, pred_oh])
    state_pred_pair_tree_vis.append([TrueStateTreeGraphViz(y_oh), TrueStateTreeGraphViz(pred_oh)])
#         loss(compute_loss(m2, test_x, test_y, loss_func))
    diffs = np.rint(y_oh.numpy()) - np.rint(pred_oh.numpy())
#     diffs = get_state_diff(true_state_model,test_x)
    nodes += len(diffs.flatten())
    diffs_sqrd = np.sum(diffs*diffs)
    sum_diffs_sqrd += diffs_sqrd
    if not diffs_sqrd >0:
#       print(diffs)
#     else:
      total_matches += 1
#       print("Match")
#       print(diffs)

loss = loss.result()
display.clear_output(wait=False)
print(f"accuracy = {total_matches}/{total} = {total_matches/total}, \nmean of squared diffs = {sum_diffs_sqrd}/{nodes}={sum_diffs_sqrd/nodes}\npercentage wrong = ({sum_diffs_sqrd}/{2})/({nodes}/{3})={(sum_diffs_sqrd/2)/(nodes/3)}")
print('Epoch: {}, Test set loss: {}, time elapse for current epoch: {}'
    .format(epoch, loss, end_time - start_time))
display_tree_pairs(state_pred_pair_tree_vis)

accuracy = 0/200 = 0.0, 
mean of squared diffs = 5998.0/15600=0.3844871794871795
percentage wrong = (5998.0/2)/(15600/3)=0.5767307692307693
Epoch: 10, Test set loss: 14.675802230834961, time elapse for current epoch: 35.34944009780884
200


object.__init__() takes exactly one argument (the instance to initialize)
This is deprecated in traitlets 4.2.This error will be raised in a future release of traitlets.
  super().__init__(**kwargs)


ExecutableNotFound: failed to execute PosixPath('dot'), make sure the Graphviz executables are on your systems' PATH

In [22]:
display_tree_pairs(state_pred_pair_tree_vis)


  and should_run_async(code)
object.__init__() takes exactly one argument (the instance to initialize)
This is deprecated in traitlets 4.2.This error will be raised in a future release of traitlets.
  super().__init__(**kwargs)


200


ExecutableNotFound: failed to execute PosixPath('dot'), make sure the Graphviz executables are on your systems' PATH

In [34]:
red_ts_predictor((np.zeros((1,26),dtype=np.uint8),np.zeros((1,26),dtype=np.uint8)))
# red_ts_predictor.save('models/RedTSPredictionMDN_1',overwrite=True)

<tf.Tensor: shape=(1, 85), dtype=float32, numpy=
array([[-1.59288092e+01,  2.43123627e+01, -9.79987907e+00,
        -4.56241965e-01,  4.26264229e+01, -1.14233541e+01,
        -1.38954401e+01, -1.00442787e+02,  3.79581871e+01,
         1.23023354e+02,  4.74715042e+01,  2.27442184e+02,
         5.22331190e+00,  1.93045826e+01,  1.20778618e+01,
        -1.74362473e+02, -3.34122849e+01,  6.78278732e+01,
        -1.11614923e+01, -9.03002834e+00,  4.34372482e+01,
        -9.98210049e+00, -2.08592548e+01, -5.44174194e+01,
         1.85143795e+01,  4.42520027e+01, -2.73313618e+01,
        -1.41870298e+01,  2.72996349e+01, -2.76405201e+01,
         2.07074947e+01, -8.53032150e+01, -2.01825256e+01,
        -1.49921865e+01, -4.34134254e+01,  2.15573845e+01,
         4.63493843e+01, -1.76734805e+00,  3.45794029e+01,
        -2.39422226e+01,  3.40773659e+01,  1.19209290e-07,
         1.19209290e-07,  1.77688046e+01,  1.19209290e-07,
         4.25124168e-03,  2.44193673e-02,  5.08455124e+01,
       

In [39]:
mean, logvar = red_ts_predictor.ts_vae.encode(np.zeros((1,78),dtype=np.float32))
z = red_ts_predictor.ts_vae.reparameterize(mean, logvar)

print(z)

tf.Tensor(
[[ 0.82098275  0.31760526 -0.03596312  1.1379313   0.25326672 -0.3022647
   0.4376471   0.3620196 ]], shape=(1, 8), dtype=float32)


In [40]:
dir(red_ts_predictor)

['_SCALAR_UPRANKING_ON',
 '_TF_MODULE_IGNORED_PROPERTIES',
 '__call__',
 '__class__',
 '__copy__',
 '__deepcopy__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_activity_regularizer',
 '_add_trackable',
 '_add_trackable_child',
 '_add_variable_with_custom_getter',
 '_assert_compile_was_called',
 '_assert_weights_created',
 '_auto_config',
 '_auto_get_config',
 '_auto_track_sub_layers',
 '_autocast',
 '_autographed_call',
 '_base_model_initialized',
 '_build_input_shape',
 '_call_spec',
 '_callable_losses',
 '_captured_weight_regularizer',
 '_cast_single_input',
 '_check_call_args',
 '_checkpoint',
 '_checkpoint_dependencies',
 '_clear_losses',
