
Replace environment interface code with dm_env import.

PiperOrigin-RevId: 260989812
yotam authored and alimuldal committed Jul 31, 2019
1 parent 09b27a0 commit 57a1a95559312e525e9af0dc4f2c1bdd5eda1bba
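
For anyone migrating downstream code along the same lines, the change is mechanical: the environment base class, `TimeStep` and `StepType` move from `dm_control.rl.environment` to the `dm_env` package, and the spec classes are renamed (`ArraySpec` becomes `specs.Array`, `BoundedArraySpec` becomes `specs.BoundedArray`). A minimal sketch of the new-style imports and specs follows, assuming only that the `dm_env` package is installed; the shapes and names are illustrative, not taken from this commit.

import dm_env
from dm_env import specs
import numpy as np

# specs.Array replaces the old specs.ArraySpec, and specs.BoundedArray
# replaces specs.BoundedArraySpec; the constructor arguments are unchanged.
observation_spec = specs.Array(shape=(3,), dtype=np.float32, name='observation')
action_spec = specs.BoundedArray(
    shape=(1,), dtype=np.float32, minimum=0.0, maximum=1.0, name='action')

# TimeStep and StepType are now taken directly from dm_env.
first_step = dm_env.TimeStep(
    step_type=dm_env.StepType.FIRST,
    reward=None,
    discount=None,
    observation=np.zeros(3, dtype=np.float32))
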
@@ -26,11 +26,10 @@
from dm_control import mjcf
from dm_control.composer import observation
from dm_control.rl import control
+import dm_env
import numpy as np
from six.moves import range

-from dm_control.rl import environment

warnings.simplefilter('always', DeprecationWarning)

_STEPS_LOGGING_INTERVAL = 10000
@@ -264,7 +263,7 @@ def control_timestep(self):
return self.task.control_timestep


-class Environment(_CommonEnvironment, environment.Base):
+class Environment(_CommonEnvironment, dm_env.Environment):
"""Reinforcement learning environment for Composer tasks."""

def __init__(self, task, time_limit=float('inf'), random_state=None,
@@ -325,8 +324,8 @@ def _reset_attempt(self):
self._hooks.initialize_episode(self._physics_proxy, self._random_state)
self._observation_updater.reset(self._physics_proxy, self._random_state)
self._reset_next_step = False
-return environment.TimeStep(
-step_type=environment.StepType.FIRST,
+return dm_env.TimeStep(
+step_type=dm_env.StepType.FIRST,
reward=None,
discount=None,
observation=self._observation_updater.get_observation())
@@ -339,7 +338,7 @@ def step_spec(self):
if (self._task.get_reward_spec() is None or
self._task.get_discount_spec() is None):
raise NotImplementedError
-return environment.TimeStep(
+return dm_env.TimeStep(
step_type=None,
reward=self._task.get_reward_spec(),
discount=self._task.get_discount_spec(),
@@ -393,12 +392,10 @@ def step(self, action):
obs = self._observation_updater.get_observation()

if not terminating:
-return environment.TimeStep(
-environment.StepType.MID, reward, discount, obs)
+return dm_env.TimeStep(dm_env.StepType.MID, reward, discount, obs)
else:
self._reset_next_step = True
-return environment.TimeStep(
-environment.StepType.LAST, reward, discount, obs)
+return dm_env.TimeStep(dm_env.StepType.LAST, reward, discount, obs)

def action_spec(self):
"""Returns the action specification for this environment."""
@@ -409,10 +406,11 @@ def reward_spec(self):
This will be the output of `self.task.reward_spec()` if it is not None,
otherwise it will be the default spec returned by
-`environment.Base.reward_spec()`.
+`dm_env.Environment.reward_spec()`.
Returns:
-An `ArraySpec`, or a nested dict, list or tuple of `ArraySpec`s.
+A `specs.Array` instance, or a nested dict, list or tuple of
+`specs.Array`s.
"""
task_reward_spec = self._task.get_reward_spec()
if task_reward_spec is not None:
@@ -425,10 +423,11 @@ def discount_spec(self):
This will be the output of `self.task.discount_spec()` if it is not None,
otherwise it will be the default spec returned by
-`environment.Base.discount_spec()`.
+`dm_env.Environment.discount_spec()`.
Returns:
-An `ArraySpec`, or a nested dict, list or tuple of `ArraySpec`s.
+A `specs.Array` instance, or a nested dict, list or tuple of
+`specs.Array`s.
"""
task_discount_spec = self._task.get_discount_spec()
if task_discount_spec is not None:
@@ -440,7 +439,7 @@ def observation_spec(self):
"""Returns the observation specification for this environment.
Returns:
-An `OrderedDict` mapping observation name to `ArraySpec` containing
+An `OrderedDict` mapping observation name to `specs.Array` containing
observation shape and dtype.
"""
return self._observation_updater.observation_spec()
@@ -22,11 +22,10 @@
import abc
import functools

+from dm_env import specs
import numpy as np
import six

-from dm_control.rl import specs


AGGREGATORS = {
'min': functools.partial(np.min, axis=0),
@@ -305,7 +304,7 @@ def width(self, value):

@property
def array_spec(self):
-return specs.ArraySpec(
+return specs.Array(
shape=(self._height, self._width, self._n_channels), dtype=self._dtype)

def _callable(self, physics):
@@ -23,8 +23,8 @@

from dm_control import mjcf
from dm_control.composer.observation.observable import base
+from dm_env import specs
import numpy as np
-from dm_control.rl import specs


_BOTH_SEGMENTATION_AND_DEPTH_ENABLED = (
@@ -221,7 +221,7 @@ def segmentation(self, value):

@property
def array_spec(self):
-return specs.ArraySpec(
+return specs.Array(
shape=(self._height, self._width, self._n_channels), dtype=self._dtype)

def _callable(self, physics):
@@ -22,12 +22,11 @@
import collections

from dm_control.composer.observation import obs_buffer
+from dm_env import specs
import numpy as np
import six
from six.moves import range

-from dm_control.rl import specs

DEFAULT_BUFFER_SIZE = 1
DEFAULT_UPDATE_INTERVAL = 1
DEFAULT_DELAY = 0
@@ -140,7 +139,7 @@ def observation_spec(self):
the first call to `reset`.
Returns:
-A dict mapping observation name to `ArraySpec` containing observation
+A dict mapping observation name to `Array` spec containing observation
shape and dtype.
Raises:
@@ -156,10 +155,10 @@ def make_observation_spec_dict(enabled_dict):
if enabled.observable.aggregator:
aggregated = enabled.observable.aggregator(
np.zeros(enabled.buffer.shape, dtype=enabled.buffer.dtype))
-spec = specs.ArraySpec(
+spec = specs.Array(
shape=aggregated.shape, dtype=aggregated.dtype, name=name)
else:
-spec = specs.ArraySpec(
+spec = specs.Array(
shape=enabled.buffer.shape, dtype=enabled.buffer.dtype, name=name)
out_dict[name] = spec
return out_dict
@@ -30,12 +30,11 @@
from dm_control.composer.observation import fake_physics
from dm_control.composer.observation import observable
from dm_control.composer.observation import updater
+from dm_env import specs
import numpy as np
import six
from six.moves import range

-from dm_control.rl import specs


class DeterministicSequence(object):

@@ -70,7 +69,7 @@ def testNestedSpecsAndValues(self, list_or_tuple):

def make_spec(obs):
array = np.array(obs.observation_callable(None, None)())
-return specs.ArraySpec((1,) + array.shape, array.dtype)
+return specs.Array((1,) + array.shape, array.dtype)
expected_specs = list_or_tuple((
{'two': make_spec(observables[0]['two'])},
collections.OrderedDict([
@@ -25,11 +25,10 @@
import sys

from dm_control import mujoco
+from dm_env import specs
import six
from six.moves import range

-from dm_control.rl import specs


def _check_timesteps_divisible(control_timestep, physics_timestep):
num_steps = control_timestep / physics_timestep
@@ -187,9 +186,9 @@ def physics_steps_per_control_step(self):
self.control_timestep, self.physics_timestep)

def action_spec(self, physics):
"""Returns an `BoundedArraySpec` matching the `Physics` actuators.
"""Returns a `BoundedArray` spec matching the `Physics` actuators.
BoundedArraySpec.name should contain a tab-separated list of actuator names.
BoundedArray.name should contain a tab-separated list of actuator names.
When overloading this method, non-MuJoCo actuators should be added to the
top of the list when possible, as a matter of convention.
@@ -199,11 +198,11 @@ def action_spec(self, physics):
names = [physics.model.id2name(i, 'actuator') or str(i)
for i in range(physics.model.nu)]
action_spec = mujoco.action_spec(physics)
-return specs.BoundedArraySpec(shape=action_spec.shape,
-dtype=action_spec.dtype,
-minimum=action_spec.minimum,
-maximum=action_spec.maximum,
-name='\t'.join(names))
+return specs.BoundedArray(shape=action_spec.shape,
+dtype=action_spec.dtype,
+minimum=action_spec.minimum,
+maximum=action_spec.maximum,
+name='\t'.join(names))

def get_reward_spec(self):
"""Optional method to define non-scalar rewards for a `Task`."""
@@ -23,11 +23,10 @@
from dm_control.locomotion.soccer import initializers
from dm_control.locomotion.soccer import observables as observables_lib
from dm_control.locomotion.soccer import soccer_ball
+from dm_env import specs
import numpy as np
from six.moves import zip

-from dm_control.rl import specs

_THROW_IN_BALL_Z = 0.5


@@ -164,7 +163,7 @@ def get_reward(self, physics):

def get_reward_spec(self):
return [
specs.ArraySpec(name="reward", shape=(), dtype=np.float32)
specs.Array(name="reward", shape=(), dtype=np.float32)
for _ in self.players
]

@@ -174,7 +173,7 @@ def get_discount(self, physics):
return np.ones((), np.float32)

def get_discount_spec(self):
return specs.ArraySpec(name="discount", shape=(), dtype=np.float32)
return specs.Array(name="discount", shape=(), dtype=np.float32)

def should_terminate_episode(self, physics):
"""Returns True if a goal was scored by either team."""
@@ -27,11 +27,10 @@
from dm_control.locomotion.walkers import initializers
from dm_control.mujoco.wrapper.mjbindings import mjlib

+from dm_env import specs
import numpy as np
import six

-from dm_control.rl import specs

_RANGEFINDER_SCALE = 10.0
_TOUCH_THRESHOLD = 1e-3

@@ -304,7 +303,7 @@ def action_spec(self):
a.ctrlrange if a.ctrlrange is not None else (-1., 1.)
for a in self.actuators
])
-return specs.BoundedArraySpec(
+return specs.BoundedArray(
shape=(len(self.actuators),),
dtype=np.float,
minimum=minimum,
@@ -49,12 +49,11 @@
from dm_control.mujoco.wrapper.mjbindings import mjlib
from dm_control.mujoco.wrapper.mjbindings import types
from dm_control.rl import control as _control
+from dm_env import specs

import numpy as np
import six

-from dm_control.rl import specs

_FONT_STYLES = {
'normal': enums.mjtFont.mjFONT_NORMAL,
'shadow': enums.mjtFont.mjFONT_SHADOW,
@@ -873,5 +872,5 @@ def action_spec(physics):
maxima = np.full(num_actions, fill_value=np.inf, dtype=np.float)
minima[is_limited], maxima[is_limited] = control_range[is_limited].T

-return specs.BoundedArraySpec(
+return specs.BoundedArray(
shape=(num_actions,), dtype=np.float, minimum=minima, maximum=maxima)
@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================

"""An environment.Base subclass for control-specific environments."""
"""A dm_env.Environment subclass for control-specific environments."""

from __future__ import absolute_import
from __future__ import division
@@ -23,17 +23,16 @@
import collections
import contextlib

+import dm_env
+from dm_env import specs
import numpy as np
import six
from six.moves import range

-from dm_control.rl import environment
-from dm_control.rl import specs

FLAT_OBSERVATION_KEY = 'observations'


-class Environment(environment.Base):
+class Environment(dm_env.Environment):
"""Class for physics-based reinforcement learning environments."""

def __init__(self,
@@ -93,8 +92,8 @@ def reset(self):
if self._flat_observation:
observation = flatten_observation(observation)

-return environment.TimeStep(
-step_type=environment.StepType.FIRST,
+return dm_env.TimeStep(
+step_type=dm_env.StepType.FIRST,
reward=None,
discount=None,
observation=observation)
@@ -125,11 +124,10 @@ def step(self, action):

if episode_over:
self._reset_next_step = True
-return environment.TimeStep(
-environment.StepType.LAST, reward, discount, observation)
+return dm_env.TimeStep(
+dm_env.StepType.LAST, reward, discount, observation)
else:
-return environment.TimeStep(
-environment.StepType.MID, reward, 1.0, observation)
+return dm_env.TimeStep(dm_env.StepType.MID, reward, 1.0, observation)

def action_spec(self):
"""Returns the action specification for this environment."""
@@ -202,7 +200,7 @@ def compute_n_steps(control_timestep, physics_timestep, tolerance=1e-8):
def _spec_from_observation(observation):
result = collections.OrderedDict()
for key, value in six.iteritems(observation):
-result[key] = specs.ArraySpec(value.shape, value.dtype, name=key)
+result[key] = specs.Array(value.shape, value.dtype, name=key)
return result

# Base class definitions for objects supplied to Environment.
@@ -25,18 +25,17 @@
from absl.testing import parameterized

from dm_control.rl import control
+from dm_env import specs

import mock
import numpy as np

-from dm_control.rl import specs

_CONSTANT_REWARD_VALUE = 1.0
_CONSTANT_OBSERVATION = {'observations': np.asarray(_CONSTANT_REWARD_VALUE)}

-_ACTION_SPEC = specs.BoundedArraySpec(
+_ACTION_SPEC = specs.BoundedArray(
shape=(1,), dtype=np.float, minimum=0.0, maximum=1.0)
-_OBSERVATION_SPEC = {'observations': specs.ArraySpec(shape=(), dtype=np.float)}
+_OBSERVATION_SPEC = {'observations': specs.Array(shape=(), dtype=np.float)}


class EnvironmentTest(parameterized.TestCase):

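The same substitution applies to user-defined environments: anything that previously subclassed `environment.Base` can subclass `dm_env.Environment` instead; the abstract interface (`reset`, `step`, `action_spec`, `observation_spec`, plus the optional `reward_spec`/`discount_spec`) is unchanged. The sketch below shows the pattern with a hypothetical toy environment (the class name, counter logic and limit are invented for illustration) using the `dm_env.restart`, `dm_env.transition` and `dm_env.termination` helpers.

import dm_env
from dm_env import specs
import numpy as np


class CountingEnvironment(dm_env.Environment):
  """Hypothetical toy environment illustrating the dm_env.Environment API."""

  def __init__(self, limit=10):
    self._limit = limit
    self._count = 0

  def reset(self):
    self._count = 0
    # dm_env.restart builds a FIRST TimeStep with reward=None and discount=None.
    return dm_env.restart(np.asarray(self._count, dtype=np.int32))

  def step(self, action):
    del action  # The toy dynamics ignore the action.
    self._count += 1
    observation = np.asarray(self._count, dtype=np.int32)
    if self._count >= self._limit:
      # dm_env.termination builds a LAST TimeStep with discount=0.0.
      return dm_env.termination(reward=1.0, observation=observation)
    # dm_env.transition builds a MID TimeStep with discount=1.0 by default.
    return dm_env.transition(reward=0.0, observation=observation)

  def action_spec(self):
    return specs.BoundedArray(shape=(), dtype=np.int32, minimum=0, maximum=1)

  def observation_spec(self):
    return specs.Array(shape=(), dtype=np.int32)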