From 41518f4aa9e00756a910067cf6f01f07ca7327da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20Mazzucotelli?= Date: Fri, 6 May 2022 10:32:33 +0200 Subject: [PATCH] feat: Support loading (and merging) `*.pyi` files Issue mkdocstrings/mkdocstrings#404: https://github.com/mkdocstrings/mkdocstrings/issues/404 --- src/griffe/agents/visitor.py | 13 ++-- src/griffe/dataclasses.py | 3 + src/griffe/encoders.py | 2 +- src/griffe/finder.py | 4 +- src/griffe/loader.py | 2 +- src/griffe/merger.py | 93 ++++++++++++++++++++++++++++ src/griffe/mixins.py | 15 ++++- tests/helpers.py | 2 +- tests/test_loader.py | 114 +++++++++++++++++++++++++++++++++++ 9 files changed, 235 insertions(+), 13 deletions(-) create mode 100644 src/griffe/merger.py diff --git a/src/griffe/agents/visitor.py b/src/griffe/agents/visitor.py index 5677b0b9..8aa8f307 100644 --- a/src/griffe/agents/visitor.py +++ b/src/griffe/agents/visitor.py @@ -10,7 +10,6 @@ import ast import inspect -from collections import defaultdict from contextlib import suppress from itertools import zip_longest from pathlib import Path @@ -146,7 +145,6 @@ def __init__( self.lines_collection: LinesCollection = lines_collection or LinesCollection() self.modules_collection: ModulesCollection = modules_collection or ModulesCollection() self.type_guarded: bool = False - self.overloads: dict[str, list[Function]] = defaultdict(list) def _get_docstring(self, node: ast.AST, strict: bool = False) -> Docstring | None: value, lineno, endlineno = get_docstring(node, strict=strict) @@ -228,7 +226,7 @@ def visit_classdef(self, node: ast.ClassDef) -> None: for decorator_node in node.decorator_list: decorators.append( Decorator( - get_value(decorator_node), + get_value(decorator_node), # type: ignore[arg-type] lineno=decorator_node.lineno, endlineno=decorator_node.end_lineno, # type: ignore[attr-defined] ) @@ -326,7 +324,7 @@ def handle_function(self, node: ast.AsyncFunctionDef | ast.FunctionDef, labels: ) decorators.append( Decorator( - decorator_value, + decorator_value, # type: ignore[arg-type] lineno=decorator_node.lineno, endlineno=decorator_node.end_lineno, # type: ignore[attr-defined] ) @@ -423,7 +421,7 @@ def handle_function(self, node: ast.AsyncFunctionDef | ast.FunctionDef, labels: ) if overload: - self.overloads[function.path].append(function) + self.current.overloads[function.name].append(function) elif base_property is not None: if property_function == "setter": base_property.setter = function @@ -433,8 +431,9 @@ def handle_function(self, node: ast.AsyncFunctionDef | ast.FunctionDef, labels: base_property.labels.add("deletable") else: self.current[node.name] = function - if self.overloads[function.path]: - function.overloads = self.overloads[function.path] + if self.current.overloads[function.name]: + function.overloads = self.current.overloads[function.name] + del self.current.overloads[function.name] # noqa: WPS420 function.labels |= labels diff --git a/src/griffe/dataclasses.py b/src/griffe/dataclasses.py index 82045b56..7610ab77 100644 --- a/src/griffe/dataclasses.py +++ b/src/griffe/dataclasses.py @@ -9,6 +9,7 @@ import enum import inspect import sys +from collections import defaultdict from contextlib import suppress from pathlib import Path from textwrap import dedent @@ -1011,6 +1012,7 @@ def __init__(self, *args: Any, filepath: Path | list[Path] | None = None, **kwar """ super().__init__(*args, **kwargs) self._filepath: Path | list[Path] | None = filepath + self.overloads: dict[str, list[Function]] = defaultdict(list) def __repr__(self) -> str: try: @@ -1131,6 +1133,7 @@ def __init__( super().__init__(*args, **kwargs) self.bases: list[Name | Expression | str] = bases or [] self.decorators: list[Decorator] = decorators or [] + self.overloads: dict[str, list[Function]] = defaultdict(list) @property def parameters(self) -> Parameters: diff --git a/src/griffe/encoders.py b/src/griffe/encoders.py index 77395ce0..d9dc2dd2 100644 --- a/src/griffe/encoders.py +++ b/src/griffe/encoders.py @@ -177,7 +177,7 @@ def _load_attribute(obj_dict: dict[str, Any]) -> Attribute: lineno=obj_dict["lineno"], endlineno=obj_dict.get("endlineno", None), docstring=_load_docstring(obj_dict), - value=obj_dict["value"], + value=obj_dict.get("value", None), annotation=_load_annotation(obj_dict.get("annotation", None)), ) attribute.labels |= set(obj_dict.get("labels", ())) diff --git a/src/griffe/finder.py b/src/griffe/finder.py index 6e6abbf9..ae5e3648 100644 --- a/src/griffe/finder.py +++ b/src/griffe/finder.py @@ -42,7 +42,7 @@ def is_namespace(self) -> bool: class ModuleFinder: """The Griffe finder, allowing to find modules on the file system.""" - accepted_py_module_extensions = [".py", ".pyc", ".pyo", ".pyd", ".so"] + accepted_py_module_extensions = [".py", ".pyc", ".pyo", ".pyd", ".pyi", ".so"] extensions_set = set(accepted_py_module_extensions) def __init__(self, search_paths: Sequence[str | Path] | None = None) -> None: @@ -166,7 +166,7 @@ def iter_submodules(self, path: Path | list[Path]) -> Iterator[NamePartsAndPathT if path.stem == "__init__": path = path.parent - # optimization: just check if the file name ends with .py[cod]/.so + # optimization: just check if the file name ends with .py[icod]/.so # (to distinguish it from a directory), # not if it's an actual file elif path.suffix in self.extensions_set: diff --git a/src/griffe/loader.py b/src/griffe/loader.py index 1ec82ace..5b13fadd 100644 --- a/src/griffe/loader.py +++ b/src/griffe/loader.py @@ -363,7 +363,7 @@ def _load_module_path( logger.debug(f"Loading path {module_path}") if isinstance(module_path, list): module = self._create_module(module_name, module_path) - elif module_path.suffix == ".py": + elif module_path.suffix in {".py", ".pyi"}: code = module_path.read_text(encoding="utf8") module = self._visit_module(code, module_name, module_path, parent) elif self.allow_inspection: diff --git a/src/griffe/merger.py b/src/griffe/merger.py new file mode 100644 index 00000000..31746696 --- /dev/null +++ b/src/griffe/merger.py @@ -0,0 +1,93 @@ +"""This module contains utilities to merge data together.""" + +from __future__ import annotations + +from contextlib import suppress +from typing import TYPE_CHECKING + +from griffe.logger import get_logger + +if TYPE_CHECKING: + from griffe.dataclasses import Attribute, Class, Function, Module, Object + + +logger = get_logger(__name__) + + +def _merge_module_stubs(module: Module, stubs: Module) -> None: + _merge_stubs_docstring(module, stubs) + _merge_stubs_overloads(module, stubs) + _merge_stubs_members(module, stubs) + + +def _merge_class_stubs(class_: Class, stubs: Class) -> None: + _merge_stubs_docstring(class_, stubs) + _merge_stubs_overloads(class_, stubs) + _merge_stubs_members(class_, stubs) + + +def _merge_function_stubs(function: Function, stubs: Function) -> None: + _merge_stubs_docstring(function, stubs) + for parameter in stubs.parameters: + with suppress(KeyError): + function.parameters[parameter.name].annotation = parameter.annotation + function.returns = stubs.returns + + +def _merge_attribute_stubs(attribute: Attribute, stubs: Attribute) -> None: + _merge_stubs_docstring(attribute, stubs) + attribute.annotation = stubs.annotation + + +def _merge_stubs_docstring(obj: Object, stubs: Object) -> None: + if not obj.docstring and stubs.docstring: + obj.docstring = stubs.docstring + + +def _merge_stubs_overloads(obj: Module | Class, stubs: Module | Class) -> None: + for function_name, overloads in list(stubs.overloads.items()): + with suppress(KeyError): + obj[function_name].overloads = overloads + del stubs.overloads[function_name] # noqa: WPS420 + + +def _merge_stubs_members(obj: Module | Class, stubs: Module | Class) -> None: # noqa: WPS231 + for member_name, stub_member in stubs.members.items(): + if member_name in obj.members: + obj_member = obj[member_name] + if obj_member.kind is not stub_member.kind: + logger.debug(f"Cannot merge stubs of kind {stub_member.kind} into object of kind {obj_member.kind}") + elif obj_member.is_class: + _merge_class_stubs(obj_member, stub_member) # type: ignore[arg-type] + elif obj_member.is_function: + _merge_function_stubs(obj_member, stub_member) # type: ignore[arg-type] + elif obj_member.is_attribute: + _merge_attribute_stubs(obj_member, stub_member) # type: ignore[arg-type] + else: + stub_member.runtime = False + obj[member_name] = stub_member + + +def merge_stubs(mod1: Module, mod2: Module) -> Module: + """Merge stubs into a module. + + Parameters: + mod1: A regular module or stubs module. + mod2: A regular module or stubs module. + + Raises: + ValueError: When both modules are regular modules (no stubs is passed). + + Returns: + The regular module. + """ + if mod1.filepath.suffix == ".pyi": # type: ignore[union-attr] + stubs = mod1 + module = mod2 + elif mod2.filepath.suffix == ".pyi": # type: ignore[union-attr] + stubs = mod2 + module = mod1 + else: + raise ValueError("cannot merge regular (non-stubs) modules together") + _merge_module_stubs(module, stubs) + return module diff --git a/src/griffe/mixins.py b/src/griffe/mixins.py index 5e6fbd59..7775f680 100644 --- a/src/griffe/mixins.py +++ b/src/griffe/mixins.py @@ -2,8 +2,14 @@ from __future__ import annotations +from contextlib import suppress from typing import Any, Sequence +from griffe.logger import get_logger +from griffe.merger import merge_stubs + +logger = get_logger(__name__) + class GetMembersMixin: """This mixin adds a `__getitem__` method to a class. @@ -51,13 +57,20 @@ class SetMembersMixin(DelMembersMixin): Each time a member is set, its `parent` attribute is set as well. """ - def __setitem__(self, key: str | Sequence[str], value) -> None: + def __setitem__(self, key: str | Sequence[str], value) -> None: # noqa: WPS231 parts = _get_parts(key) if len(parts) == 1: name = parts[0] if name in self.members: # type: ignore[attr-defined] member = self.members[name] # type: ignore[attr-defined] if not member.is_alias: + # when reassigning a module to an existing one, + # try to merge them as one regular and one stubs module + # (implicit support for .pyi modules) + if member.is_module and value.is_module: + logger.debug(f"Trying to merge {member.filepath} and {value.filepath}") + with suppress(ValueError): + value = merge_stubs(member, value) for alias in member.aliases.values(): alias.target = value self.members[name] = value # type: ignore[attr-defined] diff --git a/tests/helpers.py b/tests/helpers.py index 73b0f7ca..1570f223 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -64,7 +64,7 @@ def temporary_pypackage(package: str, modules: list[str] | None = None) -> Itera for module in modules: current_path = package_path for part in Path(module).parts: - if part.endswith(".py"): + if part.endswith(".py") or part.endswith(".pyi"): (current_path / part).touch() else: current_path /= part diff --git a/tests/test_loader.py b/tests/test_loader.py index 88da8bf3..40d6521f 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -1,5 +1,7 @@ """Tests for the `loader` module.""" +from textwrap import dedent + from griffe.expressions import Name from griffe.loader import GriffeLoader from tests.helpers import temporary_pyfile, temporary_pypackage @@ -80,3 +82,115 @@ def test_dont_overwrite_lower_member_when_expanding_wildcard(): loader.resolve_aliases() assert package["mod_a.overwritten"].value == "1" assert package["mod_a.not_overwritten"].value == "0" + + +def test_load_data_from_stubs(): + """Check that the loader is able to load data from stubs / `*.pyi` files.""" + with temporary_pypackage("package", ["_rust_notify.pyi"]) as tmp_package: + # code taken from samuelcolvin/watchfiles project + code = ''' + from typing import List, Literal, Optional, Protocol, Set, Tuple, Union + + __all__ = 'RustNotify', 'WatchfilesRustInternalError' + + class AbstractEvent(Protocol): + def is_set(self) -> bool: ... + + class RustNotify: + """ + Interface to the Rust [notify](https://crates.io/crates/notify) crate which does + the heavy lifting of watching for file changes and grouping them into a single event. + """ + + def __init__(self, watch_paths: List[str], debug: bool) -> None: + """ + Create a new RustNotify instance and start a thread to watch for changes. + + `FileNotFoundError` is raised if one of the paths does not exist. + + Args: + watch_paths: file system paths to watch for changes, can be directories or files + debug: if true, print details about all events to stderr + """ + ''' + tmp_package.path.joinpath("_rust_notify.pyi").write_text(dedent(code)) + tmp_package.path.joinpath("__init__.py").write_text( + "from ._rust_notify import RustNotify\n__all__ = ['RustNotify']" + ) + loader = GriffeLoader(search_paths=[tmp_package.tmpdir]) + package = loader.load_module(tmp_package.name) + loader.resolve_aliases() + + assert "_rust_notify" in package.members + assert "RustNotify" in package.members + assert package["RustNotify"].resolved + + +def test_load_from_both_py_and_pyi_files(): + """Check that the loader is able to merge data loaded from `*.py` and `*.pyi` files.""" + with temporary_pypackage("package", ["mod.py", "mod.pyi"]) as tmp_package: + tmp_package.path.joinpath("mod.py").write_text( + dedent( + """ + CONST = 0 + + class Class: + class_attr = True + + def function1(self, arg1): + pass + + def function2(self, arg1=2.2): + pass + """ + ) + ) + tmp_package.path.joinpath("mod.pyi").write_text( + dedent( + """ + from typing import Sequence, overload + + CONST: int + + class Class: + class_attr: bool + + @overload + def function1(self, arg1: str) -> Sequence[str]: ... + @overload + def function1(self, arg1: bytes) -> Sequence[bytes]: ... + + def function2(self, arg1: float) -> float: ... + """ + ) + ) + loader = GriffeLoader(search_paths=[tmp_package.tmpdir]) + package = loader.load_module(tmp_package.name) + loader.resolve_aliases() + + assert "mod" in package.members + mod = package["mod"] + assert mod.filepath.suffix == ".py" + + assert "CONST" in mod.members + const = mod["CONST"] + assert const.value == "0" + assert const.annotation.source == "int" + + assert "Class" in mod.members + class_ = mod["Class"] + + assert "class_attr" in class_.members + class_attr = class_["class_attr"] + assert class_attr.value == "True" + assert class_attr.annotation.source == "bool" + + assert "function1" in class_.members + function1 = class_["function1"] + assert len(function1.overloads) == 2 + + assert "function2" in class_.members + function2 = class_["function2"] + assert function2.returns.source == "float" + assert function2.parameters["arg1"].annotation.source == "float" + assert function2.parameters["arg1"].default == "2.2"