feature(gry): add acrobot env and dqn config (#577)
* add acrobot env and dqn config

* remove unused code and update README

* add the missing acrobot gif

* fix a wrong number in the README

* run the format script to fix formatting issues
ruoyuGao committed Feb 7, 2023
1 parent 9d183d9 commit b218ea7
Showing 8 changed files with 191 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -282,6 +282,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 30 |[evogym](https://github.com/EvolutionGym/evogym) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/evogym/evogym.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/evogym/envs) <br> [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/evogym.html) <br>environment guide |
| 31 |[gym-pybullet-drones](https://github.com/utiasDSL/gym-pybullet-drones) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/gym-pybullet-drones/gym-pybullet-drones.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/gym_pybullet_drones/envs)<br>environment guide |
| 32 |[beergame](https://github.com/OptMLGroup/DeepBeerInventory-RL) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/beergame/beergame.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/beergame/envs)<br>environment guide |
| 33 |[classic_control/acrobot](https://github.com/openai/gym/tree/master/gym/envs/classic_control) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/classic_control/acrobot/acrobot.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/classic_control/acrobot/envs)<br>environment guide |

![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space

dizoo/classic_control/acrobot/__init__.py
Empty file.
Binary file added dizoo/classic_control/acrobot/acrobot.gif
1 change: 1 addition & 0 deletions dizoo/classic_control/acrobot/config/__init__.py
@@ -0,0 +1 @@
from .acrobot_dqn_config import acrobot_dqn_config, acrobot_dqn_create_config
55 changes: 55 additions & 0 deletions dizoo/classic_control/acrobot/config/acrobot_dqn_config.py
@@ -0,0 +1,55 @@
from easydict import EasyDict

acrobot_dqn_config = dict(
    exp_name='acrobot_dqn_seed0',
    env=dict(
        collector_env_num=8,
        evaluator_env_num=8,
        n_evaluator_episode=8,
        stop_value=-60,
        env_id='Acrobot-v1',
        replay_path='acrobot_dqn_seed0/video',
    ),
    policy=dict(
        cuda=True,
        model=dict(
            obs_shape=6,
            action_shape=3,
            encoder_hidden_size_list=[256, 256],
            dueling=True,
        ),
        nstep=3,
        discount_factor=0.99,
        learn=dict(
            update_per_collect=10,
            batch_size=128,
            learning_rate=0.0001,
            target_update_freq=250,
        ),
        collect=dict(n_sample=96, ),
        eval=dict(evaluator=dict(eval_freq=2000, )),
        other=dict(
            eps=dict(
                type='exp',
                start=1.,
                end=0.05,
                decay=250000,
            ),
            replay_buffer=dict(replay_buffer_size=100000, ),
        ),
    ),
)
acrobot_dqn_config = EasyDict(acrobot_dqn_config)
main_config = acrobot_dqn_config
acrobot_dqn_create_config = dict(
    env=dict(type='acrobot', import_names=['dizoo.classic_control.acrobot.envs.acrobot_env']),
    env_manager=dict(type='subprocess'),
    policy=dict(type='dqn'),
    replay_buffer=dict(type='deque', import_names=['ding.data.buffer.deque_buffer_wrapper']),
)
acrobot_dqn_create_config = EasyDict(acrobot_dqn_create_config)
create_config = acrobot_dqn_create_config

if __name__ == "__main__":
    from ding.entry import serial_pipeline
    serial_pipeline((main_config, create_config), seed=0)
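
The `eps` block above defines an exponential exploration schedule. As a rough sketch of how epsilon evolves under the usual exponential form `end + (start - end) * exp(-step / decay)` (the exact schedule is implemented inside DI-engine's `ding.rl_utils`; this standalone snippet is only illustrative):

```python
import math

def exp_epsilon(step: int, start: float = 1.0, end: float = 0.05, decay: int = 250000) -> float:
    # Assumed exponential decay from `start` toward `end` with time
    # constant `decay`, mirroring the type='exp' eps config above.
    return end + (start - end) * math.exp(-step / decay)

for step in (0, 50_000, 250_000, 1_000_000):
    print(step, round(exp_epsilon(step), 3))  # 1.0, ~0.828, ~0.399, ~0.067
```

The `__main__` guard also makes the config directly runnable: `python acrobot_dqn_config.py` launches `serial_pipeline` with seed 0.
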
1 change: 1 addition & 0 deletions dizoo/classic_control/acrobot/envs/__init__.py
@@ -0,0 +1 @@
from .acrobot_env import AcroBotEnv
98 changes: 98 additions & 0 deletions dizoo/classic_control/acrobot/envs/acrobot_env.py
@@ -0,0 +1,98 @@
from typing import Any, List, Union, Optional
import time
import gym
import copy
import numpy as np
from easydict import EasyDict
from ding.envs import BaseEnv, BaseEnvTimestep
from ding.torch_utils import to_ndarray, to_list
from ding.utils import ENV_REGISTRY
from ding.envs import ObsPlusPrevActRewWrapper


@ENV_REGISTRY.register('acrobot')
class AcroBotEnv(BaseEnv):

    def __init__(self, cfg: dict = {}) -> None:
        self._cfg = cfg
        self._init_flag = False
        self._replay_path = None
        # Observation: [cos(theta1), sin(theta1), cos(theta2), sin(theta2), dtheta1, dtheta2],
        # with angular velocity bounds of roughly +/-4*pi and +/-9*pi.
        self._observation_space = gym.spaces.Box(
            low=np.array([-1.0, -1.0, -1.0, -1.0, -12.57, -28.27]),
            high=np.array([1.0, 1.0, 1.0, 1.0, 12.57, 28.27]),
            shape=(6, ),
            dtype=np.float32
        )
        self._action_space = gym.spaces.Discrete(3)
        self._action_space.seed(0)  # default seed
        self._reward_space = gym.spaces.Box(low=-1.0, high=0.0, shape=(1, ), dtype=np.float32)

    def reset(self) -> np.ndarray:
        if not self._init_flag:
            self._env = gym.make('Acrobot-v1')
            if self._replay_path is not None:
                self._env = gym.wrappers.RecordVideo(
                    self._env,
                    video_folder=self._replay_path,
                    episode_trigger=lambda episode_id: True,
                    name_prefix='rl-video-{}'.format(id(self))
                )
            self._init_flag = True
        if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
            np_seed = 100 * np.random.randint(1, 1000)
            self._env.seed(self._seed + np_seed)
            self._action_space.seed(self._seed + np_seed)
        elif hasattr(self, '_seed'):
            self._env.seed(self._seed)
            self._action_space.seed(self._seed)
        self._observation_space = self._env.observation_space
        self._eval_episode_return = 0
        obs = self._env.reset()
        obs = to_ndarray(obs)
        return obs

    def close(self) -> None:
        if self._init_flag:
            self._env.close()
        self._init_flag = False

    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
        self._seed = seed
        self._dynamic_seed = dynamic_seed
        np.random.seed(self._seed)

    def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep:
        if isinstance(action, np.ndarray) and action.shape == (1, ):
            action = action.squeeze()  # 0-dim array
        obs, rew, done, info = self._env.step(action)
        self._eval_episode_return += rew
        if done:
            info['eval_episode_return'] = self._eval_episode_return
        obs = to_ndarray(obs)
        rew = to_ndarray([rew]).astype(np.float32)  # wrap the scalar reward into an array with shape (1, )
        return BaseEnvTimestep(obs, rew, done, info)

    def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
        if replay_path is None:
            replay_path = './video'
        self._replay_path = replay_path

    def random_action(self) -> np.ndarray:
        random_action = self.action_space.sample()
        random_action = to_ndarray([random_action], dtype=np.int64)
        return random_action

    @property
    def observation_space(self) -> gym.spaces.Space:
        return self._observation_space

    @property
    def action_space(self) -> gym.spaces.Space:
        return self._action_space

    @property
    def reward_space(self) -> gym.spaces.Space:
        return self._reward_space

    def __repr__(self) -> str:
        return "DI-engine Acrobot Env"
35 changes: 35 additions & 0 deletions dizoo/classic_control/acrobot/envs/test_acrobot_env.py
@@ -0,0 +1,35 @@
import pytest
import numpy as np
from dizoo.classic_control.acrobot.envs import AcroBotEnv


@pytest.mark.envtest
class TestAcrobotEnv:

    def test_naive(self):
        env = AcroBotEnv({})
        env.seed(314, dynamic_seed=False)
        assert env._seed == 314
        obs = env.reset()
        assert obs.shape == (6, )
        for _ in range(5):
            env.reset()
        np.random.seed(314)
        print('=' * 60)
        for i in range(10):
            # Both ``env.random_action()`` and sampling from ``env.action_space``
            # directly produce legal random actions.
            if i < 5:
                random_action = np.array([env.action_space.sample()])
            else:
                random_action = env.random_action()
            timestep = env.step(random_action)
            print(timestep)
            assert isinstance(timestep.obs, np.ndarray)
            assert isinstance(timestep.done, bool)
            assert timestep.obs.shape == (6, )
            assert timestep.reward.shape == (1, )
            assert timestep.reward >= env.reward_space.low
            assert timestep.reward <= env.reward_space.high
        print(env.observation_space, env.action_space, env.reward_space)
        env.close()
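
Since the test is gated behind the `envtest` marker, it is skipped in a default run; assuming the repository's standard pytest marker setup, it can be invoked with `pytest -m envtest dizoo/classic_control/acrobot/envs/test_acrobot_env.py`.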
