Skip to content

Commit

Permalink
Various gRPC refactoring code review fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
sogartar committed Feb 3, 2022
1 parent 5a1f045 commit 5f7014e
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 37 deletions.
6 changes: 4 additions & 2 deletions compiler_gym/envs/gcc/service/gcc_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
ActionSpace,
Benchmark,
ByteSequenceSpace,
ByteTensor,
Event,
Int64Range,
Int64Tensor,
Expand Down Expand Up @@ -141,7 +142,7 @@ def make_gcc_compilation_session(gcc_bin: str):
),
deterministic=True,
platform_dependent=False,
default_observation=Event(bytes_value=b""),
default_observation=Event(byte_tensor=ByteTensor(shape=[0], value=b"")),
),
# The size of the object code
ObservationSpace(
Expand Down Expand Up @@ -512,7 +513,8 @@ def get_observation(self, observation_space: ObservationSpace) -> Event:
elif observation_space.name == "instruction_counts":
return Event(string_value=self.instruction_counts or "{}")
elif observation_space.name == "obj":
return Event(bytes_value=self.obj or b"")
value = self.obj or b""
return Event(byte_tensor=ByteTensor(shape=[len(value)], value=value))
elif observation_space.name == "obj_size":
return Event(int64_value=self.obj_size or -1)
elif observation_space.name == "obj_hash":
Expand Down
17 changes: 7 additions & 10 deletions compiler_gym/service/proto/compiler_gym_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,6 @@ message DictEvent {
}

// Common structure shared between actions and observations.
// TODO(boian): Maybe rename to Effect
message Event {
string type_id = 1;
oneof value {
Expand All @@ -371,16 +370,14 @@ message Event {
float float_value = 6;
double double_value = 7;
string string_value = 8;
// TODO(boian): maybe remove this because 1D byte tensor can be used instead.
bytes bytes_value = 9;
// Fixed and variable length sequences are represented as one-dimensional tensor.
BooleanTensor boolean_tensor = 10;
ByteTensor byte_tensor = 11;
Int64Tensor int64_tensor = 12;
FloatTensor float_tensor = 13;
DoubleTensor double_tensor = 14;
StringTensor string_tensor = 15;
google.protobuf.Any any_value = 16;
BooleanTensor boolean_tensor = 9;
ByteTensor byte_tensor = 10;
Int64Tensor int64_tensor = 11;
FloatTensor float_tensor = 12;
DoubleTensor double_tensor = 13;
StringTensor string_tensor = 14;
google.protobuf.Any any_value = 15;
}
}

Expand Down
36 changes: 29 additions & 7 deletions compiler_gym/service/proto/py_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""This module contains converters to/from protobuf messages.
For example <compiler_gym.servie.proto.ActionSpace>/<compiler_gym.servie.proto.ObservationSpace> <-> <compiler_gym.spaces>,
or <compiler_gym.servie.proto.Event> <-> actions/observation.
When defining new environments <compiler_gym.service.proto.py_convertes.make_message_default_converter>
and <compiler_gym.service.proto.py_convertes.to_event_message_default_converter>
can be used as a starting point for custom converters.
"""
import json
from builtins import getattr
from typing import Any, Callable
Expand Down Expand Up @@ -66,6 +75,12 @@ def proto_to_action_space(space: ActionSpace):


class TypeBasedConverter:
"""Converter that dispatches based on the exact type of the parameter.
>>> converter = TypeBasedConverter({ int: lambda x: float(x)})
>>> val: float = converter(5)
"""

conversion_map: DictType[Type, Callable[[Any], Any]]

def __init__(self, conversion_map: DictType[Type, Callable[[Any], Any]] = None):
Expand All @@ -75,7 +90,7 @@ def __call__(self, val: Any) -> Any:
return self.conversion_map[type(val)](val)


type_to_dtype_map = {
proto_type_to_dtype_map = {
BooleanTensor: bool,
ByteTensor: np.int8,
Int64Tensor: np.int64,
Expand Down Expand Up @@ -104,7 +119,7 @@ def __call__(self, val: Any) -> Any:
def convert_standard_tensor_message_to_numpy(
tensor: Union[BooleanTensor, Int64Tensor, FloatTensor, DoubleTensor, StringTensor]
):
res = np.array(tensor.value, dtype=type_to_dtype_map[type(tensor)])
res = np.array(tensor.value, dtype=proto_type_to_dtype_map[type(tensor)])
res = res.reshape(tensor.shape)
return res

Expand Down Expand Up @@ -183,9 +198,12 @@ def convert_trivial(val: Any):
return val


# Convert a protobuf message to an object.
# The conversion function is chosen based on the message descriptor.
class FromMessageConverter:
"""Convert a protobuf message to an object.
The conversion function is chosen based on the message descriptor.
"""

conversion_map: DictType[str, Callable[[Message], Any]]

def __init__(self, conversion_map: DictType[str, Callable[[Message], Any]] = None):
Expand Down Expand Up @@ -443,7 +461,9 @@ def convert_range_message(
range_type = type(range)
min = range.min if range.HasField("min") else range_type_default_min_map[range_type]
max = range.max if range.HasField("max") else range_type_default_max_map[range_type]
return Scalar(name=None, min=min, max=max, dtype=type_to_dtype_map[range_type])
return Scalar(
name=None, min=min, max=max, dtype=proto_type_to_dtype_map[range_type]
)


class ToRangeMessageConverter:
Expand Down Expand Up @@ -476,7 +496,7 @@ def convert_box_message(
low=convert_tensor_message_to_numpy(box.low),
high=convert_tensor_message_to_numpy(box.high),
name=None,
dtype=type_to_dtype_map[type(box)],
dtype=proto_type_to_dtype_map[type(box)],
)


Expand Down Expand Up @@ -549,7 +569,7 @@ def convert_sequence_space(
return Sequence(
name=None,
size_range=(length_range.min, length_range.max),
dtype=type_to_dtype_map[type(seq)],
dtype=proto_type_to_dtype_map[type(seq)],
scalar_range=scalar_range,
)

Expand Down Expand Up @@ -781,6 +801,8 @@ def to_space_message_default_converter() -> ToSpaceMessageConverter:


class OpaqueMessageConverter:
"""Converts <compiler_gym.service.proto.Opaque> message based on its format descriptor."""

format_coverter_map: DictType[str, Callable[[bytes], Any]]

def __init__(self, format_coverter_map=None):
Expand Down
8 changes: 2 additions & 6 deletions compiler_gym/views/observation_space_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,7 @@
# import numpy as np
from gym.spaces import Space

from compiler_gym.service.proto import ( # DoubleRange as ScalarRange,
Event,
ObservationSpace,
py_converters,
)
from compiler_gym.service.proto import Event, ObservationSpace, py_converters
from compiler_gym.util.gym_type_hints import ObservationType


Expand Down Expand Up @@ -102,7 +98,7 @@ def from_proto(cls, index: int, proto: ObservationSpace):
return cls(
id=proto.name,
index=index,
space=ObservationSpaceSpec.message_converter(proto),
space=ObservationSpaceSpec.message_converter(proto.space),
translate=ObservationSpaceSpec.message_converter,
to_string=str,
deterministic=proto.deterministic,
Expand Down
85 changes: 74 additions & 11 deletions docs/source/rpc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -95,40 +95,103 @@ Core Message Types
.. doxygenstruct:: ActionSpace
:members:

.. doxygenstruct:: Action
.. doxygenstruct:: ObservationSpace
:members:

.. doxygenstruct:: ChoiceSpace
.. doxygenstruct:: Event
:members:

.. doxygenstruct:: Choice
.. doxygenstruct:: BooleanTensor
:members:

.. doxygenstruct:: ObservationSpace
.. doxygenstruct:: ByteTensor
:members:

.. doxygenstruct:: Int64Tensor
:members:

.. doxygenstruct:: FloatTensor
:members:

.. doxygenstruct:: DoubleTensor
:members:

.. doxygenstruct:: StringTensor
:members:

.. doxygenstruct:: BooleanRange
:members:

.. doxygenstruct:: Int64Range
:members:

.. doxygenstruct:: FloatRange
:members:

.. doxygenstruct:: DoubleRange
:members:

.. doxygenstruct:: BooleanBox
:members:

.. doxygenstruct:: ByteBox
:members:

.. doxygenstruct:: Observation
.. doxygenstruct:: Int64Box
:members:

.. doxygenstruct:: FloatBox
:members:

.. doxygenstruct:: DoubleBox
:members:

.. doxygenstruct:: ListSpace
:members:

.. doxygenstruct:: DictSpace
:members:

.. doxygenstruct:: DiscreteSpace
:members:

.. doxygenstruct:: NamedDiscreteSpace
:members:

.. doxygenstruct:: Int64List
.. doxygenstruct:: BooleanSequenceSpace
:members:

.. doxygenstruct:: ByteSequenceSpace
:members:

.. doxygenstruct:: BytesSequenceSpace
:members:

.. doxygenstruct:: Int64SequenceSpace
:members:

.. doxygenstruct:: FloatSequenceSpace
:members:

.. doxygenstruct:: DoubleSequenceSpace
:members:

.. doxygenstruct:: StringSequenceSpace
:members:

.. doxygenstruct:: DoubleList
.. doxygenstruct:: StringSpace
:members:

.. doxygenstruct:: ScalarRange
.. doxygenstruct:: Opaque
:members:

.. doxygenstruct:: ScalarLimit
.. doxygenstruct:: CommandlineSpace
:members:

.. doxygenstruct:: ScalarRangeList
.. doxygenstruct:: ListEvent
:members:

.. doxygenstruct:: SequenceSpace
.. doxygenstruct:: DictEvent
:members:

.. doxygenstruct:: Benchmark
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""An example CompilerGym service in python."""
import logging
import os
Expand Down

0 comments on commit 5f7014e

Please sign in to comment.