Skip to content

Commit

Permalink
fix(codegen): deduplicate generated imports
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Jun 22, 2022
1 parent 1404f23 commit e040217
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions protoletariat/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import collections.abc
import typing
from ast import AST
from typing import Any, Callable, NamedTuple, Sequence, Union
from typing import Any, Callable, MutableSet, NamedTuple, Sequence, Union

try:
from ast import unparse as astunparse
Expand Down Expand Up @@ -205,12 +205,18 @@ class ImportNodeTransformer(ast.NodeTransformer):

def __init__(self, ast_rewriter: ASTRewriter) -> None:
self.ast_rewriter = ast_rewriter
# track the results we've produced to avoid duplication of imports
self.seen: MutableSet[str] = set()

def visit_Import(self, node: ast.Import) -> AST:
return self.ast_rewriter.rewrite(node)
def visit_Import(self, node: ast.AST) -> AST | None:
result = self.ast_rewriter.rewrite(node)
code = astunparse(result)
if code not in self.seen:
self.seen.add(code)
return result
return None

def visit_ImportFrom(self, node: ast.ImportFrom) -> AST:
return self.ast_rewriter.rewrite(node)
visit_ImportFrom = visit_Import


class ASTImportRewriter:
Expand All @@ -233,4 +239,5 @@ def _rewrite(_: AST, repl: AST = new_node) -> AST:
), f"more than one rewrite rule found for pattern `{replacement.old}`"

def rewrite(self, src: str) -> str:
self.node_transformer.seen.clear()
return astunparse(self.node_transformer.visit(ast.parse(src)))

0 comments on commit e040217

Please sign in to comment.