##### Copyright 2021 Google LLC. All Rights Reserved.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# **RLDS: Examples**
This colab provides some examples of RLDS usage based on real use cases. If you are looking for an introduction to RLDS, see the [RLDS tutorial](https://colab.research.google.com/github/google-research/rlds/blob/main/rlds/examples/rlds_tutorial.ipynb) in Google Colab.


<table class="tfo-notebook-buttons" align="left">
  <td>
    <a href="https://colab.research.google.com/github/google-research/rlds/blob/main/rlds/examples/rlds_examples.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Run In Google Colab"/></a>
  </td>
</table>

# Install Modules

In [None]:
!pip install rlds[tensorflow]
!pip install tfds-nightly --upgrade
!pip install envlogger
!apt-get install libgmp-dev

## Import Modules

In [None]:
import functools
import rlds
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds

# Load dataset
We can load the human dataset from the Panda Pick Place Can task of the [Robosuite collection in TFDS](https://www.tensorflow.org/datasets/catalog/overview#rlds). These examples assume that certain fields are present in the steps, so datasets from other tasks may not be compatible.

In [None]:
dataset_config = 'human_dc29b40a' # @param { isTemplate : true}
dataset_name = f'robosuite_panda_pick_place_can/{dataset_config}'
num_episodes_to_load = 30 # @param { isTemplate: true}

# Learning from Demonstrations or Offline RL

We consider the setup where an agent needs to solve a task specified by a reward $r$. We assume that a dataset of episodes with the corresponding rewards is available for training. This includes:
* The offline RL (ORL) setup [[1], [2], [3]], where the agent is trained solely from a dataset of episodes collected in the environment.
* The learning from demonstrations (LfD) setup [[4], [5], [6], [7]], where the agent can also interact with the environment.

Using one of the two provided datasets on the Robosuite PickPlaceCan environment, a typical RLDS pipeline would include the following steps:

1. Sample $K$ episodes from the dataset so that the performance of the trained agent can be expressed as a function of the number of available episodes.
1. Combine the observations used as input to the agent. The Robosuite datasets include many fields in the observations, and one could train the agent from the state or from the visual observations, for example.
1. Finally, convert the dataset of episodes into a dataset of transitions that can be consumed by algorithms such as SAC or TD3.

[1]: https://arxiv.org/abs/2005.01643
[2]: https://arxiv.org/abs/1911.11361
[3]: https://arxiv.org/abs/2103.01948
[4]: https://arxiv.org/abs/1909.01387
[5]: https://arxiv.org/abs/1704.03732
[6]: https://arxiv.org/abs/1707.08817
[7]: https://arxiv.org/abs/2006.12917
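
The first two steps of the pipeline can be illustrated without any RLDS dependency. The following is a minimal pure-Python sketch on toy data, not the RLDS implementation: the helper names `sample_episodes` and `flatten_observation`, the `id` field, and all values are hypothetical, and episodes are assumed to be plain lists of dictionaries.

```python
import random


def sample_episodes(episodes, k, seed=42):
    """Returns k episodes drawn without replacement using a fixed seed."""
    rng = random.Random(seed)
    return rng.sample(episodes, k)


def flatten_observation(observation, keys):
    """Concatenates the selected observation fields into one flat list."""
    flat = []
    for key in keys:
        flat.extend(observation[key])
    return flat


# Toy data (hypothetical values): 4 episodes with a single step each.
toy_episodes = [
    {
        'id': i,
        'steps': [{
            'observation': {
                'robot0_proprio-state': [0.1 * i, 0.2],
                'object-state': [1.0],
                'image': [0, 0, 0],  # Ignored when training from state.
            }
        }],
    }
    for i in range(4)
]

# Step 1: keep K=2 episodes. Step 2: keep only the state fields.
subset = sample_episodes(toy_episodes, k=2)
state_keys = ['robot0_proprio-state', 'object-state']
flat_obs = flatten_observation(subset[0]['steps'][0]['observation'], state_keys)
```

Fixing the seed keeps the sampled subset identical across runs, which matters when comparing agents trained with different values of $K$.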

In [None]:
K = 5 # @param { isTemplate: true}
buffer_size = 30 # @param { isTemplate: true}

In [None]:
dataset = tfds.load(dataset_name, split=f'train[:{num_episodes_to_load}]')
dataset = dataset.shuffle(buffer_size, seed=42, reshuffle_each_iteration=False)
dataset = dataset.take(K)

def prepare_observation(step):
  """Filters the observation to only keep the state and flattens it."""
  observation_names = ['robot0_proprio-state', 'object-state']
  step[rlds.OBSERVATION] = tf.concat(
      [step[rlds.OBSERVATION][key] for key in observation_names], axis=-1)
  return step
dataset = rlds.transformations.map_nested_steps(dataset, prepare_observation)

def batch_to_transition(batch):
  """Converts a pair of consecutive steps to a custom transition format."""
  return {'s_cur': batch[rlds.OBSERVATION][0],
          'a': batch[rlds.ACTION][0],
          'r': batch[rlds.REWARD][0],
          's_next': batch[rlds.OBSERVATION][1]}

def make_transition_dataset(episode):
  """Converts an episode of steps to a dataset of custom transitions."""
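  # Create a dataset of 2-step sequences with overlap of 1. The incomplete
  # final window is dropped so that every batch holds exactly two steps.
  batched_steps = rlds.transformations.batch(
      episode[rlds.STEPS], size=2, shift=1, drop_remainder=True)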
  return batched_steps.map(batch_to_transition)

transitions_ds = dataset.flat_map(make_transition_dataset)
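
To make the windowing concrete, here is a minimal pure-Python sketch of the same idea on a toy episode. It is an illustration rather than the RLDS implementation: the function name `episode_to_transitions` and the toy values are hypothetical, and the sketch assumes that an episode of $N$ steps should yield $N-1$ transitions.

```python
def episode_to_transitions(steps):
    """Pairs each step with its successor (window of size 2, shift 1)."""
    transitions = []
    # zip stops at the shorter sequence, so no incomplete pair is produced.
    for cur, nxt in zip(steps, steps[1:]):
        transitions.append({
            's_cur': cur['observation'],
            'a': cur['action'],
            'r': cur['reward'],
            's_next': nxt['observation'],
        })
    return transitions


# Toy episode with three steps (hypothetical values).
toy_steps = [
    {'observation': [0.0], 'action': [1.0], 'reward': 0.0},
    {'observation': [1.0], 'action': [1.0], 'reward': 0.0},
    {'observation': [2.0], 'action': [0.0], 'reward': 1.0},
]
toy_transitions = episode_to_transitions(toy_steps)
```

In this format the action and reward stored on the last step never appear in a transition, because the last step has no successor.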

# Absorbing Terminal States in Imitation Learning

Imitation learning is the setup where an agent tries to imitate a behavior, as defined by some sample episodes of that behavior.
In particular, the reward is not specified.

The dataset processing pipeline requires all the pieces seen in the learning from demonstrations setup (create a train split, assemble the observation, ...) but also has some specifics.
One of them relates to the particular role of the terminal state in imitation learning.
In standard RL tasks, looping over the terminal state contributes zero reward. In imitation learning, however, assuming zero reward for transitions from a terminal state to itself induces a bias in algorithms like GAIL.
One way to counter this bias was proposed in [1]: learn the reward value of the transition from the absorbing state to itself.
Implementation-wise, to tell a terminal state from any other state, an `absorbing` bit is added to the observation (`1` for a terminal state, `0` for a regular state). The dataset is also augmented with terminal-state-to-terminal-state transitions so that the agent can learn from them.

[1]: https://arxiv.org/abs/1809.02925

In [None]:
def duplicate_terminal_step(episode):
  """Duplicates the terminal step if the episode ends in one. Noop otherwise."""
  return rlds.transformations.concat_if_terminal(
      episode, make_extra_steps=tf.data.Dataset.from_tensors)

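def convert_to_absorbing_state(step):
  """Zeroes out terminal steps and appends an absorbing bit to the observation."""
  # Cast the flag so it can be concatenated with the numeric observation.
  padding = tf.cast(step[rlds.IS_TERMINAL], step[rlds.OBSERVATION].dtype)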
  if step[rlds.IS_TERMINAL]:
    step[rlds.OBSERVATION] = tf.zeros_like(step[rlds.OBSERVATION])
    step[rlds.ACTION] = tf.zeros_like(step[rlds.ACTION])
    # This is no longer a terminal state as the episode loops indefinitely.
    step[rlds.IS_TERMINAL] = False
    step[rlds.IS_LAST] = False
  # Add the absorbing bit to the observation.
  step[rlds.OBSERVATION] = tf.concat([step[rlds.OBSERVATION], [padding]], 0)
  return step

absorbing_state_ds = rlds.transformations.apply_nested_steps(
    dataset, duplicate_terminal_step)
absorbing_state_ds = rlds.transformations.map_nested_steps(
    absorbing_state_ds, convert_to_absorbing_state)
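
The effect of these two transformations can be traced on a toy episode. The sketch below is a pure-Python illustration under stated assumptions, not the RLDS implementation: steps are plain dictionaries, the function name `add_absorbing_states` and the toy values are hypothetical, and the terminal step is assumed to be duplicated exactly once.

```python
def add_absorbing_states(steps):
    """Duplicates a terminal last step and appends an absorbing bit."""
    steps = [dict(step) for step in steps]  # Shallow copies of each step.
    if steps and steps[-1]['is_terminal']:
        # Duplicate the terminal step to get an absorbing -> absorbing pair.
        steps.append(dict(steps[-1]))
    converted = []
    for step in steps:
        absorbing = 1.0 if step['is_terminal'] else 0.0
        if step['is_terminal']:
            step['observation'] = [0.0] * len(step['observation'])
            step['action'] = [0.0] * len(step['action'])
            # The episode now loops on the absorbing state indefinitely.
            step['is_terminal'] = False
            step['is_last'] = False
        # Append the absorbing bit as the last observation entry.
        step['observation'] = step['observation'] + [absorbing]
        converted.append(step)
    return converted


# Toy episode whose last step is terminal (hypothetical values).
toy_episode = [
    {'observation': [0.5], 'action': [1.0],
     'is_terminal': False, 'is_last': False},
    {'observation': [0.9], 'action': [0.0],
     'is_terminal': True, 'is_last': True},
]
absorbing_steps = add_absorbing_states(toy_episode)
```

The last two steps of the result are identical, with a zeroed observation and the absorbing bit set, so a transition built from them is the absorbing-to-absorbing transition whose reward the agent can learn.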

# Offline Analysis

One significant use case we envision for RLDS is the offline analysis of collected datasets.
There is no standard offline analysis procedure: the possibilities are limited only by the users' imagination. This section presents a fictitious use case to illustrate how custom tags stored in an RL dataset can be processed as part of an RLDS pipeline.
Let's assume we want to generate a histogram of the returns of the episodes in the provided dataset of human episodes on the Robosuite PickPlaceCan environment. The episodes in this dataset have a fixed length of 400 steps, but each also carries a tag that indicates the actual end of the task.
Here we consider the histogram of returns of the variable-length episodes that end on the completion tag.

In [None]:
def placed_tag_is_set(step):
  """Returns True if the 'placed' tag has any nonzero entry in this step."""
  return tf.not_equal(tf.math.count_nonzero(step['tag:placed']), 0)

def compute_return(steps):
  """Computes the return of the episode up to the 'placed' tag."""
  # Truncate the episode after the placed tag.
  steps = rlds.transformations.truncate_after_condition(
      steps, truncate_condition=placed_tag_is_set)
  return rlds.transformations.sum_dataset(steps, lambda step: step[rlds.REWARD])

returns_ds = dataset.map(lambda episode: compute_return(episode[rlds.STEPS]))
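
The truncate-then-sum logic and the final histogram can be sketched in pure Python on toy data. This is an illustration rather than the RLDS implementation: the function name `truncated_return` and the toy episodes are hypothetical, and the sketch assumes that the step on which the tag is first set is itself kept in the truncated episode.

```python
from collections import Counter


def truncated_return(steps, tag='tag:placed'):
    """Sums rewards up to and including the first step where the tag is set."""
    total = 0.0
    for step in steps:
        total += step['reward']
        if step[tag]:
            break  # Ignore everything after the completion tag.
    return total


def make_steps(rewards, tags):
    """Builds a toy list of steps from parallel reward and tag lists."""
    return [{'reward': r, 'tag:placed': t} for r, t in zip(rewards, tags)]


# Toy episodes (hypothetical values); the third never sets the tag.
tagged_episodes = [
    make_steps([0.0, 0.0, 1.0, 1.0], [False, False, True, True]),
    make_steps([0.0, 1.0, 1.0], [False, True, True]),
    make_steps([0.0, 0.0, 0.0], [False, False, False]),
]

toy_returns = [truncated_return(steps) for steps in tagged_episodes]
# Counter maps each distinct return to the number of episodes achieving it.
return_histogram = Counter(toy_returns)
```

With real-valued returns, one would typically bin the values (for example with a plotting library's histogram function) instead of counting exact matches.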