Skip to content

Commit

Permalink
feat: basic rename when parent module is imported
Browse files Browse the repository at this point in the history
  • Loading branch information
Bruno Alla committed Feb 24, 2021
1 parent beb3c23 commit 52d5078
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 62 deletions.
169 changes: 117 additions & 52 deletions django_codemod/visitors/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Module to implement base functionality."""
from abc import ABC
from typing import Generator, Optional, Sequence, Tuple, Union
from typing import Generator, List, Optional, Sequence, Tuple, Union

from libcst import (
Arg,
Attribute,
BaseExpression,
BaseSmallStatement,
Call,
Expand Down Expand Up @@ -54,48 +55,99 @@ class BaseRenameTransformer(BaseDjCodemodTransformer, ABC):

def __init__(self, context: CodemodContext) -> None:
super().__init__(context)
*self.old_module_parts, self.old_name = self.rename_from.split(".")
*self.new_module_parts, self.new_name = self.rename_to.split(".")
self.ctx_key_imported_as = f"{self.rename_from}-imported_as"
*self.old_parent_module_parts, self.old_parent_name, _ = (
*self.old_module_parts,
self.old_name,
) = self.rename_from.split(".")
*self.new_parent_module_parts, self.new_parent_name, _ = (
*self.new_module_parts,
self.new_name,
) = self.rename_to.split(".")
self.ctx_key_import_scope = f"{self.rename_from}-import_scope"
self.ctx_key_name_matcher = f"{self.rename_from}-name_matcher"
self.ctx_key_new_func = f"{self.rename_from}-new_func"

@property
def entity_imported_as(self):
return self.context.scratch.get(self.ctx_key_imported_as, None)
def name_matcher(self):
return self.context.scratch.get(self.ctx_key_name_matcher, None)

@property
def is_imported_with_old_name(self):
is_imported = self.ctx_key_imported_as in self.context.scratch
return is_imported and not self.entity_imported_as
def new_func(self):
return self.context.scratch.get(self.ctx_key_new_func, None)

def leave_ImportFrom(
self, original_node: ImportFrom, updated_node: ImportFrom
) -> Union[BaseSmallStatement, RemovalSentinel]:
"""Update import statements for matching old module name."""
if not import_from_matches(updated_node, self.old_module_parts) or isinstance(
updated_node.names, ImportStar
):
if isinstance(updated_node.names, ImportStar):
return updated_node
# This is a match
new_names = list(self.gen_new_imported_names(updated_node.names))
self.save_import_scope(original_node)
if not new_names:
# Nothing left in the import statement: remove it
return RemoveFromParent()
# Some imports are left, update the statement
cleaned_names = self.tidy_new_imported_names(new_names)
return updated_node.with_changes(names=cleaned_names)

def gen_new_imported_names(
self, old_names: Sequence[ImportAlias]
if import_from_matches(updated_node, self.old_module_parts):
# This is a match
new_import_aliases = list(self.gen_new_imported_aliases(updated_node.names))
self.save_import_scope(original_node)
if not new_import_aliases:
# Nothing left in the import statement: remove it
return RemoveFromParent()
# Some imports are left, update the statement
new_import_aliases = clean_new_import_aliases(new_import_aliases)
return updated_node.with_changes(names=new_import_aliases)
# Now check for parent module
if import_from_matches(updated_node, self.old_parent_module_parts):
new_import_aliases = []
for import_alias in updated_node.names:
if import_alias.evaluated_name == self.old_parent_name:
module_name_str = (
import_alias.evaluated_alias or import_alias.evaluated_name
)
self.context.scratch[self.ctx_key_name_matcher] = m.Attribute(
value=m.Name(module_name_str),
attr=m.Name(self.old_name),
)
self.context.scratch[self.ctx_key_new_func] = Attribute(
attr=Name(self.new_name),
value=Name(
import_alias.evaluated_alias or self.new_parent_name
),
)
self.save_import_scope(original_node)
if self.old_parent_module_parts != self.new_parent_module_parts:
# import statement needs updating
AddImportsVisitor.add_needed_import(
context=self.context,
module=".".join(self.new_parent_module_parts),
obj=self.new_parent_name,
asname=import_alias.evaluated_alias,
)
continue
else:
new_import_aliases.append(import_alias)
else:
new_import_aliases.append(import_alias)
if not new_import_aliases:
# Nothing left in the import statement: remove it
return RemoveFromParent()
# Some imports are left, update the statement
new_import_aliases = clean_new_import_aliases(new_import_aliases)
return updated_node.with_changes(names=new_import_aliases)
return updated_node

def gen_new_imported_aliases(
self, import_aliases: Sequence[ImportAlias]
) -> Generator[ImportAlias, None, None]:
"""Update import if the entity we're interested in is imported."""
for import_alias in old_names:
for import_alias in import_aliases:
if not self.old_name or import_alias.evaluated_name == self.old_name:
self.context.scratch[self.ctx_key_imported_as] = import_alias.asname
if import_alias.evaluated_alias is None:
self.context.scratch[self.ctx_key_name_matcher] = m.Name(
self.old_name
)
if self.new_name:
self.context.scratch[self.ctx_key_new_func] = Name(
self.new_name
)
if self.rename_from != self.rename_to:
if self.simple_rename:
self.add_new_import(import_alias.evaluated_name)
self.add_new_import(import_alias)
continue
yield import_alias

Expand All @@ -115,41 +167,42 @@ def save_import_scope(self, import_from: ImportFrom) -> None:
def import_scope(self) -> Optional[Scope]:
return self.context.scratch.get(self.ctx_key_import_scope, None)

@staticmethod
def tidy_new_imported_names(
new_names: Sequence[ImportAlias],
) -> Sequence[ImportAlias]:
"""Tidy up the updated list of imports"""
# Sort them
cleaned_names = sorted(new_names, key=lambda n: n.evaluated_name)
# Remove any trailing commas
last_name = cleaned_names[-1]
if last_name.comma != MaybeSentinel.DEFAULT:
cleaned_names[-1] = last_name.with_changes(comma=MaybeSentinel.DEFAULT)
return cleaned_names

def add_new_import(self, evaluated_name: Optional[str] = None) -> None:
as_name = (
self.entity_imported_as.name.value if self.entity_imported_as else None
)
def add_new_import(self, old_import_alias: ImportAlias) -> None:
AddImportsVisitor.add_needed_import(
context=self.context,
module=".".join(self.new_module_parts),
obj=self.new_name or evaluated_name,
asname=as_name,
obj=self.new_name or old_import_alias.evaluated_name,
asname=old_import_alias.evaluated_alias,
)

def leave_Name(self, original_node: Name, updated_node: Name) -> BaseExpression:
"""Rename reference to the imported name."""
matcher = self.name_matcher
if (
self.is_imported_with_old_name
and m.matches(updated_node, m.Name(value=self.old_name))
matcher
and m.matches(updated_node, matcher)
and not self.is_wrapped_in_call(original_node)
and self.matches_import_scope(original_node)
):
return updated_node.with_changes(value=self.new_name)
return super().leave_Name(original_node, updated_node)

def leave_Attribute(
self, original_node: Attribute, updated_node: Attribute
) -> BaseExpression:
matcher = self.name_matcher
if (
matcher
and m.matches(updated_node, matcher)
and not self.is_wrapped_in_call(original_node)
and self.matches_import_scope(original_node)
):
return updated_node.with_changes(
value=self.new_parent_name,
attr=Name(self.new_name),
)
return super().leave_Attribute(original_node, updated_node)

def is_wrapped_in_call(self, node: CSTNode) -> bool:
"""Check whether given node is wrapped in Call."""
parent = self.resolve_parent_node(node)
Expand All @@ -169,6 +222,19 @@ def matches_import_scope(self, node: CSTNode) -> bool:
return scope == self.import_scope


def clean_new_import_aliases(
import_aliases: Sequence[ImportAlias],
) -> List[ImportAlias]:
"""Tidy up the updated list of imports"""
# Sort them
cleaned_import_aliases = sorted(import_aliases, key=lambda n: n.evaluated_name)
# Remove any trailing commas
last_name = cleaned_import_aliases[-1]
if last_name.comma != MaybeSentinel.DEFAULT:
cleaned_import_aliases[-1] = last_name.with_changes(comma=MaybeSentinel.DEFAULT)
return cleaned_import_aliases


class BaseModuleRenameTransformer(BaseRenameTransformer, ABC):
"""Base class to help rename or move a module."""

Expand All @@ -184,15 +250,14 @@ class BaseFuncRenameTransformer(BaseRenameTransformer, ABC):
"""Base class to help rename or move a function."""

def leave_Call(self, original_node: Call, updated_node: Call) -> BaseExpression:
if self.is_imported_with_old_name and m.matches(
updated_node, m.Call(func=m.Name(self.old_name))
):
matcher = self.name_matcher
if m.matches(updated_node, m.Call(func=matcher)):
return self.update_call(updated_node=updated_node)
return super().leave_Call(original_node, updated_node)

def update_call(self, updated_node: Call) -> BaseExpression:
updated_args = self.update_call_args(updated_node)
return updated_node.with_changes(args=updated_args, func=Name(self.new_name))
return updated_node.with_changes(args=updated_args, func=self.new_func)

def update_call_args(self, node: Call) -> Sequence[Arg]:
return node.args
5 changes: 2 additions & 3 deletions django_codemod/visitors/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ class AvailableAttrsTransformer(BaseRenameTransformer):
rename_to = "functools.WRAPPER_ASSIGNMENTS"

def leave_Call(self, original_node: Call, updated_node: Call) -> BaseExpression:
if self.is_imported_with_old_name and m.matches(
updated_node, m.Call(func=m.Name(self.old_name))
):
matcher = self.name_matcher
if m.matches(updated_node, m.Call(func=matcher)):
return Name(self.new_name)
return super().leave_Call(original_node, updated_node)
18 changes: 12 additions & 6 deletions django_codemod/visitors/urls.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Sequence
from typing import Sequence, Tuple

from libcst import Arg, BaseExpression, Call, Name, SimpleString
from libcst import matchers as m
Expand Down Expand Up @@ -43,7 +43,11 @@ def update_call(self, updated_node: Call) -> BaseExpression:
return self.update_call_to_path(updated_node)
except PatternNotSupported:
# Safe fallback to re_path()
self.add_new_import()
AddImportsVisitor.add_needed_import(
context=self.context,
module=".".join(self.new_module_parts),
obj=self.new_name,
)
return super().update_call(updated_node)

def update_call_to_path(self, updated_node: Call) -> Call:
Expand All @@ -68,7 +72,7 @@ def build_path_call(self, pattern: str, other_args: Sequence[Arg]) -> Call:
updated_args = (Arg(value=SimpleString(f"'{route}'")), *other_args)
return Call(args=updated_args, func=Name("path"))

def build_route(self, pattern):
def build_route(self, pattern: str) -> str:
"""Build route from a URL pattern."""
stripped_pattern = pattern.lstrip("^").rstrip("$")
route = ""
Expand All @@ -80,7 +84,8 @@ def build_route(self, pattern):
self.check_route(route)
return route

def parse_next_group(self, left_to_parse):
@staticmethod
def parse_next_group(left_to_parse: str) -> Tuple[str, str]:
"""Extract captured group info."""
prefix, rest = left_to_parse.split("(?P<", 1)
group, left_to_parse = rest.split(")", 1)
Expand All @@ -91,7 +96,8 @@ def parse_next_group(self, left_to_parse):
raise PatternNotSupported("No converter found")
return prefix + f"<{converter}:{group_name}>", left_to_parse

def check_route(self, route):
@staticmethod
def check_route(route: str) -> None:
"""Check that route doesn't contain anymore regex."""
if set(route) & REGEX_SPECIALS_SANS_DASH:
raise PatternNotSupported(f"Route {route} contains regex")
Expand All @@ -101,5 +107,5 @@ def update_call_args(self, node: Call) -> Sequence[Arg]:
first_arg, *other_args = node.args
if m.matches(first_arg, m.Arg(keyword=m.Name("regex"))):
first_arg = Arg(value=first_arg.value)
return (first_arg, *other_args)
return [first_arg, *other_args]
return super().update_call_args(node)
41 changes: 40 additions & 1 deletion tests/visitors/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
def test_module_matcher(parts, expected_matcher):
matcher = module_matcher(parts)

# equality comparision doesn't work with matcher:
# equality comparison doesn't work with matcher:
# compare their representation seems to work
assert repr(matcher) == repr(expected_matcher)

Expand Down Expand Up @@ -60,6 +60,19 @@ def test_simple_substitution(self) -> None:
"""
self.assertCodemod(before, after)

def test_parent_module(self) -> None:
before = """
from django.dummy import module
result = module.func()
"""
after = """
from django.dummy import module
result = module.better_func()
"""
self.assertCodemod(before, after)

def test_reference_without_call(self) -> None:
"""Replace reference of the function even is it's not called."""
before = """
Expand Down Expand Up @@ -274,6 +287,32 @@ def test_simple_substitution(self) -> None:
"""
self.assertCodemod(before, after)

def test_parent_module(self) -> None:
before = """
from django.dummy import module
result = module.func()
"""
after = """
from django.better import dummy
result = dummy.better_func()
"""
self.assertCodemod(before, after)

def test_parent_module_import_alias(self) -> None:
before = """
from django.dummy import module as django_module
result = django_module.func()
"""
after = """
from django.better import dummy as django_module
result = django_module.better_func()
"""
self.assertCodemod(before, after)

def test_already_imported(self) -> None:
before = """
from django.dummy.module import func
Expand Down

0 comments on commit 52d5078

Please sign in to comment.