Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update gym -> gymnasium, ray=2.0.0 -> ray=2.5.0 #153

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/gym/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import Any, Mapping

import dm_env
from gym import spaces
from gymnasium import spaces
import numpy as np
import tree

Expand Down
2 changes: 1 addition & 1 deletion examples/pettingzoo/sb3_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
"""Binary to run Stable Baselines 3 agents on meltingpot substrates."""

import gym
import gymnasium as gym
from meltingpot import substrate
import stable_baselines3
from stable_baselines3.common import callbacks
Expand Down
6 changes: 3 additions & 3 deletions examples/pettingzoo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import functools

from gym import utils as gym_utils
from gymnasium import utils as gym_utils
import matplotlib.pyplot as plt
from meltingpot import substrate
from ml_collections import config_dict
Expand Down Expand Up @@ -75,7 +75,7 @@ def reset(self, seed=None):
timestep = self._env.reset()
self.agents = self.possible_agents[:]
self.num_cycles = 0
return utils.timestep_to_observations(timestep)
return utils.timestep_to_observations(timestep), {}

def step(self, action):
"""See base class."""
Expand All @@ -92,7 +92,7 @@ def step(self, action):
self.agents = []

observations = utils.timestep_to_observations(timestep)
return observations, rewards, dones, infos
return observations, rewards, dones, dones, infos
jagapiou marked this conversation as resolved.
Show resolved Hide resolved

def close(self):
"""See base class."""
Expand Down
5 changes: 2 additions & 3 deletions examples/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@ dm_env
dm_meltingpot
dm_tree
dmlab2d
gym
gymnasium
matplotlib
ml_collections
numpy<1.23 # Needed by Ray because it uses `np.bool`.
pettingzoo>=1.22.3
ray[rllib,default]==2.0.0
ray[rllib,default]==2.5.0
stable_baselines3
supersuit>=3.7.2
torch
3 changes: 1 addition & 2 deletions examples/rllib/self_play_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ def get_config(
# Gets the default training configuration
config = ppo.PPOConfig()
# Number of arenas.
# This is called num_rollout_workers in 2.2.0.
config.num_workers = num_rollout_workers
config.num_rollout_workers = num_rollout_workers
# This is to match our unroll lengths.
config.rollout_fragment_length = rollout_fragment_length
# Total (time x batch) timesteps on the learning update.
Expand Down
15 changes: 6 additions & 9 deletions examples/rllib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import dm_env
import dmlab2d
from gym import spaces
from gymnasium import spaces
from meltingpot import substrate
from meltingpot.utils.policies import policy
from ml_collections import config_dict
Expand Down Expand Up @@ -59,10 +59,10 @@ def __init__(self, env: dmlab2d.Environment):
utils.spec_to_space(self._env.action_spec()))
super().__init__()

def reset(self):
def reset(self, *args, **kwargs):
"""See base class."""
timestep = self._env.reset()
return utils.timestep_to_observations(timestep)
return utils.timestep_to_observations(timestep), {}

def step(self, action_dict):
"""See base class."""
Expand All @@ -76,7 +76,7 @@ def step(self, action_dict):
info = {}

observations = utils.timestep_to_observations(timestep)
return observations, rewards, done, info
return observations, rewards, done, done, info
jagapiou marked this conversation as resolved.
Show resolved Hide resolved

def close(self):
"""See base class."""
Expand All @@ -90,7 +90,7 @@ def get_dmlab2d_env(self):
# which modes the `render` method supports.
metadata = {'render.modes': ['rgb_array']}

def render(self, mode: str) -> np.ndarray:
def render(self) -> np.ndarray:
"""Render the environment.

This allows you to set `record_env` in your training config, to record
Expand All @@ -109,10 +109,7 @@ def render(self, mode: str) -> np.ndarray:
world_rgb = observation[0]['WORLD.RGB']

# The world RGB frame is returned so that it can be used for recording videos.
if mode == 'rgb_array':
return world_rgb
else:
return super().render(mode=mode)
return world_rgb

def _convert_spaces_tuple_to_dict(
self,
Expand Down
8 changes: 4 additions & 4 deletions examples/rllib/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Tests for utils.py."""

from absl.testing import absltest
from gym.spaces import discrete
from gymnasium.spaces import discrete
from meltingpot import substrate
from meltingpot.configs.substrates import commons_harvest__open

Expand Down Expand Up @@ -43,7 +43,7 @@ def test_action_space_size(self):

def test_reset_number_agents(self):
"""Test that reset() returns observations for all agents."""
obs = self._env.reset()
obs, _ = self._env.reset()
self.assertLen(obs, self._num_players)

def test_step(self):
Expand All @@ -56,7 +56,7 @@ def test_step(self):
actions['player_' + str(player_idx)] = 1
dimonenka marked this conversation as resolved.
Show resolved Hide resolved

# Step
_, rewards, _, _ = self._env.step(actions)
_, rewards, _, _, _ = self._env.step(actions)

# Check we have one reward per agent
self.assertLen(rewards, self._num_players)
Expand All @@ -68,7 +68,7 @@ def test_render_modes_metadata(self):
def test_render_rgb_array(self):
"""Test that render() returns the full world as an RGB array."""
self._env.reset()
render = self._env.render('rgb_array')
render = self._env.render()
self.assertEqual(render.shape, (144, 192, 3))


Expand Down
Loading