### MDN References:

[useful pytorch reference](https://github.com/tonyduan/mixture-density-network)

[keras version](https://github.com/cpmpercussion/keras-mdn-layer)

[another keras version](https://github.com/omimo/Keras-MDN/blob/master/kmdn/mdn.py)


In [None]:
import MDN
from MDN import sample_from_output
import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
import pandas as pd
from ProcessTrueStateActionData import read_df_in_chunks

from true_state_viewer import TrueStateTreeGraphViz, display_tree_red_preds

import time
from IPython import display

### Load and preprocess the data 
(produce tf train+test datasets)

In [None]:
batch_size = 32
train_test_split = 0.98

STATE_SIZE = 42 # (mdn output_dimension)
NUMBER_MIXTURES = 10

EPOCHS = 50

DATA_CAP = 3_000_000

In [None]:
data_path = 'logs/APPO/TrueStates_200_1000_Meander_small/data'
states = np.load(data_path + '/states.npy')
afterstates = np.load(data_path + '/afterstates.npy')
next_states = np.load(data_path + '/next_states.npy')
actions = np.load(data_path + '/actions_onehot.npy')
rewards = np.load(data_path + '/rewards.npy')

states_actions = np.concatenate([states, actions], axis=1)
next_states_delta = next_states - afterstates

afterstates = np.array(afterstates, dtype=np.float32)
next_states = np.array(next_states, dtype=np.float32)
states_actions = np.array(states_actions, dtype=np.float32)

divide = int(train_test_split * afterstates.shape[0])
afterstates_train = afterstates[:divide,:]
next_states_train = next_states[:divide,:]
afterstates_test = afterstates[divide:,:]
next_states_test = next_states[divide:,:]
print(afterstates.shape[0] - divide)

In [None]:
states[0]

In [None]:
tot = 0
for i in range(10000):
    tot += np.sum(np.abs(next_states_delta_train[i]))/2
print(tot/10000)

In [None]:
tot = 0
for i in range(10000):
    tot += np.sum(np.abs(next_states_delta_train[i]))/2
print(tot/10000)

In [None]:
train_dataset = (tf.data.Dataset.from_tensor_slices((afterstates_train, next_states_train)).shuffle(afterstates.shape[0]).batch(batch_size))
test_dataset = (tf.data.Dataset.from_tensor_slices((afterstates_test, next_states_test)).shuffle(afterstates.shape[0]).batch(batch_size))

### Create an MDN based model with pretrained encoder/decoder layers

In [None]:
class AfterstatePrediction(tf.keras.Model):

    def __init__(self, state_size, num_mixtures):
        super().__init__()
        
       # self.cross_dense = tf.keras.layers.Dense(1024, activation=tf.nn.relu)
        
        self.fc_1 = tf.keras.layers.Dense(42, activation=tf.nn.sigmoid)

        self.mdn = MDN.MDN(output_dimension=state_size, num_mixtures=num_mixtures)


    @tf.function#(input_signature=[tf.TensorSpec(shape=[None,26], dtype=tf.uint8),tf.TensorSpec(shape=[None,26], dtype=tf.uint8)])
    def call(self, inputs):

        #combined_hidden = self.cross_dense(inputs)
        
        #fc = self.fc_1(combined_hidden)

        mdn_out = self.mdn(inputs)

        return mdn_out
    
    #def decode(self, latent):
    #    return self.fc_1(latent)
        


In [None]:
@tf.function
def forward_pass(model, x, y):
    out = model(x)
    
    return out, y#model.decode(out)


@tf.function
def compute_loss(model, x, y, loss_func):
    out, y = forward_pass(model, x, y)
    
    loss = loss_func(y, out)
    
    return loss
    

@tf.function
def train_step(model, x, y, optimizer, loss_func):
    """Executes one training step and returns the loss.

    This function computes the loss and gradients, and uses the latter to
    update the model's parameters.
    """
#     y_encoded = model.encode(y)
    with tf.GradientTape() as tape:
        loss = compute_loss(model, x, y, loss_func)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

### Training loop

In [None]:
from keras.layers import Bidirectional
from keras.models import Sequential
from keras.layers.core import Activation, Dropout, Dense
from keras.layers import Flatten, LSTM, Input
from keras.layers import Bidirectional


model = Sequential()
#model.add(Bidirectional(LSTM(256, activation='relu', return_sequences=True), input_shape=(sequence_length, encoding_len)))
#model.add(Flatten())
model.add(Input(shape=(42+20,)))
model.add(Dense(512, activation='relu'))
model.add(Dense(512, activation='relu'))
model.add(Dense(512, activation='relu'))
model.add(Dense(42, activation='sigmoid'))
model.compile(optimizer='adam', loss=tf.keras.losses.BinaryCrossentropy(), metrics=[tf.keras.metrics.BinaryAccuracy()])

callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=7)

from keras import backend as K
K.set_value(model.optimizer.learning_rate, 0.0005)

print(model.summary())

#model.fit(train_dataset,epochs=250, verbose=1, callbacks=[callback])

model.fit(states_actions, afterstates, epochs=250, validation_split=0.1, verbose=1, callbacks=[callback], batch_size=64)

In [None]:
model.save_weights('AfterStateModel')

In [17]:
model = Sequential()
#model.add(Bidirectional(LSTM(256, activation='relu', return_sequences=True), input_shape=(sequence_length, encoding_len)))
#model.add(Flatten())
model.add(Input(shape=(42+20,)))
model.add(Dense(512, activation='relu'))
model.add(Dense(512, activation='relu'))
model.add(Dense(512, activation='relu'))
model.add(Dense(42, activation='sigmoid'))
model.compile(optimizer='adam', loss=tf.keras.losses.BinaryCrossentropy(), metrics=[tf.keras.metrics.BinaryAccuracy()])

callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=7)

from keras import backend as K
K.set_value(model.optimizer.learning_rate, 0.00005)

print(model.summary())

#model.fit(train_dataset,epochs=250, verbose=1, callbacks=[callback])

model.fit(states_actions, next_states, epochs=250, validation_split=0.1, verbose=1, callbacks=[callback], batch_size=64)

Epoch 10/250
Epoch 11/250
Epoch 12/250
Epoch 13/250
Epoch 14/250
Epoch 15/250
Epoch 16/250
Epoch 17/250
Epoch 18/250
Epoch 19/250
Epoch 20/250
Epoch 21/250
Epoch 22/250
Epoch 23/250
Epoch 24/250
Epoch 25/250
Epoch 26/250
Epoch 27/250
Epoch 28/250
Epoch 29/250
Epoch 30/250
Epoch 31/250
Epoch 32/250
Epoch 33/250
Epoch 34/250
 611/2785 [=====>........................] - ETA: 9s - loss: 0.1037 - binary_accuracy: 0.9467

KeyboardInterrupt: 

In [30]:
model.save_weights('NextStateModel')

In [None]:
labels, encoding, counts = np.unique(rewards, return_inverse=True, return_counts=True)
encoding = np.eye(labels.shape[0])[encoding]
reward_map = {}
for i in range(labels.shape[0]):
    reward_map[i] = labels[i]
reward_map

In [None]:
counts

In [None]:
encoding

In [None]:
model = Sequential()
#model.add(Bidirectional(LSTM(256, activation='relu', return_sequences=True), input_shape=(sequence_length, encoding_len)))
#model.add(Flatten())
model.add(Input(shape=(42,)))
model.add(Dense(512, activation='relu'))
model.add(Dense(512, activation='relu'))
model.add(Dense(512, activation='relu'))
model.add(Dense(6, activation='softmax'))
model.compile(optimizer='adam', loss=tf.keras.losses.CategoricalCrossentropy(), metrics=[tf.keras.metrics.CategoricalAccuracy()])

callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=7)

from keras import backend as K
K.set_value(model.optimizer.learning_rate, 0.005)

print(model.summary())

#model.fit(train_dataset,epochs=250, verbose=1, callbacks=[callback])

model.fit(next_states, encoding, epochs=250, validation_split=0.1, verbose=1, callbacks=[callback], batch_size=64)

In [None]:
model.predict(np.zeros((1,42)))

In [None]:
model.save_weights('RewardModel')

In [None]:
np.save('reward_map.npy', reward_map) 

In [None]:
for i in range(10):
    print(next_states_delta_train[i].sum())

In [18]:
predictions = np.round(model.predict(afterstates_test))
count = 0
for i in range(3960):
    if (predictions[i]==next_states_test[i]).all():
        count += 1
    if np.sum(predictions[i]-next_states_test[i]) != 0:
        print(np.sum(predictions[i]-next_states_test[i]))
count

ValueError: in user code:

    File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 2137, in predict_function  *
        return step_function(self, iterator)
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 2123, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 2111, in run_step  **
        outputs = model.predict_step(data)
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/training.py", line 2079, in predict_step
        return self(x, training=False)
    File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 70, in error_handler
        raise e.with_traceback(filtered_tb) from None
    File "/usr/local/lib/python3.8/dist-packages/keras/engine/input_spec.py", line 295, in assert_input_compatibility
        raise ValueError(

    ValueError: Input 0 of layer "sequential_3" is incompatible with the layer: expected shape=(None, 62), found shape=(None, 42)


In [19]:
prediction = model.predict(np.array([states_actions[0]]))[0]

new_state = {}

for i in range(3960):
    while True:
        state = np.array(np.random.rand(42) < prediction, dtype=np.int8)
        if int(np.sum(state-next_states_test[i])) == 0:
            break
        #else:
        #    print(int(np.sum(state-next_states_test[i])))
    if not state.tobytes() in new_state:
        new_state[state.tobytes()] = 1
    else:
        new_state[state.tobytes()] += 1
        
ns = {k: v for k, v in sorted(new_state.items(), key=lambda item: item[1], reverse=True)}
            
for s in ns.keys():
    print(ns[s])
    print(np.frombuffer(s, dtype="int8"))

print(len(new_state.keys()))

799
[0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0
 1 0 0 0 1]
86
[0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 0 0 0 0 1 0 0
 1 0 1 0 1]
81
[0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 1 0 0 0 0 1 0 0
 1 0 0 0 1]
65
[0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 1 1 0 0 0 1 0 0
 0 0 0 0 1]
65
[0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 0 0 0 0 1 1 0
 1 0 0 0 1]
61
[0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 0 0 1 0 1 0 0
 1 0 0 0 1]
52
[0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 1 0 0 0 0 0 1 0 0
 1 0 0 0 1]
52
[0 1 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 0 0 0 0 1 0 0
 1 0 0 0 1]
51
[0 0 1 0 0 0 1 1 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 0 0 0 0 1 0 0
 1 0 0 0 1]
50
[0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0
 1 0 1 0 0]
49
[0 0 1 0 0 0 1 0 0 1 0 1 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 0 0 0 0 1 0 0
 1 0 0 0 1]
49
[0 0 1

2
[0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 1 0 0 0 0 0 1 0
 0 0 1 0 1]
2
[0 0 1 0 0 0 1 0 0 1 0 1 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 0 0 0
 1 0 1 0 0]
2
[0 0 0 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 1 0 1 0 0 0 0 1 0
 1 0 0 0 1]
2
[0 1 0 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 0 0 0 0 1 1 0
 1 0 0 0 1]
2
[0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 1 0 1 0 0 0 1 1 0
 1 0 0 0 1]
2
[0 0 1 0 0 0 1 0 0 1 0 1 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 0 0 0 0 1 0 1
 0 0 0 0 1]
2
[0 1 1 0 0 0 1 0 0 0 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 0 0 0 0 1 1 0
 1 0 0 0 1]
2
[0 1 1 0 0 0 1 0 0 0 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0
 0 0 1 0 1]
2
[0 0 1 0 0 0 1 1 0 0 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 0 0 1 0 1 0 0
 1 0 0 0 1]
2
[0 1 1 0 0 0 1 1 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 0 0 0 0 1 0 0
 1 0 0 0 0]
2
[0 0 0 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 1 0 1 0 0 0 1 0 0
 0 0 1 0 1]
2
[0 0 0 0 0 0 1 0 0 

1
[0 0 1 0 0 0 1 1 0 1 0 1 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 0 0 1 0 0 0 0
 0 0 0 0 1]
1
[0 0 1 0 0 0 1 0 1 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 0 0 1 0 0 0 0
 1 0 1 0 0]
1
[0 1 1 0 0 0 1 0 0 1 0 1 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0
 0 0 0 0 0]
1
[0 0 0 0 0 0 1 0 0 0 0 1 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 1 1 0 0 0 1 0 0
 1 0 0 0 1]
1
[0 0 0 0 0 0 1 1 0 1 0 1 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 1 0 0 0 0 0 1 0 0
 0 0 0 0 1]
1
[0 0 1 0 0 0 1 0 0 0 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 1 0 0 0 0 0 1 0 1
 1 0 0 0 1]
1
[0 0 1 0 0 0 1 0 1 1 0 0 0 0 0 0 1 0 0 0 1 0 0 1 0 0 0 1 1 0 0 0 0 0 1 1 0
 0 0 0 0 1]
1
[0 0 0 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 1 0 0 1 0 0
 1 0 0 0 1]
1
[0 0 1 0 0 0 1 1 1 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0
 1 0 1 0 0]
1
[0 1 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0
 0 0 1 0 0]
1
[0 0 1 0 0 0 1 1 1 1 0 0 0 0 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 0 0 0 0 1 0 0
 1 0 0 0 1]
1
[0 0 1 0 0 0 1 0 0 

In [29]:
target = states_actions[0]
new_state = {}

for i in range(194040):
    if (target==states_actions[i]).all():
        state = np.array(next_states[i], dtype=np.int8)
        if not state.tobytes() in new_state:
            new_state[state.tobytes()] = 1
        else:
            new_state[state.tobytes()] += 1
        
ns = {k: v for k, v in sorted(new_state.items(), key=lambda item: item[1], reverse=True)}
            
for s in ns.keys():
    print(ns[s])
    print(np.frombuffer(s, dtype="int8"))

print(len(new_state.keys()))

93
[0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0
 1 0 0 0 1]
64
[0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 1 0 0 0 0 1 0 0
 1 0 0 0 1]
63
[0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 1 0
 0 0 1 0 0]
54
[0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 1 0 0 0 1 0 0 0 0
 1 0 0 0 1]
44
[0 1 0 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0
 1 0 0 0 1]
42
[0 0 1 0 0 0 1 1 0 0 0 1 0 0 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0
 1 0 0 0 1]
30
[0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 1
 0 0 0 0 1]
7
[0 1 0 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 1 0 0 0 0 1 0 0
 1 0 0 0 1]
7
[0 0 1 0 0 0 1 1 0 0 0 1 0 0 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 1 0 0 0 0 1 0 0
 1 0 0 0 1]
7
[0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 1 0 0 0 0 1 0 1
 0 0 0 0 1]
5
[0 0 1 0 0 0 1 0 1 0 0 0 0 1 0 0 1 0 0 0 1 0 0 1 0 0 0 1 0 1 0 0 0 0 1 0 0
 1 0 0 0 1]
5
[0 0 1 0 0 0

In [None]:
afterstates_train

In [None]:
next_states_train[i].shape

In [None]:
afterstate_predictor = AfterstatePrediction(state_size=STATE_SIZE, num_mixtures=NUMBER_MIXTURES)

In [None]:
for row in test_dataset.take(1):
    out = afterstate_predictor(row[0])
    print(out)

In [None]:
# red_ts_predictor.compile(loss=get_mixture_loss_func(LATENT_SIZE,NUMBER_MIXTURES,red_ts_predictor.encode), optimizer=tf.keras.optimizers.Adam(),metrics=['mean_squared_error'])
# red_ts_predictor.build(((1,78),(1,78)))
afterstate_predictor.summary()


In [None]:
loss_func = MDN.get_mixture_loss_func(STATE_SIZE,NUMBER_MIXTURES)
optimizer = tf.keras.optimizers.Adam(.1e-5)

In [None]:

# EPOCHS = 500
for epoch in range(1, EPOCHS + 1):
    count = 0
    start_time = time.time()
    for train_x, train_y in train_dataset:
        #if (count %1000) == 0:
        #    print(f"{count}={count*batch_size} samples")
        train_step(afterstate_predictor, train_x, train_y, optimizer, loss_func)
        count += 1
    end_time = time.time()

    
    total_matches = 0
    total = 0
    nodes = 0
    sum_diffs_sqrd = 0
    state_pred_pairs = []
    state_pred_pair_tree_vis = []
    loss = tf.keras.metrics.Mean()
    loss_result, total = 0, 0
    for test_x, test_y in test_dataset.take(1000):#.take(10):#test_dataset:
        total += 1
        out, y = forward_pass(afterstate_predictor, test_x, test_y)
        loss_val = loss_func(y, out)
        loss(loss_val)
        loss_result += loss.result()
        #     print(f"accuracy = {total_matches}/{total} = {total_matches/total}, \nmean of squared diffs = {sum_diffs_sqrd}/{nodes}={sum_diffs_sqrd/nodes}\npercentage wrong = ({sum_diffs_sqrd}/{2})/({nodes}/{3})={(sum_diffs_sqrd/2)/(nodes/3)}")
    print('Epoch: {}, Test set loss: {}, time elapse for current epoch: {}'
        .format(epoch, loss_result/total, end_time - start_time))
        
#         sampled_out = MDN.sample_from_output(out[0].numpy(), output_dim=LATENT_SIZE, num_mixes=NUMBER_MIXTURES)
#         pred_oh, y_oh = decode_zs(red_ts_predictor, sampled_out, y_encoded)
        
        
#         state_pred_pairs.append([y_oh, pred_oh])
# #         state_pred_pair_tree_vis.append([TrueStateTreeGraphViz(y_oh), TrueStateTreeGraphViz(pred_oh)])
# #         loss(compute_loss(red_ts_predictor, test_x, test_y, loss_func))
#         diffs = np.rint(y_oh.numpy()) - np.rint(pred_oh.numpy())
#     #     diffs = get_state_diff(true_state_model,test_x)
#         nodes += len(diffs.flatten())
#         diffs_sqrd = np.sum(diffs*diffs)
#         sum_diffs_sqrd += diffs_sqrd
#         if not diffs_sqrd >0:
#     #       print(diffs)
#     #     else:
#           total_matches += 1
#     #       print("Match")
    #       print(diffs)


In [None]:
# pre_ts_encoded = red_ts_predictor.ts_vae.encode(np.zeros((1,78),dtype=np.float32))
# red_ts_predictor.ts_dense(pre_ts_encoded)

np.array(sample_from_output(afterstate_predictor(np.zeros((1,42)))[0].numpy(), 42, 5)>0.5, dtype=np.int8)

#afterstate_predictor.save('models/afterstate_predictor',overwrite=True)

### Test and evaluate

In [None]:
m2 = tf.keras.models.load_model(
    'models/afterstate_predictor',
)

In [None]:
total_matches = 0
total = 0
nodes = 0
sum_diffs_sqrd = 0
state_pred_pairs = []
state_pred_pair_tree_vis = []
state_pred_samples_correct = []

red_change_test_indices = []
no_change_indices = []

loss = tf.keras.metrics.Mean()

num_eval_samples = 10

for i, (test_x, test_y) in enumerate(test_dataset.take(3000)):#.take(10):#test_dataset:
    pre_ts = tf.reshape(tf.one_hot(test_x[0],3),(-1,2,3))
    blue_ts = tf.reshape(tf.one_hot(test_x[1],3),(-1,2,3))
    
    out, y_encoded = forward_pass(m2, test_x, test_y)
#     y_oh = decode_z(m2, y_encoded)
    
    y_oh = tf.one_hot(tf.reshape(test_y,(-1,2)),depth=3)
    
#     if not ONLY_MEASURE_CHANGES or np.any(blue_ts != y_oh):

    if np.any(blue_ts != y_oh):
        red_change_test_indices.append(i)
    else:
        no_change_indices.append(i)
        
    loss_val = loss_func(y_encoded, out)
    loss(loss_val)

    y_tree_vis = TrueStateTreeGraphViz(y_oh)

    predictions = []
    pred_tree_vis = []
    pred_corrects = []
    for i in range(num_eval_samples):
        total += 1
        sampled_out = MDN.sample_from_output(out[0].numpy(), output_dim=STATE_SIZE, num_mixes=NUMBER_MIXTURES)
        pred_oh = decode_z(m2, sampled_out)

        predictions.append(pred_oh)
        pred_tree_vis.append(TrueStateTreeGraphViz(pred_oh))

        diffs = np.rint(y_oh.numpy()) - np.rint(pred_oh.numpy())
        diffs_sqrd = np.sum(diffs*diffs)
        sum_diffs_sqrd += diffs_sqrd

        nodes += len(diffs.flatten())
        pred_correct = not diffs_sqrd >0
        if pred_correct:
            total_matches += 1
        pred_corrects.append(pred_correct)



    state_pred_pairs.append([y_oh, predictions])

    state_pred_pair_tree_vis.append([TrueStateTreeGraphViz(pre_ts),
                                     TrueStateTreeGraphViz(blue_ts),
                                     y_tree_vis,
                                     pred_tree_vis])

    state_pred_samples_correct.append(pred_corrects)

state_pred_samples_correct = np.array(state_pred_samples_correct)

#         state_pred_pairs.append([y_oh, pred_oh])
#     state_pred_pair_tree_vis.append([TrueStateTreeGraphViz(y_oh), TrueStateTreeGraphViz(pred_oh)])
#         loss(compute_loss(m2, test_x, test_y, loss_func))
#     diffs = np.rint(y_oh.numpy()) - np.rint(pred_oh.numpy())
#     diffs = get_state_diff(true_state_model,test_x)
#     nodes += len(diffs.flatten())
#     diffs_sqrd = np.sum(diffs*diffs)
#     sum_diffs_sqrd += diffs_sqrd
#     if not diffs_sqrd >0:
# #       print(diffs)
# #     else:
#       total_matches += 1
#       print("Match")
#       print(diffs)

loss = loss.result()
display.clear_output(wait=False)
print(f"accuracy = {total_matches}/{total} = {total_matches/total}, \nmean of squared diffs = {sum_diffs_sqrd}/{nodes}={sum_diffs_sqrd/nodes}\npercentage wrong = ({sum_diffs_sqrd}/{2})/({nodes}/{3})={(sum_diffs_sqrd/2)/(nodes/3)}")
# print('Epoch: {}, Test set loss: {}, time elapse for current epoch: {}'
#     .format(epoch, loss, end_time - start_time))


In [None]:
import matplotlib.pyplot as plt

correct_pred_counts = np.sum(state_pred_samples_correct[no_change_indices],axis=1)

correct_pred_counts_change_only = np.sum(state_pred_samples_correct[red_change_test_indices], axis=1)

# print(f"Correct prediction frequencies: {correct_pred_counts}")

plt.hist([correct_pred_counts,correct_pred_counts_change_only],num_eval_samples+1,density=True, stacked=True, label=["No state change", "Changes only"])
# plt.hist(correct_pred_counts_change_only,num_eval_samples+1,density=True)
plt.legend(loc="lower right")
plt.xlabel('Count of correct predictions in 10 samples')
plt.ylabel('Density')
plt.show()

In [None]:
from importlib import reload
import true_state_viewer
reload(true_state_viewer)
from true_state_viewer import display_tree_red_preds

In [None]:
display_tree_red_preds(state_pred_pair_tree_vis)