fix(pu): fix obs_max_scale bug in memory_env
puyuan1996 committed Mar 27, 2024
1 parent 3b75ab9 commit 29c9afd
Showing 1 changed file with 32 additions and 9 deletions.
zoo/memory/envs/memory_lightzero_env.py: 41 changes (32 additions, 9 deletions)
@@ -1,17 +1,17 @@
import copy
import os
from datetime import datetime
from typing import List

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from ding.envs import BaseEnv, BaseEnvTimestep
from ding.torch_utils import to_ndarray
from ding.utils import ENV_REGISTRY
from easydict import EasyDict

from PIL import Image
import matplotlib.pyplot as plt
import os
from matplotlib import animation


@ENV_REGISTRY.register('memory_lightzero')
@@ -42,13 +42,17 @@ class MemoryEnvLightZero(BaseEnv):
crop=True, # Whether to crop the observation
max_frames={
"explore": 15,
# NOTE: "explore" should be >= 2; otherwise the agent won't be able to see the target color or key.
"distractor": 30,
"reward": 15
}, # Maximum frames per phase
save_replay=False, # Whether to save GIF replay
render=False, # Whether to enable real-time rendering
scale_observation=True, # Whether to scale the observation to [0, 1]
flate_observation=False, # Whether to flatten the observation
# obs_max_scale=107, # Maximum value of the observation, for key_to_door
# obs_max_scale=101, # Maximum value of the observation, for visual_match
obs_max_scale=100, # Maximum value of the observation
)

@classmethod
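For context, here is a minimal usage sketch (not part of this commit) of overriding the new obs_max_scale field per task, using the values noted in the comments above (107 for key_to_door, 101 for visual_match). It assumes the usual DI-engine-style default_config() helper built from the class-level config dict; every name outside the diff is illustrative.

# Hedged sketch: override the new obs_max_scale per environment.
# default_config() is assumed to be the standard DI-engine helper generated
# from the class-level `config` dict shown above; it is not part of this diff.
from easydict import EasyDict

from zoo.memory.envs.memory_lightzero_env import MemoryEnvLightZero

cfg = MemoryEnvLightZero.default_config()
cfg.update(EasyDict(dict(
    env_id='key_to_door',       # or 'visual_match'
    scale_observation=True,     # observations are divided by obs_max_scale into [0, 1]
    obs_max_scale=107,          # 107 for key_to_door, 101 for visual_match (see comments above)
)))
env = MemoryEnvLightZero(cfg)
env.seed(0)
obs = env.reset()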
@@ -69,6 +73,7 @@ def __init__(self, cfg: dict) -> None:
self._save_replay = cfg.save_replay
self._render = cfg.render
self._gif_images = []
self.obs_max_scale = cfg.obs_max_scale

def reset(self) -> np.ndarray:
"""
@@ -101,7 +106,7 @@ def reset(self) -> np.ndarray:
EXPLORE_GRID=PASSIVE_EXPLORE_GRID,
)
elif self._cfg.env_id == 'key_to_door':
from zoo.memory.envs.pycolab_tvt.key_to_door import Game, REWARD_GRID_SR, MAX_FRAMES_PER_PHASE_SR
from zoo.memory.envs.pycolab_tvt.key_to_door import Game, REWARD_GRID_SR
self._game = Game(
self._rng,
num_apples=self._cfg.num_apples,
@@ -118,7 +123,7 @@ def reset(self) -> np.ndarray:
if self._cfg.scale_observation:
self._observation_space = gym.spaces.Box(0, 1, shape=(1, 5, 5), dtype='float32')
else:
self._observation_space = gym.spaces.Box(0, 1000, shape=(1, 5, 5), dtype='int64')
self._observation_space = gym.spaces.Box(0, self.obs_max_scale, shape=(1, 5, 5), dtype='int64')
self._action_space = gym.spaces.Discrete(self._game.num_actions)
self._reward_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(1,), dtype=np.float32)

@@ -130,12 +135,13 @@ def reset(self) -> np.ndarray:
obs = to_ndarray(obs, dtype=np.float32)
action_mask = np.ones(self.action_space.n, 'int8')
if self._cfg.scale_observation:
obs = obs / 1000
obs = obs / self.obs_max_scale
if self._cfg.flate_observation:
obs = obs.flatten()
obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}

self._gif_images = []
self._gif_images_numpy = []

return obs

@@ -168,6 +174,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:
info['eval_episode_return'] = info['success']
print(f'episode seed:{self._seed} done! self.episode_reward_list is: {self.episode_reward_list}')

# print(f"Step: {self._current_step}, Action: {action}, Reward: {reward}, Observation: {observation}, Done: {done}, Info: {info}") # TODO
observation = to_ndarray(observation, dtype=np.float32)
reward = to_ndarray([reward])
action_mask = np.ones(self.action_space.n, 'int8')
@@ -185,6 +192,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:

if self._save_replay:
self._gif_images.append(img)
self._gif_images_numpy.append(obs_rgb)

if self._render:
plt.imshow(img)
@@ -196,11 +204,13 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:
gif_dir = os.path.join(os.path.dirname(__file__), 'replay')
os.makedirs(gif_dir, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
gif_file = os.path.join(gif_dir, f'episode_{self._current_step}_{timestamp}.gif')
gif_file = os.path.join(gif_dir, f'episode_seed{self._seed}_len{self._current_step}_{timestamp}.gif')
self._gif_images[0].save(gif_file, save_all=True, append_images=self._gif_images[1:], duration=100, loop=0)
# self.display_frames_as_gif(self._gif_images_numpy, gif_file) # the alternative way to save gif
print(f'saved replay to {gif_file}')

if self._cfg.scale_observation:
observation = observation / 1000
observation = observation / self.obs_max_scale
if self._cfg.flate_observation:
observation = observation.flatten()
observation = {'observation': observation, 'action_mask': action_mask, 'to_play': -1}
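The point of the change above, shown as a standalone illustration rather than the env's own code: dividing by the old hard-coded 1000 left key_to_door/visual_match observations compressed near zero, while dividing by an obs_max_scale that matches the task's true maximum spreads them over [0, 1]. The raw values below are hypothetical.

# Illustration only (values are made up; 107 mirrors the key_to_door comment above).
import numpy as np

raw_obs = np.array([[0, 3, 15], [42, 63, 107]], dtype=np.float32)

old_scaled = raw_obs / 1000   # former hard-coded divisor: everything lands in [0, 0.107]
new_scaled = raw_obs / 107    # obs_max_scale matched to the env: values span [0, 1]

print(old_scaled.max(), new_scaled.max())  # 0.107 1.0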
@@ -216,6 +226,8 @@ def random_action(self) -> np.ndarray:
return random_action

def seed(self, seed: int, dynamic_seed: bool = True) -> None:
# def seed(self, seed: int, dynamic_seed: bool = False) -> None: # TODO

"""
Set the seed for the environment's random number generator. Can handle both static and dynamic seeding.
"""
@@ -230,6 +242,17 @@ def close(self) -> None:
"""
self._init_flag = False

@staticmethod
def display_frames_as_gif(frames: list, path: str) -> None:
"""Save a list of RGB frames as a GIF animation at the given path using matplotlib."""
patch = plt.imshow(frames[0])
plt.axis('off')

def animate(i):
patch.set_data(frames[i])

anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=5)
anim.save(path, writer='imagemagick', fps=20)

@property
def observation_space(self) -> gym.spaces.Space:
return self._observation_space
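A hedged usage sketch for the new display_frames_as_gif helper added in this commit. The dummy frames and output path are illustrative, and animation.save with writer='imagemagick' only works if ImageMagick is installed.

# Sketch only: feed RGB numpy frames (as collected in self._gif_images_numpy)
# to the new static helper; requires matplotlib plus an ImageMagick install.
import numpy as np

from zoo.memory.envs.memory_lightzero_env import MemoryEnvLightZero

frames = [np.random.randint(0, 255, size=(50, 50, 3), dtype=np.uint8) for _ in range(10)]
MemoryEnvLightZero.display_frames_as_gif(frames, 'replay_demo.gif')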
