In [1]:
import tensorflow as tf
import reverb
import tf_agents

from tf_agents.drivers import dynamic_step_driver, py_driver
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
from tf_agents.policies import py_tf_eager_policy

from tf_agents.environments import suite_gym, parallel_py_environment
from tf_agents.environments import tf_py_environment, batched_py_environment
from reinforcement_learning import labeling_functions
# Look up the labeling function stored under the 'Pendulum-v0' key.
labeling_function = labeling_functions['Pendulum-v0']



In [2]:
import policies
from util.io.dataset_generator import ErgodicMDPTransitionGenerator

# Identifier of the Gym environment used in this notebook.
env_name = 'Pendulum-v0'
# Python environment loaded through suite_gym for `env_name`.
py_env = suite_gym.load(environment_name=env_name)
# Time step returned by resetting the freshly loaded environment.
current_time_step = py_env.reset()

In [3]:
# Path handed to SavedTFPolicy to restore the policy.
policy_dir = '../reinforcement_learning/saves/Pendulum-v0/policy'
policy = policies.SavedTFPolicy(policy_dir)
# Experience is collected from py_env (a Python environment, not a TF one),
# so the TF policy is wrapped as a Python policy.
py_policy = py_tf_eager_policy.PyTFEagerPolicy(
    policy=policy,
    use_tf_function=True)

In [4]:
# Table identifier and the path given to the checkpointer.
table_name = 'prioritized_replay_buffer'
checkpoint_path = '../saves/checkpoint/reverb_test'

# Checkpointer bound to `checkpoint_path`.
checkpointer = reverb.checkpointers.DefaultCheckpointer(checkpoint_path)

# Single table: prioritized sampler (exponent 0.9), max-heap remover, and a
# rate limiter requiring a minimum size of 1.
table = reverb.Table(
    name=table_name,
    sampler=reverb.selectors.Prioritized(priority_exponent=0.9),
    remover=reverb.selectors.MaxHeap(),
    max_size=1_000_000,
    rate_limiter=reverb.rate_limiters.MinSize(1))

# Server hosting the table, wired to the checkpointer.
reverb_server = reverb.Server(tables=[table], checkpointer=checkpointer)

In [5]:
# Replay buffer backed by the local Reverb server; it is configured with a
# sequence length of 2 and the policy's collect data spec.
reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
    policy.collect_data_spec,
    table_name=table_name,
    sequence_length=2,
    local_server=reverb_server)

# Display the table's configuration and statistics.
reverb_replay.get_table_info()

TableInfo(name='prioritized_replay_buffer', sampler_options=prioritized {
  priority_exponent: 0.9
}
, remover_options=heap {
}
is_deterministic: true
, max_size=1000000, max_times_sampled=0, rate_limiter_info=samples_per_insert: 1.0
min_diff: -1.7976931348623157e+308
max_diff: 1.7976931348623157e+308
min_size_to_sample: 1
insert_stats {
  completed_wait_time {
  }
  pending_wait_time {
  }
}
sample_stats {
  completed_wait_time {
  }
  pending_wait_time {
  }
}
, signature=None, current_size=12861, num_episodes=1, num_deleted_episodes=0)

In [6]:
# Batch size requested from the replay buffer when sampling.
batch_size = 512

# Dataset over the replay buffer, requested with `batch_size` samples per
# batch and two steps per sample.
dataset = reverb_replay.as_dataset(
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
    num_steps=2,
    sample_batch_size=batch_size)

dataset_iterator = iter(dataset)

In [7]:
# Observer that writes the driver's trajectories to the Reverb table as
# two-step sequences with stride 1.
# The priority is passed as a plain Python float: the observer forwards it to
# the Python Reverb client, so wrapping it in a TF tensor is unnecessary.
observer = reverb_utils.ReverbTrajectorySequenceObserver(
    py_client=reverb_replay.py_client,
    table_name=table_name,
    sequence_length=2,
    stride_length=1,
    priority=1.0)

In [8]:
# Value passed as the driver's `max_steps`.
num_steps = 12800
# Driver that runs py_policy in py_env and reports to the Reverb observer.
driver = py_driver.PyDriver(
    py_env,
    py_policy,
    observers=[observer],
    max_steps=num_steps)

In [9]:
# Run the driver starting from the environment's current time step.
driver.run(py_env.current_time_step())
# Ask the Reverb client to checkpoint, then display the table statistics.
reverb_replay.py_client.checkpoint()
reverb_replay.get_table_info()

TableInfo(name='prioritized_replay_buffer', sampler_options=prioritized {
  priority_exponent: 0.9
}
, remover_options=heap {
}
is_deterministic: true
, max_size=1000000, max_times_sampled=0, rate_limiter_info=samples_per_insert: 1.0
min_diff: -1.7976931348623157e+308
max_diff: 1.7976931348623157e+308
min_size_to_sample: 1
insert_stats {
  completed: 12861
  completed_wait_time {
  }
  pending_wait_time {
  }
}
sample_stats {
  completed_wait_time {
  }
  pending_wait_time {
  }
}
, signature=None, current_size=25722, num_episodes=2, num_deleted_episodes=0)

In [10]:
# Pull one element from the dataset iterator and display it.
# The element is indexed as traj[1].key further down, i.e. it is a tuple
# whose second entry carries the sample keys.
traj = next(dataset_iterator)
traj

(Trajectory(step_type=<tf.Tensor: shape=(512, 2), dtype=int32, numpy=
 array([[1, 1],
        [1, 1],
        [1, 1],
        ...,
        [1, 1],
        [1, 1],
        [1, 1]], dtype=int32)>, observation=<tf.Tensor: shape=(512, 2, 3), dtype=float32, numpy=
 array([[[ 0.92839897,  0.3715849 , -0.5861269 ],
         [ 0.9384471 ,  0.34542295, -0.56052274]],
 
        [[ 0.9661661 , -0.25792068, -0.20527856],
         [ 0.96421725, -0.26511332, -0.14904003]],
 
        [[ 0.9673638 ,  0.25339165,  0.2960781 ],
         [ 0.96432483,  0.26472184,  0.23461469]],
 
        ...,
 
        [[ 0.98898166,  0.14803825, -0.17979282],
         [ 0.99068755,  0.13615498, -0.24010344]],
 
        [[ 0.9598526 , -0.28050497, -0.15742743],
         [ 0.9586249 , -0.2846722 , -0.08688591]],
 
        [[ 0.9943444 , -0.10620368,  0.5469482 ],
         [ 0.99681413, -0.07975936,  0.5312037 ]]], dtype=float32)>, action=<tf.Tensor: shape=(512, 2, 1), dtype=float32, numpy=
 array([[[-1.6872299],
        

In [11]:
# Reset the priority of every sampled item to 1/3.
# Keys are taken from the first step of each sampled sequence, and the
# priority vector is sized from those keys rather than from the global
# `batch_size`, which is rebound to another value later in the notebook.
sampled_keys = traj[1].key[..., 0]
reverb_replay.update_priorities(
    sampled_keys,
    1./3 * tf.ones(shape=tf.shape(sampled_keys), dtype=tf.float64))

In [12]:
# Print the table's current size, then checkpoint; the value returned by
# checkpoint() is shown as the cell output.
print(reverb_replay.get_table_info().current_size)
reverb_replay.py_client.checkpoint()

25722


'../saves/checkpoint/reverb_test/2021-04-27T15:12:28.644602286+02:00'

In [13]:
# NOTE(review): this import lives in the middle of the notebook; consider
# moving it to the top import cell.
import variational_action_discretizer

# Load the model stored under the given path at step 1830000.
vae_mdp = variational_action_discretizer.load(
    '../saves/Pendulum-v0/models/vae_LS12_MC1_ER10.0-decay=7.5e-05-min=0_KLA0.0-growth=5e-05_TD0.67-0.50_1e-06-2e-06_seed=20421/policy/action_discretizer/LA3_MC1_ER10.0-decay=7.5e-05-min=0_KLA0.0-growth=5e-05_TD0.50-0.33_1e-06-2e-06_params=full_vae_optimization-relaxed_state_encoding/base',
    step=1830000
)
print("VAE MDP loaded")



VAE MDP loaded


In [14]:

# Transition generator reading from the prioritized Reverb buffer.
# The embedding callable adds a leading axis to the state and label, runs
# them through vae_mdp.binary_encode, takes the mode and squeezes the result.
generator = ErgodicMDPTransitionGenerator(
    replay_buffer=reverb_replay,
    prioritized_replay_buffer=True,
    discrete_action=False,
    labeling_function=labeling_function,
    state_embedding_function=lambda state, label: tf.squeeze(
        vae_mdp.binary_encode(
            tf.expand_dims(state, axis=0),
            tf.expand_dims(label, axis=0),
        ).mode()))

In [15]:
# Batch size applied after mapping (rebinds the earlier value).
batch_size = 128

# Unbatched two-step samples from the buffer ...
dataset = reverb_replay.as_dataset(
    num_steps=2,
    num_parallel_calls=tf.data.experimental.AUTOTUNE)
# ... each transformed by `generator`.
dataset = dataset.map(
    generator,
    num_parallel_calls=tf.data.experimental.AUTOTUNE)

# Iterator over the mapped dataset, batched and prefetched.
dataset_iterator = iter(
    dataset
    .batch(batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE))

In [16]:
# Pull one batch from the batched, prefetched iterator and display it.
next(dataset_iterator)

(<tf.Tensor: shape=(128, 3), dtype=float32, numpy=
 array([[ 9.92646575e-01,  1.21048681e-01, -8.24625324e-03],
        [ 9.96205270e-01,  8.70348364e-02,  2.72046685e-01],
        [ 9.74649310e-01, -2.23738000e-01,  4.25707012e-01],
        [ 9.67182279e-01, -2.54083574e-01,  5.03282428e-01],
        [ 9.97142434e-01,  7.55448490e-02,  5.45736372e-01],
        [-9.80952144e-01,  1.94249406e-01, -3.56258583e+00],
        [ 9.94035065e-01,  1.09061211e-01,  1.06572703e-01],
        [ 9.71203744e-01, -2.38250345e-01, -2.38811299e-01],
        [ 9.94468689e-01,  1.05033465e-01, -5.63394785e-01],
        [ 9.92794693e-01,  1.19827785e-01, -2.11463496e-01],
        [ 9.99970078e-01,  7.73215480e-03,  3.97090256e-01],
        [ 9.77079988e-01,  2.12872475e-01,  3.44169647e-01],
        [ 9.99644279e-01, -2.66695470e-02, -1.95535630e-01],
        [ 9.99227643e-01,  3.92950773e-02, -4.79143351e-01],
        [ 9.75390315e-01,  2.20485285e-01, -4.18104172e-01],
        [ 9.89747286e-01,  1.42829

In [17]:
# Display the dataset object produced by take(1); nothing is iterated here.
dataset.take(1)

<TakeDataset shapes: ((3,), (None,), (1,), <unknown>, (3,), (None,), (), ()), types: (tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float64, tf.int64)>

In [18]:
def state_embedding_function(state, label):
    """Embed one (state, label) pair with the VAE-MDP binary encoder.

    A leading axis is added to both inputs before calling
    ``vae_mdp.binary_encode``; the mode of the result is squeezed and returned.
    """
    batched_state = tf.expand_dims(state, axis=0)
    batched_label = tf.expand_dims(label, axis=0)
    encoding = vae_mdp.binary_encode(batched_state, batched_label)
    return tf.squeeze(encoding.mode())


# Display the generator's latent state size.
generator.latent_state_size

<tf.Tensor: shape=(), dtype=int32, numpy=12>

In [19]:
# Enable TF-Agents' interactive multiprocessing mode before building the
# parallel environment.
tf_agents.system.multiprocessing.enable_interactive_mode()
# Five copies of the Gym environment wrapped in one parallel environment
# (rebinds `py_env`).
py_env = parallel_py_environment.ParallelPyEnvironment(
    env_constructors=[lambda: suite_gym.load(env_name) for _ in range(5)])
# Reset the parallel environment and display the resulting time step.
py_env.reset()

TimeStep(step_type=array([0, 0, 0, 0, 0], dtype=int32), reward=array([0., 0., 0., 0., 0.], dtype=float32), discount=array([1., 1., 1., 1., 1.], dtype=float32), observation=array([[-0.9143689 ,  0.4048821 ,  0.7941084 ],
       [ 0.7959196 ,  0.6054023 ,  0.30485106],
       [-0.92786896,  0.37290642,  0.5644189 ],
       [-0.43345684, -0.9011743 ,  0.30105647],
       [ 0.42689294, -0.9043022 , -0.6943645 ]], dtype=float32))

In [34]:
# Draw 8 samples from the table through the Python client and materialize
# them as a list for display.
list(reverb_replay.py_client.sample(table_name, num_samples=8))

[[ReplaySample(info=SampleInfo(key=12821263594570176645, probability=8.186512466623928e-05, table_size=12861, priority=1.0), data=[array(1, dtype=int32), array([ 0.9588191 , -0.2840176 ,  0.04937043], dtype=float32), array([1.7314401], dtype=float32), array(1, dtype=int32), array(-0.08617506, dtype=float32), array(1., dtype=float32)]),
  ReplaySample(info=SampleInfo(key=12821263594570176645, probability=8.186512466623928e-05, table_size=12861, priority=1.0), data=[array(1, dtype=int32), array([ 0.96017236, -0.27940848,  0.09607325], dtype=float32), array([1.9878347], dtype=float32), array(1, dtype=int32), array(-0.08506427, dtype=float32), array(1., dtype=float32)])],
 [ReplaySample(info=SampleInfo(key=13449431264708799662, probability=8.186512466623928e-05, table_size=12861, priority=1.0), data=[array(1, dtype=int32), array([0.96067256, 0.27768373, 0.18673047], dtype=float32), array([-1.7925342], dtype=float32), array(1, dtype=int32), array(-0.08587594, dtype=float32), array(1., dtype

In [35]:
# Display the table's configuration and statistics.
reverb_replay.get_table_info()

TableInfo(name='prioritized_replay_buffer', sampler_options=prioritized {
  priority_exponent: 0.9
}
, remover_options=heap {
}
is_deterministic: true
, max_size=1000000, max_times_sampled=0, rate_limiter_info=samples_per_insert: 1.0
min_diff: -1.7976931348623157e+308
max_diff: 1.7976931348623157e+308
min_size_to_sample: 1
insert_stats {
  completed: 12861
  completed_wait_time {
  }
  pending_wait_time {
  }
}
sample_stats {
  completed: 3280
  completed_wait_time {
  }
  pending_wait_time {
  }
}
, signature=None, current_size=12861, num_episodes=1, num_deleted_episodes=0)

In [None]:
# Close the trajectory observer.
observer.close()

In [6]:
# Stop the Reverb server.
reverb_server.stop()