In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a href="https://colab.research.google.com/github/google-research/rlds/blob/main/rlds/examples/rlds_dataset_patterns.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Run In Google Colab"/></a>
  </td>
</table>

# RLDS Dataset Patterns

This Colab showcases how to transform an RLDS dataset using Reverb patterns by applying a pattern directly to the dataset.

If you are looking for examples of how to apply the same pattern both to an RLDS dataset and to a Reverb table, see [this colab](https://colab.research.google.com/github/google-research/rlds/blob/main/rlds/examples/rlds_reverb_patterns.ipynb).

In [None]:
#@title Install Pip packages
!pip install rlds[tensorflow]
!pip install numpy

In [None]:
#@title Imports
import rlds
import reverb
import tensorflow as tf
import tensorflow_datasets as tfds
import tree

## Load a dataset to use throughout the examples

Before experimenting with the patterns, we load a dataset to use throughout the examples. To experiment with other datasets, take a look at the available datasets in the [TFDS catalog](https://www.tensorflow.org/datasets/catalog/overview) and look for those that are RLDS-compatible.


In [None]:
dataset_name = 'd4rl_mujoco_halfcheetah/v0-medium'  # @param
num_episodes = 20  # @param

# Load the 'train' split of the episodes dataset and keep only its first
# `num_episodes` episodes, so the examples below stay small.
dataset = tfds.load(dataset_name)['train']
dataset = dataset.take(num_episodes)

# Spec of the individual steps of the episodes. It is used below to build the
# reference step, the transform-based pattern and the data conditions.
step_spec = rlds.transformations.step_spec(dataset)

# Reverb Patterns

Reverb patterns are an API to transform streams of steps. The result is another stream of elements that can be, for example, transformed steps, transitions or trajectories.

One pattern is actually a set of `configurations`, where each `configuration`
consists of:
* A transformation that specifies how the output is constructed. For example,
"take the observations from the last 5 steps".
* A set of conditions to decide whether to apply the pattern. For example, apply the pattern only to the first 10 steps of an episode.

In addition, users can decide what happens when the stream of steps reaches the end of an episode: either create output elements that contain data from different episodes, or not (in which case each output element is guaranteed to contain data from only one episode).

**Note**: When constructing the pattern, we assume steps are accumulated one by one in a queue, and we access the queue from its tail (see the examples below to understand exactly how it works). Every time a new step is inserted, all the conditions are checked to decide whether the pattern has to be applied.

## The Transformations

The transformation is the part of the pattern that specifies how the steps are transformed to produce the output elements.

The transformations express how the elements of the steps are grouped in order to construct the output elements. For example, we can construct SARS transitions by getting the observation of the last two steps, but the reward and action of only the previous-to-last step.

However, the transformations do not allow operations on the step values. For example, if we want to build N-step transitions, we have to construct a transition that contains the sequence of N-1 rewards and discounts, but we cannot directly calculate the discounted reward using the patterns.

### Example 1: using a reference step

For example, the following transformation uses RLDS steps to produce SARS transitions.

It first uses `create_reference_step` to create a reference step (`ref_step`) from the spec of the steps dataset. This reference step is then used to describe how the step elements are used to construct a transition.

**Note:** See how we access the elements from the end of the queue.

In [None]:
ref_step = reverb.structured_writer.create_reference_step(step_spec)

# Build the SARS transition field by field. Indices are relative to the tail
# of the step queue: -1 is the latest step and -2 is the step before it.
sars_pattern = {}
sars_pattern[rlds.OBSERVATION] = tree.map_structure(
    lambda obs: obs[-2], ref_step[rlds.OBSERVATION])
sars_pattern[rlds.ACTION] = ref_step[rlds.ACTION][-2]
sars_pattern[rlds.REWARD] = ref_step[rlds.REWARD][-2]
sars_pattern['next_observation'] = tree.map_structure(
    lambda obs: obs[-1], ref_step[rlds.OBSERVATION])

### Example 2: using a transformation

Sometimes, it is more convenient to define a function with the transformation.

This can be achieved by using `pattern_from_transform` (note that this call is
not necessary if we use `rlds.transformations.pattern_map_from_transform`; see the [example](https://colab.research.google.com/github/google-research/rlds/blob/main/rlds/examples/rlds_dataset_patterns.ipynb#scrollTo=sO_vJobw3JVg) below).

In [None]:
def get_sars_pattern(step):
  """Returns a SARS transition pattern built from a reference step.

  Indices are taken from the tail of the step queue: -1 is the most recently
  inserted step and -2 the step before it. The transition pairs the
  observation, action and reward of the previous-to-last step with the
  observation of the last step.

  Args:
    step: a reference step (such as the one returned by
      `reverb.structured_writer.create_reference_step`) whose fields can be
      indexed from the tail of the queue.

  Returns:
    A dict with the observation, action and reward of the previous-to-last
    step, plus the last step's observation under 'next_observation'.
  """
  previous, latest = -2, -1
  observations = step[rlds.OBSERVATION]
  return {
      rlds.OBSERVATION:
          tree.map_structure(lambda obs: obs[previous], observations),
      rlds.ACTION:
          step[rlds.ACTION][previous],
      rlds.REWARD:
          step[rlds.REWARD][previous],
      'next_observation':
          tree.map_structure(lambda obs: obs[latest], observations),
  }


# Should be equivalent to `sars_pattern` above, but derived by handing the
# step spec and the `get_sars_pattern` function to `pattern_from_transform`
# (which presumably builds the reference step internally) instead of building
# the reference step by hand.
another_sars_pattern = reverb.structured_writer.pattern_from_transform(
    step_spec, get_sars_pattern)

## The Conditions

A pattern is formed by the transformation and a set of conditions that decide which steps the pattern is applied to.

You can find a list of conditions, and utils to define them, [here](https://github.com/deepmind/reverb/blob/7e9b5693572c18e484ee57329ea4d2019501904e/reverb/structured_writer.py#L431).

By default, when creating a pattern, there is always an implicit condition: the number of steps in the buffer has to be enough to apply the transformation without causing Out-of-Bounds errors. For example, when constructing SARS transitions, we need at least two steps.

**Note**: If you are creating long trajectories without allowing episodes to be merged, and your episodes are short (e.g., you create 10-step trajectories and all your episodes have 5 steps), the dataset will not generate any output.

### Example 1: Condition over the number of steps

Use only the first 10 steps to produce a transition.

In [None]:
# Restrict the pattern to the first 10 steps of each episode. `step_index()`
# presumably refers to the index (within its episode) of the most recently
# inserted step, so transitions are only built while it is below 10.
condition_first_steps = reverb.structured_writer.Condition.step_index() < 10

### Example 2: Condition over the data

It is also possible to apply a condition to the data values, e.g., to avoid creating transitions where the last observation belongs to a terminal state.

In [None]:
def condition_fn(step):
  """Builds a condition that holds when the step is not terminal.

  Args:
    step: the structure returned by
      `reverb.structured_writer.Condition.data(step_spec)`, indexed like a
      step.

  Returns:
    The condition obtained by comparing the `rlds.IS_TERMINAL` field with
    False.
  """
  # NOTE: `== False` is deliberate and must not be rewritten as `not ...`:
  # the comparison is evaluated on a data-condition builder (which presumably
  # overloads `==` to produce a condition), whereas `not` would instead try to
  # evaluate the builder's truthiness.
  return step[rlds.IS_TERMINAL] == False


# Data condition that rejects transitions whose last step is terminal. Like
# `condition_first_steps`, it can be passed in the `conditions` list of
# `reverb.structured_writer.create_config`.
condition_terminal = condition_fn(
    reverb.structured_writer.Condition.data(step_spec))

# Applying the Pattern

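We now apply the patterns to the dataset to construct SARS transitions, either through the RLDS transformations or by using `reverb.PatternDataset` directly. The simplest case below applies a single transformation to every step. The examples that follow also use the `condition_first_steps` condition, so only the first 10 steps of each episode are used.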

## The Simplest Case: One Pattern

In [None]:
# `pattern_map_from_transform` takes the transformation function directly, so
# neither `pattern_from_transform` nor a config is needed. No conditions are
# given, so the pattern is applied to every step once enough steps are
# available to build a transition.
pattern_dataset_rlds_simple = rlds.transformations.pattern_map_from_transform(
    episodes_dataset=dataset,
    transform_fn=get_sars_pattern,
    # By setting this to true, we don't generate transitions that mix steps
    # from two episodes.
    respect_episode_boundaries=True,
)

# Inspect the first two transitions.
for transition in pattern_dataset_rlds_simple.take(2):
  print(transition)

{'action': <tf.Tensor: shape=(6,), dtype=float32, numpy=
array([ 0.98304164, -0.30126104,  0.3770365 ,  0.53884596, -0.94180214,
       -0.158314  ], dtype=float32)>, 'next_observation': <tf.Tensor: shape=(17,), dtype=float32, numpy=
array([-0.01038213, -0.02540122,  0.33964765, -0.1585834 ,  0.11570273,
        0.25722653, -0.3806618 , -0.05179007,  1.0116358 , -0.6148436 ,
       -0.09667274,  6.886969  , -5.4096785 ,  3.4446044 ,  7.0207024 ,
       -7.125424  , -0.6964876 ], dtype=float32)>, 'observation': <tf.Tensor: shape=(17,), dtype=float32, numpy=
array([ 0.00479665, -0.02969088,  0.06049511,  0.09577212, -0.01482561,
       -0.0183125 , -0.06952589, -0.02880755,  0.18119103,  0.00390713,
       -0.07221037, -0.06411345,  0.02970835, -0.12453649, -0.06031511,
       -0.02463249,  0.02966032], dtype=float32)>, 'reward': <tf.Tensor: shape=(), dtype=float32, numpy=0.6067187>}
{'action': <tf.Tensor: shape=(6,), dtype=float32, numpy=
array([ 0.9912193 , -0.8879203 , -0.97937745, -0

## Applying a list of Patterns

The examples below still apply a single pattern, but the `configs` parameter takes a list in both cases, so it's possible to apply more than one config.

### Example 1: Using RLDS transformations

In [None]:
# The table is unused
sars_config = reverb.structured_writer.create_config(
    sars_pattern, table='transition', conditions=[condition_first_steps])

pattern_dataset_rlds = rlds.transformations.pattern_map(
    episodes_dataset=dataset,
    configs=[sars_config],
    # By setting this to true, we don't generate transitions that mix steps
    # from two episodes.
    respect_episode_boundaries=True,
)

for transition in pattern_dataset_rlds.take(2):
  print(transition)

{'action': <tf.Tensor: shape=(6,), dtype=float32, numpy=
array([ 0.98304164, -0.30126104,  0.3770365 ,  0.53884596, -0.94180214,
       -0.158314  ], dtype=float32)>, 'next_observation': <tf.Tensor: shape=(17,), dtype=float32, numpy=
array([-0.01038213, -0.02540122,  0.33964765, -0.1585834 ,  0.11570273,
        0.25722653, -0.3806618 , -0.05179007,  1.0116358 , -0.6148436 ,
       -0.09667274,  6.886969  , -5.4096785 ,  3.4446044 ,  7.0207024 ,
       -7.125424  , -0.6964876 ], dtype=float32)>, 'observation': <tf.Tensor: shape=(17,), dtype=float32, numpy=
array([ 0.00479665, -0.02969088,  0.06049511,  0.09577212, -0.01482561,
       -0.0183125 , -0.06952589, -0.02880755,  0.18119103,  0.00390713,
       -0.07221037, -0.06411345,  0.02970835, -0.12453649, -0.06031511,
       -0.02463249,  0.02966032], dtype=float32)>, 'reward': <tf.Tensor: shape=(), dtype=float32, numpy=0.6067187>}
{'action': <tf.Tensor: shape=(6,), dtype=float32, numpy=
array([ 0.9912193 , -0.8879203 , -0.97937745, -0

### Example 2: Using the PatternDataset API directly

In [None]:
# The table is unused
sars_config = reverb.structured_writer.create_config(
    sars_pattern, table='transition', conditions=[condition_first_steps])

pattern_dataset = reverb.PatternDataset(
    # We convert the dataset of episodes into a dataset of steps
    input_dataset=dataset.flat_map(lambda e: e[rlds.STEPS]),
    configs=[sars_config],
    # By setting this to true, we don't generate transitions that mix steps
    # from two episodes.
    respect_episode_boundaries=True,
    # We need to tell the dataset how to identify the end of an episode.
    is_end_of_episode=lambda step: step[rlds.IS_LAST],
)

for transition in pattern_dataset.take(2):
  print(transition)

{'action': <tf.Tensor: shape=(6,), dtype=float32, numpy=
array([ 0.98304164, -0.30126104,  0.3770365 ,  0.53884596, -0.94180214,
       -0.158314  ], dtype=float32)>, 'next_observation': <tf.Tensor: shape=(17,), dtype=float32, numpy=
array([-0.01038213, -0.02540122,  0.33964765, -0.1585834 ,  0.11570273,
        0.25722653, -0.3806618 , -0.05179007,  1.0116358 , -0.6148436 ,
       -0.09667274,  6.886969  , -5.4096785 ,  3.4446044 ,  7.0207024 ,
       -7.125424  , -0.6964876 ], dtype=float32)>, 'observation': <tf.Tensor: shape=(17,), dtype=float32, numpy=
array([ 0.00479665, -0.02969088,  0.06049511,  0.09577212, -0.01482561,
       -0.0183125 , -0.06952589, -0.02880755,  0.18119103,  0.00390713,
       -0.07221037, -0.06411345,  0.02970835, -0.12453649, -0.06031511,
       -0.02463249,  0.02966032], dtype=float32)>, 'reward': <tf.Tensor: shape=(), dtype=float32, numpy=0.6067187>}
{'action': <tf.Tensor: shape=(6,), dtype=float32, numpy=
array([ 0.9912193 , -0.8879203 , -0.97937745, -0