Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix python typing stubs generation for CUDA modules #24022

Merged
merged 6 commits into from Jul 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
68 changes: 62 additions & 6 deletions modules/python/src2/typing_stubs_generation/api_refinement.py
Expand Up @@ -2,14 +2,17 @@
"apply_manual_api_refinement"
]

from typing import Sequence, Callable
from typing import cast, Sequence, Callable, Iterable

from .nodes import (NamespaceNode, FunctionNode, OptionalTypeNode,
ClassProperty, PrimitiveTypeNode)
from .ast_utils import find_function_node, SymbolName
from .nodes import (NamespaceNode, FunctionNode, OptionalTypeNode, TypeNode,
ClassProperty, PrimitiveTypeNode, ASTNodeTypeNode,
AggregatedTypeNode)
from .ast_utils import (find_function_node, SymbolName,
for_each_function_overload)


def apply_manual_api_refinement(root: NamespaceNode) -> None:
refine_cuda_module(root)
export_matrix_type_constants(root)
# Export OpenCV exception class
builtin_exception = root.add_class("Exception")
Expand Down Expand Up @@ -57,13 +60,65 @@ def _make_optional_arg(root_node: NamespaceNode,
continue

overload.arguments[arg_idx].type_node = OptionalTypeNode(
overload.arguments[arg_idx].type_node
cast(TypeNode, overload.arguments[arg_idx].type_node)
)

return _make_optional_arg


def _find_argument_index(arguments: Sequence[FunctionNode.Arg], name: str) -> int:
def refine_cuda_module(root: NamespaceNode) -> None:
def fix_cudaoptflow_enums_names() -> None:
for class_name in ("NvidiaOpticalFlow_1_0", "NvidiaOpticalFlow_2_0"):
if class_name not in cuda_root.classes:
continue
opt_flow_class = cuda_root.classes[class_name]
_trim_class_name_from_argument_types(
for_each_function_overload(opt_flow_class), class_name
)

def fix_namespace_usage_scope(cuda_ns: NamespaceNode) -> None:
USED_TYPES = ("GpuMat", "Stream")

def fix_type_usage(type_node: TypeNode) -> None:
if isinstance(type_node, AggregatedTypeNode):
for item in type_node.items:
fix_type_usage(item)
if isinstance(type_node, ASTNodeTypeNode):
if type_node._typename in USED_TYPES:
type_node._typename = f"cuda_{type_node._typename}"

for overload in for_each_function_overload(cuda_ns):
if overload.return_type is not None:
fix_type_usage(overload.return_type.type_node)
for type_node in [arg.type_node for arg in overload.arguments
if arg.type_node is not None]:
fix_type_usage(type_node)

if "cuda" not in root.namespaces:
return
cuda_root = root.namespaces["cuda"]
fix_cudaoptflow_enums_names()
for ns in [ns for ns_name, ns in root.namespaces.items()
if ns_name.startswith("cuda")]:
fix_namespace_usage_scope(ns)


def _trim_class_name_from_argument_types(
overloads: Iterable[FunctionNode.Overload],
class_name: str
) -> None:
separator = f"{class_name}_"
for overload in overloads:
for arg in [arg for arg in overload.arguments
if arg.type_node is not None]:
ast_node = cast(ASTNodeTypeNode, arg.type_node)
if class_name in ast_node.ctype_name:
fixed_name = ast_node._typename.split(separator)[-1]
ast_node._typename = fixed_name


def _find_argument_index(arguments: Sequence[FunctionNode.Arg],
name: str) -> int:
for i, arg in enumerate(arguments):
if arg.name == name:
return i
Expand All @@ -76,6 +131,7 @@ def _find_argument_index(arguments: Sequence[FunctionNode.Arg], name: str) -> in
SymbolName(("cv", ), (), "resize"): make_optional_arg("dsize"),
SymbolName(("cv", ), (), "calcHist"): make_optional_arg("mask"),
}

ERROR_CLASS_PROPERTIES = (
ClassProperty("code", PrimitiveTypeNode.int_(), False),
ClassProperty("err", PrimitiveTypeNode.str_(), False),
Expand Down
29 changes: 28 additions & 1 deletion modules/python/src2/typing_stubs_generation/ast_utils.py
@@ -1,5 +1,5 @@
from typing import (NamedTuple, Sequence, Tuple, Union, List,
Dict, Callable, Optional)
Dict, Callable, Optional, Generator)
import keyword

from .nodes import (ASTNode, NamespaceNode, ClassNode, FunctionNode,
Expand Down Expand Up @@ -404,6 +404,33 @@ def update_full_export_name(class_node: ClassNode) -> None:
return enum_export_name, namespace_node.full_export_name


def for_each_class(
node: Union[NamespaceNode, ClassNode]
) -> Generator[ClassNode, None, None]:
for cls in node.classes.values():
yield cls
if len(cls.classes):
yield from for_each_class(cls)


def for_each_function(
node: Union[NamespaceNode, ClassNode],
traverse_class_nodes: bool = True
) -> Generator[FunctionNode, None, None]:
yield from node.functions.values()
if traverse_class_nodes:
for cls in for_each_class(node):
yield from for_each_function(cls)


def for_each_function_overload(
node: Union[NamespaceNode, ClassNode],
traverse_class_nodes: bool = True
) -> Generator[FunctionNode.Overload, None, None]:
for func in for_each_function(node, traverse_class_nodes):
yield from func.overloads


if __name__ == '__main__':
import doctest
doctest.testmod()
53 changes: 18 additions & 35 deletions modules/python/src2/typing_stubs_generation/generation.py
Expand Up @@ -3,17 +3,20 @@
from io import StringIO
from pathlib import Path
import re
from typing import (Generator, Type, Callable, NamedTuple, Union, Set, Dict,
from typing import (Type, Callable, NamedTuple, Union, Set, Dict,
Collection, Tuple, List)
import warnings

from .ast_utils import get_enclosing_namespace, get_enum_module_and_export_name
from .ast_utils import (get_enclosing_namespace,
get_enum_module_and_export_name,
for_each_function_overload,
for_each_class)

from .predefined_types import PREDEFINED_TYPES
from .api_refinement import apply_manual_api_refinement

from .nodes import (ASTNode, ASTNodeType, NamespaceNode, ClassNode, FunctionNode,
EnumerationNode, ConstantNode)
from .nodes import (ASTNode, ASTNodeType, NamespaceNode, ClassNode,
FunctionNode, EnumerationNode, ConstantNode)

from .nodes.type_node import (TypeNode, AliasTypeNode, AliasRefTypeNode,
AggregatedTypeNode, ASTNodeTypeNode,
Expand Down Expand Up @@ -105,8 +108,9 @@ def _generate_typing_stubs(root: NamespaceNode, output_path: Path) -> None:

# NOTE: Enumerations require special handling, because all enumeration
# constants are exposed as module attributes
has_enums = _generate_section_stub(StubSection("# Enumerations", EnumerationNode),
root, output_stream, 0)
has_enums = _generate_section_stub(
StubSection("# Enumerations", EnumerationNode), root, output_stream, 0
)
# Collect all enums from class level and export them to module level
for class_node in root.classes.values():
if _generate_enums_from_classes_tree(class_node, output_stream, indent=0):
Expand Down Expand Up @@ -536,30 +540,6 @@ def check_overload_presence(node: Union[NamespaceNode, ClassNode]) -> bool:
return True
return False


def _for_each_class(node: Union[NamespaceNode, ClassNode]) \
-> Generator[ClassNode, None, None]:
for cls in node.classes.values():
yield cls
if len(cls.classes):
yield from _for_each_class(cls)


def _for_each_function(node: Union[NamespaceNode, ClassNode]) \
-> Generator[FunctionNode, None, None]:
for func in node.functions.values():
yield func
for cls in node.classes.values():
yield from _for_each_function(cls)


def _for_each_function_overload(node: Union[NamespaceNode, ClassNode]) \
-> Generator[FunctionNode.Overload, None, None]:
for func in _for_each_function(node):
for overload in func.overloads:
yield overload


def _collect_required_imports(root: NamespaceNode) -> Set[str]:
"""Collects all imports required for classes and functions typing stubs
declarations.
Expand All @@ -582,7 +562,7 @@ def _add_required_usage_imports(type_node: TypeNode, imports: Set[str]):
has_overload = check_overload_presence(root)
# if there is no module-level functions with overload, check its presence
# during class traversing, including their inner-classes
for cls in _for_each_class(root):
for cls in for_each_class(root):
if not has_overload and check_overload_presence(cls):
has_overload = True
required_imports.add("import typing")
Expand All @@ -600,8 +580,9 @@ def _add_required_usage_imports(type_node: TypeNode, imports: Set[str]):
if has_overload:
required_imports.add("import typing")
# Importing modules required to resolve functions arguments
for overload in _for_each_function_overload(root):
for arg in filter(lambda a: a.type_node is not None, overload.arguments):
for overload in for_each_function_overload(root):
for arg in filter(lambda a: a.type_node is not None,
overload.arguments):
_add_required_usage_imports(arg.type_node, required_imports) # type: ignore
if overload.return_type is not None:
_add_required_usage_imports(overload.return_type.type_node,
Expand All @@ -625,11 +606,13 @@ def _reexport_submodule(ns: NamespaceNode) -> None:

_reexport_submodule(root)

# Special cases, symbols defined in possible pure Python submodules should be
# Special cases, symbols defined in possible pure Python submodules
# should be
root.reexported_submodules_symbols["mat_wrapper"].append("Mat")


def _write_reexported_symbols_section(module: NamespaceNode, output_stream: StringIO) -> None:
def _write_reexported_symbols_section(module: NamespaceNode,
output_stream: StringIO) -> None:
"""Write re-export section for the given module.

Re-export statements have from `from module_name import smth as smth`.
Expand Down
Expand Up @@ -22,6 +22,10 @@
PrimitiveTypeNode.int_("uchar"),
PrimitiveTypeNode.int_("unsigned"),
PrimitiveTypeNode.int_("int64"),
PrimitiveTypeNode.int_("uint8_t"),
PrimitiveTypeNode.int_("int8_t"),
PrimitiveTypeNode.int_("int32_t"),
PrimitiveTypeNode.int_("uint32_t"),
PrimitiveTypeNode.int_("size_t"),
PrimitiveTypeNode.float_("float"),
PrimitiveTypeNode.float_("double"),
Expand Down