In [1]:
import tensorflow as tf
import tensorflow_probability as tfp
import keras
from keras import layers
import numpy as np
import matplotlib.pyplot as plt
import gym

In [6]:
from vae import VAE, create_decoder, create_encoder
from transition_gru import TransitionGRU
from recurrent_agent import DAIFAgentRecurrent

In [7]:
from util import random_observation_sequence, transform_observations

In [14]:
enc = create_encoder(2, 2, [20])
dec = create_decoder(2, 2, [20])
# tran = TransitionGRU(2, 1)

env = gym.make('MountainCarContinuous-v0')

In [141]:
import tensorflow as tf
import tensorflow_probability as tfp
import keras
from keras import layers
import numpy as np


class TransitionGRU(keras.Model):

    def __init__(self, latent_dim, action_dim, seq_length, hidden_units, output_dim, batch_size=None, **kwargs):
        super(TransitionGRU, self).__init__(**kwargs)


        self.latent_dim = latent_dim
        self.action_dim = action_dim
        self.seq_length = seq_length
        self.hidden_units = hidden_units
        self.output_dim = output_dim

        self.batch_size = batch_size  # this should be number of policies I think


        inputs = layers.Input(shape=(None, self.latent_dim + self.action_dim))
        initial_state_input = layers.Input((self.hidden_units, ))
        h_states, final_state = layers.GRU(self.hidden_units, activation="tanh", return_sequences=True, return_state=True, name="gru")(inputs, initial_state=initial_state_input)

        # TODO is this correctly getting the last hidden state or the first???
        z_mean = layers.Dense(latent_dim, name="z_mean")(final_state)  # all batch last time step all dimension
        z_log_sd = layers.Dense(latent_dim, name="z_log_sd")(final_state)
        z_stddev = tf.exp(z_log_sd)

        self.transition_model = keras.Model([inputs, initial_state_input], [z_mean, z_stddev, final_state, h_states], name="transition")

        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")


    def call(self, inputs, training=None, mask=None):

        # extract the initial state and
        x, initial_state = inputs
        if initial_state is None:
            initial_state = np.zeros((x.shape[0], self.hidden_units))  # start as zeros with number of examples times hidden dimension
        return self.transition_model([x] + [initial_state])


    @property
    def metrics(self):
        return [self.kl_loss_tracker]

    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        inputs, targets = data
        mu, stddev = targets
        x, init_states = inputs

        with tf.GradientTape() as tape:
            z_mean, z_stddev, final_state, h_states = self.transition_model([x, init_states], training=True)  # Forward pass

            # Compute the loss value
            pred_dist = tfp.distributions.MultivariateNormalDiag(loc=z_mean, scale_diag=z_stddev)
            true_dist = tfp.distributions.MultivariateNormalDiag(loc=mu, scale_diag=stddev)

            # TODO make sure this is the correct order of terms
            kl_loss = tfp.distributions.kl_divergence(pred_dist, true_dist)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(kl_loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "kl_loss": self.kl_loss_tracker.result()
        }


In [178]:
num_seqs = 1200
seq_length = 15
ob_seqs = []
next_obs = []

for i in range(num_seqs):
    o, a, r = random_observation_sequence(env, seq_length)

    train = np.concatenate([o[:-1], a], axis=1)
    test = o[-1]

    ob_seqs.append(train)
    next_obs.append(test)

ob_seqs = np.array(ob_seqs)
next_obs = np.array(next_obs)
ob_seqs.shape

ob_seqs_stddev = np.ones_like(ob_seqs)
next_obs_stddev = np.ones_like(next_obs)

ob_seqs.shape

(1200, 15, 3)

In [179]:
next_obs.shape

(1200, 2)

In [180]:
next_obs.shape

(1200, 2)

In [155]:
m = TransitionGRU(2, 1, 10, 30, 2)

m.compile(optimizer="Adam")
# m.build((None, None, 3))
# m.summary()
init_state = np.ones((1200, 30))

In [156]:
m.fit((ob_seqs, init_state), (next_obs, next_obs_stddev), batch_size=12, epochs=30)

Epoch 1/30


2022-07-13 10:39:53.999260: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
2022-07-13 10:39:54.111664: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.
2022-07-13 10:39:54.176776: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.


Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30


<keras.callbacks.History at 0x2caaa9970>

In [159]:
res = m((ob_seqs[0:10], None))
res[0]

<tf.Tensor: shape=(10, 2), dtype=float32, numpy=
array([[-0.5678598 , -0.00945393],
       [-0.55901635, -0.012023  ],
       [-0.5301562 , -0.01353312],
       [-0.59359   , -0.00653425],
       [-0.57035166, -0.01022513],
       [-0.55035514, -0.00929178],
       [-0.5918224 , -0.00945896],
       [-0.5339744 , -0.01118272],
       [-0.5343139 , -0.00886616],
       [-0.56937265, -0.00513634]], dtype=float32)>

In [172]:
res[3]

<tf.Tensor: shape=(1200, 15, 30), dtype=float32, numpy=
array([[[ 0.02425387,  0.0244065 , -0.0065957 , ...,  0.03684389,
          0.04408829,  0.01370949],
        [ 0.03664604,  0.05243195, -0.02360114, ...,  0.06099832,
          0.08207779,  0.03359329],
        [-0.02266573,  0.05749463, -0.0032896 , ...,  0.05033551,
          0.03671521, -0.00574399],
        ...,
        [ 0.11681072,  0.2025671 , -0.22530966, ...,  0.16079979,
          0.29863134,  0.2289128 ],
        [ 0.1343703 ,  0.21245562, -0.24117123, ...,  0.14922248,
          0.3282653 ,  0.2420348 ],
        [-0.01704619,  0.17420371, -0.16205727, ...,  0.07216896,
          0.1793234 ,  0.11629379]],

       [[ 0.12727839,  0.06734122, -0.08327793, ...,  0.0805898 ,
          0.1621858 ,  0.12308379],
        [ 0.04225383,  0.06653474, -0.03929769, ...,  0.04112265,
          0.0984662 ,  0.05112836],
        [-0.04273536,  0.05595212,  0.00146029, ...,  0.02969554,
          0.01936543, -0.01734672],
        ...

In [164]:
ob_seqs[0:10, 14, :].shape

(10, 3)

In [168]:
res = m((ob_seqs, None))
res[2]

<tf.Tensor: shape=(1200, 30), dtype=float32, numpy=
array([[-0.01704619,  0.17420371, -0.16205727, ...,  0.07216896,
         0.1793234 ,  0.11629379],
       [-0.02839059,  0.15845239, -0.14379501, ...,  0.08374628,
         0.14989465,  0.10145168],
       [ 0.06763794,  0.17193797, -0.1816722 , ...,  0.14947723,
         0.22297955,  0.18303703],
       ...,
       [ 0.1104444 ,  0.20756654, -0.22950043, ...,  0.16979946,
         0.29265392,  0.2326007 ],
       [-0.15274253,  0.11597212, -0.05101934, ...,  0.10283384,
        -0.0188563 , -0.00447916],
       [ 0.01144606,  0.16745105, -0.16296513, ...,  0.10915829,
         0.1965019 ,  0.13179201]], dtype=float32)>

In [171]:
m((ob_seqs[0:1, :, :], None))

[<tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[-0.56785977, -0.0094539 ]], dtype=float32)>,
 <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.99687016, 0.99474823]], dtype=float32)>,
 <tf.Tensor: shape=(1, 30), dtype=float32, numpy=
 array([[-0.0170462 ,  0.17420371, -0.16205727,  0.1459884 , -0.31527486,
         -0.00659207,  0.0312416 ,  0.0813532 ,  0.11589594, -0.34983537,
          0.09985405,  0.24849527, -0.12736231,  0.2578114 ,  0.06988851,
          0.02357278, -0.15019439,  0.01852169, -0.2631799 ,  0.25204396,
         -0.2779762 , -0.22705497,  0.35386953, -0.05633114,  0.06359442,
         -0.03013663, -0.14228506,  0.07216896,  0.1793234 ,  0.11629379]],
       dtype=float32)>,
 <tf.Tensor: shape=(1, 15, 30), dtype=float32, numpy=
 array([[[ 0.02425387,  0.0244065 , -0.0065957 ,  0.05379898,
          -0.07875035, -0.04081128, -0.02774346,  0.00961666,
           0.01717799, -0.05475242,  0.01146767,  0.09475566,
          -0.0658425 ,  0.04575911,  0.0

In [340]:
res[2]

<tf.Tensor: shape=(10, 30), dtype=float32, numpy=
array([[ 1.76516443e-01,  1.54475287e-01, -5.12306057e-02,
         1.64845765e-01,  2.68837273e-01, -9.95857716e-02,
         7.79657215e-02, -1.10403851e-01, -1.08879849e-01,
        -1.22801609e-01,  3.41023579e-02, -8.55413154e-02,
         1.35252342e-01, -5.53389899e-02, -2.05101758e-01,
         9.60189849e-02, -2.62633413e-01,  2.84929663e-01,
         2.07988352e-01, -1.35656670e-01,  1.67700097e-01,
        -9.03860945e-03,  9.67715830e-02,  1.41193375e-01,
        -1.73646688e-01, -1.49096355e-01,  1.63787216e-01,
         1.34333670e-01, -1.35705829e-01, -1.85656726e-01],
       [-4.13602814e-02,  1.25444561e-01,  1.13072321e-01,
         5.97583316e-02,  1.48286954e-01, -1.09492682e-01,
        -6.71713203e-02, -1.93654671e-01, -2.18552351e-01,
        -1.57539248e-01,  1.89289406e-01, -2.24954244e-02,
         1.34446856e-03, -4.91290167e-03, -2.39228070e-01,
         1.89485967e-01, -3.58797014e-01,  2.12003320e-01,
     

In [341]:
res[3]

<tf.Tensor: shape=(10, 5, 30), dtype=float32, numpy=
array([[[ 0.08676665,  0.0665729 , -0.01751424, ...,  0.04447326,
         -0.01017037, -0.07514461],
        [ 0.04867286,  0.10080533,  0.05040885, ...,  0.05720273,
         -0.10691113, -0.10694263],
        [ 0.15783645,  0.12385118, -0.07167391, ...,  0.10117968,
         -0.06072072, -0.15788542],
        [ 0.20955351,  0.14058515, -0.09838694, ...,  0.12855455,
         -0.07163367, -0.1843247 ],
        [ 0.17651644,  0.15447529, -0.05123061, ...,  0.13433367,
         -0.13570583, -0.18565673]],

       [[ 0.15751047,  0.06191986, -0.11045828, ...,  0.06034369,
          0.04968761, -0.08267669],
        [ 0.08696802,  0.09508388,  0.00635655, ...,  0.06304765,
         -0.08224013, -0.10179211],
        [-0.00656156,  0.11269982,  0.09019252, ...,  0.05995458,
         -0.1843041 , -0.10718634],
        [-0.00765619,  0.12198227,  0.07892087, ...,  0.07446641,
         -0.1967938 , -0.12551175],
        [-0.04136028,  0.12

In [255]:
# def gru(input_dim, seq_length, hidden_units, output_dim):
#
#     inputs = layers.Input(shape=(None, input_dim), batch_size=10)
#     out_states, h_states, *everything_else = layers.GRU(hidden_units, activation="tanh", stateful=True, return_sequences=True)(inputs)
#     h = layers.Dense(output_dim)(out_states)
#
#     model = keras.Model(inputs, h)
#
#     return model, h_states


input_dim = 3
hidden_units = 30
seq_length = 10
output_dim = 2


inputs = layers.Input(shape=(None, input_dim))
out_states = layers.GRU(2, activation="tanh", stateful=False, return_sequences=True)(inputs)
h = layers.Dense(output_dim)(out_states)

m = keras.Model(inputs, [h, out_states])

inputs2 = layers.Input(shape=(None, input_dim), batch_size=20)
out_states2 = layers.GRU(2, activation="tanh", stateful=False, return_sequences=False)(inputs2)
h2 = layers.Dense(output_dim)(out_states2)

m2 = keras.Model(inputs2, h2)

In [256]:
# m, h_states = gru(3, 10, 30, 2)

m.compile(optimizer="Adam", loss=tf.keras.losses.MeanSquaredError())
m.summary()

Model: "model_47"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_86 (InputLayer)       [(None, None, 3)]         0         
                                                                 
 gru_77 (GRU)                (None, None, 2)           42        
                                                                 
 dense_37 (Dense)            (None, None, 2)           6         
                                                                 
Total params: 48
Trainable params: 48
Non-trainable params: 0
_________________________________________________________________


In [237]:
m2.compile(optimizer="Adam", loss=tf.keras.losses.MeanSquaredError())
m2.summary()

Model: "model_40"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_77 (InputLayer)       [(20, None, 3)]           0         
                                                                 
 gru_68 (GRU)                (20, 2)                   42        
                                                                 
 dense_30 (Dense)            (20, 2)                   6         
                                                                 
Total params: 48
Trainable params: 48
Non-trainable params: 0
_________________________________________________________________


In [257]:
m.fit(ob_seqs, next_obs, batch_size=20, epochs=10)

Epoch 1/10


ValueError: in user code:

    File "/Users/Ethan/miniconda3/envs/tf_daif/lib/python3.8/site-packages/keras/engine/training.py", line 1021, in train_function  *
        return step_function(self, iterator)
    File "/Users/Ethan/miniconda3/envs/tf_daif/lib/python3.8/site-packages/keras/engine/training.py", line 1010, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/Users/Ethan/miniconda3/envs/tf_daif/lib/python3.8/site-packages/keras/engine/training.py", line 1000, in run_step  **
        outputs = model.train_step(data)
    File "/Users/Ethan/miniconda3/envs/tf_daif/lib/python3.8/site-packages/keras/engine/training.py", line 860, in train_step
        loss = self.compute_loss(x, y, y_pred, sample_weight)
    File "/Users/Ethan/miniconda3/envs/tf_daif/lib/python3.8/site-packages/keras/engine/training.py", line 918, in compute_loss
        return self.compiled_loss(
    File "/Users/Ethan/miniconda3/envs/tf_daif/lib/python3.8/site-packages/keras/engine/compile_utils.py", line 201, in __call__
        loss_value = loss_obj(y_t, y_p, sample_weight=sw)
    File "/Users/Ethan/miniconda3/envs/tf_daif/lib/python3.8/site-packages/keras/losses.py", line 141, in __call__
        losses = call_fn(y_true, y_pred)
    File "/Users/Ethan/miniconda3/envs/tf_daif/lib/python3.8/site-packages/keras/losses.py", line 245, in call  **
        return ag_fn(y_true, y_pred, **self._fn_kwargs)
    File "/Users/Ethan/miniconda3/envs/tf_daif/lib/python3.8/site-packages/keras/losses.py", line 1329, in mean_squared_error
        return backend.mean(tf.math.squared_difference(y_pred, y_true), axis=-1)

    ValueError: Dimensions must be equal, but are 10 and 20 for '{{node mean_squared_error/SquaredDifference}} = SquaredDifference[T=DT_FLOAT](model_47/dense_37/BiasAdd, IteratorGetNext:1)' with input shapes: [20,10,2], [20,2].


In [247]:
res = m(ob_seqs[0:20])

In [248]:
res[0]

<tf.Tensor: shape=(20, 10, 2), dtype=float32, numpy=
array([[[-0.03126741, -0.05007739],
        [-0.02838155, -0.00308725],
        [-0.03572483, -0.00703882],
        [-0.12308189, -0.33903465],
        [-0.09099328, -0.21175347],
        [-0.15656376, -0.48039535],
        [-0.17280677, -0.56211597],
        [-0.09679789, -0.25665036],
        [-0.09596224, -0.23660131],
        [-0.07447534, -0.14166676]],

       [[-0.09012927, -0.25503004],
        [-0.06786623, -0.12630445],
        [-0.06517889, -0.0879531 ],
        [-0.0753606 , -0.10781813],
        [-0.15388907, -0.41995582],
        [-0.16208252, -0.4539302 ],
        [-0.09656087, -0.1943801 ],
        [-0.12658402, -0.30111074],
        [-0.09610697, -0.17922345],
        [-0.1653356 , -0.4642073 ]],

       [[ 0.00175663,  0.08846912],
        [-0.10092825, -0.24563871],
        [-0.12627089, -0.31458697],
        [-0.09916592, -0.19435771],
        [-0.07646907, -0.09792791],
        [-0.09629729, -0.16202733],
       

In [223]:
res[1]

<tf.Tensor: shape=(10, 10, 2), dtype=float32, numpy=
array([[[-0.36397788,  0.0294749 ],
        [-0.48029938,  0.01573997],
        [-0.51549274, -0.00554507],
        [-0.50096333, -0.02819019],
        [-0.51849884, -0.03603037],
        [-0.4956515 , -0.04518281],
        [-0.4919299 , -0.04698842],
        [-0.5167043 , -0.04562606],
        [-0.5176949 , -0.04506196],
        [-0.5242841 , -0.04514759]],

       [[-0.3648733 ,  0.04206948],
        [-0.4807293 ,  0.04454088],
        [-0.51710445,  0.03400522],
        [-0.5260778 ,  0.02312428],
        [-0.5013063 ,  0.01332644],
        [-0.50671595,  0.01171992],
        [-0.52894664,  0.00859705],
        [-0.5185139 ,  0.00654676],
        [-0.52903926,  0.00418232],
        [-0.50056595,  0.00295787]],

       [[-0.35582137,  0.06193205],
        [-0.4635372 ,  0.05734254],
        [-0.50165814,  0.04938255],
        [-0.52453876,  0.04110029],
        [-0.5341751 ,  0.03280196],
        [-0.5287185 ,  0.02695066],
       