In [1]:
import tensorflow as tf
import reverb

from tf_agents.drivers import dynamic_step_driver, py_driver
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
from tf_agents.policies import py_tf_eager_policy

from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment, batched_py_environment
from reinforcement_learning import labeling_functions
labeling_function = labeling_functions['Pendulum-v0']



In [2]:
import policies
from util.io.dataset_generator import ErgodicMDPTransitionGenerator

env_name = 'Pendulum-v0'
py_env = suite_gym.load(env_name)
current_time_step = py_env.reset()

In [3]:
policy_dir = '../reinforcement_learning/saves/Pendulum-v0/policy'
policy = policies.SavedTFPolicy(policy_dir)
# transforms policy into a py policy because reverb replay directly interacts with py_env (and not a tf_env)
py_policy = py_tf_eager_policy.PyTFEagerPolicy(policy, use_tf_function=True)

In [4]:
table_name = 'prioritized_replay_buffer'
table = reverb.Table(
    table_name,
    max_size=int(1e6),
    sampler=reverb.selectors.Prioritized(priority_exponent=0.6),
    remover=reverb.selectors.MaxHeap(),
    rate_limiter=reverb.rate_limiters.MinSize(1))

reverb_server = reverb.Server([table])

In [5]:
reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
    data_spec=policy.collect_data_spec,
    sequence_length=2,
    table_name=table_name,
    local_server=reverb_server)

In [6]:
batch_size = 128

dataset = reverb_replay.as_dataset(
    sample_batch_size=batch_size,
    num_steps=2,
    num_parallel_calls=tf.data.experimental.AUTOTUNE)

dataset_iterator = iter(dataset)

In [7]:
observer = reverb_utils.ReverbTrajectorySequenceObserver(
    py_client=reverb_replay.py_client,
    table_name=table_name,
    sequence_length=2,
    stride_length=1,
    priority=tf.constant(1., dtype=tf.float32))

In [8]:
num_steps = 5120
driver = py_driver.PyDriver(
    env=py_env, policy=py_policy, observers=[observer], max_steps=num_steps,)

In [9]:
current_time_step = driver.run(time_step=current_time_step)

In [10]:
traj = next(dataset_iterator)

In [11]:
reverb_replay.update_priorities(traj[1].key[..., 0], 1./3 * tf.ones(shape=(batch_size,), dtype=tf.float64))

In [12]:
reverb_replay.get_table_info().current_size

5143

In [13]:
import variational_action_discretizer

vae_mdp = variational_action_discretizer.load(
    '../saves/Pendulum-v0/models/vae_LS12_MC1_ER10.0-decay=7.5e-05-min=0_KLA0.0-growth=5e-05_TD0.67-0.50_1e-06-2e-06_seed=20421/policy/action_discretizer/LA3_MC1_ER10.0-decay=7.5e-05-min=0_KLA0.0-growth=5e-05_TD0.50-0.33_1e-06-2e-06_params=full_vae_optimization-relaxed_state_encoding/base',
    step=1830000
)
print("VAE MDP loaded")



VAE MDP loaded


In [14]:

generator = ErgodicMDPTransitionGenerator(
    labeling_function=labeling_function,
    replay_buffer=reverb_replay,
    discrete_action=False,
    prioritized_replay_buffer=True,
    state_embedding_function=lambda state, label: tf.squeeze(
        vae_mdp.binary_encode(tf.expand_dims(state, axis=0), tf.expand_dims(label, axis=0)).mode()))

In [15]:
batch_size = 8

dataset = reverb_replay.as_dataset(
    num_steps=2,
    num_parallel_calls=tf.data.experimental.AUTOTUNE
).map(map_func=generator, num_parallel_calls=tf.data.experimental.AUTOTUNE)

dataset_iterator = iter(dataset.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE))

In [16]:
next(dataset_iterator)

(<tf.Tensor: shape=(8, 3), dtype=float32, numpy=
 array([[ 0.99993664,  0.01125781,  0.5358942 ],
        [ 0.9816694 ,  0.19059162,  0.31015158],
        [-0.29377538,  0.9558745 , -5.620201  ],
        [ 0.97343564,  0.22896075, -0.34662813],
        [ 0.99237573, -0.12324932, -0.35998908],
        [ 0.9571976 ,  0.28943527, -0.4271448 ],
        [ 0.9930223 , -0.11792643,  0.67813486],
        [-0.18458012, -0.9828175 ,  4.9253883 ]], dtype=float32)>,
 <tf.Tensor: shape=(8, 4), dtype=float32, numpy=
 array([[1., 1., 1., 0.],
        [1., 1., 1., 0.],
        [0., 0., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 0.],
        [0., 0., 0., 0.]], dtype=float32)>,
 <tf.Tensor: shape=(8, 1), dtype=float32, numpy=
 array([[-0.2599124 ],
        [-1.1269883 ],
        [-0.5246502 ],
        [-0.968116  ],
        [ 0.6597317 ],
        [-1.0723196 ],
        [ 0.32941908],
        [ 0.6956657 ]], dtype=float32)>,
 <tf.Tensor: sh

In [17]:
dataset.take(1)

<TakeDataset shapes: ((3,), (None,), (1,), <unknown>, (3,), (None,), ()), types: (tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float64)>

In [18]:
state_embedding_function = lambda state, label: tf.squeeze(
        vae_mdp.binary_encode(tf.expand_dims(state, axis=0), tf.expand_dims(label, axis=0)).mode())
generator.latent_state_size

<tf.Tensor: shape=(), dtype=int32, numpy=12>