Skip to content

Commit

Permalink
Add pixel observation wrapper.
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinzakka committed Aug 17, 2023
1 parent 75a5d47 commit b1c47b4
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ This codebase contains software and tasks for the benchmark, and is powered by [

## Latest Updates

- [17/08/2023] Added a [pixel wrapper](robopianist/wrappers/pixels.py) for augmenting the observation space with RGB images.
- [11/08/2023] Code to train the model-free RL policies is now public, see [robopianist-rl](https://github.com/kevinzakka/robopianist-rl).

-------
Expand Down
5 changes: 4 additions & 1 deletion robopianist/suite/tasks/piano_with_shadow_hands.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,10 @@ def _add_observables(self) -> None:
# Enable hand observables.
enabled_observables = [
"joints_pos",
"position",
# NOTE(kevin): This observable was previously enabled but it is redundant
# since it is encoded in the joint positions, specifically via the forearm
# slider joints (which are in units of meters).
# "position",
]
for hand in [self.right_hand, self.left_hand]:
for obs in enabled_observables:
Expand Down
2 changes: 2 additions & 0 deletions robopianist/wrappers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
# limitations under the License.

from robopianist.wrappers.evaluation import MidiEvaluationWrapper
from robopianist.wrappers.pixels import PixelWrapper
from robopianist.wrappers.sound import PianoSoundVideoWrapper

# Names exported as the public API of the `robopianist.wrappers` package.
__all__ = [
"MidiEvaluationWrapper",
"PianoSoundVideoWrapper",
"PixelWrapper",
]
68 changes: 68 additions & 0 deletions robopianist/wrappers/pixels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright 2023 The RoboPianist Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A wrapper for adding pixels to the observation."""


import collections
from typing import Any, Dict, Optional

import dm_env
import numpy as np
from dm_env import specs
from dm_env_wrappers import EnvironmentWrapper


class PixelWrapper(EnvironmentWrapper):
    """Adds rendered pixel observations to a wrapped environment.

    Every timestep returned by `reset` and `step` has its observation copied
    into a new `OrderedDict` with one extra entry, stored under
    `observation_key`, holding the array returned by
    `environment.physics.render(**render_kwargs)`. The observation spec is
    extended with a matching `specs.Array` entry.

    The wrapped environment must expose a `physics` attribute with a `render`
    method, and its observations / observation spec must be mappings.
    """

    def __init__(
        self,
        environment: dm_env.Environment,
        render_kwargs: Optional[Dict[str, Any]] = None,
        observation_key: str = "pixels",
    ) -> None:
        """Initializes the wrapper and computes the extended observation spec.

        Args:
            environment: The environment to wrap.
            render_kwargs: Keyword arguments forwarded to `physics.render` on
                every call. Defaults to no arguments.
            observation_key: Key under which the rendered array is stored in
                the observation and in the observation spec.

        Raises:
            ValueError: If `observation_key` is already present in the wrapped
                environment's observation spec.
        """
        super().__init__(environment)

        # Copy so that later mutation of the caller's dict cannot make the
        # rendered frames disagree with the spec computed below.
        self._render_kwargs = dict(render_kwargs) if render_kwargs else {}
        self._observation_key = observation_key

        self._wrapped_observation_spec = self._environment.observation_spec()

        # Refuse to silently clobber an observation the wrapped environment
        # already provides under the same key.
        if observation_key in self._wrapped_observation_spec:
            raise ValueError(
                f"Observation key {observation_key!r} is already present in the "
                "wrapped environment's observation spec."
            )

        # Update the observation spec. A frame is rendered once here purely to
        # discover the shape and dtype of the pixel observation.
        self._observation_spec = collections.OrderedDict()
        self._observation_spec.update(self._wrapped_observation_spec)
        pixels = self._environment.physics.render(**self._render_kwargs)
        self._observation_spec[observation_key] = specs.Array(
            shape=pixels.shape, dtype=pixels.dtype, name=observation_key
        )

    def observation_spec(self):
        """Returns the wrapped observation spec extended with the pixel entry."""
        return self._observation_spec

    def reset(self) -> dm_env.TimeStep:
        """Resets the wrapped environment and adds pixels to the first timestep."""
        timestep = self._environment.reset()
        return self._add_pixel_observation(timestep)

    def step(self, action: np.ndarray) -> dm_env.TimeStep:
        """Steps the wrapped environment and adds pixels to the resulting timestep."""
        timestep = self._environment.step(action)
        return self._add_pixel_observation(timestep)

    def _add_pixel_observation(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:
        """Returns a copy of `timestep` whose observation includes a fresh render.

        The original observation mapping is not mutated; its entries are copied
        into a new `OrderedDict` and the pixel entry is appended last.
        """
        observation = collections.OrderedDict(timestep.observation)
        observation[self._observation_key] = self._environment.physics.render(
            **self._render_kwargs
        )
        return timestep._replace(observation=observation)

0 comments on commit b1c47b4

Please sign in to comment.