diff --git a/betterproto2/tests/inputs/enum/test_enum.py b/betterproto2/tests/inputs/enum/test_enum.py index 6f7f6c9e..47443ec1 100644 --- a/betterproto2/tests/inputs/enum/test_enum.py +++ b/betterproto2/tests/inputs/enum/test_enum.py @@ -1,5 +1,4 @@ from tests.outputs.enum.enum import ( - ArithmeticOperator, Choice, Test, ) @@ -91,12 +90,3 @@ def test_enum_mapped_on_parse(): # bonus: defaults after empty init are also mapped assert Test().choice.name == Choice.ZERO.name - - -def test_renamed_enum_members(): - assert set(ArithmeticOperator.__members__) == { - "NONE", - "PLUS", - "MINUS", - "_0_PREFIXED", - } diff --git a/betterproto2/tests/test_all_definition.py b/betterproto2/tests/test_all_definition.py index 26304f5d..92ed50b2 100644 --- a/betterproto2/tests/test_all_definition.py +++ b/betterproto2/tests/test_all_definition.py @@ -17,4 +17,4 @@ def test_all_definition(): "TestSyncStub", "ThingType", ) - assert enum.__all__ == ("ArithmeticOperator", "Choice", "Test") + assert enum.__all__ == ("ArithmeticOperator", "Choice", "HttpCode", "NoStriping", "Test") diff --git a/betterproto2/tests/test_enum.py b/betterproto2/tests/test_enum.py index d6e833f2..b2dcfed2 100644 --- a/betterproto2/tests/test_enum.py +++ b/betterproto2/tests/test_enum.py @@ -74,3 +74,11 @@ def test_from_string(member: Colour, input_str: str) -> None: ) def test_construction(member: Colour, input_int: int) -> None: assert Colour(input_int) == member + + +def test_enum_renaming() -> None: + from tests.outputs.enum.enum import ArithmeticOperator, HttpCode, NoStriping + + assert set(ArithmeticOperator.__members__) == {"NONE", "PLUS", "MINUS", "_0_PREFIXED"} + assert set(HttpCode.__members__) == {"UNSPECIFIED", "OK", "NOT_FOUND"} + assert set(NoStriping.__members__) == {"NO_STRIPING_NONE", "NO_STRIPING_A", "B"} diff --git a/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py b/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py index 5b339735..953aba56 100644 --- a/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py +++ b/betterproto2_compiler/src/betterproto2_compiler/plugin/models.py @@ -34,10 +34,10 @@ from betterproto2 import unwrap +from betterproto2_compiler import casing from betterproto2_compiler.compile.importing import get_type_reference, parse_source_type_name from betterproto2_compiler.compile.naming import ( pythonize_class_name, - pythonize_enum_member_name, pythonize_field_name, pythonize_method_name, ) @@ -614,16 +614,41 @@ class EnumEntry: comment: str def __post_init__(self) -> None: - # Get entries/allowed values for this Enum self.entries = [ self.EnumEntry( - name=pythonize_enum_member_name(entry_proto_value.name, self.proto_obj.name), + name=entry_proto_value.name, value=entry_proto_value.number, comment=get_comment(proto_file=self.source_file, path=self.path + [2, entry_number]), ) for entry_number, entry_proto_value in enumerate(self.proto_obj.value) ] + if not self.entries: + return + + # Remove enum prefixes + enum_name: str = self.proto_obj.name + + enum_name_reduced = enum_name.replace("_", "").lower() + + first_entry = self.entries[0].name + + # Find the potential common prefix + enum_prefix = "" + for i in range(len(first_entry)): + if first_entry[: i + 1].replace("_", "").lower() == enum_name_reduced: + enum_prefix = f"{first_entry[: i + 1]}_" + break + + should_rename = enum_prefix and all(entry.name.startswith(enum_prefix) for entry in self.entries) + + if should_rename: + for entry in self.entries: + entry.name = entry.name[len(enum_prefix) :] + + for entry in self.entries: + entry.name = casing.sanitize_name(entry.name) + @property def proto_name(self) -> str: return self.proto_obj.name diff --git a/betterproto2_compiler/tests/inputs/enum/enum.proto b/betterproto2_compiler/tests/inputs/enum/enum.proto index 5e2e80c1..d37133a6 100644 --- a/betterproto2_compiler/tests/inputs/enum/enum.proto +++ b/betterproto2_compiler/tests/inputs/enum/enum.proto @@ -23,3 +23,18 @@ enum ArithmeticOperator { ARITHMETIC_OPERATOR_MINUS = 2; ARITHMETIC_OPERATOR_0_PREFIXED = 3; } + +// If not all the fields are prefixed, the prefix should not be stripped at all +enum NoStriping { + NO_STRIPING_NONE = 0; + NO_STRIPING_A = 1; + B = 2; +} + +// Make sure that the prefix are removed even if it's difficult to infer the position +// of underscores. +enum HTTPCode { + HTTP_CODE_UNSPECIFIED = 0; + HTTP_CODE_OK = 200; + HTTP_CODE_NOT_FOUND = 404; +}