Skip to content

Commit

Permalink
fix: use ast replacement everywhere because typed_ast/typed_astunpars…
Browse files Browse the repository at this point in the history
…e performs correct roundtrip
  • Loading branch information
cpcloud committed Nov 28, 2021
1 parent 46b7222 commit 627bffb
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 84 deletions.
14 changes: 13 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 2 additions & 14 deletions protoletariat/fdsetgen.py
Expand Up @@ -9,12 +9,7 @@

from google.protobuf.descriptor_pb2 import FileDescriptorSet

from .rewrite import (
ASTImportRewriter,
ChainedImportRewriter,
StringReplaceImportRewriter,
build_rewrites,
)
from .rewrite import ASTImportRewriter, build_rewrites

_PROTO_SUFFIX_PATTERN = re.compile(r"^(.+)\.proto$")

Expand Down Expand Up @@ -44,14 +39,7 @@ def fix_imports(

for fd in fdset.file:
fd_name = _remove_proto_suffix(fd.name)
rewriters[fd_name] = rewriter = ChainedImportRewriter(
ASTImportRewriter(),
# FIXME: mypy_protoc generates broken pyi files for long lines
# StringReplaceImportRewriter is an architectural workaround
# to allow replacement of strings so that the pyis might be able
# to work
StringReplaceImportRewriter(),
)
rewriters[fd_name] = rewriter = ASTImportRewriter()
# services live outside of the corresponding generated Python
# module, but they import it so we register a rewrite for the
# current proto as a dependency of itself to handle the case
Expand Down
80 changes: 11 additions & 69 deletions protoletariat/rewrite.py
@@ -1,14 +1,13 @@
from __future__ import annotations

import abc
import ast
import collections
import collections.abc
import re
from ast import AST
from typing import Any, Callable, ClassVar, MutableSequence, NamedTuple, Sequence, Union
import typing
from typing import Any, Callable, NamedTuple, Sequence, Union

import astor
import typed_astunparse as astunparse
from typed_ast import ast3 as ast
from typed_ast.ast3 import AST

Node = Union[AST, Sequence[AST]]

Expand Down Expand Up @@ -129,7 +128,7 @@ def build_rewrites(proto: str, dep: str) -> Sequence[Replacement]:

return [
ASTReplacement(old=old, new=new),
StringReplacement(
ASTReplacement(
old=f"import {'.'.join(parts)}_pb2",
new=(
f"from {leading_dots or '.'} import {parts[0]}"
Expand All @@ -139,23 +138,6 @@ def build_rewrites(proto: str, dep: str) -> Sequence[Replacement]:
]


class BaseRewriter(abc.ABC):
replacement_type: ClassVar[type[Replacement]]

@abc.abstractmethod
def rewrite(self, src: str) -> str:
...

@abc.abstractmethod
def do_register_rewrite(self, replacement: Replacement) -> None:
...

def register_rewrite(self, replacement: Replacement) -> None:
"""Register a rewrite rule for turning `old` into `new`."""
if isinstance(replacement, self.__class__.replacement_type):
self.do_register_rewrite(replacement)


class ImportNodeTransformer(ast.NodeTransformer):
"""A NodeTransformer to apply rewrite rules."""

Expand All @@ -169,16 +151,14 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> AST:
return self.ast_rewriter(node)


class ASTImportRewriter(BaseRewriter):
replacement_type: ClassVar[type[Replacement]] = ASTReplacement

class ASTImportRewriter:
def __init__(self) -> None:
self.node_transformer = ImportNodeTransformer(ASTRewriter())

def do_register_rewrite(self, replacement: Replacement) -> None:
def register_rewrite(self, replacement: Replacement) -> None:
"""Register a rewrite rule for turning `old` into `new`."""
(old_node,) = ast.parse(replacement.old).body
(new_node,) = ast.parse(replacement.new).body
(old_node,) = typing.cast(ast.Module, ast.parse(replacement.old)).body
(new_node,) = typing.cast(ast.Module, ast.parse(replacement.new)).body

def _rewrite(_: AST, repl: AST = new_node) -> AST:
return repl
Expand All @@ -191,42 +171,4 @@ def _rewrite(_: AST, repl: AST = new_node) -> AST:
), f"more than one rewrite rule found for pattern `{replacement.old}`"

def rewrite(self, src: str) -> str:
return astor.to_source(self.node_transformer.visit(ast.parse(src)))


class StringReplaceImportRewriter(BaseRewriter):
replacement_type: ClassVar[type[Replacement]] = StringReplacement

def __init__(self) -> None:
self.replacements: MutableSequence[tuple[re.Pattern[str], str]] = []

def do_register_rewrite(self, replacement: Replacement) -> None:
"""Register a rewrite rule for turning `old` into `new`."""
self.replacements.append(
(
re.compile(f"^{re.escape(replacement.old)}$", flags=re.MULTILINE),
replacement.new,
)
)

def rewrite(self, src: str) -> str:
for pattern, new in self.replacements:
if pattern.search(src) is not None:
src = pattern.sub(new, src)
return src


class ChainedImportRewriter(BaseRewriter):
replacement_type: ClassVar[type[Replacement]] = Replacement

def __init__(self, *rewriters: BaseRewriter) -> None:
self.rewriters = rewriters

def do_register_rewrite(self, replacement: Replacement) -> None:
for rewriter in self.rewriters:
rewriter.register_rewrite(replacement)

def rewrite(self, src: str) -> str:
for rewriter in self.rewriters:
src = rewriter.rewrite(src)
return src
return astunparse.unparse(self.node_transformer.visit(ast.parse(src)))
1 change: 1 addition & 0 deletions pyproject.toml
Expand Up @@ -35,6 +35,7 @@ pytest = "^6.2.5"
pytest-randomly = "^3.10.1"
pyupgrade = "^2.26.0"
types-protobuf = "^3.18.1"
types-typed-ast = "^1.5.0"

[tool.poetry.scripts]
protol = "protoletariat.__main__:main"
Expand Down

0 comments on commit 627bffb

Please sign in to comment.