-
Notifications
You must be signed in to change notification settings - Fork 348
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add files via upload taxi env first commit * Config Config upload * feature(wrh): first commit for taxi * feature(wrh): First commit for taxi * feature(wrh): First commit for taxi * feature(wrh): First commit for taxi * feature(wrh): Readme added * feature(wrh): taxi_dqn_config updated * feature(wrh): taxi_dqn_config updated
- Loading branch information
1 parent
13a6d45
commit d919fa5
Showing
9 changed files
with
307 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .envs import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .taxi_dqn_config import main_config, create_config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from easydict import EasyDict | ||
|
||
taxi_dqn_config = dict( | ||
exp_name='taxi_seed0', | ||
env=dict( | ||
collector_env_num=8, | ||
evaluator_env_num=8, | ||
n_evaluator_episode=10, | ||
max_episode_steps=300, | ||
env_id="Taxi-v3" | ||
), | ||
policy=dict( | ||
cuda=True, | ||
load_path="./taxi_dqn_seed0/ckpt/ckpt_best.pth.tar", | ||
model=dict( | ||
obs_shape=4, | ||
action_shape=6, | ||
encoder_hidden_size_list=[256, 128, 64] | ||
), | ||
nstep=3, | ||
discount_factor=0.98, | ||
learn=dict( | ||
update_per_collect=5, | ||
batch_size=128, | ||
learning_rate=0.001, | ||
), | ||
collect=dict(n_sample=10), | ||
eval=dict(evaluator=dict(eval_freq=5, )), | ||
other=dict( | ||
eps=dict( | ||
type="linear", | ||
start=0.8, | ||
end=0.1, | ||
decay=10000 | ||
), | ||
replay_buffer=dict(replay_buffer_size=20000,), | ||
), | ||
) | ||
) | ||
taxi_dqn_config = EasyDict(taxi_dqn_config) | ||
main_config = taxi_dqn_config | ||
|
||
taxi_dqn_create_config = dict( | ||
env=dict( | ||
type="taxi", | ||
import_names=["dizoo.taxi.envs.taxi_env"] | ||
), | ||
env_manager=dict(type='base'), | ||
policy=dict(type='dqn'), | ||
replay_buffer=dict(type='deque', import_names=['ding.data.buffer.deque_buffer_wrapper']), | ||
) | ||
|
||
taxi_dqn_create_config = EasyDict(taxi_dqn_create_config) | ||
create_config = taxi_dqn_create_config | ||
|
||
if __name__ == "__main__": | ||
from ding.entry import serial_pipeline | ||
serial_pipeline((main_config, create_config), max_env_step=5000, seed=0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import gym | ||
import torch | ||
from easydict import EasyDict | ||
|
||
from ding.config import compile_config | ||
from ding.envs import DingEnvWrapper | ||
from ding.model import DQN | ||
from ding.policy import DQNPolicy, single_env_forward_wrapper | ||
from dizoo.taxi.config.taxi_dqn_config import create_config, main_config | ||
from dizoo.taxi.envs.taxi_env import TaxiEnv | ||
|
||
def main(main_config: EasyDict, create_config: EasyDict, ckpt_path: str) -> None: | ||
main_config.exp_name = f'taxi_dqn_seed0_deploy' | ||
cfg = compile_config(main_config, create_cfg=create_config, auto=True) | ||
env = TaxiEnv(cfg.env) | ||
env.enable_save_replay(replay_path=f'./{main_config.exp_name}/video') | ||
model = DQN(**cfg.policy.model) | ||
state_dict = torch.load(ckpt_path, map_location='cpu') | ||
model.load_state_dict(state_dict['model']) | ||
policy = DQNPolicy(cfg.policy, model=model).eval_mode | ||
forward_fn = single_env_forward_wrapper(policy.forward) | ||
obs = env.reset() | ||
returns = 0. | ||
while True: | ||
action = forward_fn(obs) | ||
obs, rew, done, info = env.step(action) | ||
returns += rew | ||
if done: | ||
break | ||
print(f'Deploy is finished, final epsiode return is: {returns}') | ||
|
||
|
||
if __name__ == "__main__": | ||
main( | ||
main_config=main_config, | ||
create_config=create_config, | ||
ckpt_path=f'./taxi_dqn_seed0/ckpt/ckpt_best.pth.tar' | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .taxi_env import TaxiEnv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
from typing import List, Optional | ||
import os | ||
|
||
from easydict import EasyDict | ||
from gym.spaces import Space, Discrete | ||
from gym.spaces.box import Box | ||
import gym | ||
import numpy as np | ||
import imageio | ||
|
||
from ditk import logging | ||
from ding.envs.env.base_env import BaseEnv, BaseEnvTimestep | ||
from ding.torch_utils import to_ndarray | ||
from ding.utils import ENV_REGISTRY | ||
|
||
@ENV_REGISTRY.register('taxi', force_overwrite=True) | ||
class TaxiEnv(BaseEnv): | ||
|
||
def __init__(self, cfg: EasyDict) -> None: | ||
|
||
self._cfg = cfg | ||
assert self._cfg.env_id == "Taxi-v3", "Your environment name is not Taxi-v3!" | ||
self._init_flag = False | ||
self._replay_path = None | ||
self._save_replay = False | ||
self._frames = [] | ||
|
||
def reset(self) -> np.ndarray: | ||
if not self._init_flag: | ||
self._env = gym.make( | ||
id=self._cfg.env_id, | ||
render_mode="single_rgb_array", | ||
max_episode_steps=self._cfg.max_episode_steps | ||
) | ||
self._observation_space = self._env.observation_space | ||
self._action_space = self._env.action_space | ||
self._reward_space = Box( | ||
low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32 | ||
) | ||
self._init_flag = True | ||
self._eval_episode_return = 0 | ||
if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed: | ||
np_seed = 100 * np.random.randint(1, 1000) | ||
self._env_seed = self._seed + np_seed | ||
elif hasattr(self, '_seed'): | ||
self._env_seed = self._seed | ||
if hasattr(self, '_seed'): | ||
obs = self._env.reset(seed=self._env_seed) | ||
else: | ||
obs = self._env.reset() | ||
|
||
if self._save_replay: | ||
picture = self._env.render() | ||
self._frames.append(picture) | ||
self._eval_episode_return = 0. | ||
obs = self._encode_taxi(obs).astype(np.float32) | ||
return obs | ||
|
||
def close(self) -> None: | ||
if self._init_flag: | ||
self._env.close() | ||
self._init_flag = False | ||
|
||
def seed(self, seed: int, dynamic_seed: bool = True) -> None: | ||
self._seed = seed | ||
self._dynamic_seed = dynamic_seed | ||
np.random.seed(self._seed) | ||
|
||
def step(self, action: np.ndarray) -> BaseEnvTimestep: | ||
assert isinstance(action, np.ndarray), type(action) | ||
action = action.item() | ||
obs, rew, done, info = self._env.step(action) | ||
self._eval_episode_return += rew | ||
obs = self._encode_taxi(obs) | ||
rew = to_ndarray([rew]) # Transformed to an array with shape (1, ) | ||
if self._save_replay: | ||
picture = self._env.render() | ||
self._frames.append(picture) | ||
if done: | ||
info['eval_episode_return'] = self._eval_episode_return | ||
if self._save_replay: | ||
assert self._replay_path is not None, "your should have a path" | ||
path = os.path.join( | ||
self._replay_path, '{}_episode_{}.gif'.format(self._cfg.env_id, self._save_replay_count) | ||
) | ||
self.frames_to_gif(self._frames, path) | ||
self._frames = [] | ||
self._save_replay_count += 1 | ||
rew = rew.astype(np.float32) | ||
obs = obs.astype(np.float32) | ||
return BaseEnvTimestep(obs, rew, done, info) | ||
|
||
def enable_save_replay(self, replay_path: Optional[str] = None) -> None: | ||
if replay_path is None: | ||
replay_path = './video' | ||
if not os.path.exists(replay_path): | ||
os.makedirs(replay_path) | ||
self._replay_path = replay_path | ||
self._save_replay = True | ||
self._save_replay_count = 0 | ||
|
||
def random_action(self) -> np.ndarray: | ||
random_action = self.action_space.sample() | ||
if isinstance(random_action, np.ndarray): | ||
pass | ||
elif isinstance(random_action, int): | ||
random_action = to_ndarray([random_action], dtype=np.int64) | ||
elif isinstance(random_action, dict): | ||
random_action = to_ndarray(random_action) | ||
else: | ||
raise TypeError( | ||
'`random_action` should be either int/np.ndarray or dict of int/np.ndarray, but get {}: {}'.format( | ||
type(random_action), random_action | ||
) | ||
) | ||
return random_action | ||
|
||
#todo encode the state into a vector | ||
def _encode_taxi(self, obs: np.ndarray) -> np.ndarray: | ||
taxi_row, taxi_col, passenger_location, destination = self._env.unwrapped.decode(obs) | ||
return to_ndarray([taxi_row, taxi_col, passenger_location, destination]) | ||
|
||
@property | ||
def observation_space(self) -> Space: | ||
return self._observation_space | ||
|
||
@property | ||
def action_space(self) -> Space: | ||
return self._action_space | ||
|
||
@property | ||
def reward_space(self) -> Space: | ||
return self._reward_space | ||
|
||
def __repr__(self) -> str: | ||
return "DI-engine Taxi-v3 Env" | ||
|
||
@staticmethod | ||
def frames_to_gif(frames: List[imageio.core.util.Array], gif_path: str, duration: float = 0.1) -> None: | ||
""" | ||
Overview: | ||
Convert a list of frames into a GIF. | ||
Arguments: | ||
- frames (:obj:`List[imageio.core.util.Array]`): A list of frames, each frame is an image. | ||
- gif_path (:obj:`str`): The path to save the GIF file. | ||
- duration (:obj:`float`): Duration between each frame in the GIF (seconds). | ||
""" | ||
# Save all frames as temporary image files | ||
temp_image_files = [] | ||
for i, frame in enumerate(frames): | ||
temp_image_file = f"frame_{i}.png" # Temporary file name | ||
imageio.imwrite(temp_image_file, frame) # Save the frame as a PNG file | ||
temp_image_files.append(temp_image_file) | ||
|
||
# Use imageio to convert temporary image files to GIF | ||
with imageio.get_writer(gif_path, mode='I', duration=duration) as writer: | ||
for temp_image_file in temp_image_files: | ||
image = imageio.imread(temp_image_file) | ||
writer.append_data(image) | ||
|
||
# Clean up temporary image files | ||
for temp_image_file in temp_image_files: | ||
os.remove(temp_image_file) | ||
logging.info(f"GIF saved as {gif_path}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import numpy as np | ||
import pytest | ||
from easydict import EasyDict | ||
from dizoo.taxi import TaxiEnv | ||
|
||
@pytest.mark.envtest | ||
class TestTaxiEnv: | ||
|
||
def test_naive(self): | ||
env = TaxiEnv( | ||
EasyDict({ | ||
"env_id": "Taxi-v3", | ||
"max_episode_steps": 300 | ||
}) | ||
) | ||
env.seed(314, dynamic_seed=False) | ||
assert env._seed == 314 | ||
obs = env.reset() | ||
assert obs.shape == (4, ) | ||
for _ in range(5): | ||
env.reset() | ||
np.random.seed(314) | ||
print('=' * 60) | ||
for i in range(10): | ||
# Both ``env.random_action()``, and utilizing ``np.random`` as well as action space, | ||
# can generate legal random action. | ||
if i < 5: | ||
random_action = np.array([env.action_space.sample()]) | ||
else: | ||
random_action = env.random_action() | ||
timestep = env.step(random_action) | ||
print(f"Your timestep in wrapped mode is: {timestep}") | ||
assert isinstance(timestep.obs, np.ndarray) | ||
assert isinstance(timestep.done, bool) | ||
assert timestep.obs.shape == (4, ) | ||
assert timestep.reward.shape == (1, ) | ||
assert timestep.reward >= env.reward_space.low | ||
assert timestep.reward <= env.reward_space.high | ||
print(env.observation_space, env.action_space, env.reward_space) | ||
env.close() | ||
|