Fixing support for SymInts #518

Closed
justinchuby opened this issue Mar 11, 2023 · 9 comments · Fixed by #484
Labels
topic: discussion (For discussion) · topic: torch_lib (Related to the torch/aten function lib in development)

Comments

justinchuby (Contributor) commented Mar 11, 2023

In PyTorch, SymInts, or symbolic ints, are a way of supporting dynamic attributes. In atenlib we tend to represent them as INT64 tensors to maintain their dynamic nature. However, in fx they are usually passed in as a list of nodes (a list of tensors).

It may be better to use Sequence[INT64] to represent them, to maintain consistency with the fx graph.

This representation further enables trace_only support for SymInts when we need to obtain a rank (#517).
As an example, in native_layer_norm:

def aten_native_layer_norm(
    input: TReal,
    normalized_shape: Sequence[INT64],
    weight: Optional[TReal],
    bias: Optional[TReal],
    eps: float,
) -> Tuple[TReal, TReal, TReal]:
    """native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)"""

Here normalized_shape is a SymInt sequence. If we annotate it as Sequence[int], onnxscript will treat it as an attribute and fail with a TypeError, because we reject TorchScriptTensor values as attributes. If we annotate it as INT64, the Python line that takes the rank, start_axis = -len(normalized_shape), will fail because TorchScriptTensor does not define len; and we cannot use op.Neg(op.Size(op.Shape(normalized_shape))) instead, because start_axis is subsequently used as an attribute (so it cannot be dynamic) in LayerNormalization.
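
For illustration, a minimal sketch (not the actual torch_lib code) of a trace_only variant that works when normalized_shape arrives as a plain Python sequence of INT64 tensors; the function name, the dropped TReal annotations, and the default values are assumptions for this sketch:

from typing import Sequence

from onnxscript import opset18 as op


def aten_native_layer_norm_sketch(
    input,
    normalized_shape: Sequence,  # a plain Python list of INT64 tensors from fx
    weight=None,
    bias=None,
    eps: float = 1e-05,
):
    # Because normalized_shape is an ordinary Python list, len() works at
    # trace time, so start_axis stays a static attribute value.
    start_axis = -len(normalized_shape)
    # LayerNormalization's axis is an attribute and must be static;
    # op.Neg(op.Size(op.Shape(...))) would produce a dynamic tensor instead.
    return op.LayerNormalization(input, weight, bias, axis=start_axis, epsilon=eps)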

Change required

For this change, we need to update how we convert Python scalar inputs to tensors: whereas before we converted a List[int] to a single INT64 tensor, we should now convert it to a sequence of INT64 tensors (convert each element to INT64, then SequenceConstruct). Functions that take SymInts should use ConcatFromSequence to concatenate the sequence into a single tensor for their own use.
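
A rough sketch of that conversion, assuming op is an onnxscript opset namespace; the helper names scalars_to_int64_sequence and symint_sequence_to_tensor are hypothetical, not existing APIs:

from onnx import TensorProto
from onnxscript import opset18 as op


def scalars_to_int64_sequence(values):
    # Before: List[int] became one INT64 tensor via a single Constant.
    # Proposed: cast each element to an INT64 tensor, then build a sequence.
    elements = [op.Cast(v, to=TensorProto.INT64) for v in values]
    return op.SequenceConstruct(*elements)


def symint_sequence_to_tensor(seq):
    # Inside a function that consumes SymInts, fold the sequence back into a
    # single 1-D INT64 tensor for ops that need one (e.g. Expand, Reshape).
    return op.ConcatFromSequence(seq, axis=0)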

Potential issues

  1. I am not sure whether all SymInt sequences are passed as lists. Can the length (i.e. the rank) be dynamic too, which would force us to use INT64, or INT64 -> SplitToSequence?
  2. We may make things slower for static int sequences. To address this we would need to constant-fold ConcatFromSequence.

cc @titaiwangms @BowenBao @xiaowuhu @fatcat-z @gramalingam

justinchuby added the topic: torch_lib label Mar 11, 2023
justinchuby self-assigned this Mar 11, 2023
justinchuby added the topic: discussion label Mar 11, 2023
justinchuby (Contributor Author) commented:

@gramalingam @xiaowuhu this would change how we represent SymInts and change the recommendation from using INT64 to using Sequence[INT64].

justinchuby (Contributor Author) commented Mar 11, 2023

@titaiwangms relevant for dynamic axes support, because the ConcatFromSequence node may now be the function's responsibility.

justinchuby (Contributor Author) commented Mar 11, 2023

Hmm, we can't really use SequenceConstruct because we run into the no-len issue again. Maybe we need to implement len for our tensor instead?

Or just len for the result of SequenceConstruct.

justinchuby (Contributor Author) commented Mar 11, 2023

~~Yeah, since start_axis is an attribute we cannot even represent it as List[INT64].~~ Actually we can, because we are only using the rank and not the tensors themselves, and this is trace-only (static).

(A comment by @justinchuby was marked as resolved.)

justinchuby (Contributor Author) commented Mar 11, 2023

I thought about this again. Maybe it's better to go with Ti-tai's solution, where we programmatically handle lists of tensors and implement len on onnxscript Tensors.

This way the change is minimal and functions are kept simple.
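
As a minimal sketch of the "implement len on onnxscript Tensors" idea: assuming the graph-building tensor keeps a statically known shape around, len could return the first dimension. The class name and _shape attribute below are illustrative, not the real TorchScriptTensor internals.

from typing import Optional, Sequence


class GraphTensorSketch:
    """Illustrative stand-in for a graph-building tensor such as TorchScriptTensor."""

    def __init__(self, name: str, shape: Optional[Sequence[Optional[int]]] = None):
        self.name = name
        self._shape = shape  # static shape when known, e.g. (3,) for a SymInt list

    def __len__(self) -> int:
        # len() is only meaningful when the first dimension is statically known;
        # otherwise callers must fall back to dynamic ops such as Shape/Size.
        if self._shape is None or self._shape[0] is None:
            raise TypeError("len() of a tensor with dynamic or unknown shape")
        return self._shape[0]

With something like this in place, start_axis = -len(normalized_shape) would work whenever the length of the SymInt list is statically known.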

titaiwangms (Contributor) commented Mar 11, 2023

I don't fully understand the perspective of this issue and case, but I have been working on the symbolic fx.graph and thought it might be helpful to share what I have now.

AFAIK, the symbolic fx.graph generates aten::sym_size to replace the real shape value of a size. For example, the size input to op.Expand looks like [2, 3, 4] when the shape is fixed, but with the symbolic graph we get [sym_size0, sym_size1, sym_size2], each of them connected to an fx.Node. What the PR does now is check whether the value is List[INT64] or not, to decide whether it goes to op.Concat or op.Constant.

The current status is that it works well with most of the dynamic cases we had ad hoc, but it breaks the test_fx static-shape cases with some input-out-of-bounds problems which I am still looking into. Most likely the differences between the symbolic fx.graph and the fake fx.graph are causing the regression. We can have more discussion on Monday to sync up on this topic!
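
To make the converter-side behavior concrete, here is a rough sketch of the described check; build_size_tensor and the isinstance test are hypothetical stand-ins for the logic in the PR, not its actual code.

from onnxscript import opset18 as op


def build_size_tensor(sizes):
    # Mixed/dynamic case: some elements are INT64 values (e.g. from
    # aten::sym_size), so reshape each to rank 1 and concatenate them into a
    # single 1-D INT64 tensor that can feed ops such as Expand.
    if any(not isinstance(s, int) for s in sizes):
        elements = [
            op.Constant(value_ints=[s]) if isinstance(s, int)
            else op.Reshape(s, op.Constant(value_ints=[-1]))
            for s in sizes
        ]
        return op.Concat(*elements, axis=0)
    # Fully static case: emit a single constant tensor such as [2, 3, 4].
    return op.Constant(value_ints=list(sizes))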

related issue: #481
PyTorch working PR: pytorch/pytorch#96350
ONNX-Script PR: #484

titaiwangms (Contributor) commented Mar 11, 2023

> I thought about this again. Maybe it's better to go with Ti-tai's solution, where we programmatically handle lists of tensors and implement len on onnxscript Tensors.
>
> This way the change is minimal and functions are kept simple.

I think I kind of understand the concern here now. Yes, and I think sticking with List[INT64] would make the code on the converter side clearer as well.

justinchuby (Contributor Author) commented:

In the atenlib office hour, we decided that: 1. We use INT64 to represent a sequence of SymInts. 2. Concat should be done by the exporter.
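
Illustrative only: under this decision a torch_lib function annotates the SymInt list parameter as a single 1-D INT64 tensor and the exporter performs the Concat before calling it; the sketch below is not the actual registered function.

from onnxscript import INT64
from onnxscript import opset18 as op


def aten_expand_sketch(self, size: INT64):
    # size is already one 1-D INT64 tensor: the exporter concatenated the
    # individual sym_size values, so the function needs no len() or
    # ConcatFromSequence of its own.
    return op.Expand(self, size)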

titaiwangms added a commit that referenced this issue Mar 21, 2023
Used by pytorch/pytorch#96350

fixes #518 fixes #390

---------

Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>