### MDN References:

[useful pytorch reference](https://github.com/tonyduan/mixture-density-network)

[keras version](https://github.com/cpmpercussion/keras-mdn-layer)

[another keras version](https://github.com/omimo/Keras-MDN/blob/master/kmdn/mdn.py)


In [1]:
import MDN
from MDN import sample_from_output
import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
import pandas as pd
from ProcessTrueStateActionData import read_df_in_chunks

from true_state_viewer import TrueStateTreeGraphViz, display_tree_red_preds

import time
from IPython import display

### Load and preprocess the data 
(produce tf train+test datasets)

In [2]:
batch_size = 32
train_test_split = 0.98

STATE_SIZE = 42 # (mdn output_dimension)
NUMBER_MIXTURES = 10

EPOCHS = 50

DATA_CAP = 3_000_000

  and should_run_async(code)


In [3]:
data_path = 'logs/APPO/TrueStates_200_1000_Meander_small/data'
afterstates = np.load(data_path + '/afterstates.npy')
next_states = np.load(data_path + '/next_states.npy')
afterstates = np.array(afterstates, dtype=np.float32)
next_states = np.array(next_states, dtype=np.float32)
divide = int(train_test_split * afterstates.shape[0])
afterstates_train = afterstates[:divide,:]
next_states_train = next_states[:divide,:]
afterstates_test = afterstates[divide:,:]
next_states_test = next_states[divide:,:]
print(afterstates.shape[0] - divide)

3960


In [4]:
train_dataset = (tf.data.Dataset.from_tensor_slices((afterstates_train, next_states_train)).shuffle(afterstates.shape[0]).batch(batch_size))
test_dataset = (tf.data.Dataset.from_tensor_slices((afterstates_test, next_states_test)).shuffle(afterstates.shape[0]).batch(batch_size))

### Create an MDN based model with pretrained encoder/decoder layers

### Training loop

In [111]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_probability as tfp
from keras.layers import Flatten, LSTM, Input

def prior(kernel_size, bias_size, dtype=None):
    n = kernel_size + bias_size
    prior_model = keras.Sequential(
        [
            tfp.layers.DistributionLambda(
                lambda t: tfp.distributions.MultivariateNormalDiag(
                    loc=tf.zeros(n), scale_diag=tf.ones(n)
                )
            )
        ]
    )
    return prior_model


def posterior(kernel_size, bias_size, dtype=None):
    n = kernel_size + bias_size
    posterior_model = keras.Sequential(
        [
            tfp.layers.VariableLayer(
                tfp.layers.MultivariateNormalTriL.params_size(n), dtype=dtype
            ),
            tfp.layers.MultivariateNormalTriL(n),
        ]
    )
    return posterior_model


def create_probablistic_bnn_model(train_size):
    inputs = Input(shape=(42,))

    # Create hidden layers with weight uncertainty using the DenseVariational layer.
    features = tfp.layers.DenseVariational(
        units=8,
        make_prior_fn=prior,
        make_posterior_fn=posterior,
        kl_weight=1 / train_size,
        activation="sigmoid",
    )(inputs)
    
    features = tfp.layers.DenseVariational(
        units=8,
        make_prior_fn=prior,
        make_posterior_fn=posterior,
        kl_weight=1 / train_size,
        activation="sigmoid",
    )(features)

    # Create a probabilisticå output (Normal distribution), and use the `Dense` layer
    # to produce the parameters of the distribution.
    # We set units=2 to learn both the mean and the variance of the Normal distribution.
    distribution_params = layers.Dense(units=84)(features)
    outputs = tfp.layers.IndependentNormal(42)(distribution_params)

    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


"""
Since the output of the model is a distribution, rather than a point estimate,
we use the [negative loglikelihood](https://en.wikipedia.org/wiki/Likelihood_function)
as our loss function to compute how likely to see the true data (targets) from the
estimated distribution produced by the model.
"""


def negative_loglikelihood(targets, estimated_distribution):
    return -estimated_distribution.log_prob(targets)


num_epochs = 1000
prob_bnn_model = create_probablistic_bnn_model(1000)


prob_bnn_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=0.0001),
        loss=negative_loglikelihood,
        metrics=[keras.metrics.RootMeanSquaredError()],
    )

callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=7)

print(prob_bnn_model.summary())

#model.fit(train_dataset,epochs=250, verbose=1, callbacks=[callback])

prob_bnn_model.fit(afterstates_train, next_states_train, epochs=250, validation_split=0.1, verbose=1, callbacks=[callback], batch_size=128)


Model: "model_18"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_19 (InputLayer)       [(None, 42)]              0         
                                                                 
 dense_variational_31 (Dense  (None, 8)                59684     
 Variational)                                                    
                                                                 
 dense_variational_32 (Dense  (None, 8)                2700      
 Variational)                                                    
                                                                 
 dense_18 (Dense)            (None, 84)                756       
                                                                 
 independent_normal_18 (Inde  ((None, 42),             0         
 pendentNormal)               (None, 42))                        
                                                          


KeyboardInterrupt



In [103]:
prob_bnn_model(afterstates_train[1:1])
prediction_distribution.mean().numpy().shape

AttributeError: 'numpy.float32' object has no attribute 'numpy'

In [None]:

prediction_distribution = prob_bnn_model(afterstates_train[0:100])
prediction_mean = prediction_distribution.mean().numpy()
prediction_stdv = prediction_distribution.stddev().numpy()

# The 95% CI is computed as mean ± (1.96 * stdv)
upper = (prediction_mean + (1.96 * prediction_stdv)).tolist()
lower = (prediction_mean - (1.96 * prediction_stdv)).tolist()
prediction_stdv = prediction_stdv.tolist()

for idx in range(1):
    print(
        f"Prediction mean: {round(prediction_mean[idx][0], 2)}, "
        f"stddev: {round(prediction_stdv[idx][0], 2)}, "
        f"95% CI: [{round(upper[idx][0], 3)} - {round(lower[idx][0], 2)}]"
        f" - Actual: {afterstates_train[idx]}"
    )
    print(np.array(prediction_mean.mean(axis=0)>0.5, dtype=np.float32))


In [81]:
prediction_distribution = prob_bnn_model(afterstates_test)
prediction_mean = prediction_distribution.mean().numpy()
count = 0
for i in range(3960):
    if (np.array(prediction_mean[i]>0.5, dtype=np.float32)==next_states_test[i]).all():
        count += 1
count

0

In [78]:
np.array(prediction_mean[i]>0.5, dtype=np.float32)

array([0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 1.,
       0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0.,
       0., 0., 0., 1., 0., 0., 0., 1.], dtype=float32)

In [79]:
next_states_test[i]

array([0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 1.,
       0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0.,
       1., 0., 1., 0., 0., 0., 0., 1.], dtype=float32)

In [108]:

from sklearn.naive_bayes import GaussianNB
clf = GaussianNB()
clf.fit(afterstates_train, next_states_train)
clf_pf = GaussianNB()
clf_pf.partial_fit(X, Y, np.unique(Y))

ValueError: y should be a 1d array, got an array of shape (194040, 42) instead.

In [98]:
prediction_distribution = model.predict(afterstates_test)
prediction_mean = prediction_distribution.mean(axis=-1)
count = 0
for i in range(3960):
    if (np.array(prediction_distribution[i]>0.5, dtype=np.float32)==next_states_test[i]).all():
        count += 1
count



0

In [99]:
model.predict(afterstates_test)[0]



array([ 0.328803  , -0.3529563 ,  0.79401225, -0.44880345,  0.35050374,
       -0.31568593,  0.6794046 ,  1.1513491 ,  0.46755043,  0.67215884,
        0.026644  , -0.66147274,  0.50816745, -0.01869421,  0.5954744 ,
        0.8644304 ,  0.01966348,  0.109952  ,  0.0674424 ,  0.5989681 ,
        1.2056735 , -0.52840537, -0.8211647 , -0.3841311 ,  0.0389117 ,
        1.2178085 , -1.090363  ,  0.9322548 , -1.1436393 , -0.25294337,
       -1.0193534 , -0.48505506, -1.8802378 , -0.73338646, -0.7592469 ,
       -0.8835641 ,  1.5495107 ,  1.1624637 ,  0.2516081 ,  0.2449704 ,
        0.51363486,  1.5921156 ], dtype=float32)

In [146]:
predictions = np.round(model.predict(afterstates_test))
count = 0
for i in range(3960):
    if (predictions[i]==next_states_test[i]).all():
        count += 1
count



0

In [147]:
prediction = model.predict(np.array([afterstates_train[3]]))[0]

new_state = {}

for i in range(3960):
    state = np.zeros(42)
    for i in range(42):
        state[i] = np.array(np.random.rand() < prediction[i], dtype=np.int8)
    if not state.tobytes() in new_state:
        new_state[state.tobytes()] = 1
    else:
        new_state[state.tobytes()] += 1
        
ns = {k: v for k, v in sorted(new_state.items(), key=lambda item: item[1], reverse=True)}
            
for s in ns.keys():
    print(ns[s])
    print(np.frombuffer(s))

print(len(new_state.keys()))

3886
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
14
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
12
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
12
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
7
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
6
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
6
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
5
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1.
 1

In [148]:
target = afterstates_train[10]
new_state = {}

for i in range(194040):
    if (target==afterstates_train[i]).all():
        state = np.array(next_states_train[i], dtype=np.int8)
        if not state.tobytes() in new_state:
            new_state[state.tobytes()] = 1
        else:
            new_state[state.tobytes()] += 1
        
ns = {k: v for k, v in sorted(new_state.items(), key=lambda item: item[1], reverse=True)}
            
for s in ns.keys():
    print(ns[s])
    print(np.frombuffer(s, dtype="int8"))

print(len(new_state.keys()))

73
[0 0 1 0 0 0 1 0 0 1 0 1 0 0 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 1 0 0 0
 1 0 1 0 0]
40
[0 1 0 0 0 0 1 0 0 1 0 1 0 0 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 1 0 0 0
 1 0 1 0 0]
16
[0 0 1 0 0 0 1 1 0 0 0 1 0 0 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 1 0 0 0
 1 0 1 0 0]
10
[0 0 1 0 0 0 1 1 0 0 0 1 0 0 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 1 0 0 0
 1 0 0 0 1]
7
[0 1 0 0 0 0 1 0 0 1 0 1 0 0 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 1 0 0 0 1 0 0 0
 1 0 1 0 0]
6
[1 0 0 0 1 0 0 0 0 1 0 1 0 0 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 1 0 0 0
 1 0 1 0 0]
5
[0 0 1 0 0 0 1 0 0 1 0 0 1 0 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 1 0 0 0
 1 0 1 0 0]
4
[0 0 1 0 0 0 1 0 0 1 0 1 0 0 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0
 1 0 1 0 0]
4
[0 0 1 0 0 0 1 0 0 1 0 1 0 0 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 1 0 0 1
 0 0 1 0 0]
4
[0 0 1 0 0 0 1 0 0 1 0 1 0 0 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 1 0 0 0 1 0 0 0
 1 0 1 0 0]
3
[0 0 1 0 0 0 1 0 1 0 0 1 0 0 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 1 0 0 0
 1 0 1 0 0]
2
[1 0 0 0 1 0 0 

In [82]:
afterstates_train

array([[0., 0., 1., ..., 0., 0., 1.],
       [0., 0., 1., ..., 0., 0., 1.],
       [0., 0., 1., ..., 0., 0., 1.],
       ...,
       [1., 0., 0., ..., 0., 0., 1.],
       [1., 0., 0., ..., 0., 0., 1.],
       [1., 0., 0., ..., 0., 0., 1.]], dtype=float32)

In [76]:
next_states_train[i].shape

(42,)

In [34]:
afterstate_predictor = AfterstatePrediction(state_size=STATE_SIZE, num_mixtures=NUMBER_MIXTURES)

In [35]:
for row in test_dataset.take(1):
    out = afterstate_predictor(row[0])
    print(out)

tf.Tensor(
[[-0.05244306  0.17184001  0.10252205 ... -0.46421444 -1.0943612
  -1.3317926 ]
 [ 0.09726037  0.19045556  0.37282902 ...  0.4968205  -1.7411695
  -0.8825296 ]
 [-0.01012154  0.28047228  0.2224648  ...  0.10149805 -0.963551
  -1.3187613 ]
 ...
 [-0.06594759  0.05624647  0.05861745 ...  0.30842698 -1.5060287
  -0.5204618 ]
 [ 0.05493884  0.08182326  0.2528863  ... -0.06889202 -1.8719797
  -0.89556086]
 [-0.27234328 -0.03344521  0.15791325 ... -0.48857522 -1.3138354
  -0.6336591 ]], shape=(32, 850), dtype=float32)


In [36]:
# red_ts_predictor.compile(loss=get_mixture_loss_func(LATENT_SIZE,NUMBER_MIXTURES,red_ts_predictor.encode), optimizer=tf.keras.optimizers.Adam(),metrics=['mean_squared_error'])
# red_ts_predictor.build(((1,78),(1,78)))
afterstate_predictor.summary()


Model: "afterstate_prediction_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_1 (Dense)             multiple                  0 (unused)
                                                                 
 mdn_1 (MDN)                 multiple                  36550     
                                                                 
Total params: 36,550
Trainable params: 36,550
Non-trainable params: 0
_________________________________________________________________


In [37]:
loss_func = MDN.get_mixture_loss_func(STATE_SIZE,NUMBER_MIXTURES)
optimizer = tf.keras.optimizers.Adam(.1e-5)

In [47]:

# EPOCHS = 500
for epoch in range(1, EPOCHS + 1):
    count = 0
    start_time = time.time()
    for train_x, train_y in train_dataset:
        #if (count %1000) == 0:
        #    print(f"{count}={count*batch_size} samples")
        train_step(afterstate_predictor, train_x, train_y, optimizer, loss_func)
        count += 1
    end_time = time.time()

    
    total_matches = 0
    total = 0
    nodes = 0
    sum_diffs_sqrd = 0
    state_pred_pairs = []
    state_pred_pair_tree_vis = []
    loss = tf.keras.metrics.Mean()
    loss_result, total = 0, 0
    for test_x, test_y in test_dataset.take(1000):#.take(10):#test_dataset:
        total += 1
        out, y = forward_pass(afterstate_predictor, test_x, test_y)
        loss_val = loss_func(y, out)
        loss(loss_val)
        loss_result += loss.result()
        #     print(f"accuracy = {total_matches}/{total} = {total_matches/total}, \nmean of squared diffs = {sum_diffs_sqrd}/{nodes}={sum_diffs_sqrd/nodes}\npercentage wrong = ({sum_diffs_sqrd}/{2})/({nodes}/{3})={(sum_diffs_sqrd/2)/(nodes/3)}")
    print('Epoch: {}, Test set loss: {}, time elapse for current epoch: {}'
        .format(epoch, loss_result/total, end_time - start_time))
        
#         sampled_out = MDN.sample_from_output(out[0].numpy(), output_dim=LATENT_SIZE, num_mixes=NUMBER_MIXTURES)
#         pred_oh, y_oh = decode_zs(red_ts_predictor, sampled_out, y_encoded)
        
        
#         state_pred_pairs.append([y_oh, pred_oh])
# #         state_pred_pair_tree_vis.append([TrueStateTreeGraphViz(y_oh), TrueStateTreeGraphViz(pred_oh)])
# #         loss(compute_loss(red_ts_predictor, test_x, test_y, loss_func))
#         diffs = np.rint(y_oh.numpy()) - np.rint(pred_oh.numpy())
#     #     diffs = get_state_diff(true_state_model,test_x)
#         nodes += len(diffs.flatten())
#         diffs_sqrd = np.sum(diffs*diffs)
#         sum_diffs_sqrd += diffs_sqrd
#         if not diffs_sqrd >0:
#     #       print(diffs)
#     #     else:
#           total_matches += 1
#     #       print("Match")
    #       print(diffs)


Epoch: 1, Test set loss: -1.3719133138656616, time elapse for current epoch: 31.3272385597229
Epoch: 2, Test set loss: -3.8558034896850586, time elapse for current epoch: 31.34262490272522


Exception ignored in: <bound method WeakStructRef._cleanup of WeakStructRef(HashableWeakRef(<weakref at 0x7f4ccc700b80; dead>))>
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/tensorflow_probability/python/internal/cache_util.py", line 151, in _cleanup
    self._callback(self)
  File "/usr/local/lib/python3.8/dist-packages/tensorflow_probability/python/internal/cache_util.py", line 221, in maybe_del
    del self[key]
KeyboardInterrupt: 


KeyboardInterrupt: 

In [46]:
# pre_ts_encoded = red_ts_predictor.ts_vae.encode(np.zeros((1,78),dtype=np.float32))
# red_ts_predictor.ts_dense(pre_ts_encoded)

np.array(sample_from_output(afterstate_predictor(np.zeros((1,42)))[0].numpy(), 42, 5)>0.5, dtype=np.int8)

#afterstate_predictor.save('models/afterstate_predictor',overwrite=True)

array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
      dtype=int8)

### Test and evaluate

In [26]:
m2 = tf.keras.models.load_model(
    'models/afterstate_predictor',
)

  and should_run_async(code)






In [27]:
total_matches = 0
total = 0
nodes = 0
sum_diffs_sqrd = 0
state_pred_pairs = []
state_pred_pair_tree_vis = []
state_pred_samples_correct = []

red_change_test_indices = []
no_change_indices = []

loss = tf.keras.metrics.Mean()

num_eval_samples = 10

for i, (test_x, test_y) in enumerate(test_dataset.take(3000)):#.take(10):#test_dataset:
    pre_ts = tf.reshape(tf.one_hot(test_x[0],3),(-1,2,3))
    blue_ts = tf.reshape(tf.one_hot(test_x[1],3),(-1,2,3))
    
    out, y_encoded = forward_pass(m2, test_x, test_y)
#     y_oh = decode_z(m2, y_encoded)
    
    y_oh = tf.one_hot(tf.reshape(test_y,(-1,2)),depth=3)
    
#     if not ONLY_MEASURE_CHANGES or np.any(blue_ts != y_oh):

    if np.any(blue_ts != y_oh):
        red_change_test_indices.append(i)
    else:
        no_change_indices.append(i)
        
    loss_val = loss_func(y_encoded, out)
    loss(loss_val)

    y_tree_vis = TrueStateTreeGraphViz(y_oh)

    predictions = []
    pred_tree_vis = []
    pred_corrects = []
    for i in range(num_eval_samples):
        total += 1
        sampled_out = MDN.sample_from_output(out[0].numpy(), output_dim=STATE_SIZE, num_mixes=NUMBER_MIXTURES)
        pred_oh = decode_z(m2, sampled_out)

        predictions.append(pred_oh)
        pred_tree_vis.append(TrueStateTreeGraphViz(pred_oh))

        diffs = np.rint(y_oh.numpy()) - np.rint(pred_oh.numpy())
        diffs_sqrd = np.sum(diffs*diffs)
        sum_diffs_sqrd += diffs_sqrd

        nodes += len(diffs.flatten())
        pred_correct = not diffs_sqrd >0
        if pred_correct:
            total_matches += 1
        pred_corrects.append(pred_correct)



    state_pred_pairs.append([y_oh, predictions])

    state_pred_pair_tree_vis.append([TrueStateTreeGraphViz(pre_ts),
                                     TrueStateTreeGraphViz(blue_ts),
                                     y_tree_vis,
                                     pred_tree_vis])

    state_pred_samples_correct.append(pred_corrects)

state_pred_samples_correct = np.array(state_pred_samples_correct)

#         state_pred_pairs.append([y_oh, pred_oh])
#     state_pred_pair_tree_vis.append([TrueStateTreeGraphViz(y_oh), TrueStateTreeGraphViz(pred_oh)])
#         loss(compute_loss(m2, test_x, test_y, loss_func))
#     diffs = np.rint(y_oh.numpy()) - np.rint(pred_oh.numpy())
#     diffs = get_state_diff(true_state_model,test_x)
#     nodes += len(diffs.flatten())
#     diffs_sqrd = np.sum(diffs*diffs)
#     sum_diffs_sqrd += diffs_sqrd
#     if not diffs_sqrd >0:
# #       print(diffs)
# #     else:
#       total_matches += 1
#       print("Match")
#       print(diffs)

loss = loss.result()
display.clear_output(wait=False)
print(f"accuracy = {total_matches}/{total} = {total_matches/total}, \nmean of squared diffs = {sum_diffs_sqrd}/{nodes}={sum_diffs_sqrd/nodes}\npercentage wrong = ({sum_diffs_sqrd}/{2})/({nodes}/{3})={(sum_diffs_sqrd/2)/(nodes/3)}")
# print('Epoch: {}, Test set loss: {}, time elapse for current epoch: {}'
#     .format(epoch, loss, end_time - start_time))


  and should_run_async(code)


InvalidArgumentError: Value for attr 'TI' of float is not in the list of allowed values: uint8, int8, int32, int64
	; NodeDef: {{node OneHot}}; Op<name=OneHot; signature=indices:TI, depth:int32, on_value:T, off_value:T -> output:T; attr=axis:int,default=-1; attr=T:type; attr=TI:type,default=DT_INT64,allowed=[DT_UINT8, DT_INT8, DT_INT32, DT_INT64]> [Op:OneHot]

In [None]:
import matplotlib.pyplot as plt

correct_pred_counts = np.sum(state_pred_samples_correct[no_change_indices],axis=1)

correct_pred_counts_change_only = np.sum(state_pred_samples_correct[red_change_test_indices], axis=1)

# print(f"Correct prediction frequencies: {correct_pred_counts}")

plt.hist([correct_pred_counts,correct_pred_counts_change_only],num_eval_samples+1,density=True, stacked=True, label=["No state change", "Changes only"])
# plt.hist(correct_pred_counts_change_only,num_eval_samples+1,density=True)
plt.legend(loc="lower right")
plt.xlabel('Count of correct predictions in 10 samples')
plt.ylabel('Density')
plt.show()

In [None]:
from importlib import reload
import true_state_viewer
reload(true_state_viewer)
from true_state_viewer import display_tree_red_preds

In [None]:
display_tree_red_preds(state_pred_pair_tree_vis)