Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make action spaces and action service message conversion configurable in ClientServiceCompilerEnv #641

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 44 additions & 7 deletions compiler_gym/service/client_service_compiler_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from math import isclose
from pathlib import Path
from time import time
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
from deprecated.sphinx import deprecated
Expand All @@ -31,7 +31,7 @@
SessionNotFound,
)
from compiler_gym.service.connection import ServiceIsClosed
from compiler_gym.service.proto import AddBenchmarkRequest
from compiler_gym.service.proto import ActionSpace, AddBenchmarkRequest
from compiler_gym.service.proto import Benchmark as BenchmarkProto
from compiler_gym.service.proto import (
EndSessionReply,
Expand All @@ -47,7 +47,7 @@
StartSessionRequest,
StepReply,
StepRequest,
proto_to_action_space,
py_converters,
)
from compiler_gym.spaces import DefaultRewardFromObservation, NamedDiscrete, Reward
from compiler_gym.util.gym_type_hints import (
Expand Down Expand Up @@ -82,6 +82,27 @@ def _wrapped_step(
raise


class ServiceMessageConverters:
    """Bundle of the converters used at the client/service message boundary.

    An instance carries two callables:

    - ``action_space_converter`` maps an ``ActionSpace`` message to a
      ``Space``.
    - ``action_converter`` maps an action to an ``Event`` message.

    Either callable may be supplied by the caller; any that is left as
    ``None`` is replaced by the corresponding default converter obtained
    from ``py_converters``.
    """

    action_space_converter: Callable[[ActionSpace], Space]
    action_converter: Callable[[ActionType], Event]

    def __init__(
        self,
        action_space_converter: Optional[Callable[[ActionSpace], Space]] = None,
        action_converter: Optional[Callable[[Any], Event]] = None,
    ):
        """Store the given converters, filling in defaults where omitted.

        :param action_space_converter: Callable that converts an
            ``ActionSpace`` message. If ``None``, the result of
            ``py_converters.make_message_default_converter()`` is used.

        :param action_converter: Callable that converts an action to an
            ``Event`` message. If ``None``, the result of
            ``py_converters.to_event_message_default_converter()`` is used.
        """
        # Defaults are built lazily, and only for the converters that the
        # caller did not provide.
        if action_space_converter is None:
            action_space_converter = py_converters.make_message_default_converter()
        self.action_space_converter = action_space_converter

        if action_converter is None:
            action_converter = py_converters.to_event_message_default_converter()
        self.action_converter = action_converter


class ClientServiceCompilerEnv(CompilerEnv):
"""Implementation using gRPC for a client-server communication.

Expand All @@ -106,6 +127,7 @@ def __init__(
reward_space: Optional[Union[str, Reward]] = None,
action_space: Optional[str] = None,
derived_observation_spaces: Optional[List[Dict[str, Any]]] = None,
service_message_converters: ServiceMessageConverters = None,
connection_settings: Optional[ConnectionOpts] = None,
service_connection: Optional[CompilerGymServiceConnection] = None,
logger: Optional[logging.Logger] = None,
Expand Down Expand Up @@ -156,6 +178,8 @@ def __init__(
passed to :meth:`env.observation.add_derived_space()
<compiler_gym.views.observation.Observation.add_derived_space>`.

:param service_message_converters: Custom converters for action spaces
and actions. If omitted, a default :class:`ServiceMessageConverters` is used.
ChrisCummins marked this conversation as resolved.
Show resolved Hide resolved

:param connection_settings: The settings used to establish a connection
with the remote service.

Expand Down Expand Up @@ -239,9 +263,16 @@ def __init__(
# first reset() call.
pass

self.service_message_converters = (
ServiceMessageConverters()
if service_message_converters is None
else service_message_converters
)

# Process the available action, observation, and reward spaces.
self.action_spaces = [
proto_to_action_space(space) for space in self.service.action_spaces
self.service_message_converters.action_space_converter(space)
for space in self.service.action_spaces
]

self.observation = self._observation_view_type(
Expand Down Expand Up @@ -788,7 +819,9 @@ def _call_with_error(

# If the action space has changed, update it.
if reply.HasField("new_action_space"):
self.action_space = proto_to_action_space(reply.new_action_space)
self.action_space = self.service_message_converters.action_space_converter(
reply.new_action_space
)

self.reward.reset(benchmark=self.benchmark, observation_view=self.observation)
if self.reward_space:
Expand Down Expand Up @@ -857,7 +890,9 @@ def raw_step(
# Send the request to the backend service.
request = StepRequest(
session_id=self._session_id,
action=[Event(int64_value=a) for a in actions],
action=[
self.service_message_converters.action_converter(a) for a in actions
],
observation_space=[
observation_space.index for observation_space in observations_to_compute
],
Expand Down Expand Up @@ -901,7 +936,9 @@ def raw_step(

# If the action space has changed, update it.
if reply.HasField("new_action_space"):
self.action_space = proto_to_action_space(reply.new_action_space)
self.action_space = self.service_message_converters.action_space_converter(
reply.new_action_space
)

# Translate observations to python representations.
if len(reply.observation) != len(observations_to_compute):
Expand Down
2 changes: 0 additions & 2 deletions compiler_gym/service/proto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@
CompilerGymServiceServicer,
CompilerGymServiceStub,
)
from compiler_gym.service.proto.py_converters import proto_to_action_space

__all__ = [
"ActionSpace",
Expand Down Expand Up @@ -133,5 +132,4 @@
"StringSequenceSpace",
"StringSpace",
"StringTensor",
"proto_to_action_space",
]
9 changes: 5 additions & 4 deletions compiler_gym/service/proto/py_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,6 @@
from compiler_gym.spaces.tuple import Tuple


def proto_to_action_space(space: ActionSpace):
return message_default_converter(space)


class TypeBasedConverter:
"""Converter that dispatches based on the exact type of the parameter.

Expand Down Expand Up @@ -238,6 +234,7 @@ def __init__(self, converter: TypeBasedConverter):
DictEvent: "event_dict",
bool: "boolean_value",
int: "int64_value",
np.int32: "int64_value",
np.float32: "float_value",
float: "double_value",
str: "string_value",
Expand Down Expand Up @@ -375,7 +372,9 @@ def make_message_default_converter() -> TypeBasedConverter:
conversion_map = {
bool: convert_trivial,
int: convert_trivial,
np.int32: convert_trivial,
float: convert_trivial,
np.float32: convert_trivial,
str: convert_trivial,
bytes: convert_bytes_to_numpy,
BooleanTensor: convert_tensor_message_to_numpy,
Expand Down Expand Up @@ -429,7 +428,9 @@ def to_event_message_default_converter() -> ToEventMessageConverter:
conversion_map = {
bool: convert_trivial,
int: convert_trivial,
np.int32: convert_trivial,
float: convert_trivial,
np.float32: convert_trivial,
str: convert_trivial,
np.ndarray: NumpyToTensorMessageConverter(),
}
Expand Down
2 changes: 1 addition & 1 deletion compiler_gym/views/observation_space_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def from_proto(cls, index: int, proto: ObservationSpace):
:raises ValueError: If protocol buffer is invalid.
"""
try:
spec = ObservationSpaceSpec.message_converter(proto.space)
spec = ObservationSpaceSpec.message_converter(proto)
except ValueError as e:
raise ValueError(
f"Error interpreting description of observation space '{proto.name}'.\n"
Expand Down