diff --git a/protoletariat/rewrite.py b/protoletariat/rewrite.py index 809c723f..e427b5bc 100644 --- a/protoletariat/rewrite.py +++ b/protoletariat/rewrite.py @@ -130,7 +130,10 @@ def build_rewrites(proto: str, dep: str) -> Sequence[Replacement]: ASTReplacement(old=old, new=new), StringReplacement( old=f"import {'.'.join(parts)}_pb2", - new=f"from {leading_dots or '.'}{'.'.join(import_parts)} import {part}_pb2", + new=( + f"from {leading_dots or '.'} import {parts[0]}" + + "_pb2" * (not import_parts) + ), ), ] diff --git a/protoletariat/tests/test_rewrite.py b/protoletariat/tests/test_rewrite.py index bf77f7d9..a4d499da 100644 --- a/protoletariat/tests/test_rewrite.py +++ b/protoletariat/tests/test_rewrite.py @@ -28,13 +28,19 @@ def check_proto_out(out: Path) -> None: pytest.param( "a", "foo", - ["from . import foo_pb2 as foo__pb2", "from . import foo_pb2"], + [ + "from . import foo_pb2 as foo__pb2", + "from . import foo_pb2", + ], id="no_nesting", ), pytest.param( "a/b", "foo", - ["from .. import foo_pb2 as foo__pb2", "from .. import foo_pb2"], + [ + "from .. import foo_pb2 as foo__pb2", + "from .. import foo_pb2", + ], id="proto_one_level", ), pytest.param( @@ -42,7 +48,7 @@ def check_proto_out(out: Path) -> None: "foo/bar", [ "from .foo import bar_pb2 as foo_dot_bar__pb2", - "from .foo import bar_pb2", + "from . import foo", ], id="dep_one_level", ), @@ -51,7 +57,7 @@ def check_proto_out(out: Path) -> None: "foo/bar", [ "from ..foo import bar_pb2 as foo_dot_bar__pb2", - "from ..foo import bar_pb2", + "from .. import foo", ], id="both_one_level", ), @@ -60,7 +66,7 @@ def check_proto_out(out: Path) -> None: "foo/bar/baz", [ "from .foo.bar import baz_pb2 as foo_dot_bar_dot_baz__pb2", - "from .foo.bar import baz_pb2", + "from . import foo", ], id="dep_two_levels", ), @@ -69,7 +75,7 @@ def check_proto_out(out: Path) -> None: "foo/bar/baz", [ "from ..foo.bar import baz_pb2 as foo_dot_bar_dot_baz__pb2", - "from ..foo.bar import baz_pb2", + "from .. import foo", ], id="proto_one_level_dep_two_levels", ), @@ -78,7 +84,7 @@ def check_proto_out(out: Path) -> None: "foo/bar/bizz_buzz", [ "from .foo.bar import bizz_buzz_pb2 as foo_dot_bar_dot_bizz__buzz__pb2", - "from .foo.bar import bizz_buzz_pb2", + "from . import foo", ], id="dep_three_levels", ), @@ -90,14 +96,17 @@ def check_proto_out(out: Path) -> None: "from ..foo.bar import bizz_buzz_pb2 as " "foo_dot_bar_dot_bizz__buzz__pb2" ), - "from ..foo.bar import bizz_buzz_pb2", + "from .. import foo", ], id="proto_one_level_dep_three_levels", ), pytest.param( "a/b", "a/b", - ["from ..a import b_pb2 as a_dot_b__pb2", "from ..a import b_pb2"], + [ + "from ..a import b_pb2 as a_dot_b__pb2", + "from .. import a", + ], id="self_dep", ), ],