In [1]:
import tensorflow as tf
import reverb

from tf_agents.drivers import dynamic_step_driver, py_driver
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
from tf_agents.policies import py_tf_eager_policy

from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment, batched_py_environment



In [2]:
import policies

# Gym environment used for data collection; must be the environment the saved
# policy (loaded below from a path containing the env name) was trained on.
env_name = 'Pendulum-v0'
py_env = suite_gym.load(environment_name=env_name)

# Reset once so the cell output shows the initial TimeStep.
py_env.reset()

TimeStep(step_type=array(0, dtype=int32), reward=array(0., dtype=float32), discount=array(1., dtype=float32), observation=array([0.999547  , 0.0300953 , 0.62336147], dtype=float32))

In [3]:
# Directory of the saved TF policy for `env_name`. Building the path from
# `env_name` (instead of repeating 'Pendulum-v0') keeps the loaded policy in
# sync with the environment if the environment name is ever changed.
policy_dir = f'../reinforcement_learning/saves/{env_name}/policy'
policy = policies.SavedTFPolicy(policy_dir)

# PyDriver (used below) steps the Python environment `py_env` directly, so it
# needs a Python policy: wrap the TF policy for eager execution.
# use_tf_function=True asks the wrapper to run the policy through tf.function.
py_policy = py_tf_eager_policy.PyTFEagerPolicy(policy, use_tf_function=True)

In [4]:
# Prioritized replay table.
# - sampler: items are drawn with probability proportional to
#   priority ** priority_exponent.
# - remover: when the table reaches max_size, the OLDEST item is evicted (FIFO).
#   The previous MaxHeap remover evicted the highest-priority items first,
#   which discards exactly the experience prioritized replay wants to keep.
table_name = 'prioritized_replay_buffer'
priority_exponent = 0.6
replay_capacity = 1_000_000

table = reverb.Table(
    table_name,
    max_size=replay_capacity,
    sampler=reverb.selectors.Prioritized(priority_exponent=priority_exponent),
    remover=reverb.selectors.Fifo(),
    # Sampling is allowed as soon as a single item has been inserted.
    rate_limiter=reverb.rate_limiters.MinSize(1))

# In-process Reverb server hosting the table.
reverb_server = reverb.Server([table])

In [5]:
# TF-Agents replay-buffer view over the Reverb table served locally.
# sequence_length must agree with the observer's sequence_length (both 2 here).
reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
    data_spec=policy.collect_data_spec,
    table_name=table_name,
    local_server=reverb_server,
    sequence_length=2,
)

In [6]:
batch_size = 8

# tf.data pipeline sampling from the replay buffer. Each element is a
# (Trajectory, sample info) pair whose tensors have leading shape
# [batch_size, num_steps] (see the sampled output further down).
dataset = reverb_replay.as_dataset(
    num_steps=2,
    sample_batch_size=batch_size,
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
)
dataset_iterator = iter(dataset)

In [7]:
# Observer that writes collected steps into the Reverb table as items of
# 2 consecutive steps, starting a new item every step (stride 1), so
# consecutive items overlap by one step.
observer = reverb_utils.ReverbTrajectorySequenceObserver(
    py_client=reverb_replay.py_client,
    table_name=table_name,
    stride_length=1,
    sequence_length=2,
)

In [8]:
num_steps = 1280

# Runs py_policy in py_env for up to `num_steps` environment steps and feeds
# every trajectory to the Reverb observer.
driver = py_driver.PyDriver(
    env=py_env,
    policy=py_policy,
    max_steps=num_steps,
    observers=[observer],
)

In [9]:
# Collect from a freshly reset episode. The cell output is the final
# (TimeStep, policy_state) pair returned by the driver.
initial_time_step = py_env.reset()
driver.run(initial_time_step)

(TimeStep(step_type=array(1, dtype=int32), reward=array(-0.08640341, dtype=float32), discount=array(1., dtype=float32), observation=array([ 0.95845556, -0.28524184,  0.00667524], dtype=float32)),
 ())

In [20]:
# Draw one batch from the replay dataset. traj[0] is the batched Trajectory;
# traj[1] is the sample info, whose `.key` holds the Reverb item keys used
# below to update priorities.
traj = next(dataset_iterator)

In [28]:
# Set the priority of every sampled item to 1/3.
# Keys come from the first step of each sampled sequence (`[..., 0]`); the
# steps of one sampled item presumably share its key — TODO confirm.
# The priorities tensor is shaped from the keys themselves instead of a
# hard-coded 8, so it stays correct if `batch_size` changes. Reverb expects
# float64 priorities.
sampled_keys = traj[1].key[..., 0]
new_priorities = tf.fill(tf.shape(sampled_keys), tf.constant(1. / 3, dtype=tf.float64))
reverb_replay.update_priorities(sampled_keys, new_priorities)

In [26]:
# Show the sampled item keys and the uniform 1/3 priorities for them.
# The priorities are shaped from the keys rather than a hard-coded 8.
sampled_keys = traj[1].key[..., 0]
print(sampled_keys)
print(tf.fill(tf.shape(sampled_keys), tf.constant(1. / 3, dtype=tf.float64)))

tf.Tensor(
[  535561011561748336  6660948364819225248  2863702704173667460
  7855816775213045753 11241696028115428654  6754368979020631321
  8903915521695998471   430297716918194926], shape=(8,), dtype=uint64)
tf.Tensor(
[0.33333333 0.33333333 0.33333333 0.33333333 0.33333333 0.33333333
 0.33333333 0.33333333], shape=(8,), dtype=float64)


In [54]:
# Draw another batch and display it (Trajectory plus sample info).
next_sample = next(dataset_iterator)
next_sample


(Trajectory(step_type=<tf.Tensor: shape=(8, 2), dtype=int32, numpy=
 array([[1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1]], dtype=int32)>, observation=<tf.Tensor: shape=(8, 2, 3), dtype=float32, numpy=
 array([[[ 0.9800823 ,  0.1985919 ,  0.41112426],
         [ 0.97654444,  0.21531594,  0.34188697]],
 
        [[ 0.9626323 ,  0.27081192,  0.09548815],
         [ 0.96210617,  0.27267516,  0.03872208]],
 
        [[ 0.9609107 , -0.27685857,  0.08212221],
         [ 0.96312565, -0.2690521 ,  0.16229314]],
 
        [[ 0.9641903 ,  0.26521143, -0.11068436],
         [ 0.96638525,  0.25709826, -0.16809733]],
 
        [[ 0.96336776, -0.2681838 , -0.20709708],
         [ 0.96117985, -0.27592254, -0.16084154]],
 
        [[ 0.9991272 , -0.04177071,  0.5775303 ],
         [ 0.9998886 , -0.01492673,  0.53711176]],
 
        [[ 0.96020067,  0.279311  ,  0.20772548],
         [ 0.9572971 ,  0.28910595,  0.20432597]],
 
    