diff --git a/meltingpot/testing/mocks_test.py b/meltingpot/testing/mocks_test.py index 166ad345..7b94b572 100644 --- a/meltingpot/testing/mocks_test.py +++ b/meltingpot/testing/mocks_test.py @@ -30,8 +30,8 @@ def test_value_from_specs(self): ) actual = mocks._values_from_specs(specs) expected = ( - {'a': np.zeros([1, 2, 3], dtype=np.uint8) + 0}, - {'b': np.zeros([1, 2, 3], dtype=np.uint8) + 1}, + {'a': np.zeros([1, 2, 3], dtype=np.uint8)}, + {'b': np.ones([1, 2, 3], dtype=np.uint8)}, ) np.testing.assert_equal(actual, expected) @@ -44,7 +44,10 @@ def test_mock_substrate(self): num_actions=num_actions, observation_spec=observation_spec) - expected_observation = ({'a': np.uint8()}, {'a': np.uint8() + 1}) + expected_observation = ( + {'a': np.zeros([], dtype=np.uint8)}, + {'a': np.ones([], dtype=np.uint8)}, + ) expected_reward = tuple(float(n) for n in range(num_players)) with self.subTest('is_substrate'): diff --git a/meltingpot/utils/policies/saved_model_policy.py b/meltingpot/utils/policies/saved_model_policy.py index 8df461f3..01fd4ea0 100644 --- a/meltingpot/utils/policies/saved_model_policy.py +++ b/meltingpot/utils/policies/saved_model_policy.py @@ -136,7 +136,7 @@ def __init__(self, model_path: str, device_name: str = 'cpu') -> None: @contextlib.contextmanager def _build_context(self): - with self._graph.as_default(): + with self._graph.as_default(): # pylint: disable=not-context-manager with tf.compat.v1.device(self._device_name): yield