From 607e87d29ce9369e3f72aafcb3a5b7c818acf356 Mon Sep 17 00:00:00 2001
From: dimonenka
Date: Wed, 26 Jul 2023 17:34:26 +0300
Subject: [PATCH] update gym -> gymnasium, ray=2.0.0 -> ray=2.5.0

---
 examples/gym/utils.py             |  2 +-
 examples/pettingzoo/sb3_train.py  |  2 +-
 examples/pettingzoo/utils.py      |  6 +++---
 examples/requirements.txt         |  5 ++---
 examples/rllib/self_play_train.py |  3 +--
 examples/rllib/utils.py           | 15 ++++++---------
 examples/rllib/utils_test.py      |  8 ++++----
 7 files changed, 18 insertions(+), 23 deletions(-)

diff --git a/examples/gym/utils.py b/examples/gym/utils.py
index 80677886..28404b72 100644
--- a/examples/gym/utils.py
+++ b/examples/gym/utils.py
@@ -16,7 +16,7 @@
 from typing import Any, Mapping
 
 import dm_env
-from gym import spaces
+from gymnasium import spaces
 import numpy as np
 import tree
 
diff --git a/examples/pettingzoo/sb3_train.py b/examples/pettingzoo/sb3_train.py
index 7db999c7..92343bad 100644
--- a/examples/pettingzoo/sb3_train.py
+++ b/examples/pettingzoo/sb3_train.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 """Binary to run Stable Baselines 3 agents on meltingpot substrates."""
 
-import gym
+import gymnasium as gym
 from meltingpot import substrate
 import stable_baselines3
 from stable_baselines3.common import callbacks
diff --git a/examples/pettingzoo/utils.py b/examples/pettingzoo/utils.py
index 51c39d93..d54923d1 100644
--- a/examples/pettingzoo/utils.py
+++ b/examples/pettingzoo/utils.py
@@ -15,7 +15,7 @@
 
 import functools
 
-from gym import utils as gym_utils
+from gymnasium import utils as gym_utils
 import matplotlib.pyplot as plt
 from meltingpot import substrate
 from ml_collections import config_dict
@@ -75,7 +75,7 @@ def reset(self, seed=None):
     timestep = self._env.reset()
     self.agents = self.possible_agents[:]
     self.num_cycles = 0
-    return utils.timestep_to_observations(timestep)
+    return utils.timestep_to_observations(timestep), {}
 
   def step(self, action):
     """See base class."""
@@ -92,7 +92,7 @@ def step(self, action):
       self.agents = []
 
     observations = utils.timestep_to_observations(timestep)
-    return observations, rewards, dones, infos
+    return observations, rewards, dones, dones, infos
 
   def close(self):
     """See base class."""
diff --git a/examples/requirements.txt b/examples/requirements.txt
index b422325b..bdf6b41c 100644
--- a/examples/requirements.txt
+++ b/examples/requirements.txt
@@ -3,12 +3,11 @@ dm_env
 dm_meltingpot
 dm_tree
 dmlab2d
-gym
+gymnasium
 matplotlib
 ml_collections
-numpy<1.23  # Needed by Ray because it uses `np.bool`.
 pettingzoo>=1.22.3
-ray[rllib,default]==2.0.0
+ray[rllib,default]==2.5.0
 stable_baselines3
 supersuit>=3.7.2
 torch
diff --git a/examples/rllib/self_play_train.py b/examples/rllib/self_play_train.py
index 0ef9275d..ce6b9cdd 100644
--- a/examples/rllib/self_play_train.py
+++ b/examples/rllib/self_play_train.py
@@ -56,8 +56,7 @@ def get_config(
   # Gets the default training configuration
   config = ppo.PPOConfig()
   # Number of arenas.
-  # This is called num_rollout_workers in 2.2.0.
-  config.num_workers = num_rollout_workers
+  config.num_rollout_workers = num_rollout_workers
   # This is to match our unroll lengths.
   config.rollout_fragment_length = rollout_fragment_length
   # Total (time x batch) timesteps on the learning update.
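The changes above all track the same upstream API shift: Gymnasium's reset() returns an (observation, info) pair rather than a bare observation, and step() splits the old boolean done into separate terminated and truncated flags, turning the 4-tuple into a 5-tuple. That is why the PettingZoo wrapper now returns an extra empty dict from reset() and duplicates dones in step(). A minimal single-agent sketch of the contract these wrappers target (CartPole is purely illustrative):

import gymnasium as gym

env = gym.make('CartPole-v1')  # any registered environment works here

# reset() now returns (observation, info) instead of just the observation.
observation, info = env.reset(seed=42)

for _ in range(100):
    action = env.action_space.sample()
    # step() now returns five values; the old `done` is
    # terminated (true end of episode) or truncated (e.g. a time limit).
    observation, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        observation, info = env.reset()

env.close()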
diff --git a/examples/rllib/utils.py b/examples/rllib/utils.py
index 828b23cf..2c185170 100644
--- a/examples/rllib/utils.py
+++ b/examples/rllib/utils.py
@@ -17,7 +17,7 @@
 
 import dm_env
 import dmlab2d
-from gym import spaces
+from gymnasium import spaces
 from meltingpot import substrate
 from meltingpot.utils.policies import policy
 from ml_collections import config_dict
@@ -59,10 +59,10 @@ def __init__(self, env: dmlab2d.Environment):
         utils.spec_to_space(self._env.action_spec()))
     super().__init__()
 
-  def reset(self):
+  def reset(self, *args, **kwargs):
     """See base class."""
     timestep = self._env.reset()
-    return utils.timestep_to_observations(timestep)
+    return utils.timestep_to_observations(timestep), {}
 
   def step(self, action_dict):
     """See base class."""
@@ -76,7 +76,7 @@ def step(self, action_dict):
     info = {}
 
     observations = utils.timestep_to_observations(timestep)
-    return observations, rewards, done, info
+    return observations, rewards, done, done, info
 
   def close(self):
     """See base class."""
@@ -90,7 +90,7 @@ def get_dmlab2d_env(self):
   # which modes the `render` method supports.
   metadata = {'render.modes': ['rgb_array']}
 
-  def render(self, mode: str) -> np.ndarray:
+  def render(self) -> np.ndarray:
     """Render the environment.
 
     This allows you to set `record_env` in your training config, to record
@@ -109,10 +109,7 @@ def render(self, mode: str) -> np.ndarray:
     world_rgb = observation[0]['WORLD.RGB']
 
     # RGB mode is used for recording videos
-    if mode == 'rgb_array':
-      return world_rgb
-    else:
-      return super().render(mode=mode)
+    return world_rgb
 
   def _convert_spaces_tuple_to_dict(
       self,
diff --git a/examples/rllib/utils_test.py b/examples/rllib/utils_test.py
index 309b7fa5..2053b6e2 100644
--- a/examples/rllib/utils_test.py
+++ b/examples/rllib/utils_test.py
@@ -14,7 +14,7 @@
 """Tests for utils.py."""
 
 from absl.testing import absltest
-from gym.spaces import discrete
+from gymnasium.spaces import discrete
 from meltingpot import substrate
 from meltingpot.configs.substrates import commons_harvest__open
 
@@ -43,7 +43,7 @@ def test_action_space_size(self):
 
   def test_reset_number_agents(self):
     """Test that reset() returns observations for all agents."""
-    obs = self._env.reset()
+    obs, _ = self._env.reset()
     self.assertLen(obs, self._num_players)
 
   def test_step(self):
@@ -56,7 +56,7 @@ def test_step(self):
       actions['player_' + str(player_idx)] = 1
 
     # Step
-    _, rewards, _, _ = self._env.step(actions)
+    _, rewards, _, _, _ = self._env.step(actions)
 
     # Check we have one reward per agent
     self.assertLen(rewards, self._num_players)
@@ -68,7 +68,7 @@ def test_render_modes_metadata(self):
 
   def test_render_rgb_array(self):
     """Test that render('rgb_array') returns the full world."""
     self._env.reset()
-    render = self._env.render('rgb_array')
+    render = self._env.render()
     self.assertEqual(render.shape, (144, 192, 3))
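The rllib wrapper mirrors the same contract at the multi-agent level: reset() returns (observations, infos), and step() returns separate terminateds and truncateds dicts. Passing the same done dict for both is the minimal port, since these substrates only end by reaching their final timestep; an environment that distinguished true termination from a time-limit truncation would populate the two dicts separately. A rough sketch of how a caller consumes the patched MeltingPotEnv, assuming env was built through this module's env_creator as in self_play_train.py (the loop is illustrative, not part of the patch):

# `env` is assumed to be a MeltingPotEnv from examples/rllib/utils.py;
# construction is elided because it mirrors self_play_train.py.
obs, infos = env.reset()
done = False
while not done:
    # One action per player, keyed the same way as the observations.
    actions = {player: env.action_space[player].sample() for player in obs}
    obs, rewards, terminateds, truncateds, infos = env.step(actions)
    # The wrapper sets the episode-level '__all__' flag on the done dicts.
    done = terminateds['__all__'] or truncateds['__all__']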