Skip to content

Commit

Permalink
feat: avoid imports in try blocks
Browse files Browse the repository at this point in the history
Fixes #147
  • Loading branch information
jayvdb authored and browniebroke committed Jan 10, 2022
1 parent 3541992 commit f6c5159
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 3 deletions.
36 changes: 33 additions & 3 deletions django_codemod/visitors/base.py
Expand Up @@ -7,6 +7,7 @@
Attribute,
BaseExpression,
BaseSmallStatement,
BatchableMetadataProvider,
Call,
CSTNode,
ImportAlias,
Expand All @@ -16,15 +17,34 @@
Name,
RemovalSentinel,
RemoveFromParent,
Try,
)
from libcst import matchers as m
from libcst.codemod import CodemodContext, ContextAwareTransformer
from libcst.codemod import CodemodContext, ContextAwareTransformer, SkipFile
from libcst.codemod.visitors import AddImportsVisitor, RemoveImportsVisitor
from libcst.metadata import ParentNodeProvider, Scope, ScopeProvider

from django_codemod.feature_flags import REPLACE_PARENT_MODULE_IMPORTED


class IsTryImportProvider(BatchableMetadataProvider[bool]):
"""
Marks ImportFrom nodes found inside a try block.
"""
def __init__(self) -> None:
super().__init__()
self.try_level = 0

def visit_Try(self, node: Try) -> None:
self.try_level += 1

def leave_Try(self, node: Try) -> None:
self.try_level -= 1

def visit_ImportFrom(self, node: ImportFrom) -> None:
self.set_metadata(node, bool(self.try_level))


class BaseDjCodemodTransformer(ContextAwareTransformer, ABC):
deprecated_in: Tuple[int, int]
removed_in: Tuple[int, int]
Expand All @@ -50,6 +70,8 @@ def import_from_matches(node: ImportFrom, module_parts: Sequence[str]) -> bool:
class BaseRenameTransformer(BaseDjCodemodTransformer, ABC):
"""Base class to help rename or move a declaration."""

METADATA_DEPENDENCIES = (IsTryImportProvider, )

rename_from: str
rename_to: str

Expand Down Expand Up @@ -91,10 +113,11 @@ def leave_ImportFrom(
self, original_node: ImportFrom, updated_node: ImportFrom
) -> Union[BaseSmallStatement, RemovalSentinel]:
"""Update import statements for matching old module name."""

return (
self._check_import_from_exact(original_node, updated_node)
or self._check_import_from_parent(original_node, updated_node)
or self._check_import_from_child(updated_node)
or self._check_import_from_child(original_node, updated_node)
or updated_node
)

Expand All @@ -114,6 +137,8 @@ def _check_import_from_exact(
# Check whether the exact symbol is imported
if not import_from_matches(updated_node, self.old_module_parts):
return None
if self.get_metadata(IsTryImportProvider, original_node):
raise SkipFile
# Match, update the node an return it
new_import_aliases = []
for import_alias in updated_node.names:
Expand Down Expand Up @@ -162,6 +187,8 @@ def _check_import_from_parent(
# Check whether parent module is imported
if not import_from_matches(updated_node, self.old_parent_module_parts):
return None
if self.get_metadata(IsTryImportProvider, original_node):
raise SkipFile
# Match, check imports and extract metadata
for import_alias in updated_node.names:
if import_alias.evaluated_name == self.old_parent_name:
Expand Down Expand Up @@ -208,7 +235,7 @@ def _check_import_from_parent(
return updated_node

def _check_import_from_child(
self, updated_node: ImportFrom
self, original_node: ImportFrom, updated_node: ImportFrom
) -> Optional[Union[BaseSmallStatement, RemovalSentinel]]:
"""
Check import of a member of the module being codemodded.
Expand All @@ -223,6 +250,8 @@ def _check_import_from_child(
# Check whether a member of the module is imported
if not import_from_matches(updated_node, self.old_all_parts):
return None
if self.get_metadata(IsTryImportProvider, original_node):
raise SkipFile
# Match, add import for all imported names and remove the existing import
for import_alias in updated_node.names:
AddImportsVisitor.add_needed_import(
Expand Down Expand Up @@ -260,6 +289,7 @@ def update_imports(self):

def leave_Name(self, original_node: Name, updated_node: Name) -> BaseExpression:
"""Rename reference to the imported name."""

matcher = self.name_matcher
if (
matcher
Expand Down
25 changes: 25 additions & 0 deletions tests/visitors/test_base.py
@@ -1,4 +1,5 @@
import pytest
from libcst.codemod import SkipFile
from libcst import matchers as m
from parameterized import parameterized

Expand Down Expand Up @@ -60,6 +61,18 @@ def test_simple_substitution(self) -> None:
"""
self.assertCodemod(before, after)

def test_avoid_try_import(self) -> None:
before = after = """
try:
from django.dummy.module import func
except:
from django.dummy.other_module import better_func as func
result = func()
"""
with pytest.raises(SkipFile):
self.assertCodemod(before, after)

@pytest.mark.usefixtures("parent_module_import_enabled")
def test_parent_module(self) -> None:
before = """
Expand Down Expand Up @@ -424,6 +437,18 @@ def test_simple_substitution(self) -> None:
"""
self.assertCodemod(before, after)

def test_avoid_try_import(self) -> None:
before = after = """
try:
from django.dummy.module import func
except:
from django.dummy.other_module import better_func as func
result = func()
"""
with pytest.raises(SkipFile):
self.assertCodemod(before, after)

def test_parent_module_substitution(self) -> None:
before = """
from django.dummy import module
Expand Down

0 comments on commit f6c5159

Please sign in to comment.