Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 187 additions & 61 deletions build_scripts/gen_api_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
API_JSON_DIR = Path("doc/_api")
API_MD_DIR = Path("doc/api")

# Modules excluded from generated API docs (internal implementation details).
# Entries are compared against module names with a plain ``str.startswith``
# prefix test, so listing a package also excludes its submodules.
# NOTE(review): the prefix test has no "." boundary, so a sibling module whose
# name merely begins with an entry (e.g. "pyrit.backend_x") would be excluded
# too - confirm that is intended.
EXCLUDED_MODULES = {
"pyrit.backend",
}


def render_params(params: list[dict]) -> str:
"""Render parameter list as a markdown table."""
Expand Down Expand Up @@ -88,6 +93,30 @@ def render_signature(member: dict) -> str:
return f"({sig})"


def _escape_docstring_examples(text: str) -> str:
"""Wrap doctest-style examples (>>> lines) in code fences."""
lines = text.split("\n")
result: list[str] = []
in_example = False
for line in lines:
stripped = line.strip()
if stripped.startswith(">>>") and not in_example:
in_example = True
result.append("```python")
result.append(line)
elif in_example and stripped.startswith((">>>", "...")):
result.append(line)
elif in_example:
result.append("```")
in_example = False
result.append(line)
else:
result.append(line)
if in_example:
result.append("```")
return "\n".join(result)


def render_function(func: dict, heading_level: str = "###") -> str:
"""Render a function as markdown."""
name = func["name"]
Expand All @@ -97,18 +126,13 @@ def render_function(func: dict, heading_level: str = "###") -> str:
ret = func.get("returns_annotation", "")
ret_str = f" → {ret}" if ret else ""

# Use heading for name, code block for full signature if long
full_sig = f"{prefix}{name}{sig}{ret_str}"
if len(full_sig) > 80:
parts = [f"{heading_level} {prefix}{name}\n"]
parts.append(f"```python\n{prefix}{name}{sig}{ret_str}\n```\n")
else:
parts = [f"{heading_level} `{full_sig}`\n"]
parts = [f"{heading_level} `{prefix}{name}`\n"]
parts.append(f"```python\n{prefix}{name}{sig}{ret_str}\n```\n")

ds = func.get("docstring", {})
if ds:
if ds.get("text"):
parts.append(ds["text"] + "\n")
parts.append(_escape_docstring_examples(ds["text"]) + "\n")
params_table = render_params(ds.get("params", []))
if params_table:
parts.append(params_table + "\n")
Expand All @@ -128,11 +152,13 @@ def render_class(cls: dict) -> str:
bases = cls.get("bases", [])
bases_str = f"({', '.join(bases)})" if bases else ""

parts = [f"## `class {name}{bases_str}`\n"]
parts = [f"## `{name}`\n"]
if bases_str:
parts.append(f"Bases: `{bases_str[1:-1]}`\n")

ds = cls.get("docstring", {})
if ds and ds.get("text"):
parts.append(ds["text"] + "\n")
parts.append(_escape_docstring_examples(ds["text"]) + "\n")

# __init__
init = cls.get("init")
Expand All @@ -151,6 +177,16 @@ def render_class(cls: dict) -> str:
return "\n".join(parts)


def render_alias(alias: dict) -> str:
    """Produce the markdown snippet for a single alias entry.

    The alias name becomes a level-3 heading in backticks. When the entry
    records a non-empty ``target``, an "Alias of" line naming it follows.
    """
    heading = f"### `{alias['name']}`\n"
    target = alias.get("target", "")
    if not target:
        return heading
    return "\n".join([heading, f"Alias of `{target}`.\n"])


def render_module(data: dict) -> str:
"""Render a full module page."""
mod_name = data["name"]
Expand All @@ -162,10 +198,8 @@ def render_module(data: dict) -> str:

members = data.get("members", [])

# Separate classes and functions
classes = [m for m in members if m.get("kind") == "class"]
functions = [m for m in members if m.get("kind") == "function"]
aliases = [m for m in members if m.get("kind") == "alias"]

if functions:
parts.append("## Functions\n")
Expand All @@ -176,89 +210,181 @@ def render_module(data: dict) -> str:
return "\n".join(parts)


def split_aggregate_json(api_json_dir: Path) -> None:
"""Split aggregate JSON files that contain nested submodules into individual files.
def _build_definition_index(
data: dict,
index: dict | None = None,
name_to_modules: dict[str, list[str]] | None = None,
) -> tuple[dict, dict[str, list[str]]]:
"""Build a flat lookup from fully-qualified name to member definition.

Also builds a reverse lookup mapping each short member name to the list of
module paths where it is defined, so imports can be distinguished from native
definitions.
"""
if index is None:
index = {}
if name_to_modules is None:
name_to_modules = {}
mod_name = data.get("name", "")
for member in data.get("members", []):
kind = member.get("kind", "")
name = member.get("name", "")
if kind in ("class", "function") and name:
fqn = f"{mod_name}.{name}" if mod_name else name
index[fqn] = member
name_to_modules.setdefault(name, []).append(mod_name)
if kind == "module":
_build_definition_index(member, index, name_to_modules)
return index, name_to_modules


def _resolve_aliases(modules: list[dict], definition_index: dict, name_to_modules: dict[str, list[str]]) -> None:
"""Replace bare alias entries with the full definition they point to.

Aliases whose targets resolve to a class or function in the definition index
are swapped in-place so they render with full documentation. Unresolvable
aliases that appear to reference a pyrit class (capitalized name with a
pyrit target) are kept as minimal class stubs. Aliases pointing outside the
pyrit namespace are dropped.

Also removes classes/functions that griffe reports as direct members but are
actually imported from a different pyrit module (the same short name is
defined in another module in the index).
"""
for module in modules:
mod_name = module.get("name", "")
resolved_members: list[dict] = []
for member in module.get("members", []):
kind = member.get("kind", "")
name = member.get("name", "")

if kind == "alias":
target = member.get("target", "")
if not target.startswith("pyrit."):
continue # External import (stdlib, third-party) – skip
if target in definition_index:
defn = definition_index[target].copy()
defn["name"] = name
resolved_members.append(defn)
elif name and name[0].isupper():
resolved_members.append({"name": name, "kind": "class"})
elif kind in ("class", "function"):
# Keep only if this module's tree contains a definition.
# A member defined in this module or its children is native;
# appearances in unrelated modules are just imports.
defining_modules = name_to_modules.get(name, [])
is_native = not defining_modules or any(
m == mod_name or m.startswith(mod_name + ".") for m in defining_modules
)
if is_native:
resolved_members.append(member)
else:
resolved_members.append(member)

module["members"] = resolved_members


def _expand_module(module: dict) -> list[dict]:
"""Recursively expand pure-aggregate modules into their children.

A pure-aggregate module has only submodule members and no direct public API
(classes, functions, aliases). Its children are returned instead, recursing
further if a child is also a pure aggregate.
"""
members = module.get("members", [])
has_api = any(m.get("kind") in ("class", "function", "alias") for m in members)
submodules = [m for m in members if m.get("kind") == "module"]

if has_api or not submodules:
# Module has its own API, or is a leaf – keep it (filter empty later)
return [module]

# Pure aggregate – recurse into children
result: list[dict] = []
for sub in submodules:
result.extend(_expand_module(sub))
return result


def collect_top_level_modules(api_json_dir: Path) -> list[dict]:
"""Collect top-level modules from aggregate JSON files.

When pydoc2json.py runs with --submodules, it produces a single JSON file
(e.g. pyrit_all.json) whose members are submodules. This function recursively
splits those nested submodules into individual JSON files so that each
submodule gets its own API reference page.
(e.g. pyrit_all.json) whose members are submodules. We only generate pages
for the public packages users import from, not for deeply nested internal
submodules whose content is re-exported by the parent.

Pure-aggregate modules (those with only submodule members) are recursively
expanded so their children with real API surface get their own pages.
"""
modules: list[dict] = []
for jf in sorted(api_json_dir.glob("*.json")):
data = json.loads(jf.read_text(encoding="utf-8"))
_split_submodules(data, jf.name, api_json_dir)
modules.extend(_expand_module(data))


def _split_submodules(data: dict, source_name: str, api_json_dir: Path) -> None:
"""Recursively extract and write submodule members to individual JSON files."""
for member in data.get("members", []):
if member.get("kind") != "module":
continue
sub_name = member["name"]
sub_path = api_json_dir / f"{sub_name}.json"
if not sub_path.exists():
sub_path.write_text(json.dumps(member, indent=2, default=str), encoding="utf-8")
print(f"Split {sub_name} from {source_name}")
# Recurse into nested submodules
_split_submodules(member, source_name, api_json_dir)
# Drop excluded and empty modules
return [
m
for m in modules
if not any(m.get("name", "").startswith(ex) for ex in EXCLUDED_MODULES)
and any(member.get("kind") in ("class", "function", "alias") for member in m.get("members", []))
]


def main() -> None:
API_MD_DIR.mkdir(parents=True, exist_ok=True)

# Split aggregate JSON files (e.g. pyrit_all.json) into per-module files
split_aggregate_json(API_JSON_DIR)

# Exclude aggregate files that only contain submodules (no direct classes/functions)
json_files = sorted(API_JSON_DIR.glob("*.json"))
if not json_files:
print("No JSON files found in", API_JSON_DIR)
return

# Collect module data, skipping pure-aggregate files
modules = []
modules = collect_top_level_modules(API_JSON_DIR)

# Build a lookup of all definitions and resolve aliases to their targets
definition_index: dict = {}
name_to_modules: dict[str, list[str]] = {}
for jf in json_files:
data = json.loads(jf.read_text(encoding="utf-8"))
_build_definition_index(data, definition_index, name_to_modules)
_resolve_aliases(modules, definition_index, name_to_modules)

# Generate per-module pages
for data in modules:
mod_name = data["name"]
slug = mod_name.replace(".", "_")
md_path = API_MD_DIR / f"{slug}.md"
content = render_module(data)
members = data.get("members", [])
# Skip files whose members are all submodules (aggregates like pyrit_all.json)
non_module_members = [m for m in members if m.get("kind") != "module"]
if not non_module_members and any(m.get("kind") == "module" for m in members):
continue
modules.append(data)
rendered_count = sum(1 for m in members if m.get("kind") in ("class", "function"))
md_path.write_text(content, encoding="utf-8")
print(f"Written {md_path} ({rendered_count} members)")

# Generate index page
index_parts = ["# API Reference\n"]
for data in modules:
mod_name = data["name"]
members = data.get("members", [])
member_count = len(members)
slug = mod_name.replace(".", "_")
classes = [m["name"] for m in members if m.get("kind") == "class"][:8]
preview = ", ".join(f"`{c}`" for c in classes)
if len(classes) < member_count:
preview += f" ... ({member_count} total)"

classes = [f"`{m['name']}`" for m in members if m.get("kind") == "class"]
functions = [f"`{m['name']}()`" for m in members if m.get("kind") == "function"]
rendered_count = len(classes) + len(functions)
preview_items = (classes + functions)[:8]
preview = ", ".join(preview_items)
if rendered_count > len(preview_items):
preview += f" ... ({rendered_count} total)"

index_parts.append(f"## [{mod_name}]({slug}.md)\n")
if preview:
index_parts.append(preview + "\n")
else:
index_parts.append("_No public API members detected._\n")

index_path = API_MD_DIR / "index.md"
index_path.write_text("\n".join(index_parts), encoding="utf-8")
print(f"Written {index_path}")

# Generate per-module pages
for data in modules:
mod_name = data["name"]
members = data.get("members", [])
# Skip modules with no members and no meaningful docstring
ds_text = (data.get("docstring") or {}).get("text", "")
if not members and len(ds_text) < 50:
continue
slug = mod_name.replace(".", "_")
md_path = API_MD_DIR / f"{slug}.md"
content = render_module(data)
md_path.write_text(content, encoding="utf-8")
print(f"Written {md_path} ({len(members)} members)")


if __name__ == "__main__":
main()
36 changes: 34 additions & 2 deletions build_scripts/pydoc2json.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,33 @@ def class_to_dict(cls: griffe.Class) -> dict:
return result


def _resolve_alias_from_source(target_path: str) -> dict | None:
"""Try to resolve an unresolvable alias by loading the target .py file directly.

When griffe cannot resolve an alias (e.g. due to missing __init__.py in
namespace packages), fall back to parsing the individual source file and
extracting the class or function definition.
"""
parts = target_path.rsplit(".", 1)
if len(parts) != 2:
return None
module_path, member_name = parts
source_file = Path(module_path.replace(".", "/") + ".py")
if not source_file.exists():
return None
try:
code = source_file.read_text(encoding="utf-8")
file_mod = griffe.visit(module_path, code=code, filepath=source_file)
member = file_mod.members.get(member_name)
if isinstance(member, griffe.Class):
return class_to_dict(member)
if isinstance(member, griffe.Function):
return function_to_dict(member)
except Exception:
pass
return None


def module_to_dict(mod: griffe.Module, include_submodules: bool = False) -> dict:
"""Convert a griffe Module to a structured dict."""
result = {
Expand All @@ -167,8 +194,13 @@ def module_to_dict(mod: griffe.Module, include_submodules: bool = False) -> dict
elif isinstance(target, griffe.Function):
result["members"].append(function_to_dict(target))
except Exception:
# Unresolvable alias — just record the name
result["members"].append({"name": name, "kind": "alias", "target": str(member.target_path)})
# Griffe cannot resolve (e.g. namespace package) — try source file
resolved = _resolve_alias_from_source(str(member.target_path))
if resolved:
resolved["name"] = name
result["members"].append(resolved)
else:
result["members"].append({"name": name, "kind": "alias", "target": str(member.target_path)})
elif isinstance(member, griffe.Module) and include_submodules:
result["members"].append(module_to_dict(member, include_submodules=True))
except Exception as e:
Expand Down
Loading
Loading