In [None]:
# default_exp wrappers
%load_ext autoreload
%autoreload 2

In [None]:
# hide
from nbdev.showdoc import *
from nbdev.imports import *
from nbdev.export2html import *
# Sanity-check the execution environment unless the IN_TEST environment
# variable is set to a non-empty value. `os` and the IN_* flags are not
# defined in this cell; they are expected to come from the star imports above.
if not os.environ.get("IN_TEST"):
    assert IN_NOTEBOOK
    assert not IN_COLAB
    assert IN_IPYTHON

# Wrappers

> List of env wrappers. Some are copied directly from the `gym` package, found [here](https://github.com/openai/gym).

In [None]:
# export
"""An observation wrapper that augments observations by pixel values."""
# Originally from https://github.com/openai/gym/blob/master/gym/wrappers/pixel_observation.py
# Seems to be an issue importing it from gym however.

import collections
import collections.abc
import copy

import numpy as np

import gym
from gym import spaces
from gym import ObservationWrapper

# Key under which the wrapped env's original observation is exposed when that
# env has a `Box` observation space and `pixels_only` is False. It is reserved
# in that case: using it as a pixel key raises a ValueError in the wrapper.
STATE_KEY = 'state'


class PixelObservationWrapper(ObservationWrapper):
    """Augment (or replace) observations with rendered pixel frames.

    The wrapped environment is rendered once during construction to infer the
    shape and dtype of the pixel observation space, and once per pixel key
    every time an observation is produced.
    """

    def __init__(self,
                 env,
                 pixels_only=True,
                 render_kwargs=None,
                 pixel_keys=('pixels',),
                 boxify=True):
        """Initializes a new pixel Wrapper.
        Args:
            env: The environment to wrap. It is rendered during construction.
            pixels_only: If `True` (default), the original observation returned
                by the wrapped environment will be discarded, and the
                observation will only include pixels. If `False`, the
                observation dictionary will contain both the original
                observations and the pixel observations.
            render_kwargs: Optional `dict` mapping each pixel key to a `dict`
                of keyword arguments passed to `env.render` for that key.
                Only the `'rgb_array'` render mode is accepted. The argument
                is copied, never mutated.
            pixel_keys: Sequence of strings naming the pixel observations'
                keys in the `OrderedDict` of observations.
                Defaults to `('pixels',)`.
            boxify: (fastrl change) Default to True where instead of a Dict,
                return a Box. Only the first entry of `pixel_keys` is used in
                that case.
        Raises:
            ValueError: If `env`'s observation spec is not compatible with the
                wrapper. Supported formats are a single array, or a dict of
                arrays.
            ValueError: If `pixels_only` is `False` and `env`'s observation
                already contains any of the specified `pixel_keys`.
            ValueError: If `pixels_only` is `False` while `boxify` is `True`.
            TypeError: If the rendered frame is neither integer nor floating
                point typed.
        """

        super(PixelObservationWrapper, self).__init__(env)

        # Work on copies so that neither the caller's outer dict nor its
        # per-key kwargs dicts are mutated by the normalisation below.
        render_kwargs = dict(render_kwargs) if render_kwargs is not None else {}

        for key in pixel_keys:
            key_kwargs = dict(render_kwargs.get(key, {}))

            render_mode = key_kwargs.pop('mode', 'rgb_array')
            assert render_mode == 'rgb_array', render_mode
            key_kwargs['mode'] = 'rgb_array'
            render_kwargs[key] = key_kwargs

        wrapped_observation_space = env.observation_space

        if isinstance(wrapped_observation_space, spaces.Box):
            self._observation_is_dict = False
            invalid_keys = {STATE_KEY}
        elif isinstance(wrapped_observation_space,
                        (spaces.Dict, collections.abc.MutableMapping)):
            # `collections.abc` is required here: the bare
            # `collections.MutableMapping` alias was removed in Python 3.10.
            self._observation_is_dict = True
            invalid_keys = set(wrapped_observation_space.spaces.keys())
        else:
            raise ValueError("Unsupported observation space structure.")

        if not pixels_only:
            # Make sure that no keys in the `pixel_keys` overlap with
            # `observation_keys`
            overlapping_keys = set(pixel_keys) & set(invalid_keys)
            if overlapping_keys:
                raise ValueError("Duplicate or reserved pixel keys {!r}."
                                 .format(overlapping_keys))
            if boxify:
                raise ValueError("boxify cannot be True of pixels_only is False.")

        if pixels_only:
            self.observation_space = spaces.Dict()
        elif self._observation_is_dict:
            self.observation_space = copy.deepcopy(wrapped_observation_space)
        else:
            self.observation_space = spaces.Dict()
            self.observation_space.spaces[STATE_KEY] = wrapped_observation_space

        # Extend observation space with pixels.

        pixels_spaces = {}
        for pixel_key in pixel_keys:
            # Render one frame to discover the shape and dtype of the pixels.
            pixels = self.env.render(**render_kwargs[pixel_key])

            if np.issubdtype(pixels.dtype, np.integer):
                low, high = (0, 255)
            elif np.issubdtype(pixels.dtype, np.floating):
                # `np.floating` matches every float width; the `np.float`
                # alias used previously was removed in NumPy 1.24.
                low, high = (-float('inf'), float('inf'))
            else:
                raise TypeError(pixels.dtype)

            pixels_space = spaces.Box(
                shape=pixels.shape, low=low, high=high, dtype=pixels.dtype)

            if boxify:
                # A Box can describe a single frame only: use the first key.
                self.observation_space = pixels_space
                break

            pixels_spaces[pixel_key] = pixels_space

        if not boxify:
            self.observation_space.spaces.update(pixels_spaces)

        self._env = env
        self._pixels_only = pixels_only
        self._render_kwargs = render_kwargs
        self._pixel_keys = pixel_keys
        self._boxify = boxify

    def observation(self, observation):
        """Return the pixel-based observation built from `observation`."""
        return self._add_pixel_observation(observation)

    def _add_pixel_observation(self, observation):
        """Build the observation that carries the rendered pixels.

        Args:
            observation: The observation produced by the wrapped environment.
        Returns:
            With `boxify`, the frame rendered for the first pixel key.
            Otherwise a mapping with one rendered frame per pixel key, plus
            the original observation unless `pixels_only` is set: copied
            key-by-key for dict observations, stored under `STATE_KEY`
            for array observations.
        """
        if self._boxify:
            return self.env.render(**self._render_kwargs[self._pixel_keys[0]])

        if self._pixels_only:
            augmented = collections.OrderedDict()
        elif self._observation_is_dict:
            augmented = type(observation)(observation)
        else:
            # Use a separate name for the container so the wrapped
            # observation itself (not the new dict) is stored under STATE_KEY.
            augmented = collections.OrderedDict()
            augmented[STATE_KEY] = observation

        pixel_observations = {
            pixel_key: self.env.render(**self._render_kwargs[pixel_key])
            for pixel_key in self._pixel_keys
        }
        augmented.update(pixel_observations)

        return augmented

In [None]:
# boxify=False: the observation space is a Dict holding a single Box under the
# default 'pixels' key.
# The env is reset before wrapping because the wrapper renders a frame in its
# constructor (presumably this needs an initialised env).
env = gym.make("CartPole-v1")
env.reset()
env = PixelObservationWrapper(env, boxify=False)
test_eq(
    env.observation_space,
    gym.spaces.Dict({'pixels': gym.spaces.Box(low=0, high=255, shape=(400, 600, 3))}),
)

# Default (boxify=True): the observation space is the pixel Box itself.
env = gym.make("CartPole-v1")
env.reset()
env = PixelObservationWrapper(env)
test_eq(
    env.observation_space,
    gym.spaces.Box(low=0, high=255, shape=(400, 600, 3)),
)

In [None]:
# hide
from nbdev.export import *
# Both helpers are not defined in this notebook; they are expected to come from
# the `nbdev.export` star import above. Presumably `notebook2script` writes the
# `# export` cells out to the library modules - confirm against the nbdev docs.
notebook2script()
# Presumably renders the notebooks to HTML docs; `n_workers=0` is passed through
# unchanged (NOTE(review): looks like it disables parallel workers - confirm).
notebook2html(n_workers=0)

Converted 00_core.ipynb.
Converted 01_wrappers.ipynb.
Converted 02_callbacks.ipynb.
Converted 03_basic_agents.ipynb.
Converted 04_metrics.ipynb.
Converted 05_data_block.ipynb.
Converted 06_basic_train.ipynb.
Converted 12_a3c.a3c_data.ipynb.
Converted index.ipynb.
Converted notes.ipynb.
converting: /opt/project/fastrl/nbs/12_a3c.a3c_data.ipynb


