From f6c515965714677ce2a9deff2eef1bb2232d4acb Mon Sep 17 00:00:00 2001 From: John Vandenberg Date: Sat, 8 Jan 2022 18:40:56 +0800 Subject: [PATCH] feat: avoid imports in try blocks Fixes https://github.com/browniebroke/django-codemod/issues/147 --- django_codemod/visitors/base.py | 36 ++++++++++++++++++++++++++++++--- tests/visitors/test_base.py | 25 +++++++++++++++++++++++ 2 files changed, 58 insertions(+), 3 deletions(-) diff --git a/django_codemod/visitors/base.py b/django_codemod/visitors/base.py index 1522ae53..55c819cb 100644 --- a/django_codemod/visitors/base.py +++ b/django_codemod/visitors/base.py @@ -7,6 +7,7 @@ Attribute, BaseExpression, BaseSmallStatement, + BatchableMetadataProvider, Call, CSTNode, ImportAlias, @@ -16,15 +17,34 @@ Name, RemovalSentinel, RemoveFromParent, + Try, ) from libcst import matchers as m -from libcst.codemod import CodemodContext, ContextAwareTransformer +from libcst.codemod import CodemodContext, ContextAwareTransformer, SkipFile from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor from libcst.metadata import ParentNodeProvider, Scope, ScopeProvider from django_codemod.feature_flags import REPLACE_PARENT_MODULE_IMPORTED +class IsTryImportProvider(BatchableMetadataProvider[bool]): + """ + Marks ImportFrom nodes found inside a try block. + """ + def __init__(self) -> None: + super().__init__() + self.try_level = 0 + + def visit_Try(self, node: Try) -> None: + self.try_level += 1 + + def leave_Try(self, node: Try) -> None: + self.try_level -= 1 + + def visit_ImportFrom(self, node: ImportFrom) -> None: + self.set_metadata(node, bool(self.try_level)) + + class BaseDjCodemodTransformer(ContextAwareTransformer, ABC): deprecated_in: Tuple[int, int] removed_in: Tuple[int, int] @@ -50,6 +70,8 @@ def import_from_matches(node: ImportFrom, module_parts: Sequence[str]) -> bool: class BaseRenameTransformer(BaseDjCodemodTransformer, ABC): """Base class to help rename or move a declaration.""" + METADATA_DEPENDENCIES = (IsTryImportProvider, ) + rename_from: str rename_to: str @@ -91,10 +113,11 @@ def leave_ImportFrom( self, original_node: ImportFrom, updated_node: ImportFrom ) -> Union[BaseSmallStatement, RemovalSentinel]: """Update import statements for matching old module name.""" + return ( self._check_import_from_exact(original_node, updated_node) or self._check_import_from_parent(original_node, updated_node) - or self._check_import_from_child(updated_node) + or self._check_import_from_child(original_node, updated_node) or updated_node ) @@ -114,6 +137,8 @@ def _check_import_from_exact( # Check whether the exact symbol is imported if not import_from_matches(updated_node, self.old_module_parts): return None + if self.get_metadata(IsTryImportProvider, original_node): + raise SkipFile # Match, update the node an return it new_import_aliases = [] for import_alias in updated_node.names: @@ -162,6 +187,8 @@ def _check_import_from_parent( # Check whether parent module is imported if not import_from_matches(updated_node, self.old_parent_module_parts): return None + if self.get_metadata(IsTryImportProvider, original_node): + raise SkipFile # Match, check imports and extract metadata for import_alias in updated_node.names: if import_alias.evaluated_name == self.old_parent_name: @@ -208,7 +235,7 @@ def _check_import_from_parent( return updated_node def _check_import_from_child( - self, updated_node: ImportFrom + self, original_node: ImportFrom, updated_node: ImportFrom ) -> Optional[Union[BaseSmallStatement, RemovalSentinel]]: """ Check import of a member of the module being codemodded. @@ -223,6 +250,8 @@ def _check_import_from_child( # Check whether a member of the module is imported if not import_from_matches(updated_node, self.old_all_parts): return None + if self.get_metadata(IsTryImportProvider, original_node): + raise SkipFile # Match, add import for all imported names and remove the existing import for import_alias in updated_node.names: AddImportsVisitor.add_needed_import( @@ -260,6 +289,7 @@ def update_imports(self): def leave_Name(self, original_node: Name, updated_node: Name) -> BaseExpression: """Rename reference to the imported name.""" + matcher = self.name_matcher if ( matcher diff --git a/tests/visitors/test_base.py b/tests/visitors/test_base.py index 1f68f6bb..62b5c279 100644 --- a/tests/visitors/test_base.py +++ b/tests/visitors/test_base.py @@ -1,4 +1,5 @@ import pytest +from libcst.codemod import SkipFile from libcst import matchers as m from parameterized import parameterized @@ -60,6 +61,18 @@ def test_simple_substitution(self) -> None: """ self.assertCodemod(before, after) + def test_avoid_try_import(self) -> None: + before = after = """ + try: + from django.dummy.module import func + except: + from django.dummy.other_module import better_func as func + + result = func() + """ + with pytest.raises(SkipFile): + self.assertCodemod(before, after) + @pytest.mark.usefixtures("parent_module_import_enabled") def test_parent_module(self) -> None: before = """ @@ -424,6 +437,18 @@ def test_simple_substitution(self) -> None: """ self.assertCodemod(before, after) + def test_avoid_try_import(self) -> None: + before = after = """ + try: + from django.dummy.module import func + except: + from django.dummy.other_module import better_func as func + + result = func() + """ + with pytest.raises(SkipFile): + self.assertCodemod(before, after) + def test_parent_module_substitution(self) -> None: before = """ from django.dummy import module