Skip to content

Commit

Permalink
feat: allow enum strings in json serialization and deserialization (#107
Browse files Browse the repository at this point in the history
)

* feat: allow enum strings in json serialization and deserialization

For protobuf messages that contain enum fields, it is now possible to
specify that enum variants should be serialized as names and not as integers.

E.g.

json_str = MyMessage.to_json(my_message, enum_strings=True)

Similarly, serialization from json that uses this convention is now supported.

This is useful for interoperation with other data sources that do use
strings to define enum variants in json serialization; and for
debugging, where visually inspecting data structures can be helpful,
and variant names are more informative than numerical values.

Note: includes reformatting of many source files due to an update to Black
  • Loading branch information
software-dov committed Aug 28, 2020
1 parent 310dc18 commit a082f85
Show file tree
Hide file tree
Showing 24 changed files with 332 additions and 79 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ jobs:
- run:
name: Format files
command: |
black .
black -l 88 .
- run:
name: Check diff
command: |
Expand Down
9 changes: 9 additions & 0 deletions docs/messages.rst
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,12 @@ via the :meth:`~.Message.to_json` and :meth:`~.Message.from_json` methods.
new_song = Song.from_json(json)
The behavior of JSON serialization can be customized to use strings to
represent enum values.

.. code-block:: python
song = Song(genre=Genre.JAZZ)
json = Song.to_json(song, use_integers_for_enums=False)
assert "JAZZ" in json
25 changes: 24 additions & 1 deletion proto/_file_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import collections.abc
import collections
import inspect
import logging

from google.protobuf import descriptor_pb2
from google.protobuf import descriptor_pool
from google.protobuf import message
from google.protobuf import reflection
Expand All @@ -32,6 +33,28 @@ class _FileInfo(
):
registry = {} # Mapping[str, '_FileInfo']

@classmethod
def maybe_add_descriptor(cls, filename, package):
descriptor = cls.registry.get(filename)
if not descriptor:
descriptor = cls.registry[filename] = cls(
descriptor=descriptor_pb2.FileDescriptorProto(
name=filename,
package=package,
syntax="proto3",
),
enums=collections.OrderedDict(),
messages=collections.OrderedDict(),
name=filename,
nested={},
)

return descriptor

@staticmethod
def proto_file_name(name):
return "{0}.proto".format(name).replace(".", "/")

def _get_manifest(self, new_class):
module = inspect.getmodule(new_class)
if hasattr(module, "__protobuf__"):
Expand Down
49 changes: 49 additions & 0 deletions proto/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

import enum

from google.protobuf import descriptor_pb2

from proto import _file_info
from proto import _package_info
from proto.marshal.rules.enums import EnumRule

Expand All @@ -30,9 +33,49 @@ def __new__(mcls, name, bases, attrs):
# this component belongs within the file.
package, marshal = _package_info.compile(name, attrs)

# Determine the local path of this proto component within the file.
local_path = tuple(attrs.get("__qualname__", name).split("."))

# Sanity check: We get the wrong full name if a class is declared
# inside a function local scope; correct this.
if "<locals>" in local_path:
ix = local_path.index("<locals>")
local_path = local_path[: ix - 1] + local_path[ix + 1 :]

# Determine the full name in protocol buffers.
# The C++ proto implementation doesn't like dots in names, so use underscores.
full_name = "_".join((package,) + local_path).lstrip("_")
enum_desc = descriptor_pb2.EnumDescriptorProto(
name=full_name,
# Note: the superclass ctor removes the variants, so get them now.
# Note: proto3 requires that the first variant value be zero.
value=sorted(
(
descriptor_pb2.EnumValueDescriptorProto(name=name, number=number)
# Minor hack to get all the enum variants out.
for name, number in attrs.items()
if isinstance(number, int)
),
key=lambda v: v.number,
),
)

filename = _file_info._FileInfo.proto_file_name(
attrs.get("__module__", name.lower())
)

file_info = _file_info._FileInfo.maybe_add_descriptor(filename, package)
file_info.descriptor.enum_type.add().MergeFrom(enum_desc)

# Run the superclass constructor.
cls = super().__new__(mcls, name, bases, attrs)

# We can't just add a "_meta" element to attrs because the Enum
# machinery doesn't know what to do with a non-int value.
cls._meta = _EnumInfo(full_name=full_name, pb=enum_desc)

file_info.enums[full_name] = cls

# Register the enum with the marshal.
marshal.register(cls, EnumRule(cls))

Expand All @@ -44,3 +87,9 @@ class Enum(enum.IntEnum, metaclass=ProtoEnumMeta):
"""A enum object that also builds a protobuf enum descriptor."""

pass


class _EnumInfo:
def __init__(self, *, full_name: str, pb):
self.full_name = full_name
self.pb = pb
17 changes: 3 additions & 14 deletions proto/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def descriptor(self):
if isinstance(self.message, str):
if not self.message.startswith(self.package):
self.message = "{package}.{name}".format(
package=self.package, name=self.message,
package=self.package,
name=self.message,
)
type_name = self.message
elif self.message:
Expand All @@ -88,19 +89,7 @@ def descriptor(self):
else self.message.meta.full_name
)
elif self.enum:
# Nos decipiat.
#
# As far as the wire format is concerned, enums are int32s.
# Protocol buffers itself also only sends ints; the enum
# objects are simply helper classes for translating names
# and values and it is the user's job to resolve to an int.
#
# Therefore, the non-trivial effort of adding the actual
# enum descriptors seems to add little or no actual value.
#
# FIXME: Eventually, come back and put in the actual enum
# descriptors.
proto_type = ProtoType.INT32
type_name = self.enum._meta.full_name

# Set the descriptor.
self._descriptor = descriptor_pb2.FieldDescriptorProto(
Expand Down
3 changes: 2 additions & 1 deletion proto/marshal/marshal.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,8 @@ def to_proto(self, proto_type, value, *, strict: bool = False):
raise TypeError(
"Parameter must be instance of the same class; "
"expected {expected}, got {got}".format(
expected=proto_type.__name__, got=pb_value.__class__.__name__,
expected=proto_type.__name__,
got=pb_value.__class__.__name__,
),
)

Expand Down
3 changes: 2 additions & 1 deletion proto/marshal/rules/dates.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def to_proto(self, value) -> timestamp_pb2.Timestamp:
return value.timestamp_pb()
if isinstance(value, datetime):
return timestamp_pb2.Timestamp(
seconds=int(value.timestamp()), nanos=value.microsecond * 1000,
seconds=int(value.timestamp()),
nanos=value.microsecond * 1000,
)
return value

Expand Down
3 changes: 2 additions & 1 deletion proto/marshal/rules/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def to_python(self, value, *, absent: bool = None):
# the user realizes that an unexpected value came along.
warnings.warn(
"Unrecognized {name} enum value: {value}".format(
name=self._enum.__name__, value=value,
name=self._enum.__name__,
value=value,
)
)
return value
Expand Down
12 changes: 9 additions & 3 deletions proto/marshal/rules/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,15 @@ def to_python(self, value, *, absent: bool = None):
return str(value.string_value)
if kind == "struct_value":
return self._marshal.to_python(
struct_pb2.Struct, value.struct_value, absent=False,
struct_pb2.Struct,
value.struct_value,
absent=False,
)
if kind == "list_value":
return self._marshal.to_python(
struct_pb2.ListValue, value.list_value, absent=False,
struct_pb2.ListValue,
value.list_value,
absent=False,
)
raise AttributeError

Expand Down Expand Up @@ -114,7 +118,9 @@ def to_proto(self, value) -> struct_pb2.Struct:
if isinstance(value, struct_pb2.Struct):
return value
if isinstance(value, maps.MapComposite):
return struct_pb2.Struct(fields={k: v for k, v in value.pb.items()},)
return struct_pb2.Struct(
fields={k: v for k, v in value.pb.items()},
)

# We got a dict (or something dict-like); convert it.
answer = struct_pb2.Struct(
Expand Down
57 changes: 35 additions & 22 deletions proto/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ def __new__(mcls, name, bases, attrs):
# Determine the name of the entry message.
msg_name = "{pascal_key}Entry".format(
pascal_key=re.sub(
r"_\w", lambda m: m.group()[1:].upper(), key,
r"_\w",
lambda m: m.group()[1:].upper(),
key,
).replace(key[0], key[0].upper(), 1),
)

Expand All @@ -84,20 +86,26 @@ def __new__(mcls, name, bases, attrs):
{
"__module__": attrs.get("__module__", None),
"__qualname__": "{prefix}.{name}".format(
prefix=attrs.get("__qualname__", name), name=msg_name,
prefix=attrs.get("__qualname__", name),
name=msg_name,
),
"_pb_options": {"map_entry": True},
}
)
entry_attrs["key"] = Field(field.map_key_type, number=1)
entry_attrs["value"] = Field(
field.proto_type, number=2, enum=field.enum, message=field.message,
field.proto_type,
number=2,
enum=field.enum,
message=field.message,
)
map_fields[msg_name] = MessageMeta(msg_name, (Message,), entry_attrs)

# Create the repeated field for the entry message.
map_fields[key] = RepeatedField(
ProtoType.MESSAGE, number=field.number, message=map_fields[msg_name],
ProtoType.MESSAGE,
number=field.number,
message=map_fields[msg_name],
)

# Add the new entries to the attrs
Expand Down Expand Up @@ -183,24 +191,13 @@ def __new__(mcls, name, bases, attrs):
# Determine the filename.
# We determine an appropriate proto filename based on the
# Python module.
filename = "{0}.proto".format(
new_attrs.get("__module__", name.lower()).replace(".", "/")
filename = _file_info._FileInfo.proto_file_name(
new_attrs.get("__module__", name.lower())
)

# Get or create the information about the file, including the
# descriptor to which the new message descriptor shall be added.
file_info = _file_info._FileInfo.registry.setdefault(
filename,
_file_info._FileInfo(
descriptor=descriptor_pb2.FileDescriptorProto(
name=filename, package=package, syntax="proto3",
),
enums=collections.OrderedDict(),
messages=collections.OrderedDict(),
name=filename,
nested={},
),
)
file_info = _file_info._FileInfo.maybe_add_descriptor(filename, package)

# Ensure any imports that would be necessary are assigned to the file
# descriptor proto being created.
Expand Down Expand Up @@ -286,7 +283,13 @@ def pb(cls, obj=None, *, coerce: bool = False):
if coerce:
obj = cls(obj)
else:
raise TypeError("%r is not an instance of %s" % (obj, cls.__name__,))
raise TypeError(
"%r is not an instance of %s"
% (
obj,
cls.__name__,
)
)
return obj._pb

def wrap(cls, pb):
Expand Down Expand Up @@ -325,17 +328,24 @@ def deserialize(cls, payload: bytes) -> "Message":
"""
return cls.wrap(cls.pb().FromString(payload))

def to_json(cls, instance) -> str:
def to_json(cls, instance, *, use_integers_for_enums=True) -> str:
"""Given a message instance, serialize it to json
Args:
instance: An instance of this message type, or something
compatible (accepted by the type's constructor).
use_integers_for_enums (Optional(bool)): An option that determines whether enum
values should be represented by strings (False) or integers (True).
Default is True.
Returns:
str: The json string representation of the protocol buffer.
"""
return MessageToJson(cls.pb(instance))
return MessageToJson(
cls.pb(instance),
use_integers_for_enums=use_integers_for_enums,
including_default_value_fields=True,
)

def from_json(cls, payload) -> "Message":
"""Given a json string representing an instance,
Expand Down Expand Up @@ -399,7 +409,10 @@ def __init__(self, mapping=None, **kwargs):
# Sanity check: Did we get something not a map? Error if so.
raise TypeError(
"Invalid constructor input for %s: %r"
% (self.__class__.__name__, mapping,)
% (
self.__class__.__name__,
mapping,
)
)
else:
# Can't have side effects on mapping.
Expand Down
9 changes: 7 additions & 2 deletions proto/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@


_ProtoModule = collections.namedtuple(
"ProtoModule", ["package", "marshal", "manifest"],
"ProtoModule",
["package", "marshal", "manifest"],
)


Expand All @@ -39,7 +40,11 @@ def define_module(
"""
if not marshal:
marshal = package
return _ProtoModule(package=package, marshal=marshal, manifest=frozenset(manifest),)
return _ProtoModule(
package=package,
marshal=marshal,
manifest=frozenset(manifest),
)


__all__ = ("define_module",)
4 changes: 3 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ def _register_messages(scope, iterable, sym_db):
"""Create and register messages from the file descriptor."""
for name, descriptor in iterable.items():
new_msg = reflection.GeneratedProtocolMessageType(
name, (message.Message,), {"DESCRIPTOR": descriptor, "__module__": None},
name,
(message.Message,),
{"DESCRIPTOR": descriptor, "__module__": None},
)
sym_db.RegisterMessage(new_msg)
setattr(scope, name, new_msg)
Expand Down
13 changes: 13 additions & 0 deletions tests/test_fields_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,16 @@ class Foo(proto.Message):
assert "color" not in foo
assert not foo.color
assert Foo.pb(foo).color == 0


class Zone(proto.Enum):
EPIPELAGIC = 0
MESOPELAGIC = 1
ABYSSOPELAGIC = 2
HADOPELAGIC = 3


def test_enum_outest():
z = Zone(value=Zone.MESOPELAGIC)

assert z == Zone.MESOPELAGIC

0 comments on commit a082f85

Please sign in to comment.