Skip to content

Commit

Permalink
refactor: return a sequence of replacements
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Nov 26, 2021
1 parent de3ead7 commit 37f4481
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
12 changes: 7 additions & 5 deletions protoletariat/fdsetgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import astor
from google.protobuf.descriptor_pb2 import FileDescriptorSet

from .rewrite import ImportRewriter, build_import_rewrite
from .rewrite import ImportRewriter, build_rewrites

_PROTO_SUFFIX_PATTERN = re.compile(r"^(.+)\.proto$")

Expand Down Expand Up @@ -46,12 +46,14 @@ def fix_imports(
# module, but they import it so we register a rewrite for the
# current proto as a dependency of itself to handle the case
# of services
rewriter.register_import_rewrite(build_import_rewrite(fd_name, fd_name))
for repl in build_rewrites(fd_name, fd_name):
rewriter.register_rewrite(repl)

# register _proto_ import rewrites
for dep in fd.dependency:
dep_name = _remove_proto_suffix(dep)
rewriter.register_import_rewrite(
build_import_rewrite(fd_name, dep_name)
)
for repl in build_rewrites(fd_name, dep_name):
rewriter.register_rewrite(repl)

for fd in fdset.file:
fd_name = _remove_proto_suffix(fd.name)
Expand Down
6 changes: 3 additions & 3 deletions protoletariat/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __call__(self, node: AST) -> AST:
return node


def build_import_rewrite(proto: str, dep: str) -> Replacement:
def build_rewrites(proto: str, dep: str) -> Sequence[Replacement]:
"""Construct a replacement import for `dep`.
Parameters
Expand Down Expand Up @@ -119,7 +119,7 @@ def build_import_rewrite(proto: str, dep: str) -> Replacement:
old = f"from {from_} import {part}_pb2 as {as_}"
new = f"from {leading_dots}{from_} import {part}_pb2 as {as_}"

return Replacement(old=old, new=new)
return [Replacement(old=old, new=new)]


class ImportRewriter(ast.NodeTransformer):
Expand All @@ -128,7 +128,7 @@ class ImportRewriter(ast.NodeTransformer):
def __init__(self) -> None:
self.rewrite = Rewriter()

def register_import_rewrite(self, replacement: Replacement) -> None:
def register_rewrite(self, replacement: Replacement) -> None:
"""Register a rewrite rule for turning `old` into `new`."""
(old_import,) = ast.parse(replacement.old).body
(new_import,) = ast.parse(replacement.new).body
Expand Down

0 comments on commit 37f4481

Please sign in to comment.