Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
87 lines (69 sloc) 3.06 KB
# Copyright 2019 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for dm_control.locomotion.soccer.pitch."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
from absl.testing import parameterized
from dm_control import composer
from dm_control.composer.variation import distributions
from dm_control.entities import props
from dm_control.locomotion.soccer import pitch as pitch_lib
from dm_control.locomotion.soccer import team as team_lib
import numpy as np
class PitchTest(parameterized.TestCase):
  """Checks ball detection on `Pitch` and size sampling on `RandomizedPitch`."""

  # Size used by every fixed-size detection test in this class.
  _PITCH_SIZE = (12, 9)

  def _pitch_with_ball(self, pitch_size, ball_pos):
    """Builds a pitch holding one spherical ball and resets an env around it.

    Args:
      pitch_size: Passed as `size` to `pitch_lib.Pitch`; the resulting
        `pitch.size` is asserted to equal it.
      ball_pos: Initial position given to the sphere prop used as the ball.

    Returns:
      The `pitch_lib.Pitch` instance after a single `env.reset()`.
    """
    arena = pitch_lib.Pitch(size=pitch_size)
    self.assertEqual(arena.size, pitch_size)

    ball = props.Primitive(geom_type='sphere', size=(0.1,), pos=ball_pos)
    arena.register_ball(ball)
    arena.attach(ball)

    # Wrap the pitch in a do-nothing task and reset once. The callers inspect
    # the pitch's detection results immediately afterwards (presumably these
    # are refreshed during reset -- confirm against pitch_lib if this changes).
    task = composer.NullTask(arena)
    env = composer.Environment(task, random_state=np.random.RandomState(42))
    env.reset()
    return arena

  def _assert_goal_detected(self, ball_pos, expected_team):
    """Expects one off-court detection and `expected_team` as the goal."""
    arena = self._pitch_with_ball(self._PITCH_SIZE, ball_pos)
    self.assertLen(arena.detected_off_court(), 1)
    self.assertEqual(expected_team, arena.detected_goal())

  def test_pitch_none_detected(self):
    # A ball at the origin triggers neither an off-court nor a goal detection.
    arena = self._pitch_with_ball(self._PITCH_SIZE, (0, 0, 0))
    self.assertEmpty(arena.detected_off_court())
    self.assertIsNone(arena.detected_goal())

  def test_pitch_detected_off_court(self):
    # A ball at x=20 is expected to be off court without counting as a goal.
    arena = self._pitch_with_ball(self._PITCH_SIZE, (20, 0, 0))
    self.assertLen(arena.detected_off_court(), 1)
    self.assertIsNone(arena.detected_goal())

  def test_pitch_detected_away_goal(self):
    self._assert_goal_detected((-9.5, 0, 1), team_lib.Team.AWAY)

  def test_pitch_detected_home_goal(self):
    self._assert_goal_detected((9.5, 0, 1), team_lib.Team.HOME)

  @parameterized.parameters((True, distributions.Uniform()),
                            (False, distributions.Uniform()))
  def test_randomize_pitch(self, keep_aspect_ratio, randomizer):
    """Sampled sizes stay in bounds and optionally keep a common ratio."""
    (lo0, lo1), (hi0, hi1) = (4, 3), (8, 6)
    arena = pitch_lib.RandomizedPitch(
        min_size=(lo0, lo1),
        max_size=(hi0, hi1),
        randomizer=randomizer,
        keep_aspect_ratio=keep_aspect_ratio)
    arena.initialize_episode_mjcf(np.random.RandomState(42))

    size0, size1 = arena.size[0], arena.size[1]
    self.assertBetween(size0, lo0, hi0)
    self.assertBetween(size1, lo1, hi1)
    if keep_aspect_ratio:
      # Both dimensions must sit at the same fraction of their allowed range.
      self.assertAlmostEqual((size0 - lo0) / float(hi0 - lo0),
                             (size1 - lo1) / float(hi1 - lo1))
if __name__ == '__main__':
  # Run this module's test cases through absl's test runner.
  absltest.main()
You can’t perform that action at this time.