Skip to content
Permalink
Browse files

Expose `disable_walker_contacts` as an option in soccer loader.

Closes #105.

PiperOrigin-RevId: 255381790
  • Loading branch information...
liusiqi43 authored and alimuldal committed Jun 27, 2019
1 parent 360c48d commit db37d5038791e0392f83b8fe4f87711f4ef076fc
Showing with 34 additions and 11 deletions.
  1. +7 −2 dm_control/locomotion/soccer/__init__.py
  2. +27 −9 dm_control/locomotion/soccer/loader_test.py
@@ -60,14 +60,19 @@ def _make_players(team_size):
return home_players + away_players


def load(team_size, time_limit=45., random_state=None):
def load(team_size,
time_limit=45.,
random_state=None,
disable_walker_contacts=False):
"""Construct `team_size`-vs-`team_size` soccer environment.
Args:
team_size: Integer, the number of players per team. Must be between 1 and
11.
time_limit: Float, the maximum duration of each episode in seconds.
random_state: (optional) an int seed or `np.random.RandomState` instance.
disable_walker_contacts: (optional) if `True`, disable physical contacts
between walkers.
Returns:
A `composer.Environment` instance.
@@ -84,6 +89,6 @@ def load(team_size, time_limit=45., random_state=None):
players=_make_players(team_size),
arena=RandomizedPitch(
min_size=(32, 24), max_size=(48, 36), keep_aspect_ratio=True),
),
disable_walker_contacts=disable_walker_contacts),
time_limit=time_limit,
random_state=random_state)
@@ -29,9 +29,12 @@

class LoadTest(parameterized.TestCase):

@parameterized.parameters(1, 2)
def test_load_env(self, team_size):
env = soccer.load(team_size=team_size, time_limit=2.)
@parameterized.named_parameters(
("2vs2_nocontacts", 2, True), ("2vs2_contacts", 2, False),
("1vs1_nocontacts", 1, True), ("1vs1_contacts", 1, False))
def test_load_env(self, team_size, disable_walker_contacts):
env = soccer.load(team_size=team_size, time_limit=2.,
disable_walker_contacts=disable_walker_contacts)
action_specs = env.action_spec()

random_state = np.random.RandomState(0)
@@ -68,15 +71,30 @@ def assertSameObservation(self, expected_observation, actual_observation):
np.testing.assert_array_equal(expected_array, actual_array,
err_msg=msg)

def test_same_first_observation_if_same_seed(self):
@parameterized.parameters(True, False)
def test_same_first_observation_if_same_seed(self, disable_walker_contacts):
seed = 42
timestep_1 = soccer.load(team_size=2, random_state=seed).reset()
timestep_2 = soccer.load(team_size=2, random_state=seed).reset()
timestep_1 = soccer.load(
team_size=2,
random_state=seed,
disable_walker_contacts=disable_walker_contacts).reset()
timestep_2 = soccer.load(
team_size=2,
random_state=seed,
disable_walker_contacts=disable_walker_contacts).reset()
self.assertSameObservation(timestep_1.observation, timestep_2.observation)

def test_different_first_observation_if_different_seed(self):
timestep_1 = soccer.load(team_size=2, random_state=1).reset()
timestep_2 = soccer.load(team_size=2, random_state=2).reset()
@parameterized.parameters(True, False)
def test_different_first_observation_if_different_seed(
self, disable_walker_contacts):
timestep_1 = soccer.load(
team_size=2,
random_state=1,
disable_walker_contacts=disable_walker_contacts).reset()
timestep_2 = soccer.load(
team_size=2,
random_state=2,
disable_walker_contacts=disable_walker_contacts).reset()
try:
self.assertSameObservation(timestep_1.observation, timestep_2.observation)
except AssertionError:

0 comments on commit db37d50

Please sign in to comment.
You can’t perform that action at this time.