This colab demonstrates how to use Reverb and its more advanced features (samplers, removers, rate limiters). It also presents Reverb's data model, so users can understand what is going on under the hood.


In [0]:
import reverb

import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()

Let's define a dummy RL environment for the examples below.

In [0]:
observations_shape = tf.TensorShape([10, 10])
actions_shape = tf.TensorShape([2])


def agent_step(unused_timestep) -> tf.Tensor:
  """Returns a random 0/1 action of shape `actions_shape` as float32.

  The timestep argument is ignored: this dummy agent acts at random.
  """
  coin_flips = tf.random.uniform(actions_shape)
  return tf.cast(coin_flips > .5, tf.float32)


def environment_step(unused_action) -> tf.Tensor:
  """Returns a random uint8 observation of shape `observations_shape`.

  The action argument is ignored. Values are drawn from [0, 256) and then
  truncated to uint8 by the cast.
  """
  pixels = tf.random.uniform(observations_shape, maxval=256)
  return tf.cast(pixels, tf.uint8)

# Creating a Server and Client

In [0]:
# Start a Reverb server hosting a single table. With port=None the server
# picks a free port automatically.
simple_server = reverb.Server(
    port=None,
    priority_tables=[
        reverb.PriorityTable(
            name='my_table',
            sampler=reverb.distributions.Prioritized(priority_exponent=0.8),
            remover=reverb.distributions.Fifo(),
            max_size=1_000_000,
            # A deliberately low minimum size for this small example; see the
            # Rate Limiters section below for details.
            rate_limiter=reverb.rate_limiters.MinSize(2),
        ),
    ],
)

# Connect a Python client to the port the server selected.
client = reverb.Client(f'localhost:{simple_server.port}')


For more details on customizing the sampler, remover, and rate limiter, see below.

# Example 1: Overlapping Trajectories


## Inserting Overlapping Trajectories

In [0]:
# Stream four timesteps through a client writer. Once three timesteps are
# available, every step registers an item in 'my_table' that covers the three
# most recent timesteps, so consecutive items overlap.
with client.writer(max_sequence_length=3) as writer:
  timestep = environment_step(None)
  for step in range(4):
    action = agent_step(timestep)
    writer.append_timestep((timestep, action))
    timestep = environment_step(action)
    if step < 2:
      # Not enough history yet for a length-3 item.
      continue
    # Every item gets the same fixed priority (1.5) in this example.
    writer.create_prioritized_item(
        table='my_table', num_timesteps=3, priority=1.5)

This animation shows what the state of the server looks like at each step in the above code block. Although we are manually setting each item to have the same priority value of 1.5, items do not need to have the same priority values (and, in reality, will likely have differing and dynamically-calculated priority values).



<img src="https://github.com/deepmind/reverb/blob/master/docs/animations/diagram1.svg?raw=true" />

## Sampling Overlapping Trajectories in TensorFlow

In [0]:
# A TensorFlow client exposes the server's tables as tf.data datasets.
tf_client = reverb.TFClient(f'localhost:{simple_server.port}')

# Each item inserted above spans 3 timesteps, so sequences are 3 long.
sequence_length = 3

# The dataset samples sequences of length 3 but yields their timesteps one by
# one, which allows streaming sequences that do not necessarily fit in memory.
# Batching by the sequence length regroups the timesteps into sequences, so
# observations now have shape [3, 10, 10].
dataset = tf_client.dataset(
    table='my_table',
    dtypes=(tf.uint8, tf.float32),
    shapes=(observations_shape, actions_shape),
).batch(sequence_length)

In [0]:
# Batch 2 sequences together.
# Observations now have shape [2, 3, 10, 10].
dataset = dataset.batch(2)

for sample in dataset.take(1):
  # Results in the following format (shape, dtype).
  print(sample.info.key)          # ([2, 3], uint64)
  print(sample.info.probability)  # ([2, 3], float64)

  observation, action = sample.data
  # Observations keep the tf.uint8 dtype requested from tf_client.dataset.
  print(observation)              # ([2, 3, 10, 10], uint8)
  print(action)                   # ([2, 3, 2], float32)

# Example 2: Complete Episodes

We create a new server for this example so that its table contains only the complete episodes inserted below, without the trajectories inserted in Example 1.

In [0]:
# A fresh server for this example, configured like the first one: a single
# prioritized table with a FIFO remover. With port=None the server picks a
# free port automatically.
complete_episode_server = reverb.Server(
    port=None,
    priority_tables=[
        reverb.PriorityTable(
            name='my_table',
            sampler=reverb.distributions.Prioritized(priority_exponent=0.8),
            remover=reverb.distributions.Fifo(),
            max_size=1_000_000,
            # A deliberately low minimum size for this small example; see the
            # Rate Limiters section below for details.
            rate_limiter=reverb.rate_limiters.MinSize(2),
        ),
    ],
)

# Rebind the Python client to the new server's port.
client = reverb.Client(f'localhost:{complete_episode_server.port}')

## Inserting Complete Episodes

In [0]:
# Add whole episodes to 'my_table', one table entry per episode, using the
# client's insert function.
episode_length = 100
num_episodes = 200


def unroll_full_episode():
  """Rolls out one episode with the dummy agent and environment.

  Returns:
    A pair `(observations, actions)` of stacked tensors holding
    `episode_length` observations and `episode_length - 1` actions.
  """
  observation = environment_step(None)
  observations, actions = [observation], []
  for _ in range(episode_length - 1):
    action = agent_step(observation)
    observation = environment_step(action)
    actions.append(action)
    observations.append(observation)
  return tf.stack(observations), tf.stack(actions)


for _ in range(num_episodes):
  # client.insert is meant for full trajectories. Individual timesteps go
  # through a writer instead.
  client.insert(unroll_full_episode(), {'my_table': 1.5})

## Sampling Complete Episodes in TensorFlow

In [0]:
# Create a TensorFlow client connected to the complete-episode server.
tf_client = reverb.TFClient(f'localhost:{complete_episode_server.port}')

# Each sample is an entire episode, so the expected shapes are prefixed with
# the episode length. unroll_full_episode returns `episode_length`
# observations but only `episode_length - 1` actions.
# The requested dtypes need to agree with the inserted data:
#   - environment_step emits tf.uint8 observations (tf.int8 would not match);
#   - agent_step emits tf.float32 actions.
dataset = tf_client.dataset(
  table='my_table',
  dtypes=(tf.uint8, tf.float32),
  shapes=([episode_length] + observations_shape,
          [episode_length - 1] + actions_shape))

# Batch 128 episodes together.
# Each item is an episode of the format (observations, actions) as above.
# Shapes of items in this example are ([128, 100, 10, 10], [128, 99, 2]).
dataset = dataset.batch(128)

# Sample has type reverb.ReplaySample.
for sample in dataset.take(1):
  # Results in the following format (shape, dtype).
  print(sample.info.key)          # ([128], uint64)
  print(sample.info.probability)  # ([128], float64)

  observation, action = sample.data
  print(observation)              # ([128, 100, 10, 10], uint8)
  print(action)                   # ([128, 99, 2], float32)

# Example 3: Multiple Priority Tables

First, we create a server that maintains multiple priority tables.

In [0]:
# A server hosting two identically configured tables, 'my_table_a' and
# 'my_table_b'. Generating both from one template keeps their configuration
# from drifting apart.
multitable_server = reverb.Server(
    priority_tables=[
        reverb.PriorityTable(
            name=table_name,
            sampler=reverb.distributions.Prioritized(priority_exponent=0.8),
            remover=reverb.distributions.Fifo(),
            max_size=int(1e6),
            # Set this to a low number for the examples here
            # (see Rate Limiters section).
            rate_limiter=reverb.rate_limiters.MinSize(1))
        for table_name in ('my_table_a', 'my_table_b')
    ],
    port=None
)

# Connect a Python client to the port the server picked.
client = reverb.Client(f'localhost:{multitable_server.port}')


## Inserting Sequences of Varying Length into Multiple Priority Tables


In [0]:
# Write four timesteps and register items of two different lengths:
#   - 2-timestep items go to 'my_table_b' from the second step on;
#   - 3-timestep items go to 'my_table_a' from the third step on.
# Later items receive lower priorities (4 - step).
with client.writer(max_sequence_length=3) as writer:
  timestep = environment_step(None)
  for step in range(4):
    writer.append_timestep(timestep)
    action = agent_step(timestep)
    timestep = environment_step(action)

    # At least 2 timesteps written: add a length-2 item to table B.
    if step > 0:
      writer.create_prioritized_item(
          table='my_table_b', num_timesteps=2, priority=4 - step)
    # At least 3 timesteps written: add a length-3 item to table A.
    if step > 1:
      writer.create_prioritized_item(
          table='my_table_a', num_timesteps=3, priority=4 - step)

<img src="https://github.com/deepmind/reverb/blob/master/docs/animations/diagram2.svg?raw=true" />

The diagram above shows the state of the server after executing this code, which inserts overlapping sequences of different lengths into the two tables.

It is also possible to insert full trajectories into multiple tables using `client.insert`, as in Example 2, by passing one priority per table:

```python
client.insert(episode, {'my_table_a': 1.5, 'my_table_b': 2.5})
```

# Example 4: Samplers and Removers



##  Creating a Server with a Prioritized Sampler and a FIFO Remover

In [0]:
# A single table combining a prioritized sampler (priority_exponent=0.8) with a
# FIFO remover. The MinSize(100) rate limiter holds back sampling until the
# table contains 100 items.
reverb.Server(
    port=None,
    priority_tables=[
        reverb.PriorityTable(
            name='my_table',
            max_size=1_000_000,
            sampler=reverb.distributions.Prioritized(priority_exponent=0.8),
            remover=reverb.distributions.Fifo(),
            rate_limiter=reverb.rate_limiters.MinSize(100),
        ),
    ],
)

## Creating a Server with a MaxHeap Sampler and a MinHeap Remover

By setting `max_times_sampled=1`, each item is removed after it is sampled once, so this priority table essentially functions as a max priority queue.


In [0]:
# A max-heap sampler paired with a min-heap remover. With max_times_sampled=1
# every item is dropped after its first sample, so the table behaves like a
# bounded max priority queue.
reverb.Server(
    priority_tables=[
        reverb.PriorityTable(
            name='my_priority_queue',
            sampler=reverb.distributions.MaxHeap(),
            remover=reverb.distributions.MinHeap(),
            max_times_sampled=1,
            max_size=1000,
            # PriorityTable needs an explicit rate limiter, as in every other
            # table in this colab. MinSize(1) only holds back sampling while
            # the queue is empty.
            rate_limiter=reverb.rate_limiters.MinSize(1)),
    ],
    port=None
)

## Creating a Server with One Queue and One Circular Buffer

Reverb can achieve the behavior of canonical data structures (such as a
[circular buffer](https://en.wikipedia.org/wiki/Circular_buffer) or a max
[priority queue](https://en.wikipedia.org/wiki/Priority_queue) as above) by
configuring the `sampler` and `remover`, or by using the `PriorityTable.queue`
initializer.

In [0]:
# Two canonical data structures built from PriorityTables:
#   - 'my_queue' comes from the PriorityTable.queue initializer.
#   - 'my_circular_buffer' samples and removes in FIFO order, drops each item
#     after a single sample, and needs only one item before sampling. Its
#     MinSize(1) rate limiter does not cap inserts.
reverb.Server(
    port=None,
    priority_tables=[
        reverb.PriorityTable.queue(max_size=10_000, name='my_queue'),
        reverb.PriorityTable(
            name='my_circular_buffer',
            max_size=10_000,
            max_times_sampled=1,
            sampler=reverb.distributions.Fifo(),
            remover=reverb.distributions.Fifo(),
            rate_limiter=reverb.rate_limiters.MinSize(1),
        ),
    ],
)

# Example 5: Rate Limiters


## Creating a Server with a SampleToInsertRatio Rate Limiter

In [0]:
# A table whose SampleToInsertRatio rate limiter is configured as follows:
#   - samples_per_insert=3.0: targets three samples per inserted item;
#   - min_size_to_sample=3: requires three items before sampling;
#   - error_buffer=3.0: sets how much slack is allowed before insert or
#     sample calls block.
reverb.Server(
    port=None,
    priority_tables=[
        reverb.PriorityTable(
            name='my_table',
            max_size=1_000_000,
            sampler=reverb.distributions.Prioritized(priority_exponent=0.8),
            remover=reverb.distributions.Fifo(),
            rate_limiter=reverb.rate_limiters.SampleToInsertRatio(
                samples_per_insert=3.0,
                min_size_to_sample=3,
                error_buffer=3.0,
            ),
        ),
    ],
)


The SampleToInsertRatio rate limiter blocks insert calls when inserts get too far ahead of samples (and blocks sample calls in the opposite case). This does not cause deadlocks when the system is distributed (or multi-threaded), because a blocked insertion is eventually unblocked by sample calls from an independent thread. If the system had only a single thread, however, a blocked insertion call would cause a deadlock.
