Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions betterproto2/tests/inputs/enum/test_enum.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from tests.outputs.enum.enum import (
ArithmeticOperator,
Choice,
Test,
)
Expand Down Expand Up @@ -91,12 +90,3 @@ def test_enum_mapped_on_parse():

# bonus: defaults after empty init are also mapped
assert Test().choice.name == Choice.ZERO.name


def test_renamed_enum_members():
assert set(ArithmeticOperator.__members__) == {
"NONE",
"PLUS",
"MINUS",
"_0_PREFIXED",
}
2 changes: 1 addition & 1 deletion betterproto2/tests/test_all_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ def test_all_definition():
"TestSyncStub",
"ThingType",
)
assert enum.__all__ == ("ArithmeticOperator", "Choice", "Test")
assert enum.__all__ == ("ArithmeticOperator", "Choice", "HttpCode", "NoStriping", "Test")
8 changes: 8 additions & 0 deletions betterproto2/tests/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,11 @@ def test_from_string(member: Colour, input_str: str) -> None:
)
def test_construction(member: Colour, input_int: int) -> None:
assert Colour(input_int) == member


def test_enum_renaming() -> None:
from tests.outputs.enum.enum import ArithmeticOperator, HttpCode, NoStriping

assert set(ArithmeticOperator.__members__) == {"NONE", "PLUS", "MINUS", "_0_PREFIXED"}
assert set(HttpCode.__members__) == {"UNSPECIFIED", "OK", "NOT_FOUND"}
assert set(NoStriping.__members__) == {"NO_STRIPING_NONE", "NO_STRIPING_A", "B"}
31 changes: 28 additions & 3 deletions betterproto2_compiler/src/betterproto2_compiler/plugin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@

from betterproto2 import unwrap

from betterproto2_compiler import casing
from betterproto2_compiler.compile.importing import get_type_reference, parse_source_type_name
from betterproto2_compiler.compile.naming import (
pythonize_class_name,
pythonize_enum_member_name,
pythonize_field_name,
pythonize_method_name,
)
Expand Down Expand Up @@ -614,16 +614,41 @@ class EnumEntry:
comment: str

def __post_init__(self) -> None:
# Get entries/allowed values for this Enum
self.entries = [
self.EnumEntry(
name=pythonize_enum_member_name(entry_proto_value.name, self.proto_obj.name),
name=entry_proto_value.name,
value=entry_proto_value.number,
comment=get_comment(proto_file=self.source_file, path=self.path + [2, entry_number]),
)
for entry_number, entry_proto_value in enumerate(self.proto_obj.value)
]

if not self.entries:
return

# Remove enum prefixes
enum_name: str = self.proto_obj.name

enum_name_reduced = enum_name.replace("_", "").lower()

first_entry = self.entries[0].name

# Find the potential common prefix
enum_prefix = ""
for i in range(len(first_entry)):
if first_entry[: i + 1].replace("_", "").lower() == enum_name_reduced:
enum_prefix = f"{first_entry[: i + 1]}_"
break

should_rename = enum_prefix and all(entry.name.startswith(enum_prefix) for entry in self.entries)

if should_rename:
for entry in self.entries:
entry.name = entry.name[len(enum_prefix) :]

for entry in self.entries:
entry.name = casing.sanitize_name(entry.name)

@property
def proto_name(self) -> str:
return self.proto_obj.name
Expand Down
15 changes: 15 additions & 0 deletions betterproto2_compiler/tests/inputs/enum/enum.proto
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,18 @@ enum ArithmeticOperator {
ARITHMETIC_OPERATOR_MINUS = 2;
ARITHMETIC_OPERATOR_0_PREFIXED = 3;
}

// If not all the fields are prefixed, the prefix should not be stripped at all
enum NoStriping {
NO_STRIPING_NONE = 0;
NO_STRIPING_A = 1;
B = 2;
}

// Make sure that the prefix are removed even if it's difficult to infer the position
// of underscores.
enum HTTPCode {
HTTP_CODE_UNSPECIFIED = 0;
HTTP_CODE_OK = 200;
HTTP_CODE_NOT_FOUND = 404;
}
Loading