In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a href="https://colab.research.google.com/github/google-research/rlds/blob/main/rlds/examples/rlds_reverb_patterns.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Run In Google Colab"/></a>
  </td>
</table>

# RLDS & Reverb patterns

This Colab shows how to transform an RLDS dataset with Reverb patterns by:

1. Applying a pattern directly to the dataset.
1. Applying the same pattern when inserting the steps into a Reverb table.

Note that the same patterns can be applied to data collected online (for example, from running an environment), making it easier to mix online and offline experience.

This Colab compares how the same patterns are applied to an RLDS dataset and to a Reverb table. If you are looking for details about the patterns API, see [this colab](https://colab.research.google.com/github/google-research/rlds/blob/main/rlds/examples/rlds_dataset_patterns.ipynb).

In [None]:
#@title Install Pip packages
!pip install rlds[tensorflow]
!pip install envlogger[tfds]
!apt-get install libgmp-dev
!pip install numpy

In [1]:
#@title Imports
import os
import rlds
import envlogger
from envlogger.backends import rlds_utils
from envlogger.backends import tfds_backend_writer
from envlogger.testing import catch_env
import numpy as np
import reverb
import tensorflow as tf
import tensorflow_datasets as tfds
import time
from typing import Optional, List

## Generate a dataset to use throughout the examples

Before experimenting with the patterns, we generate a test dataset using random actions in a Catch environment.

In [2]:
# Colab form parameters: the directory where the generated dataset is written,
# the number of episodes to record, and the maximum number of episodes stored
# per file (passed to the writer as `max_episodes_per_file`).
dataset_path = '/tmp/tensorflow_datasets/catch/'  # @param
num_episodes = 20  # @param
max_episodes_per_shard = 1000  # @param

In [3]:
#@title Record Data Utils
def record_data(data_dir, num_episodes, max_episodes_per_shard,
                seed: Optional[int] = None):
  """Records episodes of a random agent in the Catch environment.

  The episodes are written through EnvLogger's TFDS backend as an RLDS
  dataset in `data_dir`, under the `train` split. Every step additionally
  stores a `timestamp_ns` step-metadata field holding `time.time_ns()`.

  Args:
    data_dir: directory where the dataset files are written.
    num_episodes: number of episodes to record.
    max_episodes_per_shard: maximum number of episodes written per file.
    seed: optional seed for the random action policy. With the default
      (`None`), actions are drawn from NumPy's global random state. With an
      int, the sequence of actions is reproducible. Only the agent's actions
      are seeded; any randomness inside the environment is not controlled here.
  """
  # Keep the raw environment under its own name. Below, `env` is the
  # EnvLogger wrapper, and steps are taken through the wrapper so they get
  # logged.
  base_env = catch_env.Catch()

  # Global NumPy state when unseeded; otherwise a dedicated RandomState, so
  # that seeding does not modify the global NumPy state.
  action_rng = np.random if seed is None else np.random.RandomState(seed)

  def step_fn(unused_timestep, unused_action, unused_env):
    # Extra per-step metadata recorded alongside each step.
    return {'timestamp_ns': time.time_ns()}

  # Feature specs of the recorded steps: 10x5 float32 observations
  # (ZLIB-compressed), int64 actions, float64 rewards and discounts.
  ds_config = tfds.rlds.rlds_base.DatasetConfig(
      name='catch_example',
      observation_info=tfds.features.Tensor(
          shape=(10, 5), dtype=tf.float32,
          encoding=tfds.features.Encoding.ZLIB),
      action_info=tf.int64,
      reward_info=tf.float64,
      discount_info=tf.float64,
      step_metadata_info={'timestamp_ns': tf.int64})

  with envlogger.EnvLogger(
      base_env,
      backend=tfds_backend_writer.TFDSBackendWriter(
          data_directory=data_dir,
          split_name='train',
          max_episodes_per_file=max_episodes_per_shard,
          ds_config=ds_config),
      step_fn=step_fn) as env:
    print('Done wrapping environment with EnvironmentLogger.')

    print(f'Training a random agent for {num_episodes} episodes...')
    for i in range(num_episodes):
      print(f'episode {i}')
      timestep = env.reset()
      while not timestep.last():
        # Uniformly random action in {0, 1, 2}.
        action = action_rng.randint(low=0, high=3)
        timestep = env.step(action)
    print(f'Done training a random agent for {num_episodes} episodes.')

In [4]:
# Make sure the output directory exists, then record the episodes into it.
os.makedirs(name=dataset_path, exist_ok=True)
record_data(
    data_dir=dataset_path,
    num_episodes=num_episodes,
    max_episodes_per_shard=max_episodes_per_shard,
)

Done wrapping environment with EnvironmentLogger.
Training a random agent for 20 episodes...
episode 0
episode 1
episode 2
episode 3
episode 4
episode 5
episode 6
episode 7
episode 8
episode 9
episode 10
episode 11
episode 12
episode 13
episode 14
episode 15
episode 16
episode 17
episode 18
episode 19
Done training a random agent for 20 episodes.


In [5]:
# Load the episodes recorded above back from disk, using the builder that
# TFDS reconstructs from the dataset directory.
loaded_dataset = (
    tfds.builder_from_directory(dataset_path)
    .as_dataset(split='all'))

## Launch a Reverb server

We launch a Reverb server that we will use later.

In [6]:
# Starts a Reverb server in this process with a single table, 'transition',
# which receives the transitions produced by the `transition` pattern.
simple_server = reverb.Server(
    tables=[
        reverb.Table(
            name='transition',
            # Items are sampled uniformly at random and, once the table
            # reaches `max_size`, removed in insertion (FIFO) order.
            sampler=reverb.selectors.Uniform(),
            remover=reverb.selectors.Fifo(),
            max_size=int(1e6),
            # Sampling is only allowed once the table holds at least 2 items.
            # This limit is kept low for the examples; see the Reverb
            # documentation on rate limiters for other configurations.
            rate_limiter=reverb.rate_limiters.MinSize(2),
            # The keys match the dict returned by the `transition` pattern,
            # and the specs match the recorded dataset (int64 actions, 10x5
            # float32 observations). Instead of defining the signature
            # explicitly, reverb.structured_writer.infer_signature can derive
            # the table signature from the Reverb pattern.
            signature={
                'action': tf.TensorSpec([], tf.int64),
                'observation': tf.TensorSpec([10, 5], tf.float32),
                'next_observation': tf.TensorSpec([10, 5], tf.float32),
            },
        )
    ])

# Initializes the reverb client on the same port as the server.
client = reverb.Client(f'localhost:{simple_server.port}')

[1m[32m[Reverb] Live dashboard: http://sabela.c.googlers.com:24800[0m


# Reverb Pattern

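We define a pattern that builds one transition from the last two RLDS steps: the observation and action of the earlier step, plus the observation of the later step as `next_observation`.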

In [7]:
def transition(step):
  """Builds a transition from the two most recent steps.

  Args:
    step: structure of RLDS step fields that supports negative indexing over
      the step history; `[-2]` is the earlier step and `[-1]` the later one.

  Returns:
    A dict holding the earlier step's observation and action, plus the later
    step's observation under `next_observation`.
  """
  observations = step[rlds.OBSERVATION]
  return {
      rlds.OBSERVATION: observations[-2],
      rlds.ACTION: step[rlds.ACTION][-2],
      'next_observation': observations[-1],
  }

# Derive the Reverb pattern from the dataset's step spec and wrap it in a
# writer config that targets the 'transition' table.
step_spec = rlds.transformations.step_spec(loaded_dataset)
pattern = reverb.structured_writer.pattern_from_transform(step_spec, transition)
config = reverb.structured_writer.create_config(pattern, table='transition')

# Transform dataset with the Reverb Pattern

By using the `PatternDataset` or the `rlds.transformations`, it is possible to transform a stream of steps into a stream of trajectories with Reverb Patterns.

If instead of one pattern you need to apply a list of patterns, check `rlds.transformations.pattern_map`.

In [8]:
# Apply the `transition` pattern to the episodes of the loaded dataset.
# `respect_episode_boundaries=True` keeps a pattern from spanning two episodes
# (see the RLDS `pattern_map_from_transform` documentation).
pattern_dataset = rlds.transformations.pattern_map_from_transform(
    episodes_dataset=loaded_dataset,
    transform_fn=transition,
    respect_episode_boundaries=True
)

# The loop variable must not be named `transition`: that would rebind the
# pattern function above to a dict, so re-running this cell (or the pattern
# cell) would fail.
for example_transition in pattern_dataset.take(2):
  print(example_transition)

{'action': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'next_observation': <tf.Tensor: shape=(10, 5), dtype=float32, numpy=
array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.]], dtype=float32)>, 'observation': <tf.Tensor: shape=(10, 5), dtype=float32, numpy=
array([[0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0.]], dtype=float32)>}
{'action': <tf.Tensor: shape=(), dtype=int64, numpy=2>, 'next_observation': <tf.Tensor: shape=(10, 5), dtype=float32, numpy=
array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 

# Convert to trajectories by inserting in Reverb

When inserting into a Reverb table, we can use the same patterns we used to transform a dataset. In this example, we apply them to an RLDS dataset, but they could also be applied to online data collected directly from an environment.

## Insert
In a real setting, we would have two processes: one producing data and inserting it into Reverb, and another one sampling from Reverb. For simplicity, here we insert all the data first, and all of it comes from the dataset we created before.

In [9]:
# The structured writer applies `config` (the `transition` pattern) to the
# appended steps and inserts the resulting items into the 'transition' table.
reverb_writer = client.structured_writer(configs=[config])

for step in loaded_dataset.flat_map(lambda episode: episode[rlds.STEPS]):
  reverb_writer.append(data=step)
  if step[rlds.IS_LAST]:
    # Close the episode and clear the buffered steps, so that the next
    # episode's transitions are not built from this episode's steps.
    reverb_writer.end_episode(clear_buffers=True)

# Flush once after all steps are appended: this blocks until every pending
# item has been written to the table. `flush` must actually be called; a bare
# attribute access does nothing.
reverb_writer.flush()

## Sample

The data sampled from the Reverb server has the same shape as the data we obtained when applying the pattern directly to the dataset.

In [10]:
# Stream samples from the 'transition' table; the structure of each sample's
# data is taken from the table signature.
reverb_dataset = reverb.TrajectoryDataset.from_table_signature(
    server_address=f'localhost:{simple_server.port}',
    table='transition',
    max_in_flight_samples_per_worker=1,
)

for sample in reverb_dataset.take(2):
  print(sample)

ReplaySample(info=SampleInfo(key=<tf.Tensor: shape=(), dtype=uint64, numpy=8567135249577419563>, probability=<tf.Tensor: shape=(), dtype=float64, numpy=0.005555555555555556>, table_size=<tf.Tensor: shape=(), dtype=int64, numpy=180>, priority=<tf.Tensor: shape=(), dtype=float64, numpy=1.0>, times_sampled=<tf.Tensor: shape=(), dtype=int32, numpy=1>), data={'action': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'next_observation': <tf.Tensor: shape=(10, 5), dtype=float32, numpy=
array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [1., 1., 0., 0., 0.]], dtype=float32)>, 'observation': <tf.Tensor: shape=(10, 5), dtype=float32, numpy=
array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
  