/
loader.py
180 lines (124 loc) · 5.2 KB
/
loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import importlib
import pkgutil
import sys
from collections import defaultdict
from collections.abc import Generator
from importlib.metadata import entry_points
from inspect import getsourcefile, getsourcelines, signature
from pathlib import Path
from types import GenericAlias, ModuleType, UnionType
from typing import Any, TypeGuard
from mypy.nodes import Node
from refurb.visitor.mapping import METHOD_NODE_MAPPINGS
from . import checks as checks_module
from .error import Error, ErrorCategory, ErrorCode
from .settings import Settings
from .types import Check
def get_modules(paths: list[str]) -> Generator[ModuleType, None, None]:
    """Yield every check module exactly once.

    Modules come from the built-in checks package, the importable names in
    ``paths``, and the values of all ``refurb.plugins`` entry points. A package
    (anything with ``__path__``) is walked recursively and only its non-package
    submodules are yielded; a plain module is yielded as-is.

    Side effect: the current working directory is made importable by adding it
    to ``sys.path``.
    """
    cwd = str(Path.cwd())

    # Only add the entry once; repeated calls used to grow sys.path without bound.
    if cwd not in sys.path:
        sys.path.append(cwd)

    plugins = [x.value for x in entry_points(group="refurb.plugins")]
    extra_modules = (importlib.import_module(x) for x in paths + plugins)

    loaded: set[ModuleType] = set()

    for pkg in (checks_module, *extra_modules):
        if pkg in loaded:
            continue

        if hasattr(pkg, "__path__"):
            # Lazy on purpose: walk_packages imports subpackages as it goes, so
            # keep it interleaved with our own imports below.
            names: Any = (
                info.name
                for info in pkgutil.walk_packages(pkg.__path__, f"{pkg.__name__}.")
                if not info.ispkg
            )
        else:
            names = (pkg.__name__,)

        for name in names:
            module = importlib.import_module(name)

            if module not in loaded:
                loaded.add(module)
                yield module

        loaded.add(pkg)
def is_valid_error_class(obj: Any) -> TypeGuard[type[Error]]:  # type: ignore
    """Return True if ``obj`` is a concrete error class defined by a check.

    A valid error class is a class whose ``__name__`` starts with ``"Error"``,
    is not one of the framework's own names (``Error``, ``ErrorCode``,
    ``ErrorCategory``), and is a subclass of ``Error``.
    """
    # issubclass() raises TypeError for non-classes, so functions or instances
    # that merely happen to be named "Error..." must be rejected up front.
    if not isinstance(obj, type):
        return False

    name = obj.__name__
    ignored_names = ("Error", "ErrorCode", "ErrorCategory")

    return name.startswith("Error") and name not in ignored_names and issubclass(obj, Error)
def get_error_class(module: ModuleType) -> type[Error] | None:
    """Find the error class exported by a check module.

    Attributes are scanned in ``dir(module)`` order; the first one whose name
    starts with ``"Error"`` (other than ``Error`` / ``ErrorCode``) and that
    passes ``is_valid_error_class`` is returned. Returns ``None`` if there is
    no such attribute.
    """
    skipped = {"Error", "ErrorCode"}

    # Generator keeps the lookup lazy: attributes after the first match are
    # never fetched.
    candidates = (
        getattr(module, attr)
        for attr in dir(module)
        if attr.startswith("Error") and attr not in skipped
    )

    return next(filter(is_valid_error_class, candidates), None)
def should_load_check(settings: Settings, error: type[Error]) -> bool:
    """Decide whether the check that reports ``error`` should be loaded.

    Rules are applied in order and the first match wins:

    1. the error code is in ``settings.enable``            -> load
    2. the error code is in ``settings.disable``/``ignore`` -> skip
    3. one of the error's categories is in ``settings.enable``  -> load
    4. one of the error's categories is in ``settings.disable``,
       or ``settings.disable_all`` is set                  -> skip
    5. otherwise fall back to ``error.enabled or settings.enable_all``
    """
    code = ErrorCode.from_error(error)

    if code in settings.enable:
        return True

    blocked = settings.disable | settings.ignore
    if code in blocked:
        return False

    error_categories = set(map(ErrorCategory, error.categories))

    if settings.enable & error_categories:
        return True

    if settings.disable & error_categories:
        return False

    if settings.disable_all:
        return False

    return error.enabled or settings.enable_all
# Types a check function may declare for its first (node) parameter: the
# values of METHOD_NODE_MAPPINGS.
VALID_NODE_TYPES = {*METHOD_NODE_MAPPINGS.values()}
# (parameter name, annotation) pairs accepted for a check function's optional
# trailing parameter.
VALID_OPTIONAL_ARGS = (
    ("settings", Settings),
)
def type_error_with_line_info(func: Any, msg: str) -> TypeError:  # type: ignore
    """Build (not raise) a ``TypeError`` for ``func``, prefixed with its location.

    The message has the form ``"<file>:<line>: <msg>"`` when the source of
    ``func`` can be located. If it cannot (built-ins, callable instances,
    callables without retrievable source), the error carries just ``msg``.
    """
    try:
        filename = getsourcefile(func)
        line = getsourcelines(func)[1]
    except (OSError, TypeError):
        # inspect raises TypeError for built-in/unsupported objects and OSError
        # when the source cannot be retrieved. Reporting the real problem
        # matters more than the location, so never let the lookup itself fail.
        return TypeError(msg)

    if not filename:
        return TypeError(msg)  # pragma: no cover

    return TypeError(f"{filename}:{line}: {msg}")
def _annotation_name(annotation: Any) -> str:  # type: ignore
    """Return a printable name for ``annotation`` without assuming it is a class."""
    # String annotations and other non-class objects have no ``__name__``.
    return getattr(annotation, "__name__", str(annotation))


def extract_function_types(  # type: ignore
    func: Any,
) -> Generator[type[Node], None, None]:
    """Validate a check function's signature and yield the node types it handles.

    A valid check function takes 2 or 3 parameters:

    1. a node, annotated with a type from ``VALID_NODE_TYPES`` or a ``X | Y``
       union of such types;
    2. the error list, annotated exactly ``list[Error]``;
    3. optionally, one parameter listed in ``VALID_OPTIONAL_ARGS``.

    Yields:
        Each node type named by the first parameter's annotation.

    Raises:
        TypeError: if ``func`` is not callable or its signature does not match
            the rules above. Because this is a generator, the error surfaces
            when the result is iterated, not when the function is called.
    """
    if not callable(func):
        raise TypeError("Check function must be callable")

    params = list(signature(func).parameters.values())

    if len(params) not in {2, 3}:
        raise type_error_with_line_info(func, "Check function must take 2-3 parameters")

    node_param = params[0].annotation
    error_param = params[1].annotation
    optional_params = params[2:]

    if not (
        isinstance(error_param, GenericAlias)
        and error_param.__origin__ is list
        and error_param.__args__[0] is Error
    ):
        raise type_error_with_line_info(func, '"error" param must be of type list[Error]')

    for param in optional_params:
        if (param.name, param.annotation) not in VALID_OPTIONAL_ARGS:
            raise type_error_with_line_info(
                func,
                f'"{param.name}: {_annotation_name(param.annotation)}" is not a valid service',  # noqa: E501
            )

    # A bare annotation is handled as a one-member union so that both shapes go
    # through the same validation and produce the same error message.
    if isinstance(node_param, UnionType):
        candidates = node_param.__args__
    else:
        candidates = (node_param,)

    for ty in candidates:
        if ty not in VALID_NODE_TYPES:
            raise type_error_with_line_info(
                func,
                f'"{_annotation_name(ty)}" is not a valid Mypy node type',
            )

        yield ty
def load_checks(settings: Settings) -> defaultdict[type[Node], list[Check]]:
    """Collect the enabled check functions, grouped by the node type they handle.

    For every module from ``get_modules(settings.load)`` that has an error
    class accepted by ``should_load_check`` and a truthy ``check`` attribute,
    the check is registered under each node type its signature declares.

    When ``settings.verbose`` is set, the sorted list of enabled error codes
    (or ``"No checks enabled"``) is printed.
    """
    checks_by_node: defaultdict[type[Node], list[Check]] = defaultdict(list)
    enabled_codes: set[str] = set()

    for module in get_modules(settings.load):
        error_cls = get_error_class(module)

        if not error_cls or not should_load_check(settings, error_cls):
            continue

        check_func = getattr(module, "check", None)

        if not check_func:
            continue

        for node_type in extract_function_types(check_func):
            checks_by_node[node_type].append(check_func)

        enabled_codes.add(str(ErrorCode.from_error(error_cls)))

    if settings.verbose:
        if enabled_codes:
            summary = ", ".join(sorted(enabled_codes))
        else:
            summary = "No checks enabled"

        print(f"Enabled checks: {summary}\n")

    return checks_by_node