### Mixture Density Network (MDN) references:

[PyTorch reference implementation](https://github.com/tonyduan/mixture-density-network)

[Keras MDN layer implementation](https://github.com/cpmpercussion/keras-mdn-layer)

[Alternative Keras implementation](https://github.com/omimo/Keras-MDN/blob/master/kmdn/mdn.py)


In [None]:
import MDN
import tensorflow as tf
import tensorflow.keras.backend as K
import pandas as pd
from ProcessTrueStateActionData import read_df_in_chunks

### Load and preprocess the data
(produce TensorFlow train and test datasets)

In [None]:
# batch_size: mini-batch size of the training dataset.
# train_test_split: fraction of rows sampled into the training split; the rest
# becomes the test split.
batch_size, train_test_split = 64, 0.8

In [4]:
def get_columns_for_training(num_nodes=13, true_states=("pre", "blue", "red")):
    """Return the true-state column names used for training, grouped by state.

    For each state, the list walks nodes ``0 .. num_nodes - 1`` in order. For
    every node it holds the ``known_status`` column followed by the
    ``access_status`` column, so the two statuses are interleaved per node:
    ``["0_ts_pre_known_status", "0_ts_pre_access_status",
    "1_ts_pre_known_status", ...]``.

    Args:
        num_nodes: Number of nodes per true state. Defaults to the 13 nodes
            this notebook was written for.
        true_states: Names of the true-state snapshots to build columns for.

    Returns:
        A dict mapping each state name, in ``true_states`` order, to its list
        of ``2 * num_nodes`` column names.
    """
    return {
        state: [
            f"{node}_ts_{state}_{status}_status"
            for node in range(num_nodes)
            for status in ("known", "access")
        ]
        for state in true_states
    }

In [5]:
cols_dict = get_columns_for_training()
pre_cols, blue_cols, red_cols = (cols_dict[state] for state in ("pre", "blue", "red"))
# Column order used when selecting from the raw frame: all pre, then blue, then red.
all_cols = [*pre_cols, *blue_cols, *red_cols]

In [6]:
# Read only the training columns. Parquet is columnar, so unrequested columns
# are never loaded. The explicit [all_cols] pins the column order, and casting
# to "category" keeps the small set of status codes compact in memory.
df = pd.read_parquet(
    "csv_data/TrueStatesObsActsRwds_1221_4000_B_Line.parquet",
    columns=all_cols,
)[all_cols].astype("category")

In [7]:
# Per-column memory footprint in bytes (index included, deep inspection).
df.memory_usage(index=True, deep=True)

  and should_run_async(code)


Index                      39072000
0_ts_pre_known_status       4884124
0_ts_pre_access_status      4884116
1_ts_pre_known_status       4884132
1_ts_pre_access_status      4884132
                             ...   
10_ts_red_access_status     4884132
11_ts_red_known_status      4884124
11_ts_red_access_status     4884132
12_ts_red_known_status      4884124
12_ts_red_access_status     4884132
Length: 79, dtype: int64

In [8]:
# Random train/test split, done on row *positions* rather than index labels.
# `df.drop(train_df.index)` would drop every row sharing a label with a
# training row, so a non-unique index would silently shrink the test set. With
# a unique index this selects exactly the same rows, in the same order, as
# sampling the frame directly: pandas draws the sample from the row count and
# random_state only.
row_positions = pd.Series(range(len(df)))
train_positions = row_positions.sample(frac=train_test_split, random_state=42).to_numpy()
test_positions = row_positions.drop(train_positions).to_numpy()

train_df = df.iloc[train_positions]
train_pre_df = train_df[pre_cols]
train_blue_df = train_df[blue_cols]
train_red_df = train_df[red_cols]

test_df = df.iloc[test_positions]
test_pre_df = test_df[pre_cols]
test_blue_df = test_df[blue_cols]
test_red_df = test_df[red_cols]

train_size = train_df.shape[0]
test_size = test_df.shape[0]

  and should_run_async(code)


In [9]:
# Number of rows in the training split.
print(f"{train_size}")

3907200


  and should_run_async(code)


In [14]:
def _to_dataset(pre_frame, blue_frame, red_frame):
    """Slice three row-aligned frames into ``((pre, blue), red)`` elements."""
    inputs = (pre_frame.values, blue_frame.values)
    return tf.data.Dataset.from_tensor_slices((inputs, red_frame.values))


train_dataset = _to_dataset(train_pre_df, train_blue_df, train_red_df).batch(batch_size)
# The shuffle buffer equals the full test size, so the whole test set is
# shuffled; examples are then served one at a time.
test_dataset = (
    _to_dataset(test_pre_df, test_blue_df, test_red_df)
    .shuffle(test_size)
    .batch(1)
)

In [15]:
# Peek at the first training batch: ((pre_ts, blue_ts), red_ts).
first_batch = next(iter(train_dataset.take(1)), None)
if first_batch is not None:
    print(first_batch)

((<tf.Tensor: shape=(64, 26), dtype=int64, numpy=
array([[0, 0, 0, ..., 0, 1, 0],
       [0, 0, 2, ..., 0, 2, 2],
       [0, 0, 0, ..., 0, 1, 0],
       ...,
       [0, 0, 2, ..., 2, 1, 0],
       [0, 0, 2, ..., 2, 1, 0],
       [1, 0, 2, ..., 0, 2, 2]])>, <tf.Tensor: shape=(64, 26), dtype=int64, numpy=
array([[0, 0, 0, ..., 0, 1, 0],
       [0, 0, 2, ..., 0, 2, 2],
       [0, 0, 0, ..., 0, 1, 0],
       ...,
       [0, 0, 2, ..., 2, 1, 0],
       [0, 0, 2, ..., 2, 1, 0],
       [1, 0, 2, ..., 0, 2, 2]])>), <tf.Tensor: shape=(64, 26), dtype=int64, numpy=
array([[0, 0, 0, ..., 0, 1, 0],
       [0, 0, 2, ..., 0, 2, 2],
       [0, 0, 0, ..., 0, 1, 0],
       ...,
       [0, 0, 2, ..., 2, 1, 0],
       [0, 0, 2, ..., 2, 1, 0],
       [1, 0, 2, ..., 0, 2, 2]])>)


### Create an MDN-based model with pretrained (frozen) encoder/decoder layers

In [48]:
class RedTSPrediction(tf.keras.Model):
    """Predict the red true state, as a VAE latent, from the pre and blue states.

    Both input true states are one-hot encoded and passed through the encoder
    of a pretrained VAE loaded from ``vae_path``, whose weights are frozen.
    Each encoding is projected by a shared dense layer. The two projections are
    concatenated and fed through a wide hidden layer into a mixture density
    network head with ``latent_size`` output dimensions and ``num_mixtures``
    components. ``decode`` maps latent vectors through the VAE decoder.
    """

    # One-hot depth for each status column. The batches printed in this
    # notebook only contain the codes 0, 1 and 2.
    _NUM_STATUS_VALUES = 3

    def __init__(self, vae_path, latent_size, num_mixtures):
        """Build the model around a saved VAE.

        Args:
            vae_path: Path of a saved Keras model exposing ``encoder`` and
                ``decoder`` attributes.
            latent_size: Output dimension of the MDN. Presumably this must
                equal the VAE latent size so predictions can be decoded
                (TODO confirm).
            num_mixtures: Number of mixture components in the MDN head.
        """
        super().__init__()
        self.ts_vae = tf.keras.models.load_model(vae_path)
        # Keep the pretrained VAE fixed; only the layers created below train.
        self.ts_vae.trainable = False
        self.encoder = self.ts_vae.encoder
        self.decoder = self.ts_vae.decoder

        # NOTE: this one layer (and its weights) projects both the pre and the
        # blue encodings.
        self.ts_dense = tf.keras.layers.Dense(128, activation=tf.nn.relu)
        self.cross_dense = tf.keras.layers.Dense(2048, activation=tf.nn.relu)

        self.mdn = MDN.MDN(output_dimension=latent_size, num_mixtures=num_mixtures)

    def _one_hot_flat(self, ts):
        """One-hot encode every status column and flatten each example.

        Integer codes of shape ``(batch, num_columns)`` become a float tensor
        of shape ``(batch, num_columns * _NUM_STATUS_VALUES)``, i.e. 78
        features for the 26 columns used here. The width is derived from the
        input rather than hard-coded, so a different column count cannot
        silently regroup values across examples.
        """
        depth = self._NUM_STATUS_VALUES
        one_hot = tf.one_hot(ts, depth)
        return tf.reshape(one_hot, (-1, ts.shape[-1] * depth))

    def call(self, inputs):
        """Run the MDN on a ``(pre_ts, blue_ts)`` input pair.

        Args:
            inputs: Pair of integer tensors ``(pre_ts, blue_ts)``, each of
                shape ``(batch, num_columns)`` and holding status codes in
                ``[0, _NUM_STATUS_VALUES)``.

        Returns:
            The output of the MDN layer for the batch.

        Assumes the VAE encoder returns a single tensor, not e.g. a
        (mean, log_var, z) tuple (TODO confirm).
        """
        pre_ts = inputs[0]
        blue_ts = inputs[1]

        pre_latent = self.ts_dense(self.encoder(self._one_hot_flat(pre_ts)))
        blue_latent = self.ts_dense(self.encoder(self._one_hot_flat(blue_ts)))

        # Join the two projections along the feature axis.
        combined = tf.concat([pre_latent, blue_latent], axis=-1)
        combined_hidden = self.cross_dense(combined)
        return self.mdn(combined_hidden)

    def decode(self, latent_pred):
        """Pass latent vectors through the pretrained VAE decoder."""
        return self.decoder(latent_pred)
        


In [49]:
# Wrap the saved VAE with an MDN head: 8 output dimensions, 5 mixture components.
red_ts_predictor = RedTSPrediction(
    vae_path='models/trueStateVAE_7_L8',
    latent_size=8,
    num_mixtures=5,
)



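### Training loop
(Not implemented yet: the cell below only runs a single test example through the freshly built model as a smoke test.)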

In [None]:
# Smoke test: push one example from the shuffled test set through the model.
for row in test_dataset.take(1):
    model_inputs = row[0]  # the (pre_ts, blue_ts) pair; row[1] is the red target
    out = red_ts_predictor(model_inputs)
    print(out)

### Test and evaluate