Skip to content
Permalink
 
 
Cannot retrieve contributors at this time
#!/usr/bin/env python3
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import inspect
import logging
import pathlib
import re
from functools import wraps
from typing import Callable, List, Optional, Type, TypeVar, Union, cast
from fissix.fixer_base import BaseFix
from fissix.fixer_util import Attr, Comma, Dot, LParen, Name, Newline, RParen
from fissix.pytree import Leaf, Node, type_repr
from .helpers import (
Once,
dotted_parts,
find_first,
find_last,
find_previous,
get_class,
power_parts,
print_selector_pattern,
print_tree,
quoted_parts,
)
from .imr import FunctionArgument, FunctionSpec
from .tool import BowlerTool
from .types import (
LN,
SENTINEL,
START,
SYMBOL,
TOKEN,
BowlerException,
Callback,
Capture,
Filename,
FilenameMatcher,
Filter,
Hunk,
Processor,
Stringish,
Transform,
)
# Maps a selector's short name (the method name with "select_" removed, lowercased)
# to its pattern template. Filled in by the `selector` decorator and read back in
# `Query.create_fixer`.
SELECTORS = {}
# Type variable bound to `Query`, used to annotate the fluent selector methods.
Q = TypeVar("Q", bound="Query")
# Shape of a selector method: arbitrary arguments in, a query out.
QM = Callable[..., Q]
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
def selector(pattern: str) -> Callable[[QM], QM]:
    """
    Register ``pattern`` as a selector template and build its fluent method.

    The decorated function contributes only its name and parameter list; its
    body is never invoked.  The selector key is the function name with
    ``select_`` removed and lowercased, and ``pattern`` is stored under that key
    in ``SELECTORS``.

    The generated method maps positional arguments onto the declared parameter
    names (skipping ``self``), appends ``Transform(key, kwargs)`` to
    ``self.transforms`` and returns ``self`` so calls can be chained.
    """

    def wrapper(fn: QM) -> QM:
        key = fn.__name__.replace("select_", "").lower()
        SELECTORS[key] = pattern
        # Drop the leading ``self`` so the names line up with caller positionals.
        positional_names = list(inspect.signature(fn).parameters)[1:]

        @wraps(fn)
        def wrapped(self: Q, *args, **kwargs) -> Q:
            for param, given in zip(positional_names, args):
                if not hasattr(given, "__name__"):
                    kwargs[param] = str(given)
                    continue
                # Objects carrying ``__name__`` are remembered as "source" and
                # selected by that name.
                kwargs["source"] = given
                kwargs[param] = given.__name__

            if "name" in kwargs:
                target = kwargs["name"]
                kwargs["dotted_name"] = " ".join(quoted_parts(target))
                kwargs["power_name"] = " ".join(power_parts(target))

            self.transforms.append(Transform(key, kwargs))
            return self

        return wrapped

    return wrapper
class Query:
    """
    Fluent builder for a refactoring run.

    Selector methods append a ``Transform`` to ``transforms``; filters and
    modifiers attach to the most recently added transform; ``execute`` compiles
    the transforms into fixers and runs them over ``paths``.
    """

    def __init__(
        self,
        *paths: Union[str, List[str]],
        filename_matcher: Optional[FilenameMatcher] = None,
        python_version: int = 3,
    ) -> None:
        """
        Collect the paths to operate on and reset all query state.

        Args:
            *paths: individual paths (``str`` or ``pathlib.Path``) and/or
                iterables of paths.  Everything is stored as ``str``; when no
                path is given the query defaults to ``"."``.
            filename_matcher: passed to the tool as ``filename_matcher`` unless
                the caller of ``execute`` supplies its own.
            python_version: when ``3``, ``execute`` turns on the
                ``print_function`` option.
        """
        self.paths: List[str] = []
        self.transforms: List[Transform] = []
        self.processors: List[Processor] = []
        self.retcode: Optional[int] = None
        self.filename_matcher = filename_matcher
        self.python_version = python_version
        self.exceptions: List[BowlerException] = []

        for path in paths:
            if isinstance(path, (str, pathlib.Path)):
                self.paths.append(str(path))
            else:
                # Stringify each entry so Path objects inside a list are handled
                # the same way as a bare Path argument.
                self.paths.extend(str(entry) for entry in path)

        if not self.paths:
            self.paths.append(".")
# Selector "root": the pattern matches a `file_input` node with any children.
# The stub body below is never executed; the `selector` decorator replaces this
# method with one that appends a Transform for the pattern and returns the query.
@selector(
"""
file_input< any* >
"""
)
def select_root(self) -> "Query":
...
# Selector "module": matches references to the module called `name`.
# Alternatives in the pattern:
#   * `import_name`  - `import X`, dotted `import X.y`, and `import X as nick`
#   * `import_from`  - `from X import ...`, with or without parentheses / `as`
#   * `power`        - usage of the (dotted) name followed by any trailers
# Captures: module_name, module_nickname, module_import / module_imports,
# module_access.
# `{name}` comes from the caller; `{dotted_name}` and `{power_name}` are derived
# from it by the `selector` wrapper. The stub body is never executed.
@selector(
"""
(
import_name< 'import'
(
module_name='{name}'
|
module_name=dotted_name< {dotted_name} any* >
|
dotted_as_name<
(
module_name='{name}'
|
module_name=dotted_name< {dotted_name} any* >
)
'as' module_nickname=any
>
)
>
|
import_from< 'from'
(
module_name='{name}'
|
module_name=dotted_name< {dotted_name} any* >
)
'import' ['(']
(
import_as_name<
module_import=any
'as'
module_nickname=any
>*
|
import_as_names<
module_imports=any*
>
|
module_import=any
)
[')'] >
|
module_name=power<
[TOKEN]
{power_name}
module_access=trailer< any* >*
>
)
"""
)
def select_module(self, name: str) -> "Query":
...
# Selector "class": matches the class called `name` in four forms.
#   * class_def      - `class <name> ...:` definition
#   * class_call     - `<name>(...)`; the call arguments are captured as
#                      class_arguments
#   * class_subclass - a classdef in which `<name>` appears before the suite,
#                      either directly or inside an arglist
#   * class_import   - `from <module> import ... <name> ...`
# Captures also include class_name and (for imports) module_name.
# The stub body is never executed; see the `selector` decorator.
@selector(
"""
(
class_def=classdef<
'class' class_name='{name}'
any*
suite<
any*
>
any*
>
|
class_call=power<
class_name='{name}'
trailer< '(' class_arguments=any* ')' >
>
|
class_subclass=classdef< any*
(
any* class_name='{name}' any*
|
arglist< any* class_name='{name}' any* >
)
any*
suite<
any*
>
any*
>
|
class_import=import_from<
'from' module_name=any
'import' ['(']
(
import_as_names<
any*
class_name='{name}'
any*
>
|
any*
class_name='{name}'
any*
)
[')'] >
)
"""
)
def select_class(self, name: str) -> "Query":
...
# Selector "subclass": matches class definitions that list `name` inside the
# parentheses after the class name, either directly or within an arglist.
# Captures: class_def (the whole definition), class_name (any name), and
# class_ancestor (the matched `name`).
# The stub body is never executed; see the `selector` decorator.
@selector(
"""
(
class_def=classdef<
'class' class_name=any '('
(
any* class_ancestor='{name}' any*
|
arglist< any* class_ancestor='{name}' any* >
)
any*
suite<
any*
>
any*
>
)
"""
)
def select_subclass(self, name: str) -> "Query":
...
# Selector "attribute": matches the attribute called `name` in three forms.
#   * attr_class      - a classdef whose suite contains a simple statement
#                       assigning to `<name>` (captured as attr_assignment)
#   * attr_assignment - an expr_stmt whose target is a power node ending in a
#                       `.<name>` trailer
#   * attr_access     - a power node containing a `.<name>` trailer
# Captures also include attr_name and attr_value.
# The stub body is never executed; see the `selector` decorator.
@selector(
"""
(
attr_class=classdef< any*
suite< any*
simple_stmt< any*
attr_assignment=expr_stmt<
attr_name='{name}' attr_value=any*
>
any* >
any* >
any* >
|
attr_assignment=expr_stmt<
power<
any*
trailer< any* >*
trailer< '.'
attr_name='{name}'
>
>
attr_value=any*
>
|
attr_access=power<
any*
trailer< any* >*
trailer< '.'
attr_name='{name}'
any* >
any* >
)
"""
)
def select_attribute(self, name: str) -> "Query":
...
# Selector "method": matches the method called `name` in four forms.
#   * decorated       - a decorated funcdef named `<name>`
#   * function_def    - an undecorated funcdef named `<name>`
#     (in both, a typedargslist, when present, must start with 'self' or 'cls')
#   * function_call   - a power node ending in `.<name>` followed by a call
#                       trailer
#   * function_import - `from <module> import ... <name> ...`
# Captures include function_name, function_parameters, function_arguments,
# decorators and module_name.
# The stub body is never executed; see the `selector` decorator.
@selector(
"""
(
decorated=decorated<
decorators=decorators
function_def=funcdef<
'def' function_name='{name}'
function_parameters=parameters< '('
function_arguments=typedargslist< ( 'self' | 'cls' ) any* >*
')' >
any*
>
>
|
function_def=funcdef<
'def' function_name='{name}'
function_parameters=parameters< '('
function_arguments=typedargslist< ( 'self' | 'cls' ) any* >*
')' >
any*
>
|
function_call=power<
any*
trailer< any* >*
trailer<
'.' function_name='{name}'
>
function_parameters=trailer< '(' function_arguments=any* ')' >
any*
>
|
function_import=import_from<
'from' module_name=any
'import' ['(']
(
import_as_names<
any*
function_name='{name}'
any*
>
|
any*
function_name='{name}'
any*
)
[')'] >
)
"""
)
def select_method(self, name: str) -> "Query":
...
# Selector "function": matches the function called `name` in four forms.
#   * decorated       - a decorated funcdef named `<name>`
#   * function_def    - an undecorated funcdef named `<name>`
#   * function_call   - a power node where `<name>` is followed by a call
#                       trailer; anything after the call is captured as
#                       remainder
#   * function_import - `from <module> import ... <name> ...`
# Captures include function_name, function_parameters, function_arguments,
# decorators and module_name.
# The stub body is never executed; see the `selector` decorator.
@selector(
"""
(
decorated=decorated<
decorators=decorators
function_def=funcdef<
'def' function_name='{name}'
function_parameters=parameters< '(' function_arguments=any* ')' >
any*
>
>
|
function_def=funcdef<
'def' function_name='{name}'
function_parameters=parameters< '(' function_arguments=any* ')' >
any*
>
|
function_call=power<
[TOKEN]
function_name='{name}'
function_parameters=trailer< '(' function_arguments=any* ')' >
remainder=any*
>
|
function_import=import_from<
'from' module_name=any
'import' ['(']
(
import_as_names<
any*
function_name='{name}'
any*
>
|
any*
function_name='{name}'
any*
)
[')'] >
)
"""
)
def select_function(self, name: str) -> "Query":
...
# Selector "var": matches the name `name` either as the first element of an
# expr_stmt (captured as var_assignment, with the rest as var_value) or as a
# bare occurrence (var_name only).
# The stub body is never executed; see the `selector` decorator.
@selector(
"""
(
var_assignment=expr_stmt<
var_name='{name}'
var_value=any*
>
|
var_name='{name}'
)
"""
)
def select_var(self, name: str) -> "Query":
...
# Selector "pattern": the template is just `{pattern}`, so the caller-supplied
# pattern string is substituted in as the whole pattern.
# The stub body is never executed; see the `selector` decorator.
@selector("""{pattern}""")
def select_pattern(self, pattern: str) -> "Query":
...
def select(self, pattern: str) -> "Query":
    """
    Select nodes using a raw pattern string.

    Thin alias: forwards ``pattern`` unchanged to ``select_pattern`` and returns
    whatever that call returns.
    """
    query = self.select_pattern(pattern)
    return query
@property
def current(self) -> Transform:
    """
    Return the most recently added transform.

    Raises:
        ValueError: if no selector has been used yet, i.e. ``transforms`` is
            empty.
    """
    registered = self.transforms
    if registered:
        return registered[-1]
    raise ValueError("no selectors used")
def is_filename(
    self, include: Optional[str] = None, exclude: Optional[str] = None
) -> "Query":
    """
    Restrict the current transform to files whose name matches a regex.

    Args:
        include: regular expression; only filenames where ``re.search`` finds a
            match are kept.  Ignored when empty or ``None``.
        exclude: regular expression; filenames where ``re.search`` finds a match
            are dropped.  Ignored when empty or ``None``.

    Returns:
        The query, for chaining.
    """
    if include:
        # Each filter closes over its own variable.  Sharing one ``regex`` name
        # would make the include filter pick up the exclude pattern whenever
        # both arguments are given, because closures bind late.
        include_regex = re.compile(include)

        def filter_filename_include(
            node: LN, capture: Capture, filename: Filename
        ) -> bool:
            return include_regex.search(filename) is not None

        self.current.filters.append(filter_filename_include)

    if exclude:
        exclude_regex = re.compile(exclude)

        def filter_filename_exclude(
            node: LN, capture: Capture, filename: Filename
        ) -> bool:
            return exclude_regex.search(filename) is None

        self.current.filters.append(filter_filename_exclude)

    return self
def is_call(self) -> "Query":
    """
    Keep only matches that are calls.

    Adds a filter to the current transform that passes when the capture
    contains ``function_call`` or ``class_call``.  Returns the query.
    """
    call_keys = ("function_call", "class_call")

    def filter_is_call(node: LN, capture: Capture, filename: Filename) -> bool:
        return any(key in capture for key in call_keys)

    self.current.filters.append(filter_is_call)
    return self
def is_def(self) -> "Query":
    """
    Keep only matches that are definitions.

    Adds a filter to the current transform that passes when the capture
    contains ``function_def`` or ``class_def``.  Returns the query.
    """
    definition_keys = ("function_def", "class_def")

    def filter_is_def(node: LN, capture: Capture, filename: Filename) -> bool:
        return any(key in capture for key in definition_keys)

    self.current.filters.append(filter_is_def)
    return self
def in_class(self, class_name: str, include_subclasses: bool = True) -> "Query":
    """
    Keep only matches located inside the class ``class_name``.

    The filter walks up from the matched node to the first enclosing
    ``classdef`` and decides there:

    * the class itself is named ``class_name`` -> pass;
    * otherwise, with ``include_subclasses``, pass when ``class_name`` appears
      as a NAME token before the first colon of the class statement;
    * anything else -> reject.  Outer classes are not consulted.

    Returns the query, for chaining.
    """

    def filter_in_class(node: LN, capture: Capture, filename: Filename) -> bool:
        current = node
        while current.parent is not None:
            if current.type != SYMBOL.classdef:
                current = current.parent
                continue

            # children[1] is the name that follows the ``class`` keyword.
            if current.children[1].value == class_name:
                return True
            if not include_subclasses:
                return False

            # Scan the class header only: stop at the colon opening the body.
            for leaf in current.leaves():
                if leaf.type == TOKEN.COLON:
                    break
                if leaf.type == TOKEN.NAME and leaf.value == class_name:
                    return True
            return False

        return False

    self.current.filters.append(filter_in_class)
    return self
def encapsulate(self, internal_name: str = "") -> "Query":
    """
    Hide the selected attribute behind a property that keeps the original name.

    Must follow ``select_attribute`` combined with an ``in_class`` filter.
    Captured attribute assignments are renamed to ``internal_name`` (default:
    the old name prefixed with one underscore), and a ``@property`` getter plus
    matching setter are inserted into the class suite.

    Args:
        internal_name: name for the backing attribute; empty means
            ``"_" + old name``.

    Returns:
        The query, for chaining.

    Raises:
        ValueError: if the current selector is not ``attribute``, if no
            ``in_class`` filter is registered, or if the backing name starts
            with a double underscore (which would trigger name mangling).
    """
    transform = self.current
    # Exact comparison on purpose: ``not in ("attribute")`` is a substring test
    # against a plain string, so selectors such as "attr" would slip through.
    if transform.selector != "attribute":
        raise ValueError("encapsulate requires select_attribute")
    if not any("filter_in_class" in f.__name__ for f in transform.filters):
        raise ValueError("encapsulate requires in_class filter")

    old_name = transform.kwargs["name"]
    new_name = internal_name or f"_{old_name}"
    if new_name.startswith("__"):
        raise ValueError(
            f"renaming {old_name} -> {new_name} is dangerous, "
            "please specify internal_name to avoid name mangling"
        )

    # presumably truthy only the first time it is tested, so the property pair
    # is generated once -- confirm against helpers.Once
    make_property = Once()

    def encapsulate_transform(
        node: LN, capture: Capture, filename: Filename
    ) -> None:
        # NOTE(review): the nesting of ``if make_property`` under the
        # attr_assignment branch was reconstructed; confirm against upstream.
        if "attr_assignment" in capture:
            leaf = capture["attr_name"]
            leaf.replace(Name(new_name, prefix=leaf.prefix))

            if make_property:
                # TODO: capture and use type annotation from original assignment
                class_node = get_class(node)
                suite = find_first(class_node, SYMBOL.suite)
                assert isinstance(suite, Node)
                indent_node = find_first(suite, TOKEN.INDENT)
                assert isinstance(indent_node, Leaf)
                # Reuse the class body's own indentation for the new methods.
                indent = indent_node.value

                # @property / def <old_name>(self): return self.<new_name>
                getter = Node(
                    SYMBOL.decorated,
                    [
                        Node(
                            SYMBOL.decorator,
                            [
                                Leaf(TOKEN.INDENT, indent),
                                Leaf(TOKEN.AT, "@"),
                                Name("property"),
                                Leaf(TOKEN.NEWLINE, "\n"),
                            ],
                        ),
                        Node(
                            SYMBOL.funcdef,
                            [
                                Name("def", indent),
                                Name(old_name, prefix=" "),
                                Node(
                                    SYMBOL.parameters,
                                    [LParen(), Name("self"), RParen()],
                                ),
                                Leaf(TOKEN.COLON, ":"),
                                Node(
                                    SYMBOL.suite,
                                    [
                                        Newline(),
                                        Leaf(TOKEN.INDENT, indent + " "),
                                        Node(
                                            SYMBOL.simple_stmt,
                                            [
                                                Node(
                                                    SYMBOL.return_stmt,
                                                    [
                                                        Name("return"),
                                                        Node(
                                                            SYMBOL.power,
                                                            Attr(
                                                                Name("self"),
                                                                Name(new_name),
                                                            ),
                                                            prefix=" ",
                                                        ),
                                                    ],
                                                ),
                                                Newline(),
                                            ],
                                        ),
                                        Leaf(TOKEN.DEDENT, "\n" + indent),
                                    ],
                                ),
                            ],
                            prefix=indent,
                        ),
                    ],
                )

                # @<old_name>.setter / def <old_name>(self, value): self.<new_name> = value
                setter = Node(
                    SYMBOL.decorated,
                    [
                        Node(
                            SYMBOL.decorator,
                            [
                                Leaf(TOKEN.AT, "@"),
                                Node(
                                    SYMBOL.dotted_name,
                                    [Name(old_name), Dot(), Name("setter")],
                                ),
                                Leaf(TOKEN.NEWLINE, "\n"),
                            ],
                        ),
                        Node(
                            SYMBOL.funcdef,
                            [
                                Name("def", indent),
                                Name(old_name, prefix=" "),
                                Node(
                                    SYMBOL.parameters,
                                    [
                                        LParen(),
                                        Node(
                                            SYMBOL.typedargslist,
                                            [
                                                Name("self"),
                                                Comma(),
                                                Name("value", prefix=" "),
                                            ],
                                        ),
                                        RParen(),
                                    ],
                                ),
                                Leaf(TOKEN.COLON, ":"),
                                Node(
                                    SYMBOL.suite,
                                    [
                                        Newline(),
                                        Leaf(TOKEN.INDENT, indent + " "),
                                        Node(
                                            SYMBOL.simple_stmt,
                                            [
                                                Node(
                                                    SYMBOL.expr_stmt,
                                                    [
                                                        Node(
                                                            SYMBOL.power,
                                                            Attr(
                                                                Name("self"),
                                                                Name(new_name),
                                                            ),
                                                        ),
                                                        Leaf(
                                                            TOKEN.EQUAL,
                                                            "=",
                                                            prefix=" ",
                                                        ),
                                                        Name("value", prefix=" "),
                                                    ],
                                                ),
                                                Newline(),
                                            ],
                                        ),
                                        Leaf(TOKEN.DEDENT, "\n" + indent),
                                    ],
                                ),
                            ],
                            prefix=indent,
                        ),
                    ],
                )

                suite.insert_child(-1, getter)
                suite.insert_child(-1, setter)

                # Swap prefix and value between the DEDENT preceding the getter
                # and the setter's final DEDENT.
                prev = find_previous(getter, TOKEN.DEDENT, recursive=True)
                curr = find_last(setter, TOKEN.DEDENT, recursive=True)
                if prev and curr:
                    assert isinstance(prev, Leaf) and isinstance(curr, Leaf)
                    prev.prefix, curr.prefix = curr.prefix, prev.prefix
                    prev.value, curr.value = curr.value, prev.value

    transform.callbacks.append(encapsulate_transform)
    return self
# Rename the selected name to `new_name`.
#
# Reads the old name from the current transform's `name` kwarg and registers a
# callback that rewrites captured nodes in place. Three capture shapes are
# handled: NAME leaves equal to the old name, `dotted_name` nodes, and `power`
# nodes (dotted access built from `.name` trailers). Returns the query.
def rename(self, new_name: str) -> "Query":
transform = self.current
old_name = transform.kwargs["name"]
# Callback run for each match; it mutates the tree and returns nothing.
def rename_transform(node: LN, capture: Capture, filename: Filename) -> None:
log.debug(f"{filename} [{list(capture)}]: {node}")
# If two keys reference the same underlying object, do not modify it twice
visited: List[LN] = []
for _key, value in capture.items():
log.debug(f"{_key}: {value}")
if value in visited:
continue
visited.append(value)
# Plain NAME leaf: swap it for the new name, keeping the leading whitespace.
# Leaves without a parent are skipped.
if isinstance(value, Leaf) and value.type == TOKEN.NAME:
if value.value == old_name and value.parent is not None:
value.replace(Name(new_name, prefix=value.prefix))
break
elif isinstance(value, Node):
# Dotted name: compare the old and new parts against the node's children in
# lockstep; stop at the first child that does not match the old path and
# replace children whose part changed.
if type_repr(value.type) == "dotted_name":
dp_old = dotted_parts(old_name)
dp_new = dotted_parts(new_name)
parts = zip(dp_old, dp_new, value.children)
for old, new, leaf in parts:
if old != leaf.value:
break
if old != new:
leaf.replace(Name(new, prefix=leaf.prefix))
if len(dp_new) < len(dp_old):
# if new path is shorter, remove excess children
del value.children[len(dp_new) : len(dp_old)]
elif len(dp_new) > len(dp_old):
# if new path is longer, add new children
children = [
Name(new) for new in dp_new[len(dp_old) : len(dp_new)]
]
value.children[len(dp_old) : len(dp_old)] = children
elif type_repr(value.type) == "power":
# We don't actually need the '.' so just skip it
dp_old = old_name.split(".")
dp_new = new_name.split(".")
for old, new, leaf in zip(dp_old, dp_new, value.children):
# For a Node child the name is taken from children[1]; a Leaf child is the
# name itself.
if isinstance(leaf, Node):
name_leaf = leaf.children[1]
else:
name_leaf = leaf
if old != name_leaf.value:
break
name_leaf.replace(Name(new, prefix=name_leaf.prefix))
if len(dp_new) < len(dp_old):
# if new path is shorter, remove excess children
del value.children[len(dp_new) : len(dp_old)]
elif len(dp_new) > len(dp_old):
# if new path is longer, add new trailers in the middle
for i in range(len(dp_old), len(dp_new)):
value.insert_child(
i, Node(SYMBOL.trailer, [Dot(), Name(dp_new[i])])
)
transform.callbacks.append(rename_transform)
return self
# Add an argument to the selected function or method.
#
# name:            name of the new parameter (used when editing definitions).
# value:           source text wrapped in a Name leaf; used as the default value
#                  of a keyword parameter and as the value inserted at
#                  positional call sites.
# positional:      add the argument positionally instead of as a keyword.
# after:           name of an existing argument to insert after, START to insert
#                  first, or SENTINEL (default) for no explicit position.
# type_annotation: annotation for the new parameter in definitions; SENTINEL
#                  means no annotation.
#
# Raises ValueError when the current selector is not function/method, when a
# positional `after=` is requested without the original function object
# ("source" kwarg), or when `after` is not a parameter of that function.
# Returns the query.
def add_argument(
self,
name: str,
value: str,
positional: bool = False,
after: Stringish = SENTINEL,
type_annotation: Stringish = SENTINEL,
) -> "Query":
keyword = not positional
transform = self.current
if transform.selector not in ("function", "method"):
raise ValueError("add_argument must follow select_function/select_method")
# determine correct position (excluding self/cls) to add new argument
stop_at = -1
if positional and after not in (SENTINEL, START):
if "source" not in transform.kwargs:
raise ValueError(
"using after= with positional= requires passing original function"
)
# The original function's signature gives the index of `after`.
signature = inspect.signature(transform.kwargs["source"])
if after not in (SENTINEL, START) and after not in signature.parameters:
raise ValueError(f"{after} does not exist in original function")
names = list(signature.parameters)
stop_at = names.index(cast(str, after))
# A leading self/cls/meta parameter is discounted from the index.
if names[0] in ("self", "cls", "meta"):
stop_at -= 1
# Callback run for each match; only definitions and calls are edited.
def add_argument_transform(
node: Node, capture: Capture, filename: Filename
) -> None:
if "function_def" not in capture and "function_call" not in capture:
return
spec = FunctionSpec.build(node, capture)
done = False
value_leaf = Name(value)
# Definitions: the new parameter carries a default only in keyword form.
if spec.is_def:
new_arg = FunctionArgument(
name,
value_leaf if keyword else None,
cast(str, type_annotation) if type_annotation != SENTINEL else "",
)
for index, argument in enumerate(spec.arguments):
# Explicit anchor: insert right after the named argument.
if after == argument.name:
spec.arguments.insert(index + 1, new_arg)
done = True
break
# Otherwise insert before the current argument when inserting at START,
# when a positional parameter meets an argument with a value or a star,
# or when a keyword parameter meets a ** argument.
if (
after == START
or (positional and (argument.value or argument.star))
or (
keyword
and argument.star
and argument.star.type == TOKEN.DOUBLESTAR
)
):
spec.arguments.insert(index, new_arg)
done = True
break
if not done:
spec.arguments.append(new_arg)
# Positional form outside a definition: insert the bare value.
elif positional:
new_arg = FunctionArgument(value=value_leaf)
for index, argument in enumerate(spec.arguments):
# A *args-style argument stops the search without inserting anything.
if argument.star and argument.star.type == TOKEN.STAR:
log.debug(f"noping out due to *{argument.name}")
done = True
break
if index == stop_at:
spec.arguments.insert(index + 1, new_arg)
done = True
break
if after == START or argument.name or argument.star:
spec.arguments.insert(index, new_arg)
done = True
break
if not done:
spec.arguments.append(new_arg)
spec.explode()
transform.callbacks.append(add_argument_transform)
return self
def modify_argument(
    self,
    name: str,
    new_name: Stringish = SENTINEL,
    type_annotation: Stringish = SENTINEL,
    default_value: Stringish = SENTINEL,
) -> "Query":
    """
    Change an existing argument of the selected function or method.

    Args:
        name: argument to modify.
        new_name: replacement name; applied to definitions and calls.
        type_annotation: new annotation; applied to definitions only.
        default_value: new default, inserted as a name with a leading space;
            applied to definitions only.

    Each option left at ``SENTINEL`` is not touched.

    Returns:
        The query, for chaining.

    Raises:
        ValueError: if the current selector is not ``function`` or ``method``.
    """
    transform = self.current
    if transform.selector not in ("function", "method"):
        raise ValueError("modifier must follow select_function or select_method")

    def modify_argument_transform(
        node: Node, capture: Capture, filename: Filename
    ) -> None:
        if not ("function_def" in capture or "function_call" in capture):
            return

        spec = FunctionSpec.build(node, capture)
        for argument in spec.arguments:
            if argument.name != name:
                continue
            if new_name is not SENTINEL:
                argument.name = str(new_name)
            if spec.is_def and type_annotation is not SENTINEL:
                argument.annotation = str(type_annotation)
            if spec.is_def and default_value is not SENTINEL:
                argument.value = Name(default_value, prefix=" ")

        spec.explode()

    transform.callbacks.append(modify_argument_transform)
    return self
def remove_argument(self, name: str) -> "Query":
    """
    Remove the argument ``name`` from the selected function or method.

    The original function object must have been passed to the selector (it is
    stored as the ``source`` kwarg); its signature tells whether the argument
    is positional and at which index it sits.

    Returns:
        The query, for chaining.

    Raises:
        ValueError: if the current selector is not ``function`` or ``method``,
            if no source function is available, if ``name`` is not one of its
            parameters, or if ``name`` is a ``*args`` / ``**kwargs`` parameter.
    """
    transform = self.current
    if transform.selector not in ("function", "method"):
        raise ValueError("modifier must follow select_function or select_method")
    if "source" not in transform.kwargs:
        raise ValueError("remove_argument requires passing original function")

    parameters = inspect.signature(transform.kwargs["source"]).parameters
    if name not in parameters:
        raise ValueError(f"{name} does not exist in original function")

    kind = parameters[name].kind
    if kind in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL):
        raise ValueError("can't remove *args or **kwargs")
    positional = kind in (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )

    # Index of the argument at call sites, discounting a leading self/cls/meta.
    ordered = list(parameters)
    stop_at = ordered.index(name)
    if ordered[0] in ("self", "cls", "meta"):
        stop_at -= 1

    def remove_argument_transform(
        node: Node, capture: Capture, filename: Filename
    ) -> None:
        if not ("function_def" in capture or "function_call" in capture):
            return

        spec = FunctionSpec.build(node, capture)
        if spec.is_def or not positional:
            # By name: drop the first argument called ``name``.
            for argument in spec.arguments:
                if argument.name == name:
                    spec.arguments.remove(argument)
                    break
        else:
            # Scan from the end over a snapshot: drop the argument if passed by
            # keyword, or the unnamed, unstarred one at the expected index.
            for index, argument in reversed(list(enumerate(spec.arguments))):
                if argument.name == name:
                    spec.arguments.pop(index)
                    break
                if index == stop_at and not argument.name and not argument.star:
                    spec.arguments.pop(index)
                    break

        spec.explode()

    transform.callbacks.append(remove_argument_transform)
    return self
def fixer(self, fx: Type[BaseFix]) -> "Query":
    """
    Add an existing fixer class to the query.

    Wraps ``fx`` in a ``Transform`` (passed as its ``fixer`` argument), appends
    it to ``transforms`` and returns the query for chaining.
    """
    wrapped = Transform(fixer=fx)
    self.transforms.append(wrapped)
    return self
def filter(self, filter_callback: Union[str, Filter]) -> "Query":
    """
    Add a filter to the current transform.

    Args:
        filter_callback: either a callable taking ``(node, capture, filename)``
            or a string holding a Python expression.  A string is compiled once
            and evaluated for every match with ``node``, ``capture`` and
            ``filename`` in scope; its truthiness is the filter result.

    Returns:
        The query, for chaining.
    """
    if isinstance(filter_callback, str):
        # NOTE: this evaluates caller-supplied source; only pass trusted
        # expressions.
        code = compile(filter_callback, "<string>", "eval")

        def callback(node: Node, capture: Capture, filename: Filename) -> bool:
            return bool(eval(code))  # noqa: developer tool

        # Register the compiled wrapper, not the raw string: a string in the
        # filter list fails as soon as the fixer tries to call it.
        resolved: Filter = callback
    else:
        resolved = filter_callback

    self.current.filters.append(resolved)
    return self
def modify(self, callback: Union[str, Callback]) -> "Query":
    """
    Add a modifier callback to the current transform.

    Args:
        callback: either a callable taking ``(node, capture, filename)`` or a
            string of Python statements.  A string is compiled once and executed
            for every match with ``node``, ``capture`` and ``filename`` in
            scope.

    Returns:
        The query, for chaining.
    """
    if not isinstance(callback, str):
        self.current.callbacks.append(callback)
        return self

    code = compile(callback, "<string>", "exec")

    # Parameter names are part of the contract: the compiled source refers to
    # them directly.
    def callback(node: Node, capture: Capture, filename: Filename) -> None:
        exec(code)

    self.current.callbacks.append(callback)
    return self
def process(self, callback: Processor) -> "Query":
    """
    Register a hunk processor.

    Processors are called by ``execute`` with ``(filename, hunk)``; a processor
    returning ``False`` vetoes that hunk.  Returns the query for chaining.
    """
    registered = self.processors
    registered.append(callback)
    return self
# Build a BaseFix subclass for one transform.
#
# For a transform wrapping an existing fixer, PATTERN and BM_compatible are taken
# from that fixer. Otherwise the selector's template from SELECTORS is filled in
# with the transform's kwargs and collapsed onto a single line. The returned
# class applies the transform's filters, then the wrapped fixer (if any), then
# its callbacks.
def create_fixer(self, transform):
if transform.fixer:
bm_compat = transform.fixer.BM_compatible
pattern = transform.fixer.PATTERN
else:
bm_compat = False
log.debug(f"select {transform.selector}[{transform.kwargs}]")
pattern = SELECTORS[transform.selector].format(**transform.kwargs)
# Strip every line and join the non-empty ones with single spaces.
pattern = " ".join(
line
for wholeline in pattern.splitlines()
for line in (wholeline.strip(),)
if line
)
log.debug(f"generated pattern: {pattern}")
filters = transform.filters
callbacks = transform.callbacks
log.debug(f"registered {len(filters)} filters: {filters}")
log.debug(f"registered {len(callbacks)} callbacks: {callbacks}")
class Fixer(BaseFix):
PATTERN = pattern # type: ignore
BM_compatible = bm_compat
# Inside this method `transform` is the enclosing function's argument (a
# closure), not this method itself.
def transform(self, node: LN, capture: Capture) -> Optional[LN]:
filename = cast(Filename, self.filename)
returned_node = None
# All filters must pass (or there are none) before anything is applied.
if not filters or all(f(node, capture, filename) for f in filters):
if transform.fixer:
returned_node = transform.fixer().transform(node, capture)
for callback in callbacks:
# Once a step has returned a different node, no further callback may run.
if returned_node and returned_node is not node:
raise BowlerException(
"Only the last fixer/callback may return "
"a different node. See "
"https://pybowler.io/docs/api-modifiers"
)
returned_node = callback(node, capture, filename)
return returned_node
return Fixer
def compile(self) -> List[Type[BaseFix]]:
    """
    Turn every transform into a fixer class.

    When no selector has been used, ``select_root`` is applied first so there
    is at least one transform.  Returns the fixers in transform order.
    """
    if not self.transforms:
        log.debug("no selectors chosen, defaulting to select_root")
        self.select_root()

    return [self.create_fixer(transform) for transform in self.transforms]
def execute(self, **kwargs) -> "Query":
    """
    Compile the query and run it through ``BowlerTool``.

    ``kwargs`` are forwarded to the tool with these additions:

    * ``hunk_processor``: set when processors are registered; it calls every
      processor and rejects the hunk if any of them returns ``False``.
    * ``filename_matcher``: defaults to the query's matcher.
    * ``options["print_function"]``: forced to ``True`` for Python 3.

    The tool's return code and exceptions are stored on ``retcode`` and
    ``exceptions``.  Returns the query.
    """
    fixers = self.compile()

    if self.processors:

        def processor(filename: Filename, hunk: Hunk) -> bool:
            # Every processor runs (no short-circuit); an explicit False from
            # any of them vetoes the hunk.
            verdicts = [p(filename, hunk) for p in self.processors]
            return not any(verdict is False for verdict in verdicts)

        kwargs["hunk_processor"] = processor

    kwargs.setdefault("filename_matcher", self.filename_matcher)
    if self.python_version == 3:
        options = kwargs.setdefault("options", {})
        options["print_function"] = True

    tool = BowlerTool(fixers, **kwargs)
    self.retcode = tool.run(self.paths)
    self.exceptions = tool.exceptions
    return self
def dump(self, selector_pattern=False) -> "Query":
    """
    Print debugging output for every match without writing changes.

    Appends ``print_selector_pattern`` (when ``selector_pattern`` is truthy) or
    ``print_tree`` to every transform's callbacks, then runs
    ``execute(write=False)`` and returns its result.
    """
    for transform in self.transforms:
        if selector_pattern:
            transform.callbacks.append(print_selector_pattern)
        else:
            transform.callbacks.append(print_tree)

    return self.execute(write=False)
def diff(self, interactive: bool = False, **kwargs) -> "Query":
    """
    Run the query with ``write=False``.

    ``interactive`` and any extra keyword arguments are forwarded to
    ``execute``; its result is returned.
    """
    outcome = self.execute(write=False, interactive=interactive, **kwargs)
    return outcome
def idiff(self, **kwargs) -> "Query":
    """
    Interactive variant of ``diff``.

    Calls ``diff`` with ``interactive=True`` plus any extra keyword arguments
    and returns its result.
    """
    outcome = self.diff(interactive=True, **kwargs)
    return outcome
def silent(self, **kwargs) -> "Query":
    """
    Run the query with ``silent=True``.

    Extra keyword arguments are forwarded to ``execute``; its result is
    returned.
    """
    outcome = self.execute(silent=True, **kwargs)
    return outcome
def write(self, **kwargs) -> "Query":
    """
    Run the query and write the changes.

    Calls ``execute`` with ``write=True``, ``silent=True`` and
    ``interactive=False`` plus any extra keyword arguments, and returns its
    result.
    """
    outcome = self.execute(write=True, silent=True, interactive=False, **kwargs)
    return outcome