Skip to content

Commit

Permalink
feat: Support loading (and merging) *.pyi files
Browse files Browse the repository at this point in the history
  • Loading branch information
pawamoy committed May 6, 2022
1 parent 190585d commit 41518f4
Show file tree
Hide file tree
Showing 9 changed files with 235 additions and 13 deletions.
13 changes: 6 additions & 7 deletions src/griffe/agents/visitor.py
Expand Up @@ -10,7 +10,6 @@

import ast
import inspect
from collections import defaultdict
from contextlib import suppress
from itertools import zip_longest
from pathlib import Path
Expand Down Expand Up @@ -146,7 +145,6 @@ def __init__(
self.lines_collection: LinesCollection = lines_collection or LinesCollection()
self.modules_collection: ModulesCollection = modules_collection or ModulesCollection()
self.type_guarded: bool = False
self.overloads: dict[str, list[Function]] = defaultdict(list)

def _get_docstring(self, node: ast.AST, strict: bool = False) -> Docstring | None:
value, lineno, endlineno = get_docstring(node, strict=strict)
Expand Down Expand Up @@ -228,7 +226,7 @@ def visit_classdef(self, node: ast.ClassDef) -> None:
for decorator_node in node.decorator_list:
decorators.append(
Decorator(
get_value(decorator_node),
get_value(decorator_node), # type: ignore[arg-type]
lineno=decorator_node.lineno,
endlineno=decorator_node.end_lineno, # type: ignore[attr-defined]
)
Expand Down Expand Up @@ -326,7 +324,7 @@ def handle_function(self, node: ast.AsyncFunctionDef | ast.FunctionDef, labels:
)
decorators.append(
Decorator(
decorator_value,
decorator_value, # type: ignore[arg-type]
lineno=decorator_node.lineno,
endlineno=decorator_node.end_lineno, # type: ignore[attr-defined]
)
Expand Down Expand Up @@ -423,7 +421,7 @@ def handle_function(self, node: ast.AsyncFunctionDef | ast.FunctionDef, labels:
)

if overload:
self.overloads[function.path].append(function)
self.current.overloads[function.name].append(function)
elif base_property is not None:
if property_function == "setter":
base_property.setter = function
Expand All @@ -433,8 +431,9 @@ def handle_function(self, node: ast.AsyncFunctionDef | ast.FunctionDef, labels:
base_property.labels.add("deletable")
else:
self.current[node.name] = function
if self.overloads[function.path]:
function.overloads = self.overloads[function.path]
if self.current.overloads[function.name]:
function.overloads = self.current.overloads[function.name]
del self.current.overloads[function.name] # noqa: WPS420

function.labels |= labels

Expand Down
3 changes: 3 additions & 0 deletions src/griffe/dataclasses.py
Expand Up @@ -9,6 +9,7 @@
import enum
import inspect
import sys
from collections import defaultdict
from contextlib import suppress
from pathlib import Path
from textwrap import dedent
Expand Down Expand Up @@ -1011,6 +1012,7 @@ def __init__(self, *args: Any, filepath: Path | list[Path] | None = None, **kwar
"""
super().__init__(*args, **kwargs)
self._filepath: Path | list[Path] | None = filepath
self.overloads: dict[str, list[Function]] = defaultdict(list)

def __repr__(self) -> str:
try:
Expand Down Expand Up @@ -1131,6 +1133,7 @@ def __init__(
super().__init__(*args, **kwargs)
self.bases: list[Name | Expression | str] = bases or []
self.decorators: list[Decorator] = decorators or []
self.overloads: dict[str, list[Function]] = defaultdict(list)

@property
def parameters(self) -> Parameters:
Expand Down
2 changes: 1 addition & 1 deletion src/griffe/encoders.py
Expand Up @@ -177,7 +177,7 @@ def _load_attribute(obj_dict: dict[str, Any]) -> Attribute:
lineno=obj_dict["lineno"],
endlineno=obj_dict.get("endlineno", None),
docstring=_load_docstring(obj_dict),
value=obj_dict["value"],
value=obj_dict.get("value", None),
annotation=_load_annotation(obj_dict.get("annotation", None)),
)
attribute.labels |= set(obj_dict.get("labels", ()))
Expand Down
4 changes: 2 additions & 2 deletions src/griffe/finder.py
Expand Up @@ -42,7 +42,7 @@ def is_namespace(self) -> bool:
class ModuleFinder:
"""The Griffe finder, allowing to find modules on the file system."""

accepted_py_module_extensions = [".py", ".pyc", ".pyo", ".pyd", ".so"]
accepted_py_module_extensions = [".py", ".pyc", ".pyo", ".pyd", ".pyi", ".so"]
extensions_set = set(accepted_py_module_extensions)

def __init__(self, search_paths: Sequence[str | Path] | None = None) -> None:
Expand Down Expand Up @@ -166,7 +166,7 @@ def iter_submodules(self, path: Path | list[Path]) -> Iterator[NamePartsAndPathT

if path.stem == "__init__":
path = path.parent
# optimization: just check if the file name ends with .py[cod]/.so
# optimization: just check if the file name ends with .py[icod]/.so
# (to distinguish it from a directory),
# not if it's an actual file
elif path.suffix in self.extensions_set:
Expand Down
2 changes: 1 addition & 1 deletion src/griffe/loader.py
Expand Up @@ -363,7 +363,7 @@ def _load_module_path(
logger.debug(f"Loading path {module_path}")
if isinstance(module_path, list):
module = self._create_module(module_name, module_path)
elif module_path.suffix == ".py":
elif module_path.suffix in {".py", ".pyi"}:
code = module_path.read_text(encoding="utf8")
module = self._visit_module(code, module_name, module_path, parent)
elif self.allow_inspection:
Expand Down
93 changes: 93 additions & 0 deletions src/griffe/merger.py
@@ -0,0 +1,93 @@
"""This module contains utilities to merge data together."""

from __future__ import annotations

from contextlib import suppress
from typing import TYPE_CHECKING

from griffe.logger import get_logger

if TYPE_CHECKING:
from griffe.dataclasses import Attribute, Class, Function, Module, Object


logger = get_logger(__name__)


def _merge_module_stubs(module: Module, stubs: Module) -> None:
_merge_stubs_docstring(module, stubs)
_merge_stubs_overloads(module, stubs)
_merge_stubs_members(module, stubs)


def _merge_class_stubs(class_: Class, stubs: Class) -> None:
_merge_stubs_docstring(class_, stubs)
_merge_stubs_overloads(class_, stubs)
_merge_stubs_members(class_, stubs)


def _merge_function_stubs(function: Function, stubs: Function) -> None:
_merge_stubs_docstring(function, stubs)
for parameter in stubs.parameters:
with suppress(KeyError):
function.parameters[parameter.name].annotation = parameter.annotation
function.returns = stubs.returns


def _merge_attribute_stubs(attribute: Attribute, stubs: Attribute) -> None:
_merge_stubs_docstring(attribute, stubs)
attribute.annotation = stubs.annotation


def _merge_stubs_docstring(obj: Object, stubs: Object) -> None:
if not obj.docstring and stubs.docstring:
obj.docstring = stubs.docstring


def _merge_stubs_overloads(obj: Module | Class, stubs: Module | Class) -> None:
for function_name, overloads in list(stubs.overloads.items()):
with suppress(KeyError):
obj[function_name].overloads = overloads
del stubs.overloads[function_name] # noqa: WPS420


def _merge_stubs_members(obj: Module | Class, stubs: Module | Class) -> None: # noqa: WPS231
for member_name, stub_member in stubs.members.items():
if member_name in obj.members:
obj_member = obj[member_name]
if obj_member.kind is not stub_member.kind:
logger.debug(f"Cannot merge stubs of kind {stub_member.kind} into object of kind {obj_member.kind}")
elif obj_member.is_class:
_merge_class_stubs(obj_member, stub_member) # type: ignore[arg-type]
elif obj_member.is_function:
_merge_function_stubs(obj_member, stub_member) # type: ignore[arg-type]
elif obj_member.is_attribute:
_merge_attribute_stubs(obj_member, stub_member) # type: ignore[arg-type]
else:
stub_member.runtime = False
obj[member_name] = stub_member


def merge_stubs(mod1: Module, mod2: Module) -> Module:
"""Merge stubs into a module.
Parameters:
mod1: A regular module or stubs module.
mod2: A regular module or stubs module.
Raises:
ValueError: When both modules are regular modules (no stubs is passed).
Returns:
The regular module.
"""
if mod1.filepath.suffix == ".pyi": # type: ignore[union-attr]
stubs = mod1
module = mod2
elif mod2.filepath.suffix == ".pyi": # type: ignore[union-attr]
stubs = mod2
module = mod1
else:
raise ValueError("cannot merge regular (non-stubs) modules together")
_merge_module_stubs(module, stubs)
return module
15 changes: 14 additions & 1 deletion src/griffe/mixins.py
Expand Up @@ -2,8 +2,14 @@

from __future__ import annotations

from contextlib import suppress
from typing import Any, Sequence

from griffe.logger import get_logger
from griffe.merger import merge_stubs

logger = get_logger(__name__)


class GetMembersMixin:
"""This mixin adds a `__getitem__` method to a class.
Expand Down Expand Up @@ -51,13 +57,20 @@ class SetMembersMixin(DelMembersMixin):
Each time a member is set, its `parent` attribute is set as well.
"""

def __setitem__(self, key: str | Sequence[str], value) -> None:
def __setitem__(self, key: str | Sequence[str], value) -> None: # noqa: WPS231
parts = _get_parts(key)
if len(parts) == 1:
name = parts[0]
if name in self.members: # type: ignore[attr-defined]
member = self.members[name] # type: ignore[attr-defined]
if not member.is_alias:
# when reassigning a module to an existing one,
# try to merge them as one regular and one stubs module
# (implicit support for .pyi modules)
if member.is_module and value.is_module:
logger.debug(f"Trying to merge {member.filepath} and {value.filepath}")
with suppress(ValueError):
value = merge_stubs(member, value)
for alias in member.aliases.values():
alias.target = value
self.members[name] = value # type: ignore[attr-defined]
Expand Down
2 changes: 1 addition & 1 deletion tests/helpers.py
Expand Up @@ -64,7 +64,7 @@ def temporary_pypackage(package: str, modules: list[str] | None = None) -> Itera
for module in modules:
current_path = package_path
for part in Path(module).parts:
if part.endswith(".py"):
if part.endswith(".py") or part.endswith(".pyi"):
(current_path / part).touch()
else:
current_path /= part
Expand Down
114 changes: 114 additions & 0 deletions tests/test_loader.py
@@ -1,5 +1,7 @@
"""Tests for the `loader` module."""

from textwrap import dedent

from griffe.expressions import Name
from griffe.loader import GriffeLoader
from tests.helpers import temporary_pyfile, temporary_pypackage
Expand Down Expand Up @@ -80,3 +82,115 @@ def test_dont_overwrite_lower_member_when_expanding_wildcard():
loader.resolve_aliases()
assert package["mod_a.overwritten"].value == "1"
assert package["mod_a.not_overwritten"].value == "0"


def test_load_data_from_stubs():
"""Check that the loader is able to load data from stubs / `*.pyi` files."""
with temporary_pypackage("package", ["_rust_notify.pyi"]) as tmp_package:
# code taken from samuelcolvin/watchfiles project
code = '''
from typing import List, Literal, Optional, Protocol, Set, Tuple, Union
__all__ = 'RustNotify', 'WatchfilesRustInternalError'
class AbstractEvent(Protocol):
def is_set(self) -> bool: ...
class RustNotify:
"""
Interface to the Rust [notify](https://crates.io/crates/notify) crate which does
the heavy lifting of watching for file changes and grouping them into a single event.
"""
def __init__(self, watch_paths: List[str], debug: bool) -> None:
"""
Create a new RustNotify instance and start a thread to watch for changes.
`FileNotFoundError` is raised if one of the paths does not exist.
Args:
watch_paths: file system paths to watch for changes, can be directories or files
debug: if true, print details about all events to stderr
"""
'''
tmp_package.path.joinpath("_rust_notify.pyi").write_text(dedent(code))
tmp_package.path.joinpath("__init__.py").write_text(
"from ._rust_notify import RustNotify\n__all__ = ['RustNotify']"
)
loader = GriffeLoader(search_paths=[tmp_package.tmpdir])
package = loader.load_module(tmp_package.name)
loader.resolve_aliases()

assert "_rust_notify" in package.members
assert "RustNotify" in package.members
assert package["RustNotify"].resolved


def test_load_from_both_py_and_pyi_files():
"""Check that the loader is able to merge data loaded from `*.py` and `*.pyi` files."""
with temporary_pypackage("package", ["mod.py", "mod.pyi"]) as tmp_package:
tmp_package.path.joinpath("mod.py").write_text(
dedent(
"""
CONST = 0
class Class:
class_attr = True
def function1(self, arg1):
pass
def function2(self, arg1=2.2):
pass
"""
)
)
tmp_package.path.joinpath("mod.pyi").write_text(
dedent(
"""
from typing import Sequence, overload
CONST: int
class Class:
class_attr: bool
@overload
def function1(self, arg1: str) -> Sequence[str]: ...
@overload
def function1(self, arg1: bytes) -> Sequence[bytes]: ...
def function2(self, arg1: float) -> float: ...
"""
)
)
loader = GriffeLoader(search_paths=[tmp_package.tmpdir])
package = loader.load_module(tmp_package.name)
loader.resolve_aliases()

assert "mod" in package.members
mod = package["mod"]
assert mod.filepath.suffix == ".py"

assert "CONST" in mod.members
const = mod["CONST"]
assert const.value == "0"
assert const.annotation.source == "int"

assert "Class" in mod.members
class_ = mod["Class"]

assert "class_attr" in class_.members
class_attr = class_["class_attr"]
assert class_attr.value == "True"
assert class_attr.annotation.source == "bool"

assert "function1" in class_.members
function1 = class_["function1"]
assert len(function1.overloads) == 2

assert "function2" in class_.members
function2 = class_["function2"]
assert function2.returns.source == "float"
assert function2.parameters["arg1"].annotation.source == "float"
assert function2.parameters["arg1"].default == "2.2"

0 comments on commit 41518f4

Please sign in to comment.