-
Notifications
You must be signed in to change notification settings - Fork 348
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feature(gry): add acrobot env and dqn config (#577)
* add acrobot env and dqn config * remove useless part and update readme * provide missed acrobot gif * fix wrong number in readme * run format script to fix format problem
- Loading branch information
Showing
8 changed files
with
191 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .acrobot_dqn_config import acrobot_dqn_config, acrobot_dqn_create_config |
55 changes: 55 additions & 0 deletions
55
dizoo/classic_control/acrobot/config/acrobot_dqn_config.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
from easydict import EasyDict | ||
|
||
acrobot_dqn_config = dict( | ||
exp_name='acrobot_dqn_seed0', | ||
env=dict( | ||
collector_env_num=8, | ||
evaluator_env_num=8, | ||
n_evaluator_episode=8, | ||
stop_value=-60, | ||
env_id='Acrobot-v1', | ||
replay_path='acrobot_dqn_seed0/video', | ||
), | ||
policy=dict( | ||
cuda=True, | ||
model=dict( | ||
obs_shape=6, | ||
action_shape=3, | ||
encoder_hidden_size_list=[256, 256], | ||
dueling=True, | ||
), | ||
nstep=3, | ||
discount_factor=0.99, | ||
learn=dict( | ||
update_per_collect=10, | ||
batch_size=128, | ||
learning_rate=0.0001, | ||
target_update_freq=250, | ||
), | ||
collect=dict(n_sample=96, ), | ||
eval=dict(evaluator=dict(eval_freq=2000, )), | ||
other=dict( | ||
eps=dict( | ||
type='exp', | ||
start=1., | ||
end=0.05, | ||
decay=250000, | ||
), | ||
replay_buffer=dict(replay_buffer_size=100000, ), | ||
), | ||
), | ||
) | ||
acrobot_dqn_config = EasyDict(acrobot_dqn_config) | ||
main_config = acrobot_dqn_config | ||
acrobot_dqn_create_config = dict( | ||
env=dict(type='acrobot', import_names=['dizoo.classic_control.acrobot.envs.acrobot_env']), | ||
env_manager=dict(type='subprocess'), | ||
policy=dict(type='dqn'), | ||
replay_buffer=dict(type='deque', import_names=['ding.data.buffer.deque_buffer_wrapper']), | ||
) | ||
acrobot_dqn_create_config = EasyDict(acrobot_dqn_create_config) | ||
create_config = acrobot_dqn_create_config | ||
|
||
if __name__ == "__main__": | ||
from ding.entry import serial_pipeline | ||
serial_pipeline((main_config, create_config), seed=0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .acrobot_env import AcroBotEnv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
from typing import Any, List, Union, Optional | ||
import time | ||
import gym | ||
import copy | ||
import numpy as np | ||
from easydict import EasyDict | ||
from ding.envs import BaseEnv, BaseEnvTimestep | ||
from ding.torch_utils import to_ndarray, to_list | ||
from ding.utils import ENV_REGISTRY | ||
from ding.envs import ObsPlusPrevActRewWrapper | ||
|
||
|
||
@ENV_REGISTRY.register('acrobot') | ||
class AcroBotEnv(BaseEnv): | ||
|
||
def __init__(self, cfg: dict = {}) -> None: | ||
self._cfg = cfg | ||
self._init_flag = False | ||
self._replay_path = None | ||
self._observation_space = gym.spaces.Box( | ||
low=np.array([-1.0, -1.0, -1.0, -1.0, -12.57, -28.27]), | ||
high=np.array([1.0, 1.0, 1.0, 1.0, 12.57, 28.27]), | ||
shape=(6, ), | ||
dtype=np.float32 | ||
) | ||
self._action_space = gym.spaces.Discrete(3) | ||
self._action_space.seed(0) # default seed | ||
self._reward_space = gym.spaces.Box(low=-1.0, high=0.0, shape=(1, ), dtype=np.float32) | ||
|
||
def reset(self) -> np.ndarray: | ||
if not self._init_flag: | ||
self._env = gym.make('Acrobot-v1') | ||
if self._replay_path is not None: | ||
self._env = gym.wrappers.RecordVideo( | ||
self._env, | ||
video_folder=self._replay_path, | ||
episode_trigger=lambda episode_id: True, | ||
name_prefix='rl-video-{}'.format(id(self)) | ||
) | ||
self._init_flag = True | ||
if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed: | ||
np_seed = 100 * np.random.randint(1, 1000) | ||
self._env.seed(self._seed + np_seed) | ||
self._action_space.seed(self._seed + np_seed) | ||
elif hasattr(self, '_seed'): | ||
self._env.seed(self._seed) | ||
self._action_space.seed(self._seed) | ||
self._observation_space = self._env.observation_space | ||
self._eval_episode_return = 0 | ||
obs = self._env.reset() | ||
obs = to_ndarray(obs) | ||
return obs | ||
|
||
def close(self) -> None: | ||
if self._init_flag: | ||
self._env.close() | ||
self._init_flag = False | ||
|
||
def seed(self, seed: int, dynamic_seed: bool = True) -> None: | ||
self._seed = seed | ||
self._dynamic_seed = dynamic_seed | ||
np.random.seed(self._seed) | ||
|
||
def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep: | ||
if isinstance(action, np.ndarray) and action.shape == (1, ): | ||
action = action.squeeze() # 0-dim array | ||
obs, rew, done, info = self._env.step(action) | ||
self._eval_episode_return += rew | ||
if done: | ||
info['eval_episode_return'] = self._eval_episode_return | ||
obs = to_ndarray(obs) | ||
rew = to_ndarray([rew]).astype(np.float32) # wrapped to be transfered to a array with shape (1,) | ||
return BaseEnvTimestep(obs, rew, done, info) | ||
|
||
def enable_save_replay(self, replay_path: Optional[str] = None) -> None: | ||
if replay_path is None: | ||
replay_path = './video' | ||
self._replay_path = replay_path | ||
|
||
def random_action(self) -> np.ndarray: | ||
random_action = self.action_space.sample() | ||
random_action = to_ndarray([random_action], dtype=np.int64) | ||
return random_action | ||
|
||
@property | ||
def observation_space(self) -> gym.spaces.Space: | ||
return self._observation_space | ||
|
||
@property | ||
def action_space(self) -> gym.spaces.Space: | ||
return self._action_space | ||
|
||
@property | ||
def reward_space(self) -> gym.spaces.Space: | ||
return self._reward_space | ||
|
||
def __repr__(self) -> str: | ||
return "DI-engine Acrobot Env" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import pytest | ||
import numpy as np | ||
from dizoo.classic_control.acrobot.envs import AcroBotEnv | ||
|
||
|
||
@pytest.mark.envtest | ||
class TestAcrobotEnv: | ||
|
||
def test_naive(self): | ||
env = AcroBotEnv({}) | ||
env.seed(314, dynamic_seed=False) | ||
assert env._seed == 314 | ||
obs = env.reset() | ||
assert obs.shape == (6, ) | ||
for _ in range(5): | ||
env.reset() | ||
np.random.seed(314) | ||
print('=' * 60) | ||
for i in range(10): | ||
# Both ``env.random_action()``, and utilizing ``np.random`` as well as action space, | ||
# can generate legal random action. | ||
if i < 5: | ||
random_action = np.array([env.action_space.sample()]) | ||
else: | ||
random_action = env.random_action() | ||
timestep = env.step(random_action) | ||
print(timestep) | ||
assert isinstance(timestep.obs, np.ndarray) | ||
assert isinstance(timestep.done, bool) | ||
assert timestep.obs.shape == (6, ) | ||
assert timestep.reward.shape == (1, ) | ||
assert timestep.reward >= env.reward_space.low | ||
assert timestep.reward <= env.reward_space.high | ||
print(env.observation_space, env.action_space, env.reward_space) | ||
env.close() |