In [None]:
import math
import os
from typing import Tuple
from tensorflow.python.keras.utils.generic_utils import Progbar
from tf_agents.environments import suite_gym, parallel_py_environment
from tf_agents.environments import tf_py_environment
from tf_agents.metrics import tf_metrics
from tf_agents.replay_buffers import tf_uniform_replay_buffer, episodic_replay_buffer
from tf_agents.trajectories.trajectory import Trajectory
from tf_agents.networks import actor_distribution_network
from tf_agents.drivers import dynamic_episode_driver, dynamic_step_driver
from tf_agents.policies import random_tf_policy
import tensorflow as tf
import numpy as np
from tf_agents.utils import common
from reinforcement_learning.sac_training import NumberOfSafetyViolations
from tf_agents.trajectories import time_step as ts, policy_step, trajectory
from reinforcement_learning import sac_training
from reinforcement_learning import labeling_functions
# Labeling function for the Pendulum-v0 environment; the cells below pass it to
# map_rl_trajectory_to_vae_input / ErgodicMDPTransitionGenerator and call it directly.
labeling_function = labeling_functions['Pendulum-v0']
from util.io.dataset_generator import map_rl_trajectory_to_vae_input
from util.io.dataset_generator import ErgodicMDPTransitionGenerator

In [None]:
# Load the Gym Pendulum environment, reset it, and wrap it as a TF environment.
py_env = suite_gym.load(environment_name='Pendulum-v0')
py_env.reset()
tf_env = tf_py_environment.TFPyEnvironment(environment=py_env)
# Cell output: the wrapped environment's time-step spec.
tf_env.time_step_spec()


In [None]:
def display_safe_labeling(trajectory):
    """Print the Pendulum-v0 labels of ``trajectory.observation`` if any is set.

    Takes a single trajectory and returns nothing, so it can be plugged in as a
    driver observer. Prints nothing when ``tf.reduce_any`` over the labels is false.
    """
    pendulum_labels = labeling_functions['Pendulum-v0'](trajectory.observation)
    if not tf.reduce_any(pendulum_labels):
        return
    print(pendulum_labels)

In [None]:
# Passed to the replay buffer as `max_length`.
replay_buffer_capacity = 1280

# Specs: the stored trajectory is built from the env's time-step spec and an
# action-only policy step (empty policy state and empty policy info).
action_spec = tf_env.action_spec()
policy_step_spec = policy_step.PolicyStep(action=action_spec, state=(), info=())
trajectory_spec = trajectory.from_transition(
    time_step=tf_env.time_step_spec(),
    action_step=policy_step_spec,
    next_time_step=tf_env.time_step_spec())

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=trajectory_spec,
    batch_size=tf_env.batch_size,
    max_length=replay_buffer_capacity)


def dataset_generator():
    """Return a dataset of 2-step replay-buffer samples mapped to VAE inputs.

    Each sample is read with ``num_steps=2`` and converted by
    ``map_rl_trajectory_to_vae_input`` using the module-level ``labeling_function``.
    """
    sampled_pairs = replay_buffer.as_dataset(
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
        num_steps=2,
    )
    return sampled_pairs.map(
        map_func=lambda sampled_trajectory, _sample_info: map_rl_trajectory_to_vae_input(
            sampled_trajectory, labeling_function),
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
        # deterministic=False  # requires TF >= 2.2.0
    )

In [None]:
# NOTE(review): `tf_env` was already built from `py_env` in an earlier cell; this
# re-wrap looks redundant but is kept so the cell behaves exactly as before.
tf_env = tf_py_environment.TFPyEnvironment(py_env)

# Load the saved policy (a TF SavedModel directory).
sac_policy_dir = '../saves/Pendulum-v0/policy'
policy = tf.compat.v2.saved_model.load(sac_policy_dir)

# Roll out 5 episodes: every step is rendered and then added to the replay buffer.
# (Add `display_safe_labeling` to the observers to print non-empty labels.)
episode_driver = dynamic_episode_driver.DynamicEpisodeDriver(
    tf_env,
    policy,
    num_episodes=5,
    observers=[
        lambda _: py_env.render(mode='human'),
        replay_buffer.add_batch,
    ],
)
episode_driver.run()

In [None]:
# Draw one 2-step sample from the replay buffer and split it into
# (state, label, action, reward, next_state, next_label).
dataset = replay_buffer.as_dataset(
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
    num_steps=2
)
iterator = iter(dataset)
# Bind the sample to `sample_trajectory`, NOT `trajectory`: `trajectory` is the
# tf_agents module used by the replay-buffer spec cell (`trajectory.from_transition`),
# and rebinding it makes that cell fail when re-run.
sample_trajectory, _ = next(iterator)

# Index 0 along the leading step axis is the current step, index 1 the next step.
state = sample_trajectory.observation[0, ...]
labels = tf.cast(labeling_function(sample_trajectory.observation), tf.float32)
if tf.rank(labels) == 1:
    # One label per step: add a trailing feature axis.
    labels = tf.expand_dims(labels, axis=-1)
label = labels[0, ...]
action = sample_trajectory.action[0, ...]
# Check the rank before indexing the step axis, exactly as for `labels`.
# Checking after indexing tested an already-indexed reward; for a scalar per-step
# reward that is rank 0, so the expansion never happened.
rewards = sample_trajectory.reward
if tf.rank(rewards) == 1:
    rewards = tf.expand_dims(rewards, axis=-1)
reward = rewards[0, ...]
next_state = sample_trajectory.observation[1, ...]
next_label = labels[1, ...]

print("\nstate", state)
print('\nlabels', labels)
print('\nlabel', label)
print('\naction', action)
print('\nreward', reward)
print('\nnext_state', next_state)
print('\nnext_label', next_label)

In [46]:
# Stream transitions from the replay buffer through ErgodicMDPTransitionGenerator;
# all six yielded components are declared as float32.
transition_generator = ErgodicMDPTransitionGenerator(
    replay_buffer=replay_buffer, labeling_function=labeling_function)
dataset = tf.data.Dataset.from_generator(
    generator=transition_generator,
    output_types=(tf.float32,) * 6)
# Cell output: the first batch of 16 transitions.
iterator = iter(dataset.batch(batch_size=16))
next(iterator)

(<tf.Tensor: shape=(16, 3), dtype=float32, numpy=
 array([[ 0.96077   , -0.2773464 , -0.06566303],
        [ 0.9906935 ,  0.13611153,  0.05436661],
        [ 0.96475935, -0.26313376,  0.53738177],
        [ 0.9917752 ,  0.12799199, -0.04881461],
        [ 0.984334  , -0.17631406, -0.39580426],
        [ 0.24332136, -0.9699457 ,  1.3340402 ],
        [ 0.9820759 ,  0.18848564, -0.29833403],
        [ 0.647934  ,  0.76169646, -2.7164521 ],
        [ 0.9996509 ,  0.02642205, -0.553501  ],
        [ 0.9989885 ,  0.04496658,  0.48496607],
        [ 0.9675869 , -0.2525384 , -0.2312615 ],
        [ 0.97604245, -0.21758024, -0.2490194 ],
        [ 0.9865107 , -0.1636969 , -0.33252692],
        [ 0.98471075,  0.17419727,  0.1663308 ],
        [ 0.76377136,  0.6454869 , -1.7687432 ],
        [ 0.9827999 ,  0.18467373,  0.35009584]], dtype=float32)>,
 <tf.Tensor: shape=(16, 4), dtype=float32, numpy=
 array([[1., 1., 1., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1.,

In [56]:
tf.random.set_seed(42)


def dataset_generator():
    """Build a dataset of float32 6-tuples from ErgodicMDPTransitionGenerator.

    NOTE(review): this redefines (shadows) the `dataset_generator` defined in an
    earlier cell, which has different semantics. Also, interleaving single-element
    `from_tensors` datasets appears to be an identity transform on the elements;
    it does not parallelize the Python generator. Confirm it is still wanted.
    """
    component_types = (tf.float32,) * 6
    transitions = tf.data.Dataset.from_generator(
        generator=ErgodicMDPTransitionGenerator(
            replay_buffer=replay_buffer, labeling_function=labeling_function),
        output_types=component_types)
    return transitions.interleave(
        lambda *components: tf.data.Dataset.from_tensors(components),
        num_parallel_calls=16)


dataset = dataset_generator()
print(dataset.element_spec)
# Cell output: the first full batch of 8 transitions.
iterator = iter(dataset.batch(batch_size=8, drop_remainder=True))
next(iterator)

(TensorSpec(shape=<unknown>, dtype=tf.float32, name=None), TensorSpec(shape=<unknown>, dtype=tf.float32, name=None), TensorSpec(shape=<unknown>, dtype=tf.float32, name=None), TensorSpec(shape=<unknown>, dtype=tf.float32, name=None), TensorSpec(shape=<unknown>, dtype=tf.float32, name=None), TensorSpec(shape=<unknown>, dtype=tf.float32, name=None))


(<tf.Tensor: shape=(8, 3), dtype=float32, numpy=
 array([[ 0.9995769 , -0.02908477, -0.31290954],
        [ 0.9592944 , -0.2824081 , -0.17525356],
        [ 0.99931717,  0.03694883,  0.31311813],
        [ 0.9999757 , -0.00697772, -0.56639445],
        [ 0.9902308 ,  0.13943815, -0.10237262],
        [ 0.9933483 , -0.11514864, -0.42249623],
        [ 0.95012456, -0.31187063,  0.43540317],
        [ 0.9668493 , -0.25534758,  0.30975214]], dtype=float32)>,
 <tf.Tensor: shape=(8, 4), dtype=float32, numpy=
 array([[1., 1., 1., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 0.]], dtype=float32)>,
 <tf.Tensor: shape=(8, 1), dtype=float32, numpy=
 array([[ 0.35569933],
        [ 1.7593744 ],
        [-0.9018159 ],
        [ 0.38206947],
        [-0.8714839 ],
        [ 0.8993706 ],
        [ 1.859749  ],
        [ 1.9625181 ]], dtype=float32)>,
 <tf.Tensor: sh