
Unified generic Tensor class #311

Open
justinchuby opened this issue Jan 12, 2023 · 14 comments
Labels
enhancement New feature or request topic: api topic: syntax onnxscript syntax topic: torch_lib Related to the torch/aten function lib in development topic: typing Typing related issues
Comments

@justinchuby
Collaborator

justinchuby commented Jan 12, 2023

Turn Tensor into a generic class to unify the function definition and evaluator interfaces.

Tensor represents a symbolic tensor by default, and is optionally made concrete by initializing it with a numpy array:

class Tensor(Generic[T]):
    @property
    def rank(self):
        ...
    ...
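A minimal runnable sketch of such a class (attribute and parameter names here are hypothetical; the tensor is symbolic unless backed by a numpy array):

```python
from typing import Generic, Optional, TypeVar

import numpy as np

T = TypeVar("T")


class Tensor(Generic[T]):
    """Symbolic by default; concrete when constructed with a numpy array (sketch)."""

    def __init__(self, value: Optional[np.ndarray] = None, rank: Optional[int] = None):
        self._value = value
        self._rank = rank

    @property
    def rank(self) -> Optional[int]:
        # Concrete tensors know their rank from the backing array;
        # symbolic tensors may carry it as metadata
        if self._value is not None:
            return int(self._value.ndim)
        return self._rank

    @property
    def is_concrete(self) -> bool:
        return self._value is not None
```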

Usage

Typing definition:

RealType = Union[
    BFLOAT16,
    FLOAT16,
    FLOAT,
    DOUBLE,
    INT8,
    INT16,
    INT32,
    INT64,
]

TReal = TypeVar("TReal", bound=RealType)

In function definition:

@torch_op("aten::add")
def aten_add(self: Tensor[TReal], other: Tensor[TReal], alpha: float = 1.0) -> Tensor[TReal]:
    # add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
    if alpha != 1:
        other = op.Mul(other, alpha)
    return op.Add(self, other)

Traced only functions:

@torch_op("aten::transpose", trace_only=True)
def aten_transpose(self: Tensor[Any], dim0: int, dim1: int):
    # transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)

    # Use the Tensor API
    self_rank = self.rank

    if self_rank == 0:
        result = self
    else:
        # Python code, change when onnxscript supports this
        dims = list(range(self_rank))
        dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
        # Python code ends

        result = op.Transpose(self, perm=dims)

    return result

Evaluators operate on concrete Tensors.

cc @gramalingam @fatcat-z @xiaowuhu @BowenBao @titaiwangms @abock

@justinchuby justinchuby added enhancement New feature or request topic: torch_lib Related to the torch/aten function lib in development topic: typing Typing related issues topic: syntax onnxscript syntax labels Jan 12, 2023
@BowenBao
Contributor

  • Is it possible to annotate dimensions?
  • What if the tensor type needs to be a subset (superset, intersection, etc) of RealType?
  • Related to the type system: what's the difference between INT64 and int? Can we make them interchangeable?

@gramalingam
Collaborator

I have the same question. A parametric Tensor type would be perfect; the main question is whether we enable a way to specify dimensions. I thought there was an earlier proposal to use multiple generic parameters when they become available?

@justinchuby
Collaborator Author

justinchuby commented Jan 14, 2023

  1. Shapes are a little tricky with Generics, as explored in Fix TensorType to return new TensorTypes for shapes instead of TensorType instances #228 (comment). Tensor[INT64["M", 42]] seems to be the most realistic option in this pre-Python-3.11 world. (Generic) functions defined with typevars tend not to need shape annotations.
  2. They can define a different typevar: https://github.com/microsoft/onnx-script/blob/dfaf174a7efaee415223dc628859f7c11318d1f6/onnxscript/function_libs/torch_aten/typing.py#L56-L66
  3. I think we can treat them as the same with the proposed generics system. INT64, however, is less ambiguous, because int appears to be int32 on Windows.
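Point 2 above (defining a narrower typevar over a subset of RealType) can be sketched with placeholder element types (the class names here are stand-ins for the real ONNX type aliases):

```python
from typing import TypeVar, Union


# Hypothetical stand-ins for the ONNX element-type classes
class FLOAT: ...
class DOUBLE: ...
class INT64: ...


RealType = Union[FLOAT, DOUBLE, INT64]

# A narrower constraint: floating-point types only
FloatType = Union[FLOAT, DOUBLE]
TFloat = TypeVar("TFloat", bound=FloatType)
```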

@gramalingam
Collaborator

Re @BowenBao 's question:

what's the difference between INT64 and int?

In the existing system, INT64 represents a tensor (with int64 element types). In the proposed system, we should really say Tensor[int64] instead of Tensor[INT64]. (Just a minor point, once we decide what to do.)

@justinchuby
Collaborator Author

justinchuby commented Jan 17, 2023

Some options I came up with @gramalingam @xadupre

Tensor[TReal, Shape["M", 1]] # Shape can be a non-generic class with __class_getitem__ defined; mypy will still fail
Tensor[TReal, "M", 1] # Supported only after 3.11; Python will fail at runtime before 3.11
Tensor[TReal, Literal["M"], Literal[1]] # mypy will not fail, but it will also not do anything useful
Tensor[INT64["M"]] # Reusing the existing syntax; mypy will fail; doesn't support typevars
@script(
    shape={
        "self": ("M", "N"),
    }
)
def aten_add(self: Tensor[TReal], other: Tensor[TReal], alpha: float = 1.0) -> Tensor[TReal]:

Cons: verbose, duplication.

@xadupre
Member

xadupre commented Jan 17, 2023

It would be great to be able to indicate that the output is of the exact same type as the input, not only real.

@justinchuby
Collaborator Author

It would be great to be able to indicate that the output is of the exact same type as the input, not only real.

Yes. This is why I think typevars are great.

@gramalingam
Collaborator

Another option for shapes might be:

Annotated[Tensor[TReal], Shape["M", 1]]  # a bit more verbose version of option 1 above

Yes, using type-vars to indicate type-constraints is useful. One issue @abock ran into using type-vars (in his generation of signatures for all ONNX ops) was getting IntelliSense to be useful when type-vars are used. This may not be an issue if we use predefined vars like TReal that have a standardized meaning.

Another detail: using Tensor[TReal] would mean we switch to using TReal to mean the element types. This is doable (except for types like bfloat16/float16, which may not be well supported); just mentioning it.
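The Annotated form above can be exercised with stdlib typing alone; a minimal sketch, where Tensor and Shape are hypothetical stand-ins:

```python
from typing import Annotated, Generic, TypeVar, get_args, get_type_hints

T = TypeVar("T")


class Tensor(Generic[T]): ...


class Shape:
    # Placeholder marker; a real Shape would validate and store the dims
    def __class_getitem__(cls, item):
        return (cls, item)


def aten_add(self: Annotated[Tensor[T], Shape["M", 1]]) -> Tensor[T]: ...


# The shape metadata survives and is recoverable at runtime:
hints = get_type_hints(aten_add, include_extras=True)
shape_meta = get_args(hints["self"])[1]
```

One design upside of Annotated is that tools unaware of Shape simply see Tensor[T], so the shape metadata is ignored rather than rejected.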

@abock
Contributor

abock commented Jan 18, 2023

Yes, type-vars in Python typing bring much from the world of traditional generics à la C#; we can express those constraints.

I think it's worth revisiting to see exactly which distinct type unions fall out of the codegen, and using that to inform what our generic unions could be and how they could be named. With sensible naming, the IntelliSense issue may be acceptable today, and I expect rich IDE support for typings to improve over time as well.

@justinchuby
Collaborator Author

I suggest we go with Tensor[TReal, Shape["M", 1]], where Shape is optional. Shape will not be a generic type but instead a normal class with __class_getitem__ defined. This way, Python < 3.11 allows constructs like Shape["M", 1] at runtime.
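A sketch of that proposal (the class layout is hypothetical; mypy would still reject the string dims, but the subscript works at runtime on Python < 3.11):

```python
from typing import Tuple, Union


class Shape:
    """Non-generic; Shape["M", 1] works via __class_getitem__ (sketch)."""

    dims: Tuple[Union[str, int], ...] = ()

    def __class_getitem__(cls, item):
        dims = item if isinstance(item, tuple) else (item,)
        # Return a subclass carrying the dims, so they stay introspectable
        return type(cls.__name__, (cls,), {"dims": dims})
```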

@justinchuby justinchuby self-assigned this Jun 27, 2023
@gramalingam
Collaborator

gramalingam commented Jul 10, 2023

Will the use of strings (like "M") work with mypy? It used to complain about this: it allowed int constants, but not string constants, as "dependent types".

Edited: OK, I see it is implied in an earlier message above that this will fail mypy, unlike the more verbose Literal["M"].

@justinchuby
Collaborator Author

justinchuby commented Jul 10, 2023

It will still fail mypy the same way it does now. However, since Python 3.12 (or 3.11?) will make this valid, I am hoping mypy will catch up.

@justinchuby justinchuby added this to the 0.2 milestone Aug 29, 2023
@justinchuby justinchuby changed the title Unified Tensor class Unified generic Tensor class Sep 22, 2023
@justinchuby
Collaborator Author

Also reference: https://github.com/google/jaxtyping
