Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Observation space conversion fix #6

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions dm_control2gym/viewer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pyglet
import numpy as np


class DmControlViewer:
def __init__(self, width, height, depth=False):
self.window = pyglet.window.Window(width=width, height=height, display=None)
Expand All @@ -9,12 +10,8 @@ def __init__(self, width, height, depth=False):

self.depth = depth

if depth:
self.format = 'RGB'
self.pitch = self.width * -3
else:
self.format = 'RGB'
self.pitch = self.width * -3
self.format = 'RGB'
self.pitch = self.width * -3

def update(self, pixel):
self.window.clear()
Expand Down
66 changes: 30 additions & 36 deletions dm_control2gym/wrapper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from gym import core, spaces
from dm_control import suite
from dm_control.rl import specs
from dm_env import specs
from gym.utils import seeding
import gym
from dm_control2gym.viewer import DmControlViewer
Expand All @@ -14,15 +14,16 @@ def __init__(self, _minimum, _maximum):
super().__init__(_maximum - _minimum)
self.offset = _minimum

def convertSpec2Space(spec, clip_inf=False):

def spec2space(spec, clip_inf=False):
if spec.dtype == np.int:
# Discrete
return DmcDiscrete(spec.minimum, spec.maximum)
else:
# Box
if type(spec) is specs.ArraySpec:
if type(spec) is specs.Array:
return spaces.Box(-np.inf, np.inf, shape=spec.shape)
elif type(spec) is specs.BoundedArraySpec:
elif type(spec) is specs.BoundedArray:
_min = spec.minimum
_max = spec.maximum
if clip_inf:
Expand All @@ -39,40 +40,34 @@ def convertSpec2Space(spec, clip_inf=False):
else:
raise ValueError('Unknown spec!')

def convertOrderedDict2Space(odict):

def dict2space(odict):
if len(odict.keys()) == 1:
# no concatenation
return convertSpec2Space(list(odict.values())[0])
return spec2space(list(odict.values())[0])
else:
# concatentation
numdim = sum([np.int(np.prod(odict[key].shape)) for key in odict])
return spaces.Box(-np.inf, np.inf, shape=(numdim,))
num_dim = sum([np.int(np.prod(odict[key].shape)) for key in odict])
return spaces.Box(-np.inf, np.inf, shape=(num_dim,))


def convertObservation(spec_obs):
def convert_observation(spec_obs):
if len(spec_obs.keys()) == 1:
# no concatenation
return list(spec_obs.values())[0]
else:
# concatentation
numdim = sum([np.int(np.prod(spec_obs[key].shape)) for key in spec_obs])
space_obs = np.zeros((numdim,))
i = 0
for key in spec_obs:
space_obs[i:i+np.prod(spec_obs[key].shape)] = spec_obs[key].ravel()
i += np.prod(spec_obs[key].shape)
return space_obs
observation = [spec_obs[key] if isinstance(spec_obs[key], np.ndarray) else [spec_obs[key]] for key in spec_obs]
observation = np.concatenate(observation)
return observation


class DmControlWrapper(core.Env):

def __init__(self, domain_name, task_name, task_kwargs=None, visualize_reward=False, render_mode_list=None):

self.dmcenv = suite.load(domain_name=domain_name, task_name=task_name, task_kwargs=task_kwargs,
visualize_reward=visualize_reward)

# convert spec to space
self.action_space = convertSpec2Space(self.dmcenv.action_spec(), clip_inf=True)
self.observation_space = convertOrderedDict2Space(self.dmcenv.observation_spec())
self.action_space = spec2space(self.dmcenv.action_spec(), clip_inf=True)
self.observation_space = dict2space(self.dmcenv.observation_spec())

if render_mode_list is not None:
self.metadata['render.modes'] = list(render_mode_list.keys())
Expand All @@ -82,30 +77,30 @@ def __init__(self, domain_name, task_name, task_kwargs=None, visualize_reward=Fa

self.render_mode_list = render_mode_list

# set seed
self.timestep = None
self.pixels = None

self._seed()

def getObservation(self):
return convertObservation(self.timestep.observation)
def get_observation(self):
return convert_observation(self.timestep.observation)

def _seed(self, seed=None):
self.np_random, seed = seeding.np_random(seed)
return [seed]

def _reset(self):
def reset(self):
self.timestep = self.dmcenv.reset()
return self.getObservation()

def _step(self, a):
return self.get_observation()

def step(self, a):
if type(self.action_space) == DmcDiscrete:
a += self.action_space.offset
self.timestep = self.dmcenv.step(a)

return self.getObservation(), self.timestep.reward, self.timestep.last(), {}

return self.get_observation(), self.timestep.reward, self.timestep.last(), {}

def _render(self, mode='human', close=False):
def render(self, mode='human', close=False):

self.pixels = self.dmcenv.physics.render(**self.render_mode_list[mode]['render_kwargs'])
if close:
Expand All @@ -116,13 +111,12 @@ def _render(self, mode='human', close=False):
elif self.render_mode_list[mode]['show']:
self._get_viewer(mode).update(self.pixels)



if self.render_mode_list[mode]['return_pixel']:

return self.pixels

def _get_viewer(self, mode):
if self.viewer[mode] is None:
self.viewer[mode] = DmControlViewer(self.pixels.shape[1], self.pixels.shape[0], self.render_mode_list[mode]['render_kwargs']['depth'])
self.viewer[mode] = DmControlViewer(self.pixels.shape[1],
self.pixels.shape[0],
self.render_mode_list[mode]['render_kwargs']['depth'])
return self.viewer[mode]