From e040217ea9007e01f657cfb4ffcba80a9c8be23f Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Wed, 22 Jun 2022 14:16:00 -0500 Subject: [PATCH] fix(codegen): deduplicate generated imports --- protoletariat/rewrite.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/protoletariat/rewrite.py b/protoletariat/rewrite.py index bbf07ec3..5a4d942f 100644 --- a/protoletariat/rewrite.py +++ b/protoletariat/rewrite.py @@ -5,7 +5,7 @@ import collections.abc import typing from ast import AST -from typing import Any, Callable, NamedTuple, Sequence, Union +from typing import Any, Callable, MutableSet, NamedTuple, Sequence, Union try: from ast import unparse as astunparse @@ -205,12 +205,18 @@ class ImportNodeTransformer(ast.NodeTransformer): def __init__(self, ast_rewriter: ASTRewriter) -> None: self.ast_rewriter = ast_rewriter + # track the results we've produced to avoid duplication of imports + self.seen: MutableSet[str] = set() - def visit_Import(self, node: ast.Import) -> AST: - return self.ast_rewriter.rewrite(node) + def visit_Import(self, node: ast.AST) -> AST | None: + result = self.ast_rewriter.rewrite(node) + code = astunparse(result) + if code not in self.seen: + self.seen.add(code) + return result + return None - def visit_ImportFrom(self, node: ast.ImportFrom) -> AST: - return self.ast_rewriter.rewrite(node) + visit_ImportFrom = visit_Import class ASTImportRewriter: @@ -233,4 +239,5 @@ def _rewrite(_: AST, repl: AST = new_node) -> AST: ), f"more than one rewrite rule found for pattern `{replacement.old}`" def rewrite(self, src: str) -> str: + self.node_transformer.seen.clear() return astunparse(self.node_transformer.visit(ast.parse(src)))