Skip to content

Commit

Permalink
feat: Better support for dataclasses
Browse files Browse the repository at this point in the history
Instead of generating parameters on the fly by (wrongly) checking attributes of the class,
we always load a Griffe extension that re-creates `__init__` methods and their parameters.

Issue-33: #233
Issue-34: #234
Issue-38: #238
Issue-39: #239
PR-240: #240
  • Loading branch information
pawamoy committed Mar 5, 2024
1 parent 9efda88 commit 82a9d57
Show file tree
Hide file tree
Showing 9 changed files with 409 additions and 46 deletions.
4 changes: 2 additions & 2 deletions src/griffe/agents/inspector.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from griffe.dataclasses import Alias, Attribute, Class, Docstring, Function, Module, Parameter, Parameters
from griffe.enumerations import ObjectKind, ParameterKind
from griffe.expressions import safe_get_annotation
from griffe.extensions.base import Extensions
from griffe.extensions.base import Extensions, load_extensions
from griffe.importer import dynamic_import

if TYPE_CHECKING:
Expand Down Expand Up @@ -77,7 +77,7 @@ def inspect(
return Inspector(
module_name,
filepath,
extensions or Extensions(),
extensions or load_extensions(),
parent,
docstring_parser=docstring_parser,
docstring_options=docstring_options,
Expand Down
4 changes: 2 additions & 2 deletions src/griffe/agents/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
safe_get_condition,
safe_get_expression,
)
from griffe.extensions.base import Extensions
from griffe.extensions.base import Extensions, load_extensions

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -92,7 +92,7 @@ def visit(
module_name,
filepath,
code,
extensions or Extensions(),
extensions or load_extensions(),
parent,
docstring_parser=docstring_parser,
docstring_options=docstring_options,
Expand Down
18 changes: 10 additions & 8 deletions src/griffe/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,16 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return f"Parameter(name={self.name!r}, annotation={self.annotation!r}, kind={self.kind!r}, default={self.default!r})"

def __eq__(self, __value: object) -> bool:
if not isinstance(__value, Parameter):
return NotImplemented
return (
self.name == __value.name
and self.annotation == __value.annotation
and self.kind == __value.kind
and self.default == __value.default
)

@property
def required(self) -> bool:
"""Whether this parameter is required."""
Expand Down Expand Up @@ -1561,14 +1571,6 @@ def parameters(self) -> Parameters:
try:
return self.all_members["__init__"].parameters # type: ignore[union-attr]
except KeyError:
if "dataclass" in self.labels:
return Parameters(
*[
Parameter(attr.name, annotation=attr.annotation, default=attr.value)
for attr in self.attributes.values()
if "property" not in attr.labels
],
)
return Parameters()

@cached_property
Expand Down
5 changes: 5 additions & 0 deletions src/griffe/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,11 @@ class ExprCall(Expr):
arguments: Sequence[str | Expr]
"""Passed arguments."""

@property
def canonical_path(self) -> str:
"""The canonical path of this subscript's left part."""
return self.function.canonical_path

def iterate(self, *, flat: bool = True) -> Iterator[str | Expr]: # noqa: D102
yield from _yield(self.function, flat=flat)
yield "("
Expand Down
18 changes: 16 additions & 2 deletions src/griffe/extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ def call(self, event: str, **kwargs: Any) -> None:

builtin_extensions: set[str] = {
"hybrid",
"dataclasses",
}


Expand Down Expand Up @@ -454,7 +455,9 @@ def _load_extension(
return [ext(**options) for ext in extensions]


def load_extensions(exts: Sequence[str | dict[str, Any] | ExtensionType | type[ExtensionType]]) -> Extensions:
def load_extensions(
exts: Sequence[str | dict[str, Any] | ExtensionType | type[ExtensionType]] | None = None,
) -> Extensions:
"""Load configured extensions.
Parameters:
Expand All @@ -464,12 +467,23 @@ def load_extensions(exts: Sequence[str | dict[str, Any] | ExtensionType | type[E
An extensions container.
"""
extensions = Extensions()
for extension in exts:
for extension in exts or ():
ext = _load_extension(extension)
if isinstance(ext, list):
extensions.add(*ext)
else:
extensions.add(ext)

# TODO: Deprecate and remove at some point?
# Always add our built-in dataclasses extension.
from griffe.extensions.dataclasses import DataclassesExtension

for ext in extensions._extensions:
if type(ext) == DataclassesExtension:
break
else:
extensions.add(*_load_extension("dataclasses")) # type: ignore[misc]

return extensions


Expand Down
204 changes: 204 additions & 0 deletions src/griffe/extensions/dataclasses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
"""Built-in extension adding support for dataclasses.
This extension re-creates `__init__` methods of dataclasses
during static analysis.
"""

from __future__ import annotations

import ast
from contextlib import suppress
from functools import lru_cache
from typing import Any, cast

from griffe.dataclasses import Attribute, Class, Decorator, Function, Module, Parameter, Parameters
from griffe.enumerations import ParameterKind
from griffe.expressions import (
Expr,
ExprAttribute,
ExprCall,
ExprDict,
)
from griffe.extensions.base import Extension


def _dataclass_decorator(decorators: list[Decorator]) -> Expr | None:
for decorator in decorators:
if isinstance(decorator.value, Expr) and decorator.value.canonical_path == "dataclasses.dataclass":
return decorator.value
return None


def _expr_args(expr: Expr) -> dict[str, str | Expr]:
args = {}
if isinstance(expr, ExprCall):
for argument in expr.arguments:
try:
args[argument.name] = argument.value # type: ignore[union-attr]
except AttributeError:
# Argument is a unpacked variable.
with suppress(Exception):
collection = expr.function.parent.modules_collection # type: ignore[attr-defined]
var = collection[argument.value.canonical_path] # type: ignore[union-attr]
args.update(_expr_args(var.value))
elif isinstance(expr, ExprDict):
args.update({ast.literal_eval(str(key)): value for key, value in zip(expr.keys, expr.values)})
return args


def _dataclass_arguments(decorators: list[Decorator]) -> dict[str, Any]:
if (expr := _dataclass_decorator(decorators)) and isinstance(expr, ExprCall):
return _expr_args(expr)
return {}


def _field_arguments(attribute: Attribute) -> dict[str, Any]:
if attribute.value:
value = attribute.value
if isinstance(value, ExprAttribute):
value = value.last
if isinstance(value, ExprCall) and value.canonical_path == "dataclasses.field":
return _expr_args(value)
return {}


@lru_cache(maxsize=None)
def _dataclass_parameters(class_: Class) -> list[Parameter]:
# Fetch `@dataclass` arguments if any.
dec_args = _dataclass_arguments(class_.decorators)

# Parameters not added to `__init__`, return empty list.
if dec_args.get("init") == "False":
return []

# All parameters marked as keyword-only.
kw_only = dec_args.get("kw_only") == "True"

# Iterate on current attributes to find parameters.
parameters = []
for member in class_.members.values():
if member.is_attribute:
member = cast(Attribute, member)

# Start of keyword-only parameters.
if isinstance(member.annotation, Expr) and member.annotation.canonical_path == "dataclasses.KW_ONLY":
kw_only = True
continue

# Fetch `field` arguments if any.
field_args = _field_arguments(member)

# Parameter not added to `__init__`, skip it.
if field_args.get("init") == "False":
continue

# Determine parameter kind.
kind = (
ParameterKind.keyword_only
if kw_only or field_args.get("kw_only") == "True"
else ParameterKind.positional_or_keyword
)

# Determine parameter default.
if "default_factory" in field_args:
default = ExprCall(function=field_args["default_factory"], arguments=[])
else:
default = field_args.get("default", None if field_args else member.value)

# Add parameter to the list.
parameters.append(
Parameter(
member.name,
annotation=member.annotation,
kind=kind,
default=default,
),
)

return parameters


def _reorder_parameters(parameters: list[Parameter]) -> list[Parameter]:
# De-duplicate, overwriting previous parameters.
params_dict = {param.name: param for param in parameters}

# Re-order, putting positional-only in front and keyword-only at the end.
pos_only = []
pos_kw = []
kw_only = []
for param in params_dict.values():
if param.kind is ParameterKind.positional_only:
pos_only.append(param)
elif param.kind is ParameterKind.keyword_only:
kw_only.append(param)
else:
pos_kw.append(param)
return pos_only + pos_kw + kw_only


def _set_dataclass_init(class_: Class) -> None:
# Retrieve parameters from all parent dataclasses.
parameters = []
try:
mro = class_.mro()
except ValueError:
mro = () # type: ignore[assignment]
for parent in reversed(mro):
if _dataclass_decorator(parent.decorators):
parameters.extend(_dataclass_parameters(parent))
# At least one parent dataclass makes the current class a dataclass:
# that's how `dataclasses.is_dataclass` works.
class_.labels.add("dataclass")

# If the class is not decorated with `@dataclass`, skip it.
if not _dataclass_decorator(class_.decorators):
return

# Add current class parameters.
parameters.extend(_dataclass_parameters(class_))

# Create `__init__` method with re-ordered parameters.
init = Function(
"__init__",
lineno=0,
endlineno=0,
parent=class_,
parameters=Parameters(
Parameter(name="self", annotation=None, kind=ParameterKind.positional_or_keyword, default=None),
*_reorder_parameters(parameters),
),
returns="None",
)
class_.set_member("__init__", init)


def _apply_recursively(mod_cls: Module | Class, processed: set[str]) -> None:
if mod_cls.canonical_path in processed:
return
processed.add(mod_cls.canonical_path)
if isinstance(mod_cls, Class):
if "__init__" not in mod_cls.members:
_set_dataclass_init(mod_cls)
for member in mod_cls.members.values():
if not member.is_alias and member.is_class:
_apply_recursively(member, processed) # type: ignore[arg-type]
elif isinstance(mod_cls, Module):
for member in mod_cls.members.values():
if not member.is_alias and (member.is_module or member.is_class):
_apply_recursively(member, processed) # type: ignore[arg-type]


class DataclassesExtension(Extension):
"""Built-in extension adding support for dataclasses.
This extension creates `__init__` methods of dataclasses
if they don't already exist.
"""

def on_package_loaded(self, *, pkg: Module) -> None:
"""Hook for loaded packages.
Parameters:
pkg: The loaded package.
"""
_apply_recursively(pkg, set())
4 changes: 2 additions & 2 deletions src/griffe/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from griffe.enumerations import Kind
from griffe.exceptions import AliasResolutionError, CyclicAliasError, LoadingError, UnimportableModuleError
from griffe.expressions import ExprName
from griffe.extensions.base import Extensions
from griffe.extensions.base import Extensions, load_extensions
from griffe.finder import ModuleFinder, NamespacePackage, Package
from griffe.git import tmp_worktree
from griffe.logger import get_logger
Expand Down Expand Up @@ -69,7 +69,7 @@ def __init__(
allow_inspection: Whether to allow inspecting modules when visiting them is not possible.
store_source: Whether to store code source in the lines collection.
"""
self.extensions: Extensions = extensions or Extensions()
self.extensions: Extensions = extensions or load_extensions()
"""Loaded Griffe extensions."""
self.docstring_parser: Parser | None = docstring_parser
"""Selected docstring parser."""
Expand Down
Loading

0 comments on commit 82a9d57

Please sign in to comment.