In [1]:
import sys
import os

# A notebook has no real __file__, so the current working directory is used as
# the anchor; its parent directory is put first on the import path so that the
# project packages imported below can be resolved.
path = os.path.abspath(os.getcwd())
sys.path.insert(0, f"{path}/../")

import tensorflow as tf
import reverb
import tf_agents

from tf_agents.drivers import dynamic_step_driver, py_driver
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
from tf_agents.policies import py_tf_eager_policy

from tf_agents.environments import suite_gym, parallel_py_environment
from tf_agents.environments import tf_py_environment, batched_py_environment
from reinforcement_learning import labeling_functions
# Look up the labeling function stored under the 'Pendulum-v0' key; it is
# passed to ErgodicMDPTransitionGenerator further down in this notebook.
labeling_function = labeling_functions['Pendulum-v0']



In [2]:
import policies
from util.io.dataset_generator import ErgodicMDPTransitionGenerator

# Gym environment id; reused later when the parallel environments are built.
env_name = 'Pendulum-v0'
# Python (non-TF) environment that the PyDriver steps to collect transitions.
py_env = suite_gym.load(env_name)
# Reset once before collection.
# NOTE(review): `current_time_step` is not read again in this notebook; the
# driver cell queries py_env.current_time_step() instead.
current_time_step = py_env.reset()

In [3]:
# Directory of the saved policy used for data collection.
policy_dir = '../reinforcement_learning/saves/Pendulum-v0/policy'
policy = policies.SavedTFPolicy(policy_dir)
# Wrap the TF policy as a Python policy: collection is driven by a PyDriver
# that steps py_env (a Python environment, not a TF environment).
py_policy = py_tf_eager_policy.PyTFEagerPolicy(policy, use_tf_function=True)

In [4]:
# Name of the Reverb table; the replay buffer and the observer created later
# both address the table through this name.
table_name = 'prioritized_replay_buffer'
# Directory handed to the Reverb checkpointer.
checkpoint_path = '../saves/checkpoint/reverb_test'

table = reverb.Table(
    name=table_name,
    sampler=reverb.selectors.Prioritized(priority_exponent=0.9),
    # NOTE(review): a MaxHeap remover presumably evicts the highest-priority
    # item first once the table is full -- confirm this is the intended
    # eviction order for a prioritized sampler.
    remover=reverb.selectors.MaxHeap(),
    max_size=1_000_000,
    # Sampling is allowed as soon as the table holds a single item.
    rate_limiter=reverb.rate_limiters.MinSize(1))

checkpointer = reverb.checkpointers.DefaultCheckpointer(checkpoint_path)

# In-process server hosting the table, with the checkpointer attached.
reverb_server = reverb.Server(tables=[table], checkpointer=checkpointer)

In [5]:
# TF-Agents replay buffer backed by the table hosted on the in-process Reverb
# server; the data layout follows the policy's collect_data_spec and the buffer
# is configured for sequences of length 2.
reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
    local_server=reverb_server,
    table_name=table_name,
    sequence_length=2,
    data_spec=policy.collect_data_spec)

# Display the table statistics before any new data is collected.
reverb_replay.get_table_info()

TableInfo(name='prioritized_replay_buffer', sampler_options=prioritized {
  priority_exponent: 0.9
}
, remover_options=heap {
}
is_deterministic: true
, max_size=1000000, max_times_sampled=0, rate_limiter_info=samples_per_insert: 1.0
min_diff: -1.7976931348623157e+308
max_diff: 1.7976931348623157e+308
min_size_to_sample: 1
insert_stats {
  completed_wait_time {
  }
  pending_wait_time {
  }
}
sample_stats {
  completed_wait_time {
  }
  pending_wait_time {
  }
}
, signature=None, current_size=12861, num_episodes=1, num_deleted_episodes=0)

In [6]:
# Number of sequences requested per sampled batch.
batch_size = 512

# tf.data pipeline sampling batches of two-step sequences from the buffer.
dataset = reverb_replay.as_dataset(
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
    num_steps=2,
    sample_batch_size=batch_size)

# Iterator consumed later with next(...).
dataset_iterator = iter(dataset)

In [7]:
# Observer that writes the trajectories it receives to the Reverb table through
# the replay buffer's Python client, using sequence_length=2 and
# stride_length=1 (presumably consecutive sequences overlap by one step --
# confirm against the TF-Agents documentation).
# NOTE(review): `priority` is passed as a float32 tensor equal to 1; presumably
# this is the initial priority given to every inserted item -- confirm.
observer = reverb_utils.ReverbTrajectorySequenceObserver(
    py_client=reverb_replay.py_client,
    table_name=table_name,
    sequence_length=2,
    stride_length=1,
    priority=tf.constant(1., dtype=tf.float32))

In [8]:
# Upper bound on the number of steps gathered by one driver run.
num_steps = 12800

# Driver that steps the Python environment with the wrapped policy and forwards
# what it collects to the Reverb observer.
driver = py_driver.PyDriver(
    py_env,
    py_policy,
    observers=[observer],
    max_steps=num_steps)

In [9]:
# Run the collection (bounded by the driver's max_steps), starting from the
# environment's current time step; the observer writes to the Reverb table.
driver.run(time_step=py_env.current_time_step())
# Ask the Reverb server to write a checkpoint of its tables.
reverb_replay.py_client.checkpoint()
# Display the table statistics after collection.
reverb_replay.get_table_info()

TableInfo(name='prioritized_replay_buffer', sampler_options=prioritized {
  priority_exponent: 0.9
}
, remover_options=heap {
}
is_deterministic: true
, max_size=1000000, max_times_sampled=0, rate_limiter_info=samples_per_insert: 1.0
min_diff: -1.7976931348623157e+308
max_diff: 1.7976931348623157e+308
min_size_to_sample: 1
insert_stats {
  completed: 12861
  completed_wait_time {
  }
  pending_wait_time {
  }
}
sample_stats {
  completed_wait_time {
  }
  pending_wait_time {
  }
}
, signature=None, current_size=25722, num_episodes=2, num_deleted_episodes=0)

In [10]:
# Draw one batch from the replay dataset. The result is indexed as
# traj[1].key when priorities are updated below, so element 1 carries the
# sample keys. The bare `traj` displays the batch as the cell output.
traj = next(dataset_iterator)
traj

(Trajectory(step_type=<tf.Tensor: shape=(512, 2), dtype=int32, numpy=
 array([[1, 1],
        [1, 1],
        [1, 1],
        ...,
        [1, 1],
        [1, 1],
        [1, 1]], dtype=int32)>, observation=<tf.Tensor: shape=(512, 2, 3), dtype=float32, numpy=
 array([[[ 0.98659474,  0.16318965,  0.20842017],
         [ 0.985931  ,  0.16715291,  0.08036919]],
 
        [[ 0.8664056 , -0.4993409 , -0.90088177],
         [ 0.8405068 , -0.54180104, -0.9948116 ]],
 
        [[ 0.9846218 , -0.17469929,  0.6248922 ],
         [ 0.9896979 , -0.14317156,  0.6387018 ]],
 
        ...,
 
        [[ 0.9022289 , -0.43125743,  0.9433839 ],
         [ 0.91784   , -0.39695054,  0.7538798 ]],
 
        [[-0.61261433, -0.790382  , -5.7078476 ],
         [-0.8355573 , -0.54940337, -6.5956454 ]],
 
        [[ 0.9934008 , -0.11469464,  0.55353975],
         [ 0.9961304 , -0.0878878 ,  0.5389254 ]]], dtype=float32)>, action=<tf.Tensor: shape=(512, 2, 1), dtype=float32, numpy=
 array([[[-1.6696215 ],
       

In [11]:
# Set the priority of every item in the sampled batch to 1/3. `[..., 0]` keeps
# index 0 of the last axis of the keys (presumably the per-step axis of each
# two-step sequence -- confirm).
# NOTE(review): relies on `batch_size` still matching the batch that produced
# `traj`; `batch_size` is rebound to 128 in a later cell.
reverb_replay.update_priorities(traj[1].key[..., 0], 1./3 * tf.ones(shape=(batch_size,), dtype=tf.float64))

In [12]:
# Print the number of items currently held by the table.
print(reverb_replay.get_table_info().current_size)
# Write a checkpoint; the value returned by checkpoint() is shown as the cell
# output.
reverb_replay.py_client.checkpoint()

25722


'../saves/checkpoint/reverb_test/2021-04-28T11:54:50.287872093+02:00'

In [13]:
import variational_action_discretizer

# Load a previously saved VAE-MDP model from disk.
# NOTE(review): `step` presumably selects the training step / checkpoint to
# restore -- confirm against variational_action_discretizer.load.
vae_mdp = variational_action_discretizer.load(
    '../saves/Pendulum-v0/models/vae_LS12_MC1_ER10.0-decay=7.5e-05-min=0_KLA0.0-growth=5e-05_TD0.67-0.50_1e-06-2e-06_seed=20421/policy/action_discretizer/LA3_MC1_ER10.0-decay=7.5e-05-min=0_KLA0.0-growth=5e-05_TD0.50-0.33_1e-06-2e-06_params=full_vae_optimization-relaxed_state_encoding/base',
    step=1830000
)
print("VAE MDP loaded")



VAE MDP loaded


In [14]:

# Transition generator later used as the map function of the replay dataset.
# The state embedding function handles one (state, label) pair at a time: a
# leading axis is added to both before calling vae_mdp.binary_encode, .mode()
# is taken on the result, and size-1 axes are squeezed away.
generator = ErgodicMDPTransitionGenerator(
    labeling_function=labeling_function,
    replay_buffer=reverb_replay,
    discrete_action=False,
    prioritized_replay_buffer=True,
    state_embedding_function=lambda state, label: tf.squeeze(
        vae_mdp.binary_encode(tf.expand_dims(state, axis=0), tf.expand_dims(label, axis=0)).mode()))

In [15]:
# NOTE(review): rebinds `batch_size` (512 when the first dataset was built) as
# well as `dataset` and `dataset_iterator`; cells above that read these names
# see the new values if they are re-run afterwards.
batch_size = 128

# Sample unbatched two-step items (no sample_batch_size is given) ...
dataset = reverb_replay.as_dataset(
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
    num_steps=2)
# ... and pass every sampled item through the transition generator.
dataset = dataset.map(
    generator,
    num_parallel_calls=tf.data.experimental.AUTOTUNE)

# Batching happens after the mapping, followed by prefetching.
dataset_iterator = iter(
    dataset
    .batch(batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE))

In [16]:
# Pull one batch of generator output from the mapped dataset and display it.
next(dataset_iterator)

(<tf.Tensor: shape=(128, 3), dtype=float32, numpy=
 array([[ 9.63441133e-01,  2.67920107e-01,  1.07273854e-01],
        [-9.76647854e-01,  2.14846402e-01, -4.34838057e+00],
        [ 9.99942243e-01, -1.07456772e-02, -6.04829013e-01],
        [ 8.82993281e-01,  4.69385654e-01,  3.38135809e-01],
        [ 9.87567544e-01,  1.57195255e-01,  1.46168485e-01],
        [ 9.95417118e-01,  9.56284255e-02,  3.89924645e-01],
        [ 9.86738920e-01,  1.62315309e-01, -1.18274689e-01],
        [ 9.63504910e-01, -2.67690748e-01, -1.88405469e-01],
        [ 9.78107989e-01, -2.08098054e-01, -2.57059872e-01],
        [ 9.99998152e-01, -1.92165421e-03,  5.11418939e-01],
        [ 9.59214389e-01,  2.82679528e-01, -8.56893435e-02],
        [ 9.79098737e-01, -2.03385442e-01, -3.05926293e-01],
        [ 9.99978304e-01, -6.58385176e-03, -5.47296047e-01],
        [ 8.46369565e-01,  5.32596052e-01, -1.33668363e+00],
        [-5.83769269e-02,  9.98294592e-01, -1.75240791e+00],
        [ 9.78670835e-01, -2.05434

In [17]:
# Displaying the (unbatched) mapped dataset restricted to one element shows the
# per-element shapes and dtypes without iterating over it.
dataset.take(1)

<TakeDataset shapes: ((3,), (None,), (1,), <unknown>, (3,), (None,), ()), types: (tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float32, tf.float64)>

In [18]:
def state_embedding_function(state, label):
    """Encode a single (state, label) pair with the loaded VAE-MDP.

    A leading axis is added to both inputs before they are passed to
    ``vae_mdp.binary_encode``; ``.mode()`` is then taken on the returned object
    and all size-1 axes are removed with ``tf.squeeze``.

    NOTE(review): this is the same expression as the lambda passed to
    ErgodicMDPTransitionGenerator earlier in the notebook, and the name is not
    referenced again in the visible cells.

    Args:
        state: tensor holding one state (no leading batch axis).
        label: tensor holding the label associated with ``state``.

    Returns:
        The squeezed result of ``binary_encode(...).mode()``.
    """
    batched_state = tf.expand_dims(state, axis=0)
    batched_label = tf.expand_dims(label, axis=0)
    return tf.squeeze(vae_mdp.binary_encode(batched_state, batched_label).mode())


# Display the latent state size reported by the transition generator.
generator.latent_state_size

<tf.Tensor: shape=(), dtype=int32, numpy=12>

In [19]:
# NOTE(review): presumably required so that environment subprocesses can be
# started from an interactive session -- confirm against the TF-Agents docs.
tf_agents.system.multiprocessing.enable_interactive_mode()

# Five environment constructors, one per parallel environment.
# NOTE(review): this rebinds `py_env`, which previously referred to the single
# environment handed to the PyDriver.
py_env = parallel_py_environment.ParallelPyEnvironment(
    [(lambda: suite_gym.load(env_name)) for _ in range(5)])

# Reset all environments; the batched initial time step is the cell output.
py_env.reset()

TimeStep(step_type=array([0, 0, 0, 0, 0], dtype=int32), reward=array([0., 0., 0., 0., 0.], dtype=float32), discount=array([1., 1., 1., 1., 1.], dtype=float32), observation=array([[-0.93817997,  0.34614787, -0.69929904],
       [ 0.76385903,  0.6453832 ,  0.22561696],
       [ 0.45620412,  0.8898752 ,  0.71726984],
       [ 0.8449875 , -0.534786  ,  0.66297346],
       [-0.71191454, -0.7022661 , -0.06948522]], dtype=float32))

In [20]:
# Sample directly through the Reverb Python client (no tf.data pipeline) and
# materialize the result as a list for display.
list(reverb_replay.py_client.sample(table_name, num_samples=8))

[[ReplaySample(info=SampleInfo(key=14253199262665925321, probability=3.9873669497706225e-05, table_size=25722, priority=1.0), data=[array(1, dtype=int32), array([ 0.97044015, -0.241342  ,  0.56521183], dtype=float32), array([1.9773306], dtype=float32), array(1, dtype=int32), array(-0.09526961, dtype=float32), array(1., dtype=float32)]),
  ReplaySample(info=SampleInfo(key=14253199262665925321, probability=3.9873669497706225e-05, table_size=25722, priority=1.0), data=[array(1, dtype=int32), array([ 0.9780917 , -0.20817454,  0.6808049 ], dtype=float32), array([1.1514839], dtype=float32), array(1, dtype=int32), array(-0.091653, dtype=float32), array(1., dtype=float32)])],
 [ReplaySample(info=SampleInfo(key=6206178101287334726, probability=3.9873669497706225e-05, table_size=25722, priority=1.0), data=[array(1, dtype=int32), array([ 0.97759354, -0.21050137,  0.40754882], dtype=float32), array([1.925373], dtype=float32), array(1, dtype=int32), array(-0.06529789, dtype=float32), array(1., dtyp

In [21]:
# Display the table statistics again, now that samples have been drawn.
reverb_replay.get_table_info()

TableInfo(name='prioritized_replay_buffer', sampler_options=prioritized {
  priority_exponent: 0.9
}
, remover_options=heap {
}
is_deterministic: true
, max_size=1000000, max_times_sampled=0, rate_limiter_info=samples_per_insert: 1.0
min_diff: -1.7976931348623157e+308
max_diff: 1.7976931348623157e+308
min_size_to_sample: 1
insert_stats {
  completed: 12861
  completed_wait_time {
  }
  pending_wait_time {
  }
}
sample_stats {
  completed: 2696
  completed_wait_time {
  }
  pending_wait_time {
  }
}
, signature=None, current_size=25722, num_episodes=2, num_deleted_episodes=0)

In [22]:
# Close the observer before the Reverb server is stopped in the next cell
# (presumably this releases its connection to the table -- confirm).
observer.close()

In [23]:
# Shut down the in-process Reverb server created at the top of the notebook.
reverb_server.stop()