Skip to content

Commit

Permalink
gRPC refactoring of actions and observations
Browse files Browse the repository at this point in the history
  • Loading branch information
sogartar committed Feb 3, 2022
1 parent 51e3c0a commit de0cad3
Show file tree
Hide file tree
Showing 55 changed files with 2,690 additions and 1,235 deletions.
2 changes: 2 additions & 0 deletions build_tools/cmake/grpc.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ function(cc_grpc_library)
--descriptor_set_in "${_DESCRIPTOR_SET_FILE}"
--grpc_out "${_HEADER_DST_DIR}"
--plugin "protoc-gen-grpc=${_GRPC_CPP_PLUGIN_EXECUTABLE}"
--experimental_allow_proto3_optional
"${_RELATIVE_PROTO_FILE}"
DEPENDS "${Protobuf_PROTOC_EXECUTABLE}" "${_DESCRIPTOR_SET_FILE}" "${_PROTO_FILE}" ${_DEPS}
VERBATIM)
Expand Down Expand Up @@ -112,6 +113,7 @@ function(py_grpc_library)
--proto_path "${CMAKE_SOURCE_DIR}"
--descriptor_set_in "${_DESCRIPTOR_SET_FILE}"
--grpc_python_out "${CMAKE_BINARY_DIR}"
--experimental_allow_proto3_optional
"${_RELATIVE_PROTO_FILE}"
DEPENDS "${Python3_EXECUTABLE}" "${_DESCRIPTOR_SET_FILE}" "${_PROTO_FILE}" ${_DEPS}
VERBATIM)
Expand Down
3 changes: 3 additions & 0 deletions build_tools/cmake/protobuf.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ function(proto_library)
COMMAND "${Protobuf_PROTOC_EXECUTABLE}"
--proto_path "${CMAKE_SOURCE_DIR}"
--descriptor_set_out "${_DST_FILE}"
--experimental_allow_proto3_optional
"${_RELATIVE_PROTO_FILE}"
DEPENDS "${Protobuf_PROTOC_EXECUTABLE}" "${_SRC_FILE}" ${_RULE_DEPS}
VERBATIM)
Expand Down Expand Up @@ -85,6 +86,7 @@ function(cc_proto_library)
--proto_path "${CMAKE_SOURCE_DIR}"
--descriptor_set_in "${_DESCRIPTOR_SET_FILE}"
--cpp_out "${_HEADER_DST_DIR}"
--experimental_allow_proto3_optional
"${_RELATIVE_PROTO_FILE}"
DEPENDS
"${Protobuf_PROTOC_EXECUTABLE}"
Expand Down Expand Up @@ -139,6 +141,7 @@ function(py_proto_library)
--proto_path "${CMAKE_SOURCE_DIR}"
--descriptor_set_in "${_DESCRIPTOR_SET_FILE}"
--python_out "${CMAKE_BINARY_DIR}"
--experimental_allow_proto3_optional
"${_RELATIVE_PROTO_FILE}"
DEPENDS
"${Protobuf_PROTOC_EXECUTABLE}"
Expand Down
30 changes: 11 additions & 19 deletions compiler_gym/envs/compiler_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from math import isclose
from pathlib import Path
from time import time
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import gym
import numpy as np
Expand All @@ -30,12 +30,12 @@
SessionNotFound,
)
from compiler_gym.service.connection import ServiceIsClosed
from compiler_gym.service.proto import Action, AddBenchmarkRequest
from compiler_gym.service.proto import AddBenchmarkRequest
from compiler_gym.service.proto import Benchmark as BenchmarkProto
from compiler_gym.service.proto import (
Choice,
EndSessionReply,
EndSessionRequest,
Event,
ForkSessionReply,
ForkSessionRequest,
GetVersionReply,
Expand Down Expand Up @@ -255,9 +255,11 @@ def __init__(
rewards = rewards or [
DefaultRewardFromObservation(obs.name)
for obs in self.service.observation_spaces
if obs.default_value.WhichOneof("value")
if obs.default_observation.WhichOneof("value")
and isinstance(
getattr(obs.default_value, obs.default_value.WhichOneof("value")),
getattr(
obs.default_observation, obs.default_observation.WhichOneof("value")
),
numbers.Number,
)
]
Expand Down Expand Up @@ -294,11 +296,9 @@ def __init__(
pass

# Process the available action, observation, and reward spaces.
action_spaces = [
self.action_spaces = [
proto_to_action_space(space) for space in self.service.action_spaces
]
self.action_spaces = [a.space for a in action_spaces]
self._make_actions = [a.make_action for a in action_spaces]

self.observation = self._observation_view_type(
raw_step=self.raw_step,
Expand All @@ -315,7 +315,6 @@ def __init__(
self._versions: Optional[GetVersionReply] = None

self.action_space: Optional[Space] = None
self._make_action: Optional[Callable[[Any], Action]] = None
self.observation_space: Optional[Space] = None

# Mutable state initialized in reset().
Expand Down Expand Up @@ -429,7 +428,6 @@ def action_space(self, action_space: Optional[str]):
else 0
)
self._action_space: NamedDiscrete = self.action_spaces[index]
self._make_actions: Callable[[Any], Action] = self._make_actions[index]

@property
def benchmark(self) -> Benchmark:
Expand Down Expand Up @@ -852,9 +850,7 @@ def _call_with_error(

# If the action space has changed, update it.
if reply.HasField("new_action_space"):
self.action_space, self._make_action = proto_to_action_space(
reply.new_action_space
)
self.action_space = proto_to_action_space(reply.new_action_space)

self.reward.reset(benchmark=self.benchmark, observation_view=self.observation)
if self.reward_space:
Expand Down Expand Up @@ -926,9 +922,7 @@ def raw_step(
# Send the request to the backend service.
request = StepRequest(
session_id=self._session_id,
action=[
Action(choice=[Choice(named_discrete_value_index=a)]) for a in actions
],
action=[Event(int64_value=a) for a in actions],
observation_space=[
observation_space.index for observation_space in observations_to_compute
],
Expand Down Expand Up @@ -972,9 +966,7 @@ def raw_step(

# If the action space has changed, update it.
if reply.HasField("new_action_space"):
self.action_space, self._make_action = proto_to_action_space(
reply.action_space
)
self.action_space = proto_to_action_space(reply.new_action_space)

# Translate observations to python representations.
if len(reply.observation) != len(observations_to_compute):
Expand Down
Loading

0 comments on commit de0cad3

Please sign in to comment.