Skip to content

Commit

Permalink
Merge 5e39935 into a0483e4
Browse files Browse the repository at this point in the history
  • Loading branch information
DeviousStoat committed Mar 7, 2023
2 parents a0483e4 + 5e39935 commit 8afa62d
Show file tree
Hide file tree
Showing 11 changed files with 3,352 additions and 69 deletions.
254 changes: 254 additions & 0 deletions scripts/chaining_type_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
import ast
import argparse
from pathlib import Path
from collections import defaultdict, deque

import pydash


WRAPPER_KW = "RES"
INIT_FILE = "src/pydash/__init__.py"
BASE_MODULE = """
'''Generated from the `scripts/chaining_type_generator.py` script'''
import re
import typing as t
from typing_extensions import ParamSpec, Type
import pydash as pyd
from pydash.chaining.chaining import Chain
from pydash.types import *
from pydash.helpers import Unset, UNSET
from pydash.functions import After, Ary, Before, Once, Spread, Throttle
from _typeshed import (
SupportsDunderGE,
SupportsDunderGT,
SupportsDunderLE,
SupportsDunderLT,
SupportsRichComparison,
SupportsAdd,
SupportsRichComparisonT,
SupportsSub,
)
{imports}
Value_coT = t.TypeVar("Value_coT", covariant=True)
T = t.TypeVar("T")
T2 = t.TypeVar("T2")
T3 = t.TypeVar("T3")
T4 = t.TypeVar("T4")
T5 = t.TypeVar("T5")
NumT = t.TypeVar("NumT", int, float)
CallableT = t.TypeVar("CallableT", bound=t.Callable)
SequenceT = t.TypeVar("SequenceT", bound=t.Sequence)
MutableSequenceT = t.TypeVar("MutableSequenceT", bound=t.MutableSequence)
P = ParamSpec("P")
class {class_name}:
"""


def build_header(class_name: str, imports: list[str]) -> str:
return BASE_MODULE.format(class_name=class_name, imports="\n".join(imports))


def modules_and_api_funcs() -> dict[str, list[str]]:
"""This is mostly so we don't have to import `pydash`"""

with open(INIT_FILE, "r", encoding="utf-8") as source:
tree = ast.parse(source.read())

module_to_funcs = defaultdict(list)

for node in ast.walk(tree):
# TODO: maybe handle `Import` as well, not necessary for now
if isinstance(node, ast.ImportFrom):
for name in node.names:
module_to_funcs[node.module].append(name.asname or name.name)

return module_to_funcs


def is_overload(node: ast.FunctionDef) -> bool:
return any(
(
(isinstance(decorator, ast.Name) and decorator.id == "overload")
or (isinstance(decorator, ast.Attribute) and decorator.attr == "overload")
)
for decorator in node.decorator_list
)


def returns_typeguard(node: ast.FunctionDef) -> bool:
def is_constant_typeguard(cst: ast.expr) -> bool:
return isinstance(cst, ast.Constant) and cst.value is not None and "TypeGuard" in cst.value

def is_subscript_typeguard(sub: ast.expr) -> bool:
return (
isinstance(sub, ast.Subscript)
and isinstance(sub.value, ast.Name)
and "TypeGuard" in sub.value.id
)

return node.returns is not None and (
is_constant_typeguard(node.returns) or is_subscript_typeguard(node.returns)
)


def has_single_default_arg(node: ast.FunctionDef) -> bool:
return len(node.args.args) == 1 and len(node.args.defaults) >= 1


def chainwrapper_args(
node: ast.FunctionDef,
) -> tuple[list[ast.expr], list[ast.keyword]]:
# TODO: handle posonlyargs
args: list[ast.expr] = [ast.Name(id=arg.arg) for arg in node.args.args[1:]]
kwargs: list[ast.keyword] = [
ast.keyword(arg=kw.arg, value=ast.Name(id=kw.arg)) for kw in node.args.kwonlyargs
]

if node.args.vararg:
args.append(ast.Starred(value=ast.Name(id=node.args.vararg.arg)))

if node.args.kwarg:
kwargs.append(ast.keyword(value=ast.Name(id=node.args.kwarg.arg)))

return args, kwargs


def wrap_type(wrapper: ast.Subscript, to_wrap: ast.expr) -> ast.expr:
if isinstance(wrapper.slice, ast.Tuple):
slice = ast.Tuple(
elts=[
s if not (isinstance(s, ast.Name) and s.id == WRAPPER_KW) else to_wrap
for s in wrapper.slice.elts
]
)
else:
slice = to_wrap

return ast.Subscript(
value=wrapper.value,
slice=slice,
)


def transform_function(node: ast.FunctionDef, wrapper: ast.Subscript) -> ast.FunctionDef:
first_arg = node.args.args[0]
cw_args, cw_kwargs = chainwrapper_args(node)

if first_arg.annotation:
first_arg.annotation = ast.Constant(
value=ast.unparse(wrap_type(wrapper, first_arg.annotation))
)

first_arg.arg = "self"

if node.returns:
# TODO: `(some_arg: T) -> TypeGuard[T]` to `(some_arg: Any) -> bool`
# TODO: otherwise we would get a `T` alone

# change typeguard to bool as it is useless in a chain
if returns_typeguard(node):
node.returns = ast.Name(id="bool")

node.returns = ast.Constant(value=ast.unparse(wrap_type(wrapper, node.returns)))

if not is_overload(node):
node.body = [
ast.Return(
value=ast.Call(
func=ast.Call(
func=ast.Name(id="self._wrap"),
args=[ast.Name(id=f"pyd.{node.name}")],
keywords=[],
),
args=cw_args,
keywords=cw_kwargs,
)
)
]

return node


def filename_from_module(module: str) -> str:
return "src/pydash/chaining/chaining.py" if module == "chaining" else f"src/pydash/{module}.py"


def main() -> int:
parser = argparse.ArgumentParser()
parser.add_argument(
"class_name",
help="Name of the output class to put typed methods in",
)
parser.add_argument(
"output",
type=Path,
help="Path to the file to write the typed class to (probably a `.pyi` file)",
)
parser.add_argument(
"wrapper",
help="The main generic class (eg. `Chain`)",
)
parser.add_argument(
"--imports",
nargs="+",
help="List of imports to add to the file",
)
args = parser.parse_args()

wrapper = args.wrapper + f"[{WRAPPER_KW}]"
wrapper = ast.parse(wrapper).body[0]
assert isinstance(wrapper, ast.Expr), "`wrapper` value should contain one expression"
wrapper = wrapper.value
assert isinstance(
wrapper, ast.Subscript
), "`wrapper` value should contain one with one subscript"

to_file = open(args.output, "w")
to_file.write(build_header(args.class_name, args.imports or []))

module_to_funcs = modules_and_api_funcs()

for module in module_to_funcs.keys():
filename = filename_from_module(module)

with open(filename, encoding="utf-8") as source:
tree = ast.parse(source.read(), filename=filename)

class_methods = deque()

for node in ast.walk(tree):
if isinstance(node, ast.ClassDef):
class_methods.extend(f for f in node.body if isinstance(f, ast.FunctionDef))

# skipping class methods
if node in class_methods:
class_methods.popleft()
continue

if (
isinstance(node, ast.FunctionDef)
and node.name in module_to_funcs[module]
and node.args.args # skipping funcs without args for now
and not has_single_default_arg(node) # skipping 1 default arg funcs
):
new_node = transform_function(node, wrapper)
to_file.write(" " * 4)
to_file.write(ast.unparse(new_node).replace("\n", f"\n{' ' * 4}"))
to_file.write("\n\n")
if new_node.name.endswith("_") and not is_overload(new_node):
to_file.write(f"{' ' * 4}{new_node.name.rstrip('_')} = {new_node.name}")
to_file.write("\n\n")

to_file.close()
return 0


if __name__ == "__main__":
raise SystemExit(main())
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ strict_optional = True
warn_no_return = True
warn_redundant_casts = False
warn_unused_ignores = False
exclude = "*.pyi"

[tool:isort]
line_length = 100
Expand Down
9 changes: 9 additions & 0 deletions src/pydash/chaining/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .chaining import _Dash, chain, tap, thru


__all__ = (
"_Dash",
"chain",
"tap",
"thru",
)
44 changes: 44 additions & 0 deletions src/pydash/chaining/all_funcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from abc import ABC, abstractmethod
import typing as t


class AllFuncs(ABC):
"""Exposing all of the exposed functions of a module through an class."""

module: t.Any
invalid_method_exception: t.Type[Exception]

@abstractmethod
def _wrap(self, func) -> t.Callable:
"""Proxy attribute access to :attr:`module`."""
raise NotImplementedError() # pragma: no cover

@classmethod
def get_method(cls, name: str) -> t.Callable:
"""
Return valid :attr:`module` method.
Args:
name (str): Name of pydash method to get.
Returns:
function: :attr:`module` callable.
Raises:
InvalidMethod: Raised if `name` is not a valid :attr:`module` method.
"""
method = getattr(cls.module, name, None)

if not callable(method) and not name.endswith("_"):
# Alias method names not ending in underscore to their underscore
# counterpart. This allows chaining of functions like "map_()"
# using "map()" instead.
method = getattr(cls.module, name + "_", None)

if not callable(method):
raise cls.invalid_method_exception(f"Invalid {cls.module.__name__} method: {name}")

return method

def __getattr__(self, name: str) -> t.Callable:
return self._wrap(self.get_method(name))

0 comments on commit 8afa62d

Please sign in to comment.