Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion betterproto2/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ build-backend = "hatchling.build"
# ]

[tool.ruff]
extend-exclude = ["tests/output_*", "src/betterproto2/internal_lib"]
extend-exclude = ["tests/outputs", "src/betterproto2/internal_lib"]
target-version = "py310"
line-length = 120

Expand Down
7 changes: 4 additions & 3 deletions betterproto2/src/betterproto2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from ._types import T
from ._version import __version__, check_compiler_version
from .casing import camel_case, safe_snake_case, snake_case
from .enum import Enum as Enum
from .enum_ import Enum as Enum
from .grpc.grpclib_client import ServiceStub as ServiceStub
from .utils import classproperty

Expand Down Expand Up @@ -585,9 +585,10 @@ def _value_to_dict(
if proto_type in INT_64_TYPES:
return str(value), not bool(value)
if proto_type == TYPE_BYTES:
return b64encode(value).decode("utf8"), not (bool(value))
return b64encode(value).decode("utf8"), not bool(value)
if proto_type == TYPE_ENUM:
return field_type(value).name, not bool(value)
enum_value = field_type(value)
return enum_value.proto_name or enum_value.name, not bool(value)
if proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
return _dump_float(value), not bool(value)
return value, not bool(value)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,42 @@
from enum import IntEnum
import sys
from enum import EnumMeta, IntEnum

from typing_extensions import Self


class Enum(IntEnum):
class _EnumMeta(EnumMeta):
def __new__(metacls, cls, bases, classdict):
# Find the proto names if defined
if sys.version_info >= (3, 11):
proto_names = classdict.pop("betterproto_proto_names", {})
classdict._member_names.pop("betterproto_proto_names", None)
else:
proto_names = {}
if "betterproto_proto_names" in classdict:
proto_names = classdict.pop("betterproto_proto_names")
classdict._member_names.remove("betterproto_proto_names")

enum_class = super().__new__(metacls, cls, bases, classdict)

# Attach extra info to each enum member
for member in enum_class:
value = member.value # type: ignore[reportAttributeAccessIssue]
extra = proto_names.get(value)
member._proto_name = extra # type: ignore[reportAttributeAccessIssue]

return enum_class


class Enum(IntEnum, metaclass=_EnumMeta):
@property
def proto_name(self) -> str | None:
return self._proto_name # type: ignore[reportAttributeAccessIssue]

@classmethod
def _missing_(cls, value):
# If the given value is not an integer, let the standard enum implementation raise an error
if not isinstance(value, int):
return None
return

# Create a new "unknown" instance with the given value.
obj = int.__new__(cls, value)
Expand Down
2 changes: 1 addition & 1 deletion betterproto2/tests/test_all_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ def test_all_definition():
"TestSyncStub",
"ThingType",
)
assert enum.__all__ == ("ArithmeticOperator", "Choice", "HttpCode", "NoStriping", "Test")
assert enum.__all__ == ("ArithmeticOperator", "Choice", "EnumMessage", "HttpCode", "NoStriping", "Test")
16 changes: 16 additions & 0 deletions betterproto2/tests/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,19 @@ def test_enum_renaming() -> None:
assert set(ArithmeticOperator.__members__) == {"NONE", "PLUS", "MINUS", "_0_PREFIXED"}
assert set(HttpCode.__members__) == {"UNSPECIFIED", "OK", "NOT_FOUND"}
assert set(NoStriping.__members__) == {"NO_STRIPING_NONE", "NO_STRIPING_A", "B"}


def test_enum_to_dict() -> None:
from tests.outputs.enum.enum import ArithmeticOperator, EnumMessage, NoStriping

msg = EnumMessage(
arithmetic_operator=ArithmeticOperator.PLUS,
no_striping=NoStriping.NO_STRIPING_A,
)

print(ArithmeticOperator.PLUS.proto_name)

assert msg.to_dict() == {
"arithmeticOperator": "ARITHMETIC_OPERATOR_PLUS", # The original proto name must be preserved
"noStriping": "NO_STRIPING_A",
}
2 changes: 1 addition & 1 deletion betterproto2_compiler/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.ruff]
extend-exclude = ["tests/output_*", "src/betterproto2_compiler/lib"]
extend-exclude = ["tests/outputs", "src/betterproto2_compiler/lib"]
target-version = "py310"
line-length = 120

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -610,13 +610,15 @@ class EnumEntry:
"""Representation of an Enum entry."""

name: str
proto_name: str
value: int
comment: str

def __post_init__(self) -> None:
self.entries = [
self.EnumEntry(
name=entry_proto_value.name,
proto_name=entry_proto_value.name,
value=entry_proto_value.number,
comment=get_comment(proto_file=self.source_file, path=self.path + [2, entry_number]),
)
Expand Down Expand Up @@ -672,6 +674,10 @@ def descriptor_name(self) -> str:
"""
return self.output_file.get_descriptor_name(self.source_file)

@property
def has_renamed_entries(self) -> bool:
return any(entry.proto_name != entry.name for entry in self.entries)


@dataclass(kw_only=True)
class ServiceCompiler(ProtoContentBase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@ class {{ enum.py_name | add_to_all }}(betterproto2.Enum):
return core_schema.int_schema(ge=0)
{% endif %}

{% if enum.has_renamed_entries %}
betterproto_proto_names = {
{% for entry in enum.entries %}
{% if entry.proto_name != entry.name %}
{{ entry.value }}: "{{ entry.proto_name }}",
{% endif %}
{% endfor %}
}
{% endif %}

{% endfor %}
{% for _, message in output_file.messages|dictsort(by="key") %}
{% if output_file.settings.pydantic_dataclasses %}
Expand Down
19 changes: 12 additions & 7 deletions betterproto2_compiler/tests/inputs/enum/enum.proto
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@ package enum;

// Tests that enums are correctly serialized and that it correctly handles skipped and out-of-order enum values
message Test {
Choice choice = 1;
repeated Choice choices = 2;
Choice choice = 1;
repeated Choice choices = 2;
}

enum Choice {
ZERO = 0;
ONE = 1;
// TWO = 2;
FOUR = 4;
THREE = 3;
ZERO = 0;
ONE = 1;
// TWO = 2;
FOUR = 4;
THREE = 3;
}

// A "C" like enum with the enum name prefixed onto members, these should be stripped
Expand All @@ -38,3 +38,8 @@ enum HTTPCode {
HTTP_CODE_OK = 200;
HTTP_CODE_NOT_FOUND = 404;
}

message EnumMessage {
ArithmeticOperator arithmetic_operator = 1;
NoStriping no_striping = 2;
}
Loading