From 5c05730a821b06273d6ce495a6262060b1e1f63a Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Wed, 30 Sep 2020 17:44:09 -0400 Subject: [PATCH] PoC for split GatherArgumentsResult and CppSignature Signed-off-by: Edward Z. Yang --- tools/codegen/api/cpp.py | 41 +++++++++++++------ tools/codegen/api/types.py | 80 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 110 insertions(+), 11 deletions(-) diff --git a/tools/codegen/api/cpp.py b/tools/codegen/api/cpp.py index d8445f02ee54..7290a3d6752d 100644 --- a/tools/codegen/api/cpp.py +++ b/tools/codegen/api/cpp.py @@ -191,16 +191,29 @@ def argument(a: Union[Argument, TensorOptionsArguments, ThisArgument]) -> CppArg else: assert_never(a) -def group_arguments( - func: FunctionSchema, *, method: bool = False -) -> Sequence[Union[Argument, TensorOptionsArguments, ThisArgument]]: - args: List[Union[Argument, ThisArgument, TensorOptionsArguments]] = [] +class GatherArgumentsResult(NamedTuple): + arguments: Sequence[Union[Argument, ThisArgument], ...] + gathered_arguments: Optional[Sequence[Union[Argument, TensorOptionsArguments, ThisArgument], ...]] + +def gather_arguments( + func: FunctionSchema, *, method: bool +) -> GatherArgumentsResult: + args: List[Union[Argument, ThisArgument]] = [] + gathered_args: List[Union[Argument, ThisArgument, TensorOptionsArguments]] = [] + args.extend(func.out_arguments) + gathered_args.extend(func.out_arguments) if method: - args.extend(ThisArgument(a) if a.name == "self" else a for a in func.arguments) + func_arguments = [ThisArgument(a) if a.name == "self" else a for a in func.arguments] else: - args.extend(func.arguments) + func_arguments = func.arguments + + args.extend(func_arguments) + gathered_args.extend(func_arguments) + + args.extend(func.kwarg_only_arguments) + has_gathered_args = False # group up arguments for tensor options @@ -220,7 +233,8 @@ def pred(name: str, ty: Type) -> Callable[[Argument], bool]: # And the next len(predicates) arguments look like TensorOptions arguments if all(p(a) for p, a in zip(predicates, func.kwarg_only_arguments[i : i + len(predicates)])): # Group them together as one argument - args.append(TensorOptionsArguments( + has_gathered_args = True + gathered_args.append(TensorOptionsArguments( dtype=func.kwarg_only_arguments[i], layout=func.kwarg_only_arguments[i + 1], device=func.kwarg_only_arguments[i + 2], @@ -228,11 +242,16 @@ def pred(name: str, ty: Type) -> Callable[[Argument], bool]: )) i += len(predicates) continue - args.append(func.kwarg_only_arguments[i]) + gathered_args.append(func.kwarg_only_arguments[i]) i += 1 - return args + if has_gathered_args: + return GatheredArgumentsResult(args, gathered_args) + else: + return GatheredArgumentsResult(args, None) +""" # Convert arguments to C++ API form -def arguments(func: FunctionSchema, *, method: bool = False) -> Sequence[CppArgument]: - return list(map(argument, group_arguments(func, method=method))) +def arguments(func: FunctionSchema, *, method: bool, gathered: bool) -> Sequence[CppArgument]: + return list(map(argument, group_arguments(func, method=method, gathered=gathered))) +""" diff --git a/tools/codegen/api/types.py b/tools/codegen/api/types.py index cb315cfc7525..42ba8b2e1312 100644 --- a/tools/codegen/api/types.py +++ b/tools/codegen/api/types.py @@ -2,6 +2,9 @@ from dataclasses import dataclass from typing import Optional, Union, Sequence +# Functions only, no types +import tools.codegen.api.cpp as cpp + # Represents the implicit *this argument for method calls in C++ API @dataclass(frozen=True) class ThisArgument: @@ -49,6 +52,83 @@ class CppExpr: type: str expr: str +# A CppSignature is very similar to a FunctionSchema, but it is +# augmented with decisions about defaulting and overloads that are C++ +# specific (for example, single functions in native functions +# may desugar into multiple C++ overloads). There is a CppSignature +# per C++ overload, and this class contains enough information to +# distinguish between these overloads +@dataclass(frozen=True) +class CppSignature: + # The schema this signature is derived from + func: FunctionSchema + # If this signature is a method, this is not None and contains + # the corresponding ThisArgument for this signature + method: Optional[ThisArgument] + + # Some cached stuff. Kind of creeky + cpp_arguments: Tuple[CppArgument, ...] + cpp_return_type: str + + # Read-only on arguments important to enable covariance + @staticmethod + def from_arguments( + func: FunctionSchema, + arguments: Sequence[Argument, TensorOptionsArguments, ThisArgument], + *, + strip_defaults: bool + ) -> 'CppSignature': + + def maybe_strip_default(a: CppArgument) -> CppArgument: + if strip_defaults: + return CppArgument( + type=a.type, + name=a.name, + default=None, + argument=a.argument, + ) + else: + return a + + return CppSignature( + func=func, + method=next((a for a in arguments if isinstance(a, ThisArgument)), None), + cpp_arguments=tuple( + maybe_strip_default(cpp.argument(a)) for a in arguments if not isinstance(a, ThisArgument)), + cpp_return_type=cpp.returns_type(func.returns), + ) + +@dataclass(frozen=True) +class CppSignatureGroup: + func: FunctionSchema + signature: CppSignature + gathered_signature: Optional[CppSignature] + + def user_signature(self) -> CppSignature: + if self.gathered_signature is not None: + return self.gathered_signature + else: + return self.signature + + @staticmethod + def from_schema(func: FunctionSchema, *, method: bool) -> 'CppSignatureGroup': + r = cpp.gather_arguments(func, method=method) + gathered_signature: Optional[CppSignature] = None + r_arguments = r.arguments + # BTW: this faffing about is a pretty good indication that + # signature should be the optional one, and gathered signature + # the non-optional one + strip_defaults = False + if r.gathered_arguments is not None: + strip_defaults = True + gathered_signature = CppSignature.from_arguments(func, r.gathered_arguments, method=method, strip_defaults = False) + signature = CppSignature.from_arguments(func, r_arguments, method=method, strip_defaults=strip_defaults) + return CppSignatureGroup( + func=func, + signature=signature, + gathered_signature=gather_signature, + ) + @dataclass(frozen=True) class DispatcherExpr: type: str