
Conversation

sahas3 (Member) commented Oct 14, 2024

  1. Negative indices for tensor indexing are handled by wrapping the index values around, checking their values at run time (see the sketch after this list). Without this fix, there was a runtime error.
  2. Added a lit test to lock down the behavior.
  3. Updated the xfails_set for the fx_importer_tosa config to lock down the behavior with an e2e test as well.
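
A minimal sketch of the intended wrap-around semantics (plain Python, illustrative only; the helper name is hypothetical, not part of the change):

    # Map a possibly-negative index into [0, dim_size), mirroring the
    # runtime check select(index < 0, index + dim_size, index) that the
    # lowering materializes in IR.
    def wrap_index(index: int, dim_size: int) -> int:
        return index + dim_size if index < 0 else index

    assert wrap_index(-1, 2) == 1  # last element along a dimension of size 2
    assert wrap_index(1, 2) == 1   # non-negative indices pass through unchanged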

"THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY."

sahas3 (Member, Author) commented Oct 14, 2024

Since I don't have permission to request reviews yet, I am tagging you, @AmosLewis, since you were the original author of the support for the index.Tensor op. Could you please take a look at this change? Thanks!

sahas3 (Member, Author) commented Oct 21, 2024

Hi @eric-k256, can you please review this change or add the appropriate reviewers? Thanks!

eric-k256 requested a review from sjarus on October 22, 2024 at 20:48
eric-k256 (Collaborator) commented:
I've added @sjarus as a reviewer. It might be slightly delayed, as this week is the US LLVM dev conference. While waiting, it appears that your llvm external is out of date. It would be helpful to update that to speed up the review.

sahas3 (Member, Author) commented Oct 23, 2024

> I've added @sjarus as a reviewer. It might be slightly delayed, as this week is the US LLVM dev conference. While waiting, it appears that your llvm external is out of date. It would be helpful to update that to speed up the review.

Thanks, Eric. I've updated my branch by rebasing on main.

sjarus (Collaborator) commented Oct 24, 2024

Hi @sahas3, thanks so much for this contribution! I took a quick look at it. I'm curious whether this code could also have compile-time and run-time forms, with the former predicated on being able to access the index list at compile time and modify it there, rather than materializing the check in IR?

sahas3 (Member, Author) commented Oct 25, 2024

Thanks for the review @sjarus.

Are you thinking of a possible optimization in scenarios like the one below?

func.func @main(%arg0: !torch.vtensor<[2,4,2],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> attributes {torch.assume_strict_symbolic_shapes} {
    %0 = torch.vtensor.literal(dense<-1> : tensor<si64>) : !torch.vtensor<[],si64>
    %1 = torch.prim.ListConstruct %0 : (!torch.vtensor<[],si64>) -> !torch.list<vtensor>
    %2 = torch.aten.index.Tensor_hacked_twin %arg0, %1 : !torch.vtensor<[2,4,2],si64>, !torch.list<vtensor> -> !torch.vtensor<[4,2],si64>
    return %2 : !torch.vtensor<[4,2],si64>
  }

which produces the following output with --convert-torch-to-tosa:

func.func @main(%arg0: !torch.vtensor<[2,4,2],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> attributes {torch.assume_strict_symbolic_shapes} {
    %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[2,4,2],si64> -> tensor<2x4x2xi64>
    %1 = "tosa.const"() <{value = dense<-1> : tensor<i64>}> : () -> tensor<i64>
    %2 = torch_c.from_builtin_tensor %1 : tensor<i64> -> !torch.vtensor<[],si64>
    %3 = torch.prim.ListConstruct %2 : (!torch.vtensor<[],si64>) -> !torch.list<vtensor>
    %4 = torch_c.to_builtin_tensor %2 : !torch.vtensor<[],si64> -> tensor<i64>
    %5 = tosa.cast %4 : (tensor<i64>) -> tensor<i32>
    %6 = "tosa.const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
    %7 = "tosa.const"() <{value = dense<2> : tensor<i32>}> : () -> tensor<i32>
    %8 = tosa.add %7, %5 : (tensor<i32>, tensor<i32>) -> tensor<i32>
    %9 = tosa.greater %6, %5 : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %10 = tosa.select %9, %8, %5 : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
    %11 = tosa.reshape %10 {new_shape = array<i64: 1>} : (tensor<i32>) -> tensor<1xi32>
    %12 = tosa.reshape %0 {new_shape = array<i64: 1, 2, 8>} : (tensor<2x4x2xi64>) -> tensor<1x2x8xi64>
    %13 = tosa.reshape %11 {new_shape = array<i64: 1, 1>} : (tensor<1xi32>) -> tensor<1x1xi32>
    %14 = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
    %15 = tosa.mul %13, %14 {shift = 0 : i8} : (tensor<1x1xi32>, tensor<1xi32>) -> tensor<1x1xi32>
    %16 = tosa.reduce_sum %15 {axis = 1 : i32} : (tensor<1x1xi32>) -> tensor<1x1xi32>
    %17 = tosa.reshape %16 {new_shape = array<i64: 1, 1>} : (tensor<1x1xi32>) -> tensor<1x1xi32>
    %18 = tosa.gather %12, %17 : (tensor<1x2x8xi64>, tensor<1x1xi32>) -> tensor<1x1x8xi64>
    %19 = tosa.reshape %18 {new_shape = array<i64: 4, 2>} : (tensor<1x1x8xi64>) -> tensor<4x2xi64>
    %20 = torch_c.from_builtin_tensor %19 : tensor<4x2xi64> -> !torch.vtensor<[4,2],si64>
    return %20 : !torch.vtensor<[4,2],si64>
  }

In the full torch-backend-to-tosa-backend-pipeline this is optimized away once the to/from_builtin_tensor ops are properly materialized in the FinalizingBackendTypeConversion pass:

func.func @main(%arg0: tensor<2x4x2xi64>, %arg1: tensor<i64>) -> tensor<4x2xi64> {
    %0 = "tosa.const"() <{value = dense<1> : tensor<1x1xi32>}> : () -> tensor<1x1xi32>
    %1 = tosa.reshape %arg0 {new_shape = array<i64: 1, 2, 8>} : (tensor<2x4x2xi64>) -> tensor<1x2x8xi64>
    %2 = tosa.gather %1, %0 : (tensor<1x2x8xi64>, tensor<1x1xi32>) -> tensor<1x1x8xi64>
    %3 = tosa.reshape %2 {new_shape = array<i64: 4, 2>} : (tensor<1x1x8xi64>) -> tensor<4x2xi64>
    return %3 : tensor<4x2xi64>
  }

So I think branching in this pass based on the constness of the index is not necessary. The presence of the to/from_builtin_tensor ops between the def of the tosa.const and its actual use may also make it tricky to infer that the index is constant while we are in this pass.
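
For reference, the compile-time form discussed above would amount to folding the wrap-around into a statically known index before emitting any IR. A rough sketch, assuming the index literal is visible at lowering time (plain Python; the helper name is hypothetical):

    # If the index is a compile-time constant, fold the wrap-around into the
    # literal; otherwise fall back to emitting the runtime greater/add/select.
    def normalize_index_if_static(index, dim_size):
        if index is not None:       # index statically known
            return index + dim_size if index < 0 else index
        return None                 # unknown at compile time: keep the runtime check

As argued above, the pipeline's constant folding already achieves the same end result, so the extra branch is not implemented in the pass.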

sjarus (Collaborator) commented Oct 25, 2024

Thanks, that looks pretty sane to me. Approving.

rafaelubalmw merged commit 2b01f8b into llvm:main on Oct 25, 2024
3 checks passed
rahuls-cerebras added a commit referencing this pull request on Jan 3, 2025:
… index.Tensor_hacked_twin for TorchToTosa lowering. (#3790)"

This reverts commit 2b01f8b.