Change CompilerEnv.step to accept a single action
sogartar committed Mar 5, 2022
1 parent f8a5117 commit 78294bc
Showing 23 changed files with 111 additions and 87 deletions.
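
For downstream users, the practical effect of this change is that `env.step()` no longer accepts a sequence of actions; callers that previously batched actions into one call now apply them one at a time. A minimal migration sketch follows; it is illustrative only, and the `"llvm-v0"` environment id and the sampled action list are assumptions rather than part of this commit:

```python
import gym

import compiler_gym  # noqa: registers the CompilerGym environments with gym

env = gym.make("llvm-v0")  # assumed environment id, not part of this diff
env.reset()

# A few actions to apply; in real code these would come from a search or policy.
actions = [env.action_space.sample() for _ in range(3)]

# Before this commit:  _, _, done, _ = env.step(actions)
# After this commit, each action gets its own step() call:
done = False
for action in actions:
    observation, reward, done, info = env.step(action)
    if done:  # the episode can end mid-sequence, so check after every step
        break

env.close()
```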
7 changes: 4 additions & 3 deletions compiler_gym/bin/service.py
@@ -105,7 +105,7 @@
from compiler_gym.datasets import Dataset
from compiler_gym.envs import CompilerEnv
from compiler_gym.service.connection import ConnectionOpts
from compiler_gym.spaces import Commandline
from compiler_gym.spaces import Commandline, NamedDiscrete
from compiler_gym.util.flags.env_from_flags import env_from_flags
from compiler_gym.util.tabulate import tabulate
from compiler_gym.util.truncate import truncate
@@ -249,12 +249,13 @@ def print_service_capabilities(env: CompilerEnv):
],
headers=("Action", "Description"),
)
else:
print(table)
elif isinstance(action_space, NamedDiscrete):
table = tabulate(
[(a,) for a in sorted(action_space.names)],
headers=("Action",),
)
print(table)
print(table)


def main(argv):
25 changes: 14 additions & 11 deletions compiler_gym/envs/compiler_env.py
@@ -6,7 +6,6 @@
import logging
import numbers
import warnings
from collections.abc import Iterable as IterableType
from copy import deepcopy
from math import isclose
from pathlib import Path
@@ -409,7 +408,7 @@ def state(self) -> CompilerEnvState:
)

@property
def action_space(self) -> NamedDiscrete:
def action_space(self) -> Space:
"""The current action space.
:getter: Get the current action space.
@@ -587,7 +586,9 @@ def fork(self) -> "CompilerEnv":
self.reset()
if actions:
logger.warning("Parent service of fork() has died, replaying state")
_, _, done, _ = self.step(actions)
done = False
for action in actions:
_, _, done, _ = self.step(action)
assert not done, "Failed to replay action sequence"

request = ForkSessionRequest(session_id=self._session_id)
@@ -620,7 +621,9 @@ def fork(self) -> "CompilerEnv":
# replay the state.
new_env = type(self)(**self._init_kwargs())
new_env.reset()
_, _, done, _ = new_env.step(self.actions)
done = False
for action in self.actions:
_, _, done, _ = new_env.step(action)
assert not done, "Failed to replay action sequence in forked environment"

# Create copies of the mutable reward and observation spaces. This
@@ -878,7 +881,7 @@ def _call_with_error(

def raw_step(
self,
actions: Iterable[int],
actions: Iterable[ActionType],
observations: Iterable[ObservationSpaceSpec],
rewards: Iterable[Reward],
) -> StepType:
@@ -1024,15 +1027,13 @@ def raw_step(

def step(
self,
action: Union[ActionType, Iterable[ActionType]],
action: ActionType,
observations: Optional[Iterable[Union[str, ObservationSpaceSpec]]] = None,
rewards: Optional[Iterable[Union[str, Reward]]] = None,
) -> StepType:
"""Take a step.
:param action: An action, or a sequence of actions. When multiple
actions are provided the observation and reward are returned after
running all of the actions.
:param action: An action.
:param observations: A list of observation spaces to compute
observations from. If provided, this changes the :code:`observation`
@@ -1052,7 +1053,7 @@ def step(
<compiler_gym.envs.CompilerEnv.reset>` has not been called.
"""
# Coerce actions into a list.
actions = action if isinstance(action, IterableType) else [action]
actions = [action]

# Coerce observation spaces into a list of ObservationSpaceSpec instances.
if observations:
@@ -1169,7 +1170,9 @@ def apply(self, state: CompilerEnvState) -> None: # noqa
)

actions = self.commandline_to_actions(state.commandline)
_, _, done, info = self.step(actions)
done = False
for action in actions:
_, _, done, info = self.step(action)
if done:
raise ValueError(
f"Environment terminated with error: `{info.get('error_details')}`"
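
As a usage note on the revised `step()` docstring and signature above: the optional `observations` and `rewards` keyword arguments still accept lists of space names or space objects, and the corresponding return values are sequences, as in the sensitivity-analysis examples later in this diff. A minimal sketch, assuming the `"llvm-v0"` environment and the `"IrInstructionCount"` reward space name (neither is part of this commit):

```python
import gym

import compiler_gym  # noqa: registers the CompilerGym environments with gym

env = gym.make("llvm-v0")  # assumed environment id
env.reset()

# Plain single-action form: the familiar gym 4-tuple.
observation, reward, done, info = env.step(env.action_space.sample())

# With an explicit rewards list, the reward slot becomes a sequence with one
# entry per requested space -- mirroring `_, (reward,), done, _ = env.step(...)`
# in the sensitivity-analysis examples below.
_, (ic_reward,), done, _ = env.step(
    env.action_space.sample(),
    rewards=["IrInstructionCount"],  # assumed reward space name
)

env.close()
```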
2 changes: 1 addition & 1 deletion compiler_gym/random_replay.py
@@ -15,7 +15,7 @@
)


@deprecated(version="0.2.1", reason="Use env.step(actions) instead")
@deprecated(version="0.2.1", reason="Use env.step(action) instead")
def replay_actions(env: CompilerEnv, action_names: List[str], outdir: Path):
return replay_actions_(env, action_names, outdir)

2 changes: 1 addition & 1 deletion compiler_gym/util/gym_type_hints.py
@@ -9,7 +9,7 @@

# Type hints for the values returned by gym.Env.step().
ObservationType = TypeVar("ObservationType")
ActionType = int
ActionType = TypeVar("ActionType")
RewardType = float
DoneType = bool
InfoType = JsonDictType
4 changes: 3 additions & 1 deletion compiler_gym/util/minimize_trajectory.py
@@ -41,7 +41,9 @@ def _apply_and_test(env, actions, hypothesis, flakiness) -> bool:
env.reset(benchmark=env.benchmark)
for _ in range(flakiness):
logger.debug("Applying %d actions ...", len(actions))
_, _, done, info = env.step(actions)
done = False
for action in actions:
_, _, done, info = env.step(action)
if done:
raise MinimizationError(
f"Failed to replay actions: {info.get('error_details', '')}"
38 changes: 21 additions & 17 deletions compiler_gym/wrappers/commandline.py
@@ -6,8 +6,9 @@
from typing import Dict, Iterable, List, Optional, Union

from compiler_gym.envs import CompilerEnv
from compiler_gym.spaces import Commandline, CommandlineFlag
from compiler_gym.spaces import Commandline, CommandlineFlag, Reward
from compiler_gym.util.gym_type_hints import StepType
from compiler_gym.views import ObservationSpaceSpec
from compiler_gym.wrappers.core import ActionWrapper, CompilerEnvWrapper


@@ -57,23 +58,26 @@ def __init__(
)

def step(self, action: int) -> StepType:
if isinstance(action, int):
end_of_episode = action == 0
action = [] if end_of_episode else action - 1
end_of_episode = action == 0
if end_of_episode:
if self.observation_space_spec:
observation_spaces: List[ObservationSpaceSpec] = [
self.observation_space_spec
]
else:
observation_spaces: List[ObservationSpaceSpec] = []
if self.reward_space:
reward_spaces: List[Reward] = [self.reward_space]
else:
reward_spaces: List[Reward] = []
observation, reward, done, info = self.env.raw_step(
[], observation_spaces, reward_spaces
)
if not done:
done = True
info["terminal_action"] = True
else:
try:
index = action.index(0)
end_of_episode = True
except ValueError:
index = len(action)
end_of_episode = False
action = [a - 1 for a in action[:index]]

observation, reward, done, info = self.env.step(action)
if end_of_episode and not done:
done = True
info["terminal_action"] = True

observation, reward, done, info = self.env.step(action - 1)
return observation, reward, done, info


8 changes: 3 additions & 5 deletions compiler_gym/wrappers/core.py
@@ -2,13 +2,13 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Iterable, Optional, Union
from typing import Optional, Union

import gym

from compiler_gym.envs import CompilerEnv
from compiler_gym.spaces.reward import Reward
from compiler_gym.util.gym_type_hints import ObservationType, StepType
from compiler_gym.util.gym_type_hints import ActionType, ObservationType, StepType
from compiler_gym.views import ObservationSpaceSpec


@@ -82,9 +82,7 @@ class ActionWrapper(CompilerEnvWrapper):
to allow an action space transformation.
"""

def step(
self, action: Union[int, Iterable[int]], observations=None, rewards=None
) -> StepType:
def step(self, action: ActionType, observations=None, rewards=None) -> StepType:
return self.env.step(
self.action(action), observations=observations, rewards=rewards
)
5 changes: 3 additions & 2 deletions compiler_gym/wrappers/time_limit.py
@@ -2,9 +2,10 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Iterable, Optional, Union
from typing import Optional

from compiler_gym.envs import CompilerEnv
from compiler_gym.util.gym_type_hints import ActionType
from compiler_gym.wrappers.core import CompilerEnvWrapper


@@ -31,7 +32,7 @@ def __init__(self, env: CompilerEnv, max_episode_steps: Optional[int] = None):
self._max_episode_steps = max_episode_steps
self._elapsed_steps = None

def step(self, action: Union[int, Iterable[int]], **kwargs):
def step(self, action: ActionType, **kwargs):
assert (
self._elapsed_steps is not None
), "Cannot call env.step() before calling reset()"
4 changes: 2 additions & 2 deletions examples/explore.py
@@ -159,8 +159,8 @@ def compute_edges(env, sequence):
for action in range(env.action_space.n):
env.reset()
reward_sum = 0.0
for action in sequence + [action]:
_, reward, _, _ = env.step(action)
for a in sequence + [action]:
_, reward, _, _ = env.step(a)
reward_sum += reward

edges.append((env_to_fingerprint(env), reward_sum))
6 changes: 4 additions & 2 deletions examples/llvm_autotuning/autotuners/nevergrad_.py
@@ -32,15 +32,17 @@ def nevergrad(

def calculate_negative_reward(actions: Tuple[int]) -> float:
env.reset()
env.step(actions)
for action in actions:
env.step(action)
return -env.episode_reward

else:
# Only cache the deterministic non-runtime rewards.
@lru_cache(maxsize=int(1e4))
def calculate_negative_reward(actions: Tuple[int]) -> float:
env.reset()
env.step(actions)
for action in actions:
env.step(action)
return -env.episode_reward

params = ng.p.Choice(
3 changes: 2 additions & 1 deletion examples/llvm_autotuning/autotuners/opentuner_.py
@@ -93,7 +93,8 @@ def __init__(self, data) -> None:
wrapped = DesiredResult(Configuration(manipulator.best_config))
manipulator.run(wrapped, None, None)
env.reset()
env.step(manipulator.serialize_actions(manipulator.best_config))
for action in manipulator.serialize_actions(manipulator.best_config):
env.step(action)


class LlvmOptFlagsTuner(MeasurementInterface):
4 changes: 3 additions & 1 deletion examples/llvm_autotuning/optimization_target.py
@@ -68,7 +68,9 @@ def final_reward(self, env: LlvmEnv, runtime_count: int = 30) -> float:
actions = list(env.actions)
env.reset()
for i in range(1, 5 + 1):
_, _, done, info = env.step(actions)
done = False
for action in actions:
_, _, done, info = env.step(action)
if not done:
break
logger.warning(
10 changes: 3 additions & 7 deletions examples/llvm_rl/wrappers.py
@@ -3,13 +3,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Environment wrappers to closer replicate the MLSys'20 Autophase paper."""
from collections.abc import Iterable as IterableType
from typing import List, Union

import gym
import numpy as np

from compiler_gym.envs import CompilerEnv, LlvmEnv
from compiler_gym.util.gym_type_hints import ActionType
from compiler_gym.wrappers import (
ConstrainedCommandline,
ObservationWrapper,
@@ -126,11 +125,8 @@ def reset(self, *args, **kwargs):
)
return super().reset(*args, **kwargs)

def step(self, action: Union[int, List[int]], observations=None, **kwargs):
if not isinstance(action, IterableType):
action = [action]
for a in action:
self.histogram[a] += self.increment
def step(self, action: ActionType, observations=None, **kwargs):
self.histogram[action] += self.increment
return super().step(action, **kwargs)

def observation(self, observation):
4 changes: 3 additions & 1 deletion examples/op_benchmarks.py
@@ -267,7 +267,9 @@ def get_step_times(env: CompilerEnv, num_steps: int, batched=False):
# Run all actions in a single step().
steps = [env.action_space.sample() for _ in range(num_steps)]
with Timer() as timer:
_, _, done, _ = env.step(steps)
done = False
for step in steps:
_, _, done, _ = env.step(step)
if not done:
return [timer.time / num_steps] * num_steps
env.reset()
4 changes: 3 additions & 1 deletion examples/sensitivity_analysis/action_sensitivity_analysis.py
@@ -115,7 +115,9 @@ def run_one_trial(
num_warmup_steps = random.randint(0, max_warmup_steps)
warmup_actions = [env.action_space.sample() for _ in range(num_warmup_steps)]
env.reward_space = reward_space
_, _, done, _ = env.step(warmup_actions)
done = False
for action in warmup_actions:
_, _, done, _ = env.step(action)
if done:
return None
_, (reward,), done, _ = env.step(action, rewards=[reward_space])
@@ -116,7 +116,9 @@ def run_one_trial(
num_steps = random.randint(min_steps, max_steps)
warmup_actions = [env.action_space.sample() for _ in range(num_steps)]
env.reward_space = reward_space
_, _, done, _ = env.step(warmup_actions)
done = False
for action in warmup_actions:
_, _, done, _ = env.step(action)
if done:
return None
return env.episode_reward
10 changes: 6 additions & 4 deletions tests/llvm/fork_regression_test.py
@@ -57,14 +57,16 @@ def test_fork_regression_test(env: LlvmEnv, test: ForkRegressionTest):
pre_fork = [env.action_space[f] for f in test.pre_fork.split()]
post_fork = [env.action_space[f] for f in test.post_fork.split()]

_, _, done, info = env.step(pre_fork)
assert not done, info
for action in pre_fork:
_, _, done, info = env.step(action)
assert not done, info

with env.fork() as fkd:
assert env.state == fkd.state # Sanity check

env.step(post_fork)
fkd.step(post_fork)
for action in post_fork:
env.step(action)
fkd.step(action)
# Verify that the environment states no longer line up.
assert env.state != fkd.state
