-
Notifications
You must be signed in to change notification settings - Fork 87
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
3,352 additions
and
69 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,254 @@ | ||
import ast | ||
import argparse | ||
from pathlib import Path | ||
from collections import defaultdict, deque | ||
|
||
import pydash | ||
|
||
|
||
WRAPPER_KW = "RES" | ||
INIT_FILE = "src/pydash/__init__.py" | ||
BASE_MODULE = """ | ||
'''Generated from the `scripts/chaining_type_generator.py` script''' | ||
import re | ||
import typing as t | ||
from typing_extensions import ParamSpec, Type | ||
import pydash as pyd | ||
from pydash.chaining.chaining import Chain | ||
from pydash.types import * | ||
from pydash.helpers import Unset, UNSET | ||
from pydash.functions import After, Ary, Before, Once, Spread, Throttle | ||
from _typeshed import ( | ||
SupportsDunderGE, | ||
SupportsDunderGT, | ||
SupportsDunderLE, | ||
SupportsDunderLT, | ||
SupportsRichComparison, | ||
SupportsAdd, | ||
SupportsRichComparisonT, | ||
SupportsSub, | ||
) | ||
{imports} | ||
Value_coT = t.TypeVar("Value_coT", covariant=True) | ||
T = t.TypeVar("T") | ||
T2 = t.TypeVar("T2") | ||
T3 = t.TypeVar("T3") | ||
T4 = t.TypeVar("T4") | ||
T5 = t.TypeVar("T5") | ||
NumT = t.TypeVar("NumT", int, float) | ||
CallableT = t.TypeVar("CallableT", bound=t.Callable) | ||
SequenceT = t.TypeVar("SequenceT", bound=t.Sequence) | ||
MutableSequenceT = t.TypeVar("MutableSequenceT", bound=t.MutableSequence) | ||
P = ParamSpec("P") | ||
class {class_name}: | ||
""" | ||
|
||
|
||
def build_header(class_name: str, imports: list[str]) -> str: | ||
return BASE_MODULE.format(class_name=class_name, imports="\n".join(imports)) | ||
|
||
|
||
def modules_and_api_funcs() -> dict[str, list[str]]: | ||
"""This is mostly so we don't have to import `pydash`""" | ||
|
||
with open(INIT_FILE, "r", encoding="utf-8") as source: | ||
tree = ast.parse(source.read()) | ||
|
||
module_to_funcs = defaultdict(list) | ||
|
||
for node in ast.walk(tree): | ||
# TODO: maybe handle `Import` as well, not necessary for now | ||
if isinstance(node, ast.ImportFrom): | ||
for name in node.names: | ||
module_to_funcs[node.module].append(name.asname or name.name) | ||
|
||
return module_to_funcs | ||
|
||
|
||
def is_overload(node: ast.FunctionDef) -> bool: | ||
return any( | ||
( | ||
(isinstance(decorator, ast.Name) and decorator.id == "overload") | ||
or (isinstance(decorator, ast.Attribute) and decorator.attr == "overload") | ||
) | ||
for decorator in node.decorator_list | ||
) | ||
|
||
|
||
def returns_typeguard(node: ast.FunctionDef) -> bool: | ||
def is_constant_typeguard(cst: ast.expr) -> bool: | ||
return isinstance(cst, ast.Constant) and cst.value is not None and "TypeGuard" in cst.value | ||
|
||
def is_subscript_typeguard(sub: ast.expr) -> bool: | ||
return ( | ||
isinstance(sub, ast.Subscript) | ||
and isinstance(sub.value, ast.Name) | ||
and "TypeGuard" in sub.value.id | ||
) | ||
|
||
return node.returns is not None and ( | ||
is_constant_typeguard(node.returns) or is_subscript_typeguard(node.returns) | ||
) | ||
|
||
|
||
def has_single_default_arg(node: ast.FunctionDef) -> bool: | ||
return len(node.args.args) == 1 and len(node.args.defaults) >= 1 | ||
|
||
|
||
def chainwrapper_args( | ||
node: ast.FunctionDef, | ||
) -> tuple[list[ast.expr], list[ast.keyword]]: | ||
# TODO: handle posonlyargs | ||
args: list[ast.expr] = [ast.Name(id=arg.arg) for arg in node.args.args[1:]] | ||
kwargs: list[ast.keyword] = [ | ||
ast.keyword(arg=kw.arg, value=ast.Name(id=kw.arg)) for kw in node.args.kwonlyargs | ||
] | ||
|
||
if node.args.vararg: | ||
args.append(ast.Starred(value=ast.Name(id=node.args.vararg.arg))) | ||
|
||
if node.args.kwarg: | ||
kwargs.append(ast.keyword(value=ast.Name(id=node.args.kwarg.arg))) | ||
|
||
return args, kwargs | ||
|
||
|
||
def wrap_type(wrapper: ast.Subscript, to_wrap: ast.expr) -> ast.expr: | ||
if isinstance(wrapper.slice, ast.Tuple): | ||
slice = ast.Tuple( | ||
elts=[ | ||
s if not (isinstance(s, ast.Name) and s.id == WRAPPER_KW) else to_wrap | ||
for s in wrapper.slice.elts | ||
] | ||
) | ||
else: | ||
slice = to_wrap | ||
|
||
return ast.Subscript( | ||
value=wrapper.value, | ||
slice=slice, | ||
) | ||
|
||
|
||
def transform_function(node: ast.FunctionDef, wrapper: ast.Subscript) -> ast.FunctionDef: | ||
first_arg = node.args.args[0] | ||
cw_args, cw_kwargs = chainwrapper_args(node) | ||
|
||
if first_arg.annotation: | ||
first_arg.annotation = ast.Constant( | ||
value=ast.unparse(wrap_type(wrapper, first_arg.annotation)) | ||
) | ||
|
||
first_arg.arg = "self" | ||
|
||
if node.returns: | ||
# TODO: `(some_arg: T) -> TypeGuard[T]` to `(some_arg: Any) -> bool` | ||
# TODO: otherwise we would get a `T` alone | ||
|
||
# change typeguard to bool as it is useless in a chain | ||
if returns_typeguard(node): | ||
node.returns = ast.Name(id="bool") | ||
|
||
node.returns = ast.Constant(value=ast.unparse(wrap_type(wrapper, node.returns))) | ||
|
||
if not is_overload(node): | ||
node.body = [ | ||
ast.Return( | ||
value=ast.Call( | ||
func=ast.Call( | ||
func=ast.Name(id="self._wrap"), | ||
args=[ast.Name(id=f"pyd.{node.name}")], | ||
keywords=[], | ||
), | ||
args=cw_args, | ||
keywords=cw_kwargs, | ||
) | ||
) | ||
] | ||
|
||
return node | ||
|
||
|
||
def filename_from_module(module: str) -> str: | ||
return "src/pydash/chaining/chaining.py" if module == "chaining" else f"src/pydash/{module}.py" | ||
|
||
|
||
def main() -> int: | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"class_name", | ||
help="Name of the output class to put typed methods in", | ||
) | ||
parser.add_argument( | ||
"output", | ||
type=Path, | ||
help="Path to the file to write the typed class to (probably a `.pyi` file)", | ||
) | ||
parser.add_argument( | ||
"wrapper", | ||
help="The main generic class (eg. `Chain`)", | ||
) | ||
parser.add_argument( | ||
"--imports", | ||
nargs="+", | ||
help="List of imports to add to the file", | ||
) | ||
args = parser.parse_args() | ||
|
||
wrapper = args.wrapper + f"[{WRAPPER_KW}]" | ||
wrapper = ast.parse(wrapper).body[0] | ||
assert isinstance(wrapper, ast.Expr), "`wrapper` value should contain one expression" | ||
wrapper = wrapper.value | ||
assert isinstance( | ||
wrapper, ast.Subscript | ||
), "`wrapper` value should contain one with one subscript" | ||
|
||
to_file = open(args.output, "w") | ||
to_file.write(build_header(args.class_name, args.imports or [])) | ||
|
||
module_to_funcs = modules_and_api_funcs() | ||
|
||
for module in module_to_funcs.keys(): | ||
filename = filename_from_module(module) | ||
|
||
with open(filename, encoding="utf-8") as source: | ||
tree = ast.parse(source.read(), filename=filename) | ||
|
||
class_methods = deque() | ||
|
||
for node in ast.walk(tree): | ||
if isinstance(node, ast.ClassDef): | ||
class_methods.extend(f for f in node.body if isinstance(f, ast.FunctionDef)) | ||
|
||
# skipping class methods | ||
if node in class_methods: | ||
class_methods.popleft() | ||
continue | ||
|
||
if ( | ||
isinstance(node, ast.FunctionDef) | ||
and node.name in module_to_funcs[module] | ||
and node.args.args # skipping funcs without args for now | ||
and not has_single_default_arg(node) # skipping 1 default arg funcs | ||
): | ||
new_node = transform_function(node, wrapper) | ||
to_file.write(" " * 4) | ||
to_file.write(ast.unparse(new_node).replace("\n", f"\n{' ' * 4}")) | ||
to_file.write("\n\n") | ||
if new_node.name.endswith("_") and not is_overload(new_node): | ||
to_file.write(f"{' ' * 4}{new_node.name.rstrip('_')} = {new_node.name}") | ||
to_file.write("\n\n") | ||
|
||
to_file.close() | ||
return 0 | ||
|
||
|
||
if __name__ == "__main__": | ||
raise SystemExit(main()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from .chaining import _Dash, chain, tap, thru | ||
|
||
|
||
__all__ = ( | ||
"_Dash", | ||
"chain", | ||
"tap", | ||
"thru", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from abc import ABC, abstractmethod | ||
import typing as t | ||
|
||
|
||
class AllFuncs(ABC): | ||
"""Exposing all of the exposed functions of a module through an class.""" | ||
|
||
module: t.Any | ||
invalid_method_exception: t.Type[Exception] | ||
|
||
@abstractmethod | ||
def _wrap(self, func) -> t.Callable: | ||
"""Proxy attribute access to :attr:`module`.""" | ||
raise NotImplementedError() # pragma: no cover | ||
|
||
@classmethod | ||
def get_method(cls, name: str) -> t.Callable: | ||
""" | ||
Return valid :attr:`module` method. | ||
Args: | ||
name (str): Name of pydash method to get. | ||
Returns: | ||
function: :attr:`module` callable. | ||
Raises: | ||
InvalidMethod: Raised if `name` is not a valid :attr:`module` method. | ||
""" | ||
method = getattr(cls.module, name, None) | ||
|
||
if not callable(method) and not name.endswith("_"): | ||
# Alias method names not ending in underscore to their underscore | ||
# counterpart. This allows chaining of functions like "map_()" | ||
# using "map()" instead. | ||
method = getattr(cls.module, name + "_", None) | ||
|
||
if not callable(method): | ||
raise cls.invalid_method_exception(f"Invalid {cls.module.__name__} method: {name}") | ||
|
||
return method | ||
|
||
def __getattr__(self, name: str) -> t.Callable: | ||
return self._wrap(self.get_method(name)) |
Oops, something went wrong.