-
Notifications
You must be signed in to change notification settings - Fork 7
/
rewrite.py
174 lines (130 loc) · 5.12 KB
/
rewrite.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
from __future__ import annotations
import collections
import collections.abc
import typing
from typing import Any, Callable, NamedTuple, Sequence, Union
import typed_astunparse as astunparse
from typed_ast import ast3 as ast
from typed_ast.ast3 import AST
Node = Union[AST, Sequence[AST]]
class Replacement(NamedTuple):
    """A rewrite rule expressed as a pair of source-text statements.

    Instances are plain named tuples: `old` is the text to look for and
    `new` is the text that should take its place.
    """

    # source text of the statement to be replaced
    old: str
    # source text of the statement to substitute for `old`
    new: str
class ASTReplacement(Replacement):
    """Marker subclass of `Replacement`; adds no fields or behavior.

    `build_rewrites` returns its rules as instances of this type.
    """
class StringReplacement(Replacement):
    """Marker subclass of `Replacement`; adds no fields or behavior.

    NOTE(review): presumably tags rules meant for plain-text substitution
    rather than AST matching -- confirm against callers outside this module.
    """
def _is_iterable(value: Any) -> bool: # type: ignore[misc]
"""Return whether `value` is a non-string iterable."""
return not isinstance(value, str) and isinstance(value, collections.abc.Iterable)
def matches(value: Node, pattern: Node) -> bool:
    """Check whether `value` matches `pattern`.

    Matching is structural and recursive:

    * the two objects must have exactly the same type,
    * non-string iterables must have the same length and match
      element by element,
    * non-AST values (``None``, ``True``, strings, numbers, ...) are
      compared with ``==``,
    * AST nodes match when every field listed in ``pattern._fields`` that
      is present on `pattern` matches the same field on `value`.

    Parameters
    ----------
    value
        Instance to match
    pattern
        Pattern to match against

    Returns
    -------
    bool
        Whether `value` matches `pattern`
    """
    # types must match exactly
    if type(value) is not type(pattern):
        return False

    # recur into lists
    if _is_iterable(value) and _is_iterable(pattern):
        assert isinstance(
            value, collections.abc.Iterable
        ), f"value is not a non-str iterable {type(value).__name__}"
        assert isinstance(
            pattern, collections.abc.Iterable
        ), f"pattern is not a non-str iterable {type(pattern).__name__}"
        value_items = list(value)
        pattern_items = list(pattern)
        # `map` stops at the shortest input, so without this check a
        # pattern list would match any longer list that starts with it
        # (and an empty pattern list would match everything)
        if len(value_items) != len(pattern_items):
            return False
        return all(map(matches, value_items, pattern_items))

    # primitive value, such as None, True, False, etc.
    if not isinstance(value, AST) and not isinstance(pattern, AST):
        return value == pattern

    assert isinstance(value, AST), f"value is not an AST node: {type(value).__name__}"
    assert isinstance(
        pattern, AST
    ), f"pattern is not an AST node: {type(pattern).__name__}"

    # only fields actually set on the pattern take part in the comparison
    return all(
        matches(getattr(value, field), getattr(pattern, field))
        for field in pattern._fields
        if hasattr(pattern, field)
    )
class ASTRewriter:
    """Ordered registry of ``(pattern, handler)`` pairs for rewriting nodes.

    Calling an instance with a node runs the handler of the first
    registered pattern the node matches; nodes that match nothing are
    returned unchanged.
    """

    def __init__(self) -> None:
        # (pattern, handler) pairs, consulted in registration order
        self.funcs: list[tuple[AST, Callable[[AST], AST]]] = []

    def register(
        self, pattern: AST
    ) -> Callable[[Callable[[AST], AST]], Callable[[AST], AST]]:
        """Return a decorator that records its target as the handler for `pattern`.

        The decorated function is appended to `funcs` and handed back
        unmodified.
        """

        def decorator(handler: Callable[[AST], AST]) -> Callable[[AST], AST]:
            self.funcs.append((pattern, handler))
            return handler

        return decorator

    def __call__(self, node: AST) -> AST:
        """Rewrite `node` with the first matching handler, if any."""
        for candidate, handler in self.funcs:
            if not matches(node, candidate):
                continue
            # first match wins; later rules are not consulted
            return handler(node)
        return node
def build_rewrites(proto: str, dep: str) -> Sequence[Replacement]:
    """Construct the replacement imports for `dep`.

    Two rules are produced. The first rewrites the aliased import form
    (``import x_pb2 as x__pb2`` or ``from a.b import x_pb2 as
    a_dot_b_dot_x__pb2``) into a relative import. The second rewrites the
    plain ``import a.b.x_pb2`` form into a relative import of the first
    path component.

    For example, ``proto="a/b"`` and ``dep="c/d_e"`` yield
    ``from c import d_e_pb2 as c_dot_d__e__pb2`` ->
    ``from ..c import d_e_pb2 as c_dot_d__e__pb2`` and
    ``import c.d_e_pb2`` -> ``from .. import c``.

    Parameters
    ----------
    proto
        Slash-separated path; the number of ``/`` characters plus one
        gives the number of leading dots used in the relative imports.
        NOTE(review): presumably the name of the proto file being
        rewritten, without its ``.proto`` suffix -- confirm with callers.
    dep
        The `name` field of a `FileDescriptorProto` message, stripped
        of its ``.proto`` suffix.

    Returns
    -------
    Sequence[Replacement]
        Two named tuples, each holding an old import and its replacement
    """
    segments = dep.split("/")
    package_segments, module = segments[:-1], segments[-1]

    # compute the number of dots needed to get to the package root
    dots = "." * (proto.count("/") + 1)

    # underscores are doubled by protoc/buf
    mangled = f"{module}_pb2".replace("_", "__")

    if package_segments:
        package = ".".join(package_segments)
        alias_prefix = "_dot_".join(package_segments)
        alias = f"{alias_prefix}_dot_{mangled}"
        old = f"from {package} import {module}_pb2 as {alias}"
        new = f"from {dots}{package} import {module}_pb2 as {alias}"
    else:
        old = f"import {module}_pb2 as {mangled}"
        new = f"from {dots} import {module}_pb2 as {mangled}"

    # a top-level module is imported directly, so it keeps its `_pb2`
    # suffix; otherwise only the first package component is imported
    suffix = "" if package_segments else "_pb2"
    qualified = ".".join(segments)
    relative_root = dots or "."

    aliased_rule = ASTReplacement(old=old, new=new)
    plain_rule = ASTReplacement(
        old=f"import {qualified}_pb2",
        new=f"from {relative_root} import {segments[0]}" + suffix,
    )
    return [aliased_rule, plain_rule]
class ImportNodeTransformer(ast.NodeTransformer):
    """Node transformer that hands every import statement to an `ASTRewriter`.

    Both ``import ...`` and ``from ... import ...`` nodes are passed to
    the wrapped rewriter, and whatever it returns is used as the result
    of the visit.
    """

    def __init__(self, ast_rewriter: ASTRewriter) -> None:
        # callable applied to each visited import node
        self.ast_rewriter = ast_rewriter

    def visit_Import(self, node: ast.Import) -> AST:
        """Rewrite an ``import ...`` statement."""
        rewrite = self.ast_rewriter
        return rewrite(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> AST:
        """Rewrite a ``from ... import ...`` statement."""
        rewrite = self.ast_rewriter
        return rewrite(node)
class ASTImportRewriter:
    """Rewrite import statements in Python source using registered rules.

    Rules are added with `register_rewrite`; `rewrite` then parses a
    source string, applies the rules to its import statements and
    unparses the result.
    """

    def __init__(self) -> None:
        self.node_transformer = ImportNodeTransformer(ASTRewriter())

    def register_rewrite(self, replacement: Replacement) -> None:
        """Register a rewrite rule for turning `old` into `new`.

        Both `replacement.old` and `replacement.new` must parse to exactly
        one statement; otherwise the tuple unpacking raises `ValueError`.
        A rule whose `old` pattern is already matched by a registered
        pattern is not added again.
        """
        old_body = typing.cast(ast.Module, ast.parse(replacement.old)).body
        (old_node,) = old_body
        new_body = typing.cast(ast.Module, ast.parse(replacement.new)).body
        (new_node,) = new_body

        # bind the replacement node as a default so each rule keeps its own
        def _substitute(_: AST, repl: AST = new_node) -> AST:
            return repl

        rewriter = self.node_transformer.ast_rewriter
        registered = rewriter.funcs  # same list object that `register` appends to

        already_known = any(matches(old_node, pat) for pat, _ in registered)
        if not already_known:
            rewriter.register(old_node)(_substitute)

        # sanity check: exactly one registered pattern matches `old`
        match_count = sum(matches(old_node, pat) for pat, _ in registered)
        assert (
            match_count == 1
        ), f"more than one rewrite rule found for pattern `{replacement.old}`"

    def rewrite(self, src: str) -> str:
        """Return `src` with all registered import rewrites applied."""
        tree = ast.parse(src)
        rewritten = self.node_transformer.visit(tree)
        return astunparse.unparse(rewritten)