diff --git a/compiler_gym/bin/service.py b/compiler_gym/bin/service.py
index 25f868467..3542ae22b 100644
--- a/compiler_gym/bin/service.py
+++ b/compiler_gym/bin/service.py
@@ -105,7 +105,7 @@
 from compiler_gym.datasets import Dataset
 from compiler_gym.envs import CompilerEnv
 from compiler_gym.service.connection import ConnectionOpts
-from compiler_gym.spaces import Commandline
+from compiler_gym.spaces import Commandline, NamedDiscrete
 from compiler_gym.util.flags.env_from_flags import env_from_flags
 from compiler_gym.util.tabulate import tabulate
 from compiler_gym.util.truncate import truncate
@@ -249,12 +249,13 @@ def print_service_capabilities(env: CompilerEnv):
             ],
             headers=("Action", "Description"),
         )
-    else:
+        print(table)
+    elif isinstance(action_space, NamedDiscrete):
         table = tabulate(
             [(a,) for a in sorted(action_space.names)],
             headers=("Action",),
         )
-    print(table)
+        print(table)


 def main(argv):
diff --git a/compiler_gym/envs/compiler_env.py b/compiler_gym/envs/compiler_env.py
index fce35565e..c0100428d 100644
--- a/compiler_gym/envs/compiler_env.py
+++ b/compiler_gym/envs/compiler_env.py
@@ -6,7 +6,6 @@
 import logging
 import numbers
 import warnings
-from collections.abc import Iterable as IterableType
 from copy import deepcopy
 from math import isclose
 from pathlib import Path
@@ -409,7 +408,7 @@ def state(self) -> CompilerEnvState:
         )

     @property
-    def action_space(self) -> NamedDiscrete:
+    def action_space(self) -> Space:
         """The current action space.

         :getter: Get the current action space.
@@ -587,7 +586,9 @@ def fork(self) -> "CompilerEnv":
             self.reset()
             if actions:
                 logger.warning("Parent service of fork() has died, replaying state")
-                _, _, done, _ = self.step(actions)
+                done = False
+                for action in actions:
+                    _, _, done, _ = self.step(action)
                 assert not done, "Failed to replay action sequence"

         request = ForkSessionRequest(session_id=self._session_id)
@@ -620,7 +621,9 @@ def fork(self) -> "CompilerEnv":
             # replay the state.
             new_env = type(self)(**self._init_kwargs())
             new_env.reset()
-            _, _, done, _ = new_env.step(self.actions)
+            done = False
+            for action in self.actions:
+                _, _, done, _ = new_env.step(action)
             assert not done, "Failed to replay action sequence in forked environment"

         # Create copies of the mutable reward and observation spaces. This
@@ -878,7 +881,7 @@ def _call_with_error(

     def raw_step(
         self,
-        actions: Iterable[int],
+        actions: Iterable[ActionType],
         observations: Iterable[ObservationSpaceSpec],
         rewards: Iterable[Reward],
     ) -> StepType:
@@ -1024,15 +1027,13 @@ def raw_step(

     def step(
         self,
-        action: Union[ActionType, Iterable[ActionType]],
+        action: ActionType,
         observations: Optional[Iterable[Union[str, ObservationSpaceSpec]]] = None,
         rewards: Optional[Iterable[Union[str, Reward]]] = None,
     ) -> StepType:
         """Take a step.

-        :param action: An action, or a sequence of actions. When multiple
-            actions are provided the observation and reward are returned after
-            running all of the actions.
+        :param action: An action.

         :param observations: A list of observation spaces to compute
             observations from. If provided, this changes the :code:`observation`
@@ -1052,7 +1053,7 @@ def step(
             ` has not been called.
         """
         # Coerce actions into a list.
-        actions = action if isinstance(action, IterableType) else [action]
+        actions = [action]

         # Coerce observation spaces into a list of ObservationSpaceSpec instances.
         if observations:
@@ -1169,7 +1170,9 @@ def apply(self, state: CompilerEnvState) -> None:  # noqa
             )

         actions = self.commandline_to_actions(state.commandline)
-        _, _, done, info = self.step(actions)
+        done = False
+        for action in actions:
+            _, _, done, info = self.step(action)
         if done:
             raise ValueError(
                 f"Environment terminated with error: `{info.get('error_details')}`"
diff --git a/compiler_gym/random_replay.py b/compiler_gym/random_replay.py
index 063ec9ee9..81be67ba2 100644
--- a/compiler_gym/random_replay.py
+++ b/compiler_gym/random_replay.py
@@ -15,7 +15,7 @@
 )


-@deprecated(version="0.2.1", reason="Use env.step(actions) instead")
+@deprecated(version="0.2.1", reason="Use env.step(action) instead")
 def replay_actions(env: CompilerEnv, action_names: List[str], outdir: Path):
     return replay_actions_(env, action_names, outdir)

diff --git a/compiler_gym/util/gym_type_hints.py b/compiler_gym/util/gym_type_hints.py
index cc592de45..ba7b6ef8c 100644
--- a/compiler_gym/util/gym_type_hints.py
+++ b/compiler_gym/util/gym_type_hints.py
@@ -9,7 +9,7 @@

 # Type hints for the values returned by gym.Env.step().
 ObservationType = TypeVar("ObservationType")
-ActionType = int
+ActionType = TypeVar("ActionType")
 RewardType = float
 DoneType = bool
 InfoType = JsonDictType
diff --git a/compiler_gym/util/minimize_trajectory.py b/compiler_gym/util/minimize_trajectory.py
index 0de687699..cb16fdfdf 100644
--- a/compiler_gym/util/minimize_trajectory.py
+++ b/compiler_gym/util/minimize_trajectory.py
@@ -41,7 +41,9 @@ def _apply_and_test(env, actions, hypothesis, flakiness) -> bool:
     env.reset(benchmark=env.benchmark)
     for _ in range(flakiness):
         logger.debug("Applying %d actions ...", len(actions))
-        _, _, done, info = env.step(actions)
+        done = False
+        for action in actions:
+            _, _, done, info = env.step(action)
         if done:
             raise MinimizationError(
                 f"Failed to replay actions: {info.get('error_details', '')}"
diff --git a/compiler_gym/wrappers/commandline.py b/compiler_gym/wrappers/commandline.py
index 30606a00f..2c71040de 100644
--- a/compiler_gym/wrappers/commandline.py
+++ b/compiler_gym/wrappers/commandline.py
@@ -6,8 +6,9 @@
 from typing import Dict, Iterable, List, Optional, Union

 from compiler_gym.envs import CompilerEnv
-from compiler_gym.spaces import Commandline, CommandlineFlag
+from compiler_gym.spaces import Commandline, CommandlineFlag, Reward
 from compiler_gym.util.gym_type_hints import StepType
+from compiler_gym.views import ObservationSpaceSpec
 from compiler_gym.wrappers.core import ActionWrapper, CompilerEnvWrapper


@@ -57,23 +58,26 @@ def __init__(
         )

     def step(self, action: int) -> StepType:
-        if isinstance(action, int):
-            end_of_episode = action == 0
-            action = [] if end_of_episode else action - 1
+        end_of_episode = action == 0
+        if end_of_episode:
+            if self.observation_space_spec:
+                observation_spaces: List[ObservationSpaceSpec] = [
+                    self.observation_space_spec
+                ]
+            else:
+                observation_spaces: List[ObservationSpaceSpec] = []
+            if self.reward_space:
+                reward_spaces: List[Reward] = [self.reward_space]
+            else:
+                reward_spaces: List[Reward] = []
+            observation, reward, done, info = self.env.raw_step(
+                [], observation_spaces, reward_spaces
+            )
+            if not done:
+                done = True
+                info["terminal_action"] = True
         else:
-            try:
-                index = action.index(0)
-                end_of_episode = True
-            except ValueError:
-                index = len(action)
-                end_of_episode = False
-            action = [a - 1 for a in action[:index]]
-
-        observation, reward, done, info = self.env.step(action)
-        if end_of_episode and not done:
-            done = True
-            info["terminal_action"] = True
-
+            observation, reward, done, info = self.env.step(action - 1)
         return observation, reward, done, info

diff --git a/compiler_gym/wrappers/core.py b/compiler_gym/wrappers/core.py
index 56bb5ecff..4b590d148 100644
--- a/compiler_gym/wrappers/core.py
+++ b/compiler_gym/wrappers/core.py
@@ -2,13 +2,13 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Iterable, Optional, Union
+from typing import Optional, Union

 import gym

 from compiler_gym.envs import CompilerEnv
 from compiler_gym.spaces.reward import Reward
-from compiler_gym.util.gym_type_hints import ObservationType, StepType
+from compiler_gym.util.gym_type_hints import ActionType, ObservationType, StepType
 from compiler_gym.views import ObservationSpaceSpec

@@ -82,9 +82,7 @@ class ActionWrapper(CompilerEnvWrapper):
     to allow an action space transformation.
     """

-    def step(
-        self, action: Union[int, Iterable[int]], observations=None, rewards=None
-    ) -> StepType:
+    def step(self, action: ActionType, observations=None, rewards=None) -> StepType:
         return self.env.step(
             self.action(action), observations=observations, rewards=rewards
         )
diff --git a/compiler_gym/wrappers/time_limit.py b/compiler_gym/wrappers/time_limit.py
index 743853915..2e5fda2a3 100644
--- a/compiler_gym/wrappers/time_limit.py
+++ b/compiler_gym/wrappers/time_limit.py
@@ -2,9 +2,10 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Iterable, Optional, Union
+from typing import Optional

 from compiler_gym.envs import CompilerEnv
+from compiler_gym.util.gym_type_hints import ActionType
 from compiler_gym.wrappers.core import CompilerEnvWrapper

@@ -31,7 +32,7 @@ def __init__(self, env: CompilerEnv, max_episode_steps: Optional[int] = None):
         self._max_episode_steps = max_episode_steps
         self._elapsed_steps = None

-    def step(self, action: Union[int, Iterable[int]], **kwargs):
+    def step(self, action: ActionType, **kwargs):
         assert (
             self._elapsed_steps is not None
         ), "Cannot call env.step() before calling reset()"
diff --git a/examples/explore.py b/examples/explore.py
index 6ae6668d0..ec6409aa7 100644
--- a/examples/explore.py
+++ b/examples/explore.py
@@ -159,8 +159,8 @@ def compute_edges(env, sequence):
     for action in range(env.action_space.n):
         env.reset()
         reward_sum = 0.0
-        for action in sequence + [action]:
-            _, reward, _, _ = env.step(action)
+        for a in sequence + [action]:
+            _, reward, _, _ = env.step(a)
             reward_sum += reward
         edges.append((env_to_fingerprint(env), reward_sum))

diff --git a/examples/llvm_autotuning/autotuners/nevergrad_.py b/examples/llvm_autotuning/autotuners/nevergrad_.py
index f7b8fd043..21da6c722 100644
--- a/examples/llvm_autotuning/autotuners/nevergrad_.py
+++ b/examples/llvm_autotuning/autotuners/nevergrad_.py
@@ -32,7 +32,8 @@ def nevergrad(

         def calculate_negative_reward(actions: Tuple[int]) -> float:
             env.reset()
-            env.step(actions)
+            for action in actions:
+                env.step(action)
             return -env.episode_reward

     else:
@@ -40,7 +41,8 @@ def calculate_negative_reward(actions: Tuple[int]) -> float:
         @lru_cache(maxsize=int(1e4))
         def calculate_negative_reward(actions: Tuple[int]) -> float:
             env.reset()
-            env.step(actions)
+            for action in actions:
+                env.step(action)
             return -env.episode_reward

     params = ng.p.Choice(
diff --git a/examples/llvm_autotuning/autotuners/opentuner_.py b/examples/llvm_autotuning/autotuners/opentuner_.py
index 3850de8aa..01840f0a2 100644
--- a/examples/llvm_autotuning/autotuners/opentuner_.py
+++ b/examples/llvm_autotuning/autotuners/opentuner_.py
@@ -93,7 +93,8 @@ def __init__(self, data) -> None:
     wrapped = DesiredResult(Configuration(manipulator.best_config))
     manipulator.run(wrapped, None, None)
    env.reset()
-    env.step(manipulator.serialize_actions(manipulator.best_config))
+    for action in manipulator.serialize_actions(manipulator.best_config):
+        env.step(action)


 class LlvmOptFlagsTuner(MeasurementInterface):
diff --git a/examples/llvm_autotuning/optimization_target.py b/examples/llvm_autotuning/optimization_target.py
index 58feddfc4..19b5025c0 100644
--- a/examples/llvm_autotuning/optimization_target.py
+++ b/examples/llvm_autotuning/optimization_target.py
@@ -68,7 +68,9 @@ def final_reward(self, env: LlvmEnv, runtime_count: int = 30) -> float:
         actions = list(env.actions)
         env.reset()
         for i in range(1, 5 + 1):
-            _, _, done, info = env.step(actions)
+            done = False
+            for action in actions:
+                _, _, done, info = env.step(action)
             if not done:
                 break
             logger.warning(
diff --git a/examples/llvm_rl/wrappers.py b/examples/llvm_rl/wrappers.py
index 4ee1b0619..d7c0b8e14 100644
--- a/examples/llvm_rl/wrappers.py
+++ b/examples/llvm_rl/wrappers.py
@@ -3,13 +3,12 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 """Environment wrappers to closer replicate the MLSys'20 Autophase paper."""
-from collections.abc import Iterable as IterableType
-from typing import List, Union

 import gym
 import numpy as np

 from compiler_gym.envs import CompilerEnv, LlvmEnv
+from compiler_gym.util.gym_type_hints import ActionType
 from compiler_gym.wrappers import (
     ConstrainedCommandline,
     ObservationWrapper,
@@ -126,11 +125,8 @@ def reset(self, *args, **kwargs):
         )
         return super().reset(*args, **kwargs)

-    def step(self, action: Union[int, List[int]], observations=None, **kwargs):
-        if not isinstance(action, IterableType):
-            action = [action]
-        for a in action:
-            self.histogram[a] += self.increment
+    def step(self, action: ActionType, observations=None, **kwargs):
+        self.histogram[action] += self.increment
         return super().step(action, **kwargs)

     def observation(self, observation):
diff --git a/examples/op_benchmarks.py b/examples/op_benchmarks.py
index faa53ca05..4304b267e 100644
--- a/examples/op_benchmarks.py
+++ b/examples/op_benchmarks.py
@@ -267,7 +267,9 @@ def get_step_times(env: CompilerEnv, num_steps: int, batched=False):
         # Run all actions in a single step().
         steps = [env.action_space.sample() for _ in range(num_steps)]
         with Timer() as timer:
-            _, _, done, _ = env.step(steps)
+            done = False
+            for step in steps:
+                _, _, done, _ = env.step(step)
         if not done:
             return [timer.time / num_steps] * num_steps
         env.reset()
diff --git a/examples/sensitivity_analysis/action_sensitivity_analysis.py b/examples/sensitivity_analysis/action_sensitivity_analysis.py
index 52c4e409a..a10a6977c 100644
--- a/examples/sensitivity_analysis/action_sensitivity_analysis.py
+++ b/examples/sensitivity_analysis/action_sensitivity_analysis.py
@@ -115,7 +115,9 @@ def run_one_trial(
     num_warmup_steps = random.randint(0, max_warmup_steps)
     warmup_actions = [env.action_space.sample() for _ in range(num_warmup_steps)]
     env.reward_space = reward_space
-    _, _, done, _ = env.step(warmup_actions)
+    done = False
+    for action in warmup_actions:
+        _, _, done, _ = env.step(action)
     if done:
         return None
     _, (reward,), done, _ = env.step(action, rewards=[reward_space])
diff --git a/examples/sensitivity_analysis/benchmark_sensitivity_analysis.py b/examples/sensitivity_analysis/benchmark_sensitivity_analysis.py
index 065b5bc52..d95771272 100644
--- a/examples/sensitivity_analysis/benchmark_sensitivity_analysis.py
+++ b/examples/sensitivity_analysis/benchmark_sensitivity_analysis.py
@@ -116,7 +116,9 @@ def run_one_trial(
     num_steps = random.randint(min_steps, max_steps)
     warmup_actions = [env.action_space.sample() for _ in range(num_steps)]
     env.reward_space = reward_space
-    _, _, done, _ = env.step(warmup_actions)
+    done = False
+    for action in warmup_actions:
+        _, _, done, _ = env.step(action)
     if done:
         return None
     return env.episode_reward
diff --git a/tests/llvm/fork_regression_test.py b/tests/llvm/fork_regression_test.py
index febc8b851..b7fc4316a 100644
--- a/tests/llvm/fork_regression_test.py
+++ b/tests/llvm/fork_regression_test.py
@@ -57,14 +57,16 @@ def test_fork_regression_test(env: LlvmEnv, test: ForkRegressionTest):
     pre_fork = [env.action_space[f] for f in test.pre_fork.split()]
     post_fork = [env.action_space[f] for f in test.post_fork.split()]

-    _, _, done, info = env.step(pre_fork)
-    assert not done, info
+    for action in pre_fork:
+        _, _, done, info = env.step(action)
+        assert not done, info

     with env.fork() as fkd:
         assert env.state == fkd.state  # Sanity check

-        env.step(post_fork)
-        fkd.step(post_fork)
+        for action in post_fork:
+            env.step(action)
+            fkd.step(action)

         # Verify that the environment states no longer line up.
         assert env.state != fkd.state
diff --git a/tests/llvm/llvm_env_test.py b/tests/llvm/llvm_env_test.py
index a1efc22a6..badc82c1c 100644
--- a/tests/llvm/llvm_env_test.py
+++ b/tests/llvm/llvm_env_test.py
@@ -221,8 +221,9 @@ def test_step_multiple_actions_list(env: LlvmEnv):
         env.action_space.flags.index("-mem2reg"),
         env.action_space.flags.index("-reg2mem"),
     ]
-    _, _, done, _ = env.step(actions)
-    assert not done
+    for action in actions:
+        _, _, done, _ = env.step(action)
+        assert not done
     assert env.actions == actions


@@ -233,8 +234,9 @@ def test_step_multiple_actions_generator(env: LlvmEnv):
        env.action_space.flags.index("-mem2reg"),
        env.action_space.flags.index("-reg2mem"),
     )
-    _, _, done, _ = env.step(actions)
-    assert not done
+    for action in actions:
+        _, _, done, _ = env.step(action)
+        assert not done
     assert env.actions == [
         env.action_space.flags.index("-mem2reg"),
         env.action_space.flags.index("-reg2mem"),
diff --git a/tests/util/minimize_trajectory_test.py b/tests/util/minimize_trajectory_test.py
index b9e7597e4..62ee89d2d 100644
--- a/tests/util/minimize_trajectory_test.py
+++ b/tests/util/minimize_trajectory_test.py
@@ -49,10 +49,9 @@ def reset(self, benchmark):
         self.actions = []
         assert benchmark == self.benchmark

-    def step(self, actions):
-        for action in actions:
-            assert action in self.original_trajectory
-        self.actions += actions
+    def step(self, action):
+        assert action in self.original_trajectory
+        self.actions.append(action)
         return None, None, False, {}

@@ -151,13 +150,9 @@ def hypothesis(env):
 def test_minimize_trajectory_iteratively_llvm_crc32(env):
     """Test trajectory minimization on a real environment."""
     env.reset(benchmark="cbench-v1/crc32")
-    env.step(
-        [
-            env.action_space["-mem2reg"],
-            env.action_space["-gvn"],
-            env.action_space["-reg2mem"],
-        ]
-    )
+    env.step(env.action_space["-mem2reg"])
+    env.step(env.action_space["-gvn"])
+    env.step(env.action_space["-reg2mem"])

     def hypothesis(env):
         return (
diff --git a/tests/wrappers/commandline_wrappers_test.py b/tests/wrappers/commandline_wrappers_test.py
index dba8b509a..f85a91013 100644
--- a/tests/wrappers/commandline_wrappers_test.py
+++ b/tests/wrappers/commandline_wrappers_test.py
@@ -23,7 +23,9 @@ def test_commandline_with_terminal_action(env: LlvmEnv):
     env.reset()
     _, _, done, info = env.step(mem2reg_index + 1)
     assert not done, info
-    _, _, done, info = env.step([reg2mem_index + 1, reg2mem_index + 1])
+    _, _, done, info = env.step(reg2mem_index + 1)
+    assert not done, info
+    _, _, done, info = env.step(reg2mem_index + 1)
     assert not done, info
     assert env.actions == [mem2reg_index, reg2mem_index, reg2mem_index]

@@ -63,7 +65,8 @@ def test_constrained_action_space(env: LlvmEnv):

     env.reset()
     env.step(0)
-    env.step([1, 1])
+    env.step(1)
+    env.step(1)

     assert env.actions == [0, 1, 1]

@@ -84,7 +87,8 @@ def test_constrained_action_space_fork(env: LlvmEnv):

         fkd.reset()
         fkd.step(0)
-        fkd.step([1, 1])
+        fkd.step(1)
+        fkd.step(1)

         assert fkd.actions == [0, 1, 1]
     finally:
diff --git a/tests/wrappers/core_wrappers_test.py b/tests/wrappers/core_wrappers_test.py
index a682fe7d2..efb49a8a8 100644
--- a/tests/wrappers/core_wrappers_test.py
+++ b/tests/wrappers/core_wrappers_test.py
@@ -92,7 +92,9 @@ def test_wrapped_step_multi_step(env: LlvmEnv):
     """Test passing a list of actions to step()."""
     env = CompilerEnvWrapper(env)
     env.reset()
-    env.step([0, 0, 0])
+    env.step(0)
+    env.step(0)
+    env.step(0)

     assert env.actions == [0, 0, 0]

@@ -109,11 +111,12 @@ def action(self, action):

     env = MyWrapper(env)
     env.reset()
-    (ir, ic), (icr, icroz), _, _ = env.step(
-        action=[0, 0, 0],
-        observations=["Ir", "IrInstructionCount"],
-        rewards=["IrInstructionCount", "IrInstructionCountOz"],
-    )
+    for i in range(3):
+        (ir, ic), (icr, icroz), _, _ = env.step(
+            action=0,
+            observations=["Ir", "IrInstructionCount"],
+            rewards=["IrInstructionCount", "IrInstructionCountOz"],
+        )
     assert isinstance(ir, str)
     assert isinstance(ic, int)
     assert isinstance(icr, float)
diff --git a/tests/wrappers/time_limit_wrappers_test.py b/tests/wrappers/time_limit_wrappers_test.py
index f74c76ea5..adb9bc021 100644
--- a/tests/wrappers/time_limit_wrappers_test.py
+++ b/tests/wrappers/time_limit_wrappers_test.py
@@ -28,7 +28,9 @@ def test_wrapped_fork_type(env: LlvmEnv):
 def test_wrapped_step_multi_step(env: LlvmEnv):
     env = TimeLimit(env, max_episode_steps=5)
     env.reset(benchmark="benchmark://cbench-v1/dijkstra")
-    env.step([0, 0, 0])
+    env.step(0)
+    env.step(0)
+    env.step(0)

     assert env.benchmark == "benchmark://cbench-v1/dijkstra"
     assert env.actions == [0, 0, 0]
diff --git a/www/www.py b/www/www.py
index 0586b3274..8c74001ab 100644
--- a/www/www.py
+++ b/www/www.py
@@ -217,7 +217,7 @@ def _step(request: StepRequest) -> StepReply:
     if request.all_states:
         # Replay actions one at a time to receive incremental rewards. The
         # first item represents the state prior to any actions.
-        (instcount, autophase), _, done, info = env.step(
+        (instcount, autophase), _, done, info = env.raw_step(
             action=[],
             observations=[
                 env.observation.spaces["InstCountDict"],