In [1]:
import tensorflow as tf
import reverb
import tf_agents

from tf_agents.drivers import dynamic_step_driver, py_driver
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
from tf_agents.policies import py_tf_eager_policy

from tf_agents.environments import suite_gym, parallel_py_environment
from tf_agents.environments import tf_py_environment, batched_py_environment
from reinforcement_learning import labeling_functions
# Pick the labeling function stored under the 'Pendulum-v0' key of
# `labeling_functions` (imported from the reinforcement_learning package).
labeling_function = labeling_functions['Pendulum-v0']



In [2]:
import policies
from util.io.dataset_generator import ErgodicMDPTransitionGenerator

# Gym environment used to collect experience; it is loaded as a Python
# (non-TF) environment and reset once so that it holds an initial time step.
env_name = 'Pendulum-v0'
py_env = suite_gym.load(environment_name=env_name)
current_time_step = py_env.reset()

In [3]:
# Load the saved policy from disk.
policy_dir = '../reinforcement_learning/saves/Pendulum-v0/policy'
policy = policies.SavedTFPolicy(policy_dir)

# The reverb replay pipeline interacts directly with py_env (not with a TF
# environment), so the TF policy is wrapped into a Python policy.
py_policy = py_tf_eager_policy.PyTFEagerPolicy(
    policy=policy,
    use_tf_function=True)

In [4]:
table_name = 'prioritized_replay_buffer'
checkpoint_path = '../saves/checkpoint/reverb_test'

# Checkpointer rooted at `checkpoint_path`; it is handed to the server below.
checkpointer = reverb.checkpointers.DefaultCheckpointer(checkpoint_path)

# Replay table holding up to one million items, sampled with a prioritized
# selector (exponent 0.6) and sampleable as soon as it contains one item.
# NOTE(review): MaxHeap is used as the remover; presumably this evicts the
# highest-priority items first once the table is full -- confirm this is
# the intended eviction order.
table = reverb.Table(
    name=table_name,
    sampler=reverb.selectors.Prioritized(priority_exponent=0.6),
    remover=reverb.selectors.MaxHeap(),
    max_size=1_000_000,
    rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1))

# Local reverb server exposing the single table above.
reverb_server = reverb.Server([table], checkpointer=checkpointer)

In [5]:
# Replay buffer front-end bound to the local reverb server; items are
# sequences of two consecutive steps matching the policy's collect data spec.
reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
    local_server=reverb_server,
    table_name=table_name,
    sequence_length=2,
    data_spec=policy.collect_data_spec)

# Display the table description (sampler, remover, rate limiter, size, ...).
reverb_replay.get_table_info()

TableInfo(name='prioritized_replay_buffer', sampler_options=prioritized {
  priority_exponent: 0.6
}
, remover_options=heap {
}
is_deterministic: true
, max_size=1000000, max_times_sampled=0, rate_limiter_info=samples_per_insert: 1.0
min_diff: -1.7976931348623157e+308
max_diff: 1.7976931348623157e+308
min_size_to_sample: 1
insert_stats {
  completed_wait_time {
  }
  pending_wait_time {
  }
}
sample_stats {
  completed_wait_time {
  }
  pending_wait_time {
  }
}
, signature=None, current_size=15432, num_episodes=2, num_deleted_episodes=0)

In [6]:
batch_size = 8

# Dataset sampling batches of `batch_size` two-step sequences from the
# replay buffer.
dataset = reverb_replay.as_dataset(
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
    num_steps=2,
    sample_batch_size=batch_size)

dataset_iterator = iter(dataset)

In [7]:
# Observer that writes the steps it receives into the reverb table as
# sequences of length 2, advancing by one step between consecutive sequences.
observer = reverb_utils.ReverbTrajectorySequenceObserver(
    py_client=reverb_replay.py_client,
    table_name=table_name,
    sequence_length=2,
    stride_length=1,
    # Initial priority of every inserted item. A plain Python float is passed
    # (instead of a tf.constant) because the observer forwards this value to
    # the Python reverb client; the numeric value (1.0) is unchanged.
    priority=1.0)

In [8]:
# Driver stepping `py_env` with `py_policy` for at most `num_steps` steps and
# forwarding what it collects to the reverb observer.
num_steps = 5120
driver = py_driver.PyDriver(
    env=py_env,
    policy=py_policy,
    observers=[observer],
    max_steps=num_steps)

In [9]:
# Resume collection from wherever the environment currently is.
resume_from = py_env.current_time_step()
driver.run(time_step=resume_from)

# Persist the table content, then display the updated table description.
reverb_replay.py_client.checkpoint()
reverb_replay.get_table_info()

TableInfo(name='prioritized_replay_buffer', sampler_options=prioritized {
  priority_exponent: 0.6
}
, remover_options=heap {
}
is_deterministic: true
, max_size=1000000, max_times_sampled=0, rate_limiter_info=samples_per_insert: 1.0
min_diff: -1.7976931348623157e+308
max_diff: 1.7976931348623157e+308
min_size_to_sample: 1
insert_stats {
  completed: 5143
  completed_wait_time {
  }
  pending_wait_time {
  }
}
sample_stats {
  completed_wait_time {
  }
  pending_wait_time {
  }
}
, signature=None, current_size=15432, num_episodes=2, num_deleted_episodes=0)

In [9]:
# Draw one batch from the dataset iterator and display it. The displayed value
# is a tuple whose first element is a Trajectory; the second element is
# accessed as traj[1] (exposing `.key`) when priorities are updated later on.
traj = next(dataset_iterator)
traj

(Trajectory(step_type=<tf.Tensor: shape=(8, 2), dtype=int32, numpy=
 array([[1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1]], dtype=int32)>, observation=<tf.Tensor: shape=(8, 2, 3), dtype=float32, numpy=
 array([[[ 0.7388385 , -0.67388254, -0.82901365],
         [ 0.69987863, -0.71426183, -1.1223515 ]],
 
        [[ 0.5747614 , -0.81832105, -1.6312705 ],
         [ 0.4867845 , -0.8735221 , -2.0781534 ]],
 
        [[ 0.9725067 ,  0.2328748 ,  0.23582403],
         [ 0.97024345,  0.24213134,  0.19058491]],
 
        [[ 0.9913364 , -0.13134722, -0.49659005],
         [ 0.9882166 , -0.15306222, -0.43876857]],
 
        [[ 0.9853056 , -0.17080058, -0.35951707],
         [ 0.98237467, -0.18692261, -0.3277294 ]],
 
        [[ 0.9934162 , -0.11456145,  0.6380821 ],
         [ 0.9963247 , -0.08565653,  0.581038  ]],
 
        [[ 0.6984104 ,  0.71569747, -2.2532053 ],
         [ 0.7631587 ,  0.64621115, -1.9002608 ]],
 
    

In [12]:
# Keys of the sampled items: keep the first entry along the last axis of the
# key tensor stored in the second element of `traj`.
sampled_keys = traj[1].key[..., 0]
# One new priority (1/3) per sampled item of the batch.
updated_priorities = (1. / 3) * tf.ones(shape=(batch_size,), dtype=tf.float64)
reverb_replay.update_priorities(sampled_keys, updated_priorities)

In [17]:
# Report how many items the table currently holds, then checkpoint it; the
# value returned by checkpoint() is displayed as the cell output.
table_info = reverb_replay.get_table_info()
print(table_info.current_size)
reverb_replay.py_client.checkpoint()

5143


'../saves/checkpoint/reverb_test/2021-04-26T18:20:07.832703724+02:00'

In [13]:
import variational_action_discretizer

# Location of the saved action-discretizer model to load.
vae_mdp_path = '../saves/Pendulum-v0/models/vae_LS12_MC1_ER10.0-decay=7.5e-05-min=0_KLA0.0-growth=5e-05_TD0.67-0.50_1e-06-2e-06_seed=20421/policy/action_discretizer/LA3_MC1_ER10.0-decay=7.5e-05-min=0_KLA0.0-growth=5e-05_TD0.50-0.33_1e-06-2e-06_params=full_vae_optimization-relaxed_state_encoding/base'

# Load the model, passing step=1830000 to the loader.
vae_mdp = variational_action_discretizer.load(vae_mdp_path, step=1830000)
print("VAE MDP loaded")



VAE MDP loaded


In [14]:

def _encode_latent_state(state, label):
    """Encode a single (state, label) pair with the loaded VAE-MDP.

    A leading axis is added to both inputs before calling
    `vae_mdp.binary_encode`; the mode of the returned distribution is then
    squeezed to drop size-1 axes.
    """
    batched_state = tf.expand_dims(state, axis=0)
    batched_label = tf.expand_dims(label, axis=0)
    return tf.squeeze(vae_mdp.binary_encode(batched_state, batched_label).mode())


# Transition generator reading from the prioritized reverb replay buffer and
# embedding states through the VAE-MDP encoder defined above.
generator = ErgodicMDPTransitionGenerator(
    labeling_function=labeling_function,
    replay_buffer=reverb_replay,
    discrete_action=False,
    prioritized_replay_buffer=True,
    state_embedding_function=_encode_latent_state)

In [15]:
batch_size = 8

# Unbatched two-step sequences sampled from the replay buffer ...
raw_sequences = reverb_replay.as_dataset(
    num_steps=2,
    num_parallel_calls=tf.data.experimental.AUTOTUNE)

# ... each of which is passed through the transition generator.
dataset = raw_sequences.map(
    map_func=generator,
    num_parallel_calls=tf.data.experimental.AUTOTUNE)

# Batching happens after the map, so the generator sees one sequence at a time.
batched_dataset = dataset.batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
dataset_iterator = iter(batched_dataset)

In [16]:
# Pull the next batch from the dataset iterator and display it.
next(dataset_iterator)

(<tf.Tensor: shape=(8, 3), dtype=float32, numpy=
 array([[ 0.99993664,  0.01125781,  0.5358942 ],
        [ 0.9816694 ,  0.19059162,  0.31015158],
        [-0.29377538,  0.9558745 , -5.620201  ],
        [ 0.97343564,  0.22896075, -0.34662813],
        [ 0.99237573, -0.12324932, -0.35998908],
        [ 0.9571976 ,  0.28943527, -0.4271448 ],
        [ 0.9930223 , -0.11792643,  0.67813486],
        [-0.18458012, -0.9828175 ,  4.9253883 ]], dtype=float32)>,
 <tf.Tensor: shape=(8, 4), dtype=float32, numpy=
 array([[1., 1., 1., 0.],
        [1., 1., 1., 0.],
        [0., 0., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 0.],
        [0., 0., 0., 0.]], dtype=float32)>,
 <tf.Tensor: shape=(8, 1), dtype=float32, numpy=
 array([[-0.2599124 ],
        [-1.1269883 ],
        [-0.5246502 ],
        [-0.968116  ],
        [ 0.6597317 ],
        [-1.0723196 ],
        [ 0.32941908],
        [ 0.6956657 ]], dtype=float32)>,
 <tf.Tensor: sh

In [17]:
# Display the dataset restricted to its first element; the repr lists the
# per-element shapes and dtypes.
dataset.take(1)

<TakeDataset shapes: ((3,), (None,), (1,), <unknown>, (3,), (None,), ()), types: (tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float64)>

In [18]:
def state_embedding_function(state, label):
    """Encode a single (state, label) pair with the loaded VAE-MDP.

    Both inputs get a leading axis before being passed to
    `vae_mdp.binary_encode`; the mode of the result is squeezed afterwards.
    """
    batched_state = tf.expand_dims(state, axis=0)
    batched_label = tf.expand_dims(label, axis=0)
    return tf.squeeze(vae_mdp.binary_encode(batched_state, batched_label).mode())


# Display the latent state size exposed by the generator.
generator.latent_state_size

<tf.Tensor: shape=(), dtype=int32, numpy=12>

In [24]:
# Switch TF-Agents multiprocessing to interactive mode before creating the
# parallel environment from within the notebook.
tf_agents.system.multiprocessing.enable_interactive_mode()

# Five constructors, each loading its own copy of the gym environment.
# NOTE: this rebinds `py_env`; the `driver` built earlier keeps pointing at
# the environment it was constructed with.
env_constructors = [lambda: suite_gym.load(env_name) for _ in range(5)]
py_env = parallel_py_environment.ParallelPyEnvironment(env_constructors)
py_env.reset()

TimeStep(step_type=array([0, 0, 0, 0, 0], dtype=int32), reward=array([0., 0., 0., 0., 0.], dtype=float32), discount=array([1., 1., 1., 1., 1.], dtype=float32), observation=array([[-0.98478675, -0.17376737, -0.75586486],
       [ 0.20625967,  0.97849727,  0.84964657],
       [-0.782991  ,  0.62203306, -0.08362035],
       [ 0.6861767 ,  0.7274349 ,  0.98699635],
       [-0.16147746,  0.9868764 , -0.8903516 ]], dtype=float32))

In [10]:
# Tear-down: close the reverb observer first, then stop the reverb server
# that was started for this notebook.
observer.close()
reverb_server.stop()