Skip to content

Commit

Permalink
PoC for split GatherArgumentsResult and CppSignature
Browse files Browse the repository at this point in the history
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
  • Loading branch information
ezyang committed Sep 30, 2020
1 parent 56af122 commit 5c05730
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 11 deletions.
41 changes: 30 additions & 11 deletions tools/codegen/api/cpp.py
Expand Up @@ -191,16 +191,29 @@ def argument(a: Union[Argument, TensorOptionsArguments, ThisArgument]) -> CppArg
else:
assert_never(a)

def group_arguments(
func: FunctionSchema, *, method: bool = False
) -> Sequence[Union[Argument, TensorOptionsArguments, ThisArgument]]:
args: List[Union[Argument, ThisArgument, TensorOptionsArguments]] = []
class GatherArgumentsResult(NamedTuple):
arguments: Sequence[Union[Argument, ThisArgument], ...]
gathered_arguments: Optional[Sequence[Union[Argument, TensorOptionsArguments, ThisArgument], ...]]

def gather_arguments(
func: FunctionSchema, *, method: bool
) -> GatherArgumentsResult:
args: List[Union[Argument, ThisArgument]] = []
gathered_args: List[Union[Argument, ThisArgument, TensorOptionsArguments]] = []

args.extend(func.out_arguments)
gathered_args.extend(func.out_arguments)

if method:
args.extend(ThisArgument(a) if a.name == "self" else a for a in func.arguments)
func_arguments = [ThisArgument(a) if a.name == "self" else a for a in func.arguments]
else:
args.extend(func.arguments)
func_arguments = func.arguments

args.extend(func_arguments)
gathered_args.extend(func_arguments)

args.extend(func.kwarg_only_arguments)
has_gathered_args = False

# group up arguments for tensor options

Expand All @@ -220,19 +233,25 @@ def pred(name: str, ty: Type) -> Callable[[Argument], bool]:
# And the next len(predicates) arguments look like TensorOptions arguments
if all(p(a) for p, a in zip(predicates, func.kwarg_only_arguments[i : i + len(predicates)])):
# Group them together as one argument
args.append(TensorOptionsArguments(
has_gathered_args = True
gathered_args.append(TensorOptionsArguments(
dtype=func.kwarg_only_arguments[i],
layout=func.kwarg_only_arguments[i + 1],
device=func.kwarg_only_arguments[i + 2],
pin_memory=func.kwarg_only_arguments[i + 3],
))
i += len(predicates)
continue
args.append(func.kwarg_only_arguments[i])
gathered_args.append(func.kwarg_only_arguments[i])
i += 1

return args
if has_gathered_args:
return GatheredArgumentsResult(args, gathered_args)
else:
return GatheredArgumentsResult(args, None)

"""
# Convert arguments to C++ API form
def arguments(func: FunctionSchema, *, method: bool = False) -> Sequence[CppArgument]:
return list(map(argument, group_arguments(func, method=method)))
def arguments(func: FunctionSchema, *, method: bool, gathered: bool) -> Sequence[CppArgument]:
return list(map(argument, group_arguments(func, method=method, gathered=gathered)))
"""
80 changes: 80 additions & 0 deletions tools/codegen/api/types.py
Expand Up @@ -2,6 +2,9 @@
from dataclasses import dataclass
from typing import Optional, Union, Sequence

# Functions only, no types
import tools.codegen.api.cpp as cpp

# Represents the implicit *this argument for method calls in C++ API
@dataclass(frozen=True)
class ThisArgument:
Expand Down Expand Up @@ -49,6 +52,83 @@ class CppExpr:
type: str
expr: str

# A CppSignature is very similar to a FunctionSchema, but it is
# augmented with decisions about defaulting and overloads that are C++
# specific (for example, single functions in native functions
# may desugar into multiple C++ overloads). There is a CppSignature
# per C++ overload, and this class contains enough information to
# distinguish between these overloads
@dataclass(frozen=True)
class CppSignature:
# The schema this signature is derived from
func: FunctionSchema
# If this signature is a method, this is not None and contains
# the corresponding ThisArgument for this signature
method: Optional[ThisArgument]

# Some cached stuff. Kind of creeky
cpp_arguments: Tuple[CppArgument, ...]
cpp_return_type: str

# Read-only on arguments important to enable covariance
@staticmethod
def from_arguments(
func: FunctionSchema,
arguments: Sequence[Argument, TensorOptionsArguments, ThisArgument],
*,
strip_defaults: bool
) -> 'CppSignature':

def maybe_strip_default(a: CppArgument) -> CppArgument:
if strip_defaults:
return CppArgument(
type=a.type,
name=a.name,
default=None,
argument=a.argument,
)
else:
return a

return CppSignature(
func=func,
method=next((a for a in arguments if isinstance(a, ThisArgument)), None),
cpp_arguments=tuple(
maybe_strip_default(cpp.argument(a)) for a in arguments if not isinstance(a, ThisArgument)),
cpp_return_type=cpp.returns_type(func.returns),
)

@dataclass(frozen=True)
class CppSignatureGroup:
func: FunctionSchema
signature: CppSignature
gathered_signature: Optional[CppSignature]

def user_signature(self) -> CppSignature:
if self.gathered_signature is not None:
return self.gathered_signature
else:
return self.signature

@staticmethod
def from_schema(func: FunctionSchema, *, method: bool) -> 'CppSignatureGroup':
r = cpp.gather_arguments(func, method=method)
gathered_signature: Optional[CppSignature] = None
r_arguments = r.arguments
# BTW: this faffing about is a pretty good indication that
# signature should be the optional one, and gathered signature
# the non-optional one
strip_defaults = False
if r.gathered_arguments is not None:
strip_defaults = True
gathered_signature = CppSignature.from_arguments(func, r.gathered_arguments, method=method, strip_defaults = False)
signature = CppSignature.from_arguments(func, r_arguments, method=method, strip_defaults=strip_defaults)
return CppSignatureGroup(
func=func,
signature=signature,
gathered_signature=gather_signature,
)

@dataclass(frozen=True)
class DispatcherExpr:
type: str
Expand Down

0 comments on commit 5c05730

Please sign in to comment.