feature(zc): add carracing in box2d #575

Merged · 10 commits · Feb 9, 2023
1 change: 1 addition & 0 deletions dizoo/box2d/carracing/config/__init__.py
@@ -0,0 +1 @@
from .carracing_dqn_config import carracing_dqn_config, carracing_dqn_create_config
81 changes: 81 additions & 0 deletions dizoo/box2d/carracing/config/carracing_dqn_config.py
@@ -0,0 +1,81 @@
from easydict import EasyDict

nstep = 3
carracing_dqn_config = dict(
    exp_name='carracing_dqn_seed0',
    env=dict(
        # Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'.
        # Env number respectively for collector and evaluator.
        collector_env_num=8,
        evaluator_env_num=8,
        env_id='CarRacing-v2',
        continuous=False,
        n_evaluator_episode=8,
        stop_value=950,
        # The path to save the game replay.
        # replay_path='./carracing_dqn_seed0/video',
    ),
    policy=dict(
        # Whether to use cuda for network.
        cuda=True,
        load_path="./carracing_seed0/ckpt/ckpt_best.pth.tar",
        model=dict(
            obs_shape=[3, 96, 96],
            action_shape=5,
            encoder_hidden_size_list=[128, 128, 512],
            # Whether to use dueling head.
            dueling=True,
        ),
        # Reward's future discount factor, aka. gamma.
        discount_factor=0.99,
        # How many steps in td error.
        nstep=nstep,
        # learn_mode config
        learn=dict(
            update_per_collect=10,
            batch_size=64,
            learning_rate=0.001,
            # Frequency of target network update.
            target_update_freq=100,
        ),
        # collect_mode config
        collect=dict(
            # You can use either "n_sample" or "n_episode" in collector.collect.
            # Get "n_sample" samples per collect.
            n_sample=64,
            # Cut trajectories into pieces with length "unroll_len".
            unroll_len=1,
        ),
        # command_mode config
        other=dict(
            # Epsilon greedy with decay.
            eps=dict(
                # Decay type. Supports ['exp', 'linear'].
                type='exp',
                start=0.95,
                end=0.1,
                decay=50000,
            ),
            replay_buffer=dict(replay_buffer_size=100000, ),
        ),
    ),
)
carracing_dqn_config = EasyDict(carracing_dqn_config)
main_config = carracing_dqn_config

carracing_dqn_create_config = dict(
    env=dict(
        type='carracing',
        import_names=['dizoo.box2d.carracing.envs.carracing_env'],
    ),
    env_manager=dict(type='subprocess'),
    # env_manager=dict(type='base'),
    policy=dict(type='dqn'),
)
carracing_dqn_create_config = EasyDict(carracing_dqn_create_config)
create_config = carracing_dqn_create_config

if __name__ == "__main__":
    # or you can enter `ding -m serial -c carracing_dqn_config.py -s 0`
    from ding.entry import serial_pipeline
    serial_pipeline([main_config, create_config], seed=0)
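
For intuition about the `eps` schedule configured above, here is a minimal sketch of how an exponential epsilon-greedy decay of this shape behaves. The exact formula is an assumption for illustration (DI-engine's own decay helper may differ slightly), and `eps_at` is a hypothetical name, not part of this PR:

import math

def eps_at(step, start=0.95, end=0.1, decay=50000):
    # Assumed form: epsilon decays exponentially from `start` toward `end`,
    # with `decay` acting as the time constant in collected env steps.
    return end + (start - end) * math.exp(-step / decay)

print(round(eps_at(0), 3))       # 0.95 -- fully exploratory at the start
print(round(eps_at(50000), 3))   # ~0.413 -- one time constant in
print(round(eps_at(250000), 3))  # ~0.106 -- near the floor of 0.1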
1 change: 1 addition & 0 deletions dizoo/box2d/carracing/envs/__init__.py
@@ -0,0 +1 @@
from .carracing_env import CarRacingEnv
161 changes: 161 additions & 0 deletions dizoo/box2d/carracing/envs/carracing_env.py
@@ -0,0 +1,161 @@
import copy
import os
from typing import Optional

import gym
import numpy as np
from easydict import EasyDict

from ding.envs import BaseEnv, BaseEnvTimestep
from ding.envs import ObsPlusPrevActRewWrapper
from ding.envs.common import affine_transform, save_frames_as_gif
from ding.torch_utils import to_ndarray
from ding.utils import ENV_REGISTRY


@ENV_REGISTRY.register('carracing')
class CarRacingEnv(BaseEnv):

    config = dict(
        replay_path=None,
        save_replay_gif=False,
        replay_path_gif=None,
        action_clip=False,
    )

    @classmethod
    def default_config(cls: type) -> EasyDict:
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    def __init__(self, cfg: dict) -> None:
        self._cfg = cfg
        self._init_flag = False
        # env_id: CarRacing-v2
        self._env_id = cfg.env_id
        self._replay_path = None
        self._replay_path_gif = cfg.replay_path_gif
        self._save_replay_gif = cfg.save_replay_gif
        self._save_replay_count = 0
        if 'Continuous' in self._env_id:
Reviewer (Member): remove the Continuous-related code, which has been abandoned in the latest gym.

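For context, this is the gym behavior the comment refers to: with CarRacing-v2 (gym >= 0.24, an assumption worth checking against the pinned gym version), the discrete/continuous variant is selected with a kwarg instead of a separate 'Continuous' env id:

import gym
env = gym.make('CarRacing-v2', continuous=False)  # Discrete(5): noop, left, right, gas, brake
env = gym.make('CarRacing-v2', continuous=True)   # Box(3,): steering, gas, brake
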
            self._act_scale = cfg.act_scale  # act_scale only works in continuous env
            self._action_clip = cfg.action_clip
        else:
            self._act_scale = False

    def reset(self) -> np.ndarray:
        if not self._init_flag:
            self._env = gym.make(self._cfg.env_id, continuous=self._cfg.continuous)
            if self._replay_path is not None:
                self._env = gym.wrappers.RecordVideo(
                    self._env,
                    video_folder=self._replay_path,
                    episode_trigger=lambda episode_id: True,
                    name_prefix='rl-video-{}'.format(id(self))
                )
            if hasattr(self._cfg, 'obs_plus_prev_action_reward') and self._cfg.obs_plus_prev_action_reward:
Reviewer (Member): remove this wrapper; it is only used in some special cases.

                self._env = ObsPlusPrevActRewWrapper(self._env)
            # The raw observation space is HWC (96, 96, 3); expose a CHW float32 space
            # to match the transposed, normalized observations returned below.
            # self._observation_space = self._env.observation_space
            self._observation_space = gym.spaces.Box(
                low=np.min(self._env.observation_space.low),
                high=np.max(self._env.observation_space.high),
                shape=(
                    self._env.observation_space.shape[2], self._env.observation_space.shape[0],
                    self._env.observation_space.shape[1]
                ),
                dtype=np.float32
            )
            self._action_space = self._env.action_space
            self._reward_space = gym.spaces.Box(
                low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
            )
            self._init_flag = True
        if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
            np_seed = 100 * np.random.randint(1, 1000)
            self._env.seed(self._seed + np_seed)
        elif hasattr(self, '_seed'):
            self._env.seed(self._seed)
        self._eval_episode_return = 0
        obs = self._env.reset()
        # Normalize pixels to [0, 1] and transpose HWC -> CHW for the conv encoder.
        obs = obs.astype(np.float32) / 255
        obs = obs.transpose(2, 0, 1)
        obs = to_ndarray(obs)
        if self._save_replay_gif:
            self._frames = []
        return obs

    def close(self) -> None:
        if self._init_flag:
            self._env.close()
        self._init_flag = False

    def render(self) -> None:
        self._env.render()

    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
        self._seed = seed
        self._dynamic_seed = dynamic_seed
        np.random.seed(self._seed)

    def step(self, action: np.ndarray) -> BaseEnvTimestep:
        assert isinstance(action, np.ndarray), type(action)
        if action.shape == (1, ):
            action = action.item()  # 0-dim array
        if self._act_scale:
            # Map actions from [-1, 1] into the env's native bounds (continuous control only).
            action = affine_transform(action, action_clip=self._action_clip, min_val=-1, max_val=1)
        if self._save_replay_gif:
            self._frames.append(self._env.render(mode='rgb_array'))
        obs, rew, done, info = self._env.step(action)
        # Keep observations consistent with reset(): normalize to [0, 1], transpose HWC -> CHW.
        obs = obs.astype(np.float32) / 255
        obs = obs.transpose(2, 0, 1)
        self._eval_episode_return += rew
        if done:
            info['eval_episode_return'] = self._eval_episode_return
            if self._save_replay_gif:
                if not os.path.exists(self._replay_path_gif):
                    os.makedirs(self._replay_path_gif)
                path = os.path.join(
                    self._replay_path_gif, '{}_episode_{}.gif'.format(self._env_id, self._save_replay_count)
                )
                save_frames_as_gif(self._frames, path)
                self._save_replay_count += 1
        obs = to_ndarray(obs)
        rew = to_ndarray([rew]).astype(np.float32)  # wrap the scalar reward into an array with shape (1, )
        return BaseEnvTimestep(obs, rew, done, info)


    def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
        if replay_path is None:
            replay_path = './video'
        self._replay_path = replay_path
        self._save_replay_gif = True
        self._save_replay_count = 0
        # Note: reset() also applies RecordVideo when `_replay_path` is set,
        # so calling this may wrap the env with the recorder twice.
        self._env = gym.wrappers.RecordVideo(
            self._env,
            video_folder=self._replay_path,
            episode_trigger=lambda episode_id: True,
            name_prefix='rl-video-{}'.format(id(self))
        )

    def random_action(self) -> np.ndarray:
        random_action = self.action_space.sample()
        if isinstance(random_action, np.ndarray):
            pass
        elif isinstance(random_action, int):
            random_action = to_ndarray([random_action], dtype=np.int64)
        return random_action

    @property
    def observation_space(self) -> gym.spaces.Space:
        return self._observation_space

    @property
    def action_space(self) -> gym.spaces.Space:
        return self._action_space

    @property
    def reward_space(self) -> gym.spaces.Space:
        return self._reward_space

    def __repr__(self) -> str:
        return "DI-engine CarRacing Env"
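For reviewers who want to poke at the class interactively, a minimal usage sketch (illustrative only; the cfg keys mirror this PR's config file and the class defaults):

from easydict import EasyDict
from dizoo.box2d.carracing.envs import CarRacingEnv

cfg = EasyDict(
    env_id='CarRacing-v2',
    continuous=False,
    save_replay_gif=False,
    replay_path_gif=None,
    action_clip=False,
)
env = CarRacingEnv(cfg)
env.seed(0)
obs = env.reset()  # float32 in [0, 1], shape (3, 96, 96)
timestep = env.step(env.random_action())  # BaseEnvTimestep(obs, reward, done, info)
env.close()
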
40 changes: 40 additions & 0 deletions dizoo/box2d/carracing/envs/test_carracing_env.py
@@ -0,0 +1,40 @@
import pytest
import numpy as np
from easydict import EasyDict
from dizoo.box2d.carracing.envs import CarRacingEnv


@pytest.mark.envtest
@pytest.mark.parametrize(
    'cfg', [
        EasyDict({
            'env_id': 'CarRacing-v2',
            'continuous': False,
            'act_scale': False,
            'save_replay_gif': False,
            'replay_path_gif': None,
        }),
        EasyDict({
            'env_id': 'CarRacing-v2',
            'continuous': True,
            'act_scale': True,
            'save_replay_gif': False,
            'replay_path_gif': None,
        })
    ]
)

class TestCarRacing:

    def test_naive(self, cfg):
        env = CarRacingEnv(cfg)
        env.seed(314)
        assert env._seed == 314
        obs = env.reset()
        assert obs.shape == (3, 96, 96)
        for i in range(10):
            random_action = env.random_action()
            timestep = env.step(random_action)
            print(timestep)
            assert isinstance(timestep.obs, np.ndarray)
            assert isinstance(timestep.done, bool)
            assert timestep.obs.shape == (3, 96, 96)
            assert timestep.reward.shape == (1, )
            assert timestep.reward >= env.reward_space.low
            assert timestep.reward <= env.reward_space.high
        print(env.observation_space, env.action_space, env.reward_space)
        env.close()
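
Assuming Box2D and the gym[box2d] extras are installed, the test above can be run locally via its marker:

pytest -m envtest dizoo/box2d/carracing/envs/test_carracing_env.py -s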