feature(wrh): add taxi env (#799)
* Add files via upload

taxi env first commit

* Config

Config upload

* feature(wrh): first commit for taxi

* feature(wrh): First commit for taxi

* feature(wrh): First commit for taxi

* feature(wrh): First commit for taxi

* feature(wrh): Readme added

* feature(wrh): taxi_dqn_config updated

* feature(wrh): taxi_dqn_config updated
ruiheng123 committed May 30, 2024
1 parent 13a6d45 commit d919fa5
Showing 9 changed files with 307 additions and 0 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -317,6 +317,9 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 37 | [tabmwp](https://promptpg.github.io/explore.html) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/tabmwp/tabmwp.jpeg) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/tabmwp) <br> env tutorial <br> 环境指南 |
| 38 | [frozen_lake](https://gymnasium.farama.org/environments/toy_text/frozen_lake) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/frozen_lake/FrozenLake.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/frozen_lake) <br> env tutorial <br> 环境指南 |
| 39 | [ising_model](https://github.com/mlii/mfrl/tree/master/examples/ising_model) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![marl](https://img.shields.io/badge/-MARL-yellow) | ![original](./dizoo/ising_env/ising_env.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/ising_env) <br> env tutorial <br> [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/ising_model_zh.html) |
| 40 | [taxi](https://www.gymlibrary.dev/environments/toy_text/taxi/) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/taxi/Taxi-v3_episode_0.gif) | dizoo link <br> env tutorial <br> 环境指南 |



![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space

Binary file added dizoo/taxi/Taxi-v3_episode_0.gif
1 change: 1 addition & 0 deletions dizoo/taxi/__init__.py
@@ -0,0 +1 @@
from .envs import *
1 change: 1 addition & 0 deletions dizoo/taxi/config/__init__.py
@@ -0,0 +1 @@
from .taxi_dqn_config import main_config, create_config
58 changes: 58 additions & 0 deletions dizoo/taxi/config/taxi_dqn_config.py
@@ -0,0 +1,58 @@
from easydict import EasyDict

taxi_dqn_config = dict(
    exp_name='taxi_dqn_seed0',
    env=dict(
        collector_env_num=8,
        evaluator_env_num=8,
        n_evaluator_episode=10,
        max_episode_steps=300,
        env_id="Taxi-v3"
    ),
    policy=dict(
        cuda=True,
        load_path="./taxi_dqn_seed0/ckpt/ckpt_best.pth.tar",
        model=dict(
            obs_shape=4,
            action_shape=6,
            encoder_hidden_size_list=[256, 128, 64]
        ),
        nstep=3,
        discount_factor=0.98,
        learn=dict(
            update_per_collect=5,
            batch_size=128,
            learning_rate=0.001,
        ),
        collect=dict(n_sample=10),
        eval=dict(evaluator=dict(eval_freq=5, )),
        other=dict(
            eps=dict(
                type="linear",
                start=0.8,
                end=0.1,
                decay=10000
            ),
            replay_buffer=dict(replay_buffer_size=20000, ),
        ),
    )
)
taxi_dqn_config = EasyDict(taxi_dqn_config)
main_config = taxi_dqn_config

taxi_dqn_create_config = dict(
    env=dict(
        type="taxi",
        import_names=["dizoo.taxi.envs.taxi_env"]
    ),
    env_manager=dict(type='base'),
    policy=dict(type='dqn'),
    replay_buffer=dict(type='deque', import_names=['ding.data.buffer.deque_buffer_wrapper']),
)

taxi_dqn_create_config = EasyDict(taxi_dqn_create_config)
create_config = taxi_dqn_create_config

if __name__ == "__main__":
    from ding.entry import serial_pipeline
    serial_pipeline((main_config, create_config), max_env_step=5000, seed=0)
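A minimal usage sketch (not part of the diff above): launching training programmatically with a different seed and step budget. It only assumes the `serial_pipeline` entry and the `main_config`/`create_config` objects defined in this file; the experiment name `taxi_dqn_seed1` is hypothetical.

# Hedged sketch: re-run the taxi DQN config with another seed and a longer budget.
from copy import deepcopy
from ding.entry import serial_pipeline
from dizoo.taxi.config.taxi_dqn_config import main_config, create_config

cfg = deepcopy(main_config)
cfg.exp_name = 'taxi_dqn_seed1'  # hypothetical experiment directory name
serial_pipeline((cfg, deepcopy(create_config)), seed=1, max_env_step=int(1e5))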
38 changes: 38 additions & 0 deletions dizoo/taxi/entry/taxi_dqn_deploy.py
@@ -0,0 +1,38 @@
import gym
import torch
from easydict import EasyDict

from ding.config import compile_config
from ding.envs import DingEnvWrapper
from ding.model import DQN
from ding.policy import DQNPolicy, single_env_forward_wrapper
from dizoo.taxi.config.taxi_dqn_config import create_config, main_config
from dizoo.taxi.envs.taxi_env import TaxiEnv


def main(main_config: EasyDict, create_config: EasyDict, ckpt_path: str) -> None:
    main_config.exp_name = 'taxi_dqn_seed0_deploy'
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    env = TaxiEnv(cfg.env)
    env.enable_save_replay(replay_path=f'./{main_config.exp_name}/video')
    model = DQN(**cfg.policy.model)
    state_dict = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(state_dict['model'])
    policy = DQNPolicy(cfg.policy, model=model).eval_mode
    forward_fn = single_env_forward_wrapper(policy.forward)
    obs = env.reset()
    returns = 0.
    while True:
        action = forward_fn(obs)
        obs, rew, done, info = env.step(action)
        returns += rew
        if done:
            break
    print(f'Deploy is finished, final episode return is: {returns}')


if __name__ == "__main__":
    main(
        main_config=main_config,
        create_config=create_config,
        ckpt_path='./taxi_dqn_seed0/ckpt/ckpt_best.pth.tar'
    )
1 change: 1 addition & 0 deletions dizoo/taxi/envs/__init__.py
@@ -0,0 +1 @@
from .taxi_env import TaxiEnv
164 changes: 164 additions & 0 deletions dizoo/taxi/envs/taxi_env.py
@@ -0,0 +1,164 @@
from typing import List, Optional
import os

from easydict import EasyDict
from gym.spaces import Space, Discrete
from gym.spaces.box import Box
import gym
import numpy as np
import imageio

from ditk import logging
from ding.envs.env.base_env import BaseEnv, BaseEnvTimestep
from ding.torch_utils import to_ndarray
from ding.utils import ENV_REGISTRY


@ENV_REGISTRY.register('taxi', force_overwrite=True)
class TaxiEnv(BaseEnv):

    def __init__(self, cfg: EasyDict) -> None:
        self._cfg = cfg
        assert self._cfg.env_id == "Taxi-v3", "Your environment name is not Taxi-v3!"
        self._init_flag = False
        self._replay_path = None
        self._save_replay = False
        self._frames = []

    def reset(self) -> np.ndarray:
        if not self._init_flag:
            self._env = gym.make(
                id=self._cfg.env_id,
                render_mode="single_rgb_array",
                max_episode_steps=self._cfg.max_episode_steps
            )
            self._observation_space = self._env.observation_space
            self._action_space = self._env.action_space
            self._reward_space = Box(
                low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
            )
            self._init_flag = True
        self._eval_episode_return = 0
        if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
            np_seed = 100 * np.random.randint(1, 1000)
            self._env_seed = self._seed + np_seed
        elif hasattr(self, '_seed'):
            self._env_seed = self._seed
        if hasattr(self, '_seed'):
            obs = self._env.reset(seed=self._env_seed)
        else:
            obs = self._env.reset()

        if self._save_replay:
            picture = self._env.render()
            self._frames.append(picture)
        self._eval_episode_return = 0.
        obs = self._encode_taxi(obs).astype(np.float32)
        return obs

    def close(self) -> None:
        if self._init_flag:
            self._env.close()
        self._init_flag = False

    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
        self._seed = seed
        self._dynamic_seed = dynamic_seed
        np.random.seed(self._seed)

    def step(self, action: np.ndarray) -> BaseEnvTimestep:
        assert isinstance(action, np.ndarray), type(action)
        action = action.item()
        obs, rew, done, info = self._env.step(action)
        self._eval_episode_return += rew
        obs = self._encode_taxi(obs)
        rew = to_ndarray([rew])  # Transformed to an array with shape (1, )
        if self._save_replay:
            picture = self._env.render()
            self._frames.append(picture)
        if done:
            info['eval_episode_return'] = self._eval_episode_return
            if self._save_replay:
                assert self._replay_path is not None, "The replay path must be set via enable_save_replay before saving."
                path = os.path.join(
                    self._replay_path, '{}_episode_{}.gif'.format(self._cfg.env_id, self._save_replay_count)
                )
                self.frames_to_gif(self._frames, path)
                self._frames = []
                self._save_replay_count += 1
        rew = rew.astype(np.float32)
        obs = obs.astype(np.float32)
        return BaseEnvTimestep(obs, rew, done, info)

    def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
        if replay_path is None:
            replay_path = './video'
        if not os.path.exists(replay_path):
            os.makedirs(replay_path)
        self._replay_path = replay_path
        self._save_replay = True
        self._save_replay_count = 0

    def random_action(self) -> np.ndarray:
        random_action = self.action_space.sample()
        if isinstance(random_action, np.ndarray):
            pass
        elif isinstance(random_action, int):
            random_action = to_ndarray([random_action], dtype=np.int64)
        elif isinstance(random_action, dict):
            random_action = to_ndarray(random_action)
        else:
            raise TypeError(
                '`random_action` should be either int/np.ndarray or dict of int/np.ndarray, but get {}: {}'.format(
                    type(random_action), random_action
                )
            )
        return random_action

    def _encode_taxi(self, obs: np.ndarray) -> np.ndarray:
        # Decode the scalar Taxi-v3 state into a 4-dim vector: (taxi_row, taxi_col, passenger_location, destination).
        taxi_row, taxi_col, passenger_location, destination = self._env.unwrapped.decode(obs)
        return to_ndarray([taxi_row, taxi_col, passenger_location, destination])

    @property
    def observation_space(self) -> Space:
        return self._observation_space

    @property
    def action_space(self) -> Space:
        return self._action_space

    @property
    def reward_space(self) -> Space:
        return self._reward_space

    def __repr__(self) -> str:
        return "DI-engine Taxi-v3 Env"

    @staticmethod
    def frames_to_gif(frames: List[imageio.core.util.Array], gif_path: str, duration: float = 0.1) -> None:
        """
        Overview:
            Convert a list of frames into a GIF.
        Arguments:
            - frames (:obj:`List[imageio.core.util.Array]`): A list of frames, each frame is an image.
            - gif_path (:obj:`str`): The path to save the GIF file.
            - duration (:obj:`float`): Duration between each frame in the GIF (seconds).
        """
        # Save all frames as temporary image files
        temp_image_files = []
        for i, frame in enumerate(frames):
            temp_image_file = f"frame_{i}.png"  # Temporary file name
            imageio.imwrite(temp_image_file, frame)  # Save the frame as a PNG file
            temp_image_files.append(temp_image_file)

        # Use imageio to convert temporary image files to GIF
        with imageio.get_writer(gif_path, mode='I', duration=duration) as writer:
            for temp_image_file in temp_image_files:
                image = imageio.imread(temp_image_file)
                writer.append_data(image)

        # Clean up temporary image files
        for temp_image_file in temp_image_files:
            os.remove(temp_image_file)
        logging.info(f"GIF saved as {gif_path}")
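For context, a quick sketch (assuming only that gym's Taxi-v3 and numpy are installed) of the observation encoding that `_encode_taxi` performs: the scalar Taxi-v3 state index decodes into `(taxi_row, taxi_col, passenger_location, destination)`, which is why the DQN config above sets `obs_shape=4`.

import gym
import numpy as np

env = gym.make("Taxi-v3")
state = 328  # an arbitrary state index in [0, 500)
# Taxi-v3 exposes decode() on the unwrapped env; it yields the four state components.
taxi_row, taxi_col, passenger_location, destination = env.unwrapped.decode(state)
obs = np.array([taxi_row, taxi_col, passenger_location, destination], dtype=np.float32)
print(obs, obs.shape)  # e.g. [3. 1. 2. 0.] (4,)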
41 changes: 41 additions & 0 deletions dizoo/taxi/envs/test_taxi_env.py
@@ -0,0 +1,41 @@
import numpy as np
import pytest
from easydict import EasyDict
from dizoo.taxi import TaxiEnv


@pytest.mark.envtest
class TestTaxiEnv:

    def test_naive(self):
        env = TaxiEnv(
            EasyDict({
                "env_id": "Taxi-v3",
                "max_episode_steps": 300
            })
        )
        env.seed(314, dynamic_seed=False)
        assert env._seed == 314
        obs = env.reset()
        assert obs.shape == (4, )
        for _ in range(5):
            env.reset()
        np.random.seed(314)
        print('=' * 60)
        for i in range(10):
            # Both ``env.random_action()`` and sampling directly from the action space
            # generate legal random actions.
            if i < 5:
                random_action = np.array([env.action_space.sample()])
            else:
                random_action = env.random_action()
            timestep = env.step(random_action)
            print(f"Your timestep in wrapped mode is: {timestep}")
            assert isinstance(timestep.obs, np.ndarray)
            assert isinstance(timestep.done, bool)
            assert timestep.obs.shape == (4, )
            assert timestep.reward.shape == (1, )
            assert timestep.reward >= env.reward_space.low
            assert timestep.reward <= env.reward_space.high
        print(env.observation_space, env.action_space, env.reward_space)
        env.close()
