From c3cae40c0114edbabbb719abb1d6b653157686e2 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Sat, 5 Mar 2022 07:52:15 -0800 Subject: [PATCH 1/3] Make action spaces and action service message conversion configurable in ClientServiceCompilerEnv --- .../service/client_service_compiler_env.py | 51 ++++++++++++++++--- compiler_gym/service/proto/__init__.py | 2 - compiler_gym/service/proto/py_converters.py | 4 -- compiler_gym/views/observation_space_spec.py | 2 +- 4 files changed, 45 insertions(+), 14 deletions(-) diff --git a/compiler_gym/service/client_service_compiler_env.py b/compiler_gym/service/client_service_compiler_env.py index 9b1db32e3..fd51077c6 100644 --- a/compiler_gym/service/client_service_compiler_env.py +++ b/compiler_gym/service/client_service_compiler_env.py @@ -12,7 +12,7 @@ from math import isclose from pathlib import Path from time import time -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import numpy as np from deprecated.sphinx import deprecated @@ -31,7 +31,7 @@ SessionNotFound, ) from compiler_gym.service.connection import ServiceIsClosed -from compiler_gym.service.proto import AddBenchmarkRequest +from compiler_gym.service.proto import ActionSpace, AddBenchmarkRequest from compiler_gym.service.proto import Benchmark as BenchmarkProto from compiler_gym.service.proto import ( EndSessionReply, @@ -47,7 +47,7 @@ StartSessionRequest, StepReply, StepRequest, - proto_to_action_space, + py_converters, ) from compiler_gym.spaces import DefaultRewardFromObservation, NamedDiscrete, Reward from compiler_gym.util.gym_type_hints import ( @@ -82,6 +82,27 @@ def _wrapped_step( raise +class ServiceMessageConverters: + action_space_converter: Callable[[ActionSpace], Space] + action_converter: Callable[[ActionType], Event] + + def __init__( + self, + action_space_converter: Optional[Callable[[ActionSpace], Space]] = None, + action_converter: Optional[Callable[[Any], Event]] = None, + ): + self.action_space_converter = ( + py_converters.make_message_default_converter() + if action_space_converter is None + else action_space_converter + ) + self.action_converter = ( + py_converters.to_event_message_default_converter() + if action_converter is None + else action_converter + ) + + class ClientServiceCompilerEnv(CompilerEnv): """Implementation using gRPC for a client-server communication. @@ -106,6 +127,7 @@ def __init__( reward_space: Optional[Union[str, Reward]] = None, action_space: Optional[str] = None, derived_observation_spaces: Optional[List[Dict[str, Any]]] = None, + service_message_converters: ServiceMessageConverters = None, connection_settings: Optional[ConnectionOpts] = None, service_connection: Optional[CompilerGymServiceConnection] = None, logger: Optional[logging.Logger] = None, @@ -156,6 +178,8 @@ def __init__( passed to :meth:`env.observation.add_derived_space() `. + :param service_message_converters: custom converters for action spaces and actions. + :param connection_settings: The settings used to establish a connection with the remote service. @@ -239,9 +263,16 @@ def __init__( # first reset() call. pass + self.service_message_converters = ( + ServiceMessageConverters() + if service_message_converters is None + else service_message_converters + ) + # Process the available action, observation, and reward spaces. self.action_spaces = [ - proto_to_action_space(space) for space in self.service.action_spaces + self.service_message_converters.action_space_converter(space) + for space in self.service.action_spaces ] self.observation = self._observation_view_type( @@ -788,7 +819,9 @@ def _call_with_error( # If the action space has changed, update it. if reply.HasField("new_action_space"): - self.action_space = proto_to_action_space(reply.new_action_space) + self.action_space = self.service_message_converters.action_space_converter( + reply.new_action_space + ) self.reward.reset(benchmark=self.benchmark, observation_view=self.observation) if self.reward_space: @@ -857,7 +890,9 @@ def raw_step( # Send the request to the backend service. request = StepRequest( session_id=self._session_id, - action=[Event(int64_value=a) for a in actions], + action=[ + self.service_message_converters.action_converter(a) for a in actions + ], observation_space=[ observation_space.index for observation_space in observations_to_compute ], @@ -901,7 +936,9 @@ def raw_step( # If the action space has changed, update it. if reply.HasField("new_action_space"): - self.action_space = proto_to_action_space(reply.new_action_space) + self.action_space = self.service_message_converters.action_space_converter( + reply.new_action_space + ) # Translate observations to python representations. if len(reply.observation) != len(observations_to_compute): diff --git a/compiler_gym/service/proto/__init__.py b/compiler_gym/service/proto/__init__.py index d863bfd4e..b85ad85a7 100644 --- a/compiler_gym/service/proto/__init__.py +++ b/compiler_gym/service/proto/__init__.py @@ -64,7 +64,6 @@ CompilerGymServiceServicer, CompilerGymServiceStub, ) -from compiler_gym.service.proto.py_converters import proto_to_action_space __all__ = [ "ActionSpace", @@ -133,5 +132,4 @@ "StringSequenceSpace", "StringSpace", "StringTensor", - "proto_to_action_space", ] diff --git a/compiler_gym/service/proto/py_converters.py b/compiler_gym/service/proto/py_converters.py index 42ec57927..edfa8090e 100644 --- a/compiler_gym/service/proto/py_converters.py +++ b/compiler_gym/service/proto/py_converters.py @@ -70,10 +70,6 @@ from compiler_gym.spaces.tuple import Tuple -def proto_to_action_space(space: ActionSpace): - return message_default_converter(space) - - class TypeBasedConverter: """Converter that dispatches based on the exact type of the parameter. diff --git a/compiler_gym/views/observation_space_spec.py b/compiler_gym/views/observation_space_spec.py index 77b50c9fd..c00c8cc4b 100644 --- a/compiler_gym/views/observation_space_spec.py +++ b/compiler_gym/views/observation_space_spec.py @@ -104,7 +104,7 @@ def from_proto(cls, index: int, proto: ObservationSpace): :raises ValueError: If protocol buffer is invalid. """ try: - spec = ObservationSpaceSpec.message_converter(proto.space) + spec = ObservationSpaceSpec.message_converter(proto) except ValueError as e: raise ValueError( f"Error interpreting description of observation space '{proto.name}'.\n" From 86e69f0a356974be0deefd37ff0b07663b727f40 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Wed, 30 Mar 2022 17:11:59 -0700 Subject: [PATCH 2/3] Add trivial conversion of np.int32 and np.float32 in service message conversion --- compiler_gym/service/proto/py_converters.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/compiler_gym/service/proto/py_converters.py b/compiler_gym/service/proto/py_converters.py index edfa8090e..b53c9d0dc 100644 --- a/compiler_gym/service/proto/py_converters.py +++ b/compiler_gym/service/proto/py_converters.py @@ -234,6 +234,7 @@ def __init__(self, converter: TypeBasedConverter): DictEvent: "event_dict", bool: "boolean_value", int: "int64_value", + np.int32: "int64_value", np.float32: "float_value", float: "double_value", str: "string_value", @@ -371,7 +372,9 @@ def make_message_default_converter() -> TypeBasedConverter: conversion_map = { bool: convert_trivial, int: convert_trivial, + np.int32: convert_trivial, float: convert_trivial, + np.float32: convert_trivial, str: convert_trivial, bytes: convert_bytes_to_numpy, BooleanTensor: convert_tensor_message_to_numpy, @@ -425,7 +428,9 @@ def to_event_message_default_converter() -> ToEventMessageConverter: conversion_map = { bool: convert_trivial, int: convert_trivial, + np.int32: convert_trivial, float: convert_trivial, + np.float32: convert_trivial, str: convert_trivial, np.ndarray: NumpyToTensorMessageConverter(), } From 06696c7284234d0098b2162e8008f04c93a556f9 Mon Sep 17 00:00:00 2001 From: Boian Petkantchin Date: Fri, 1 Apr 2022 19:40:38 -0700 Subject: [PATCH 3/3] Add docstring for ServiceMessageConverters --- compiler_gym/service/client_service_compiler_env.py | 10 +++++++++- docs/source/compiler_gym/service.rst | 9 +++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/compiler_gym/service/client_service_compiler_env.py b/compiler_gym/service/client_service_compiler_env.py index fd51077c6..2165b8ff9 100644 --- a/compiler_gym/service/client_service_compiler_env.py +++ b/compiler_gym/service/client_service_compiler_env.py @@ -83,6 +83,14 @@ def _wrapped_step( class ServiceMessageConverters: + """Allows for customization of conversion to/from gRPC messages for the + . + + Supports conversion customizations: + * -> . + * -> . + """ + action_space_converter: Callable[[ActionSpace], Space] action_converter: Callable[[ActionType], Event] @@ -178,7 +186,7 @@ def __init__( passed to :meth:`env.observation.add_derived_space() `. - :param service_message_converters: custom converters for action spaces and actions. + :param service_message_converters: Custom converters for action spaces and actions. :param connection_settings: The settings used to establish a connection with the remote service. diff --git a/docs/source/compiler_gym/service.rst b/docs/source/compiler_gym/service.rst index 32c93026d..b083ce935 100644 --- a/docs/source/compiler_gym/service.rst +++ b/docs/source/compiler_gym/service.rst @@ -14,6 +14,15 @@ client and service is managed by the :class:`CompilerGymServiceConnection :local: +ServiceMessageConverters +------------------------ + +.. autoclass:: ClientServiceCompilerEnv + :members: + + .. automethod:: __init__ + + ClientServiceCompilerEnv ------------------------