
Conversation

@AmosLewis (Collaborator) commented Oct 2, 2024

From onnx.ScatterElements: nod-ai/SHARK-ModelDev#823
e2e test: nod-ai/SHARK-TestSuite#363

In the final target examples, the index and src sizes are dynamic ([?]); here we use static sizes to explain the algorithm.
torch.scatter_reduce step-by-step example:
self[index[i]] += src[i]

src = [1, 2, 3, 4, 5, 6]
index = [0, 1, 0, 1, 2, 1]
self = [1, 2, 3, 4]
Step 0:
self[index[0]] += src[0]
self[0] += 1  →  1 + 1 = 2
self = [2, 2, 3, 4]

Step 1:
self[index[1]] += src[1]
self[1] += 2  →  2 + 2 = 4
self = [2, 4, 3, 4]

Step 2:
self[index[2]] += src[2]
self[0] += 3  →  2 + 3 = 5
self = [5, 4, 3, 4]

Step 3:
self[index[3]] += src[3]
self[1] += 4  →  4 + 4 = 8
self = [5, 8, 3, 4]

Step 4:
self[index[4]] += src[4]
self[2] += 5  →  3 + 5 = 8
self = [5, 8, 8, 4]

Step 5:
self[index[5]] += src[5]
self[1] += 6  →  8 + 6 = 14
self = [5, 14, 8, 4]
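
For a quick check, the same result falls out of torch.scatter_reduce directly (a minimal sketch; assumes a PyTorch build with scatter_reduce, i.e. ≥ 1.12):

import torch

# Reproduce the worked example above with the op this PR lowers.
self_t = torch.tensor([1., 2., 3., 4.])
src = torch.tensor([1., 2., 3., 4., 5., 6.])
index = torch.tensor([0, 1, 0, 1, 2, 1])

# reduce="sum" with include_self=True matches onnx reduction="add":
# every src[i] is accumulated into self_t[index[i]].
out = torch.scatter_reduce(self_t, 0, index, src, reduce="sum", include_self=True)
print(out)  # tensor([ 5., 14.,  8.,  4.])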
1. onnx-model.mlir
module {
  func.func @scatter_graph(%arg0: !torch.vtensor<[4],f32>, %arg1: !torch.vtensor<[6],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    %none = torch.constant.none
    %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<[0, 1, 0, 1, 2, 1]> : tensor<6xsi64>} : () -> !torch.vtensor<[6],si64> 
    %1 = torch.operator "onnx.ScatterElements"(%arg0, %0, %arg1) {torch.onnx.axis = 0 : si64, torch.onnx.reduction = "add"} : (!torch.vtensor<[4],f32>, !torch.vtensor<[6],si64>, !torch.vtensor<[6],f32>) -> !torch.vtensor<[4],f32> 
    return %1 : !torch.vtensor<[4],f32>
  }
}
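
For reference, an equivalent ONNX model can be built and evaluated with onnx's reference evaluator (a hedged sketch, not part of the PR; unlike the MLIR above it feeds the indices as a graph input rather than a constant, and it assumes onnx ≥ 1.13 for ReferenceEvaluator):

import numpy as np
from onnx import TensorProto, helper
from onnx.reference import ReferenceEvaluator

# ScatterElements with axis=0 and reduction="add" (needs opset >= 16;
# opset 21 matches the torch.onnx_meta.opset_version in the IR above).
node = helper.make_node(
    "ScatterElements", ["data", "indices", "updates"], ["out"],
    axis=0, reduction="add",
)
graph = helper.make_graph(
    [node], "scatter_graph",
    inputs=[
        helper.make_tensor_value_info("data", TensorProto.FLOAT, [4]),
        helper.make_tensor_value_info("indices", TensorProto.INT64, [6]),
        helper.make_tensor_value_info("updates", TensorProto.FLOAT, [6]),
    ],
    outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, [4])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 21)])
sess = ReferenceEvaluator(model)
(result,) = sess.run(None, {
    "data": np.array([1., 2., 3., 4.], dtype=np.float32),
    "indices": np.array([0, 1, 0, 1, 2, 1], dtype=np.int64),
    "updates": np.array([1., 2., 3., 4., 5., 6.], dtype=np.float32),
})
print(result)  # [ 5. 14.  8.  4.]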

The following pipeline lowers onnx-model.mlir through the torch dialect to linalg:
torch-mlir-opt --convert-torch-onnx-to-torch --torch-decompose-complex-ops --cse --canonicalize --convert-torch-to-linalg onnx-model.mlir
2. torch-model.mlir

module {
  func.func @scatter_graph(%arg0: !torch.vtensor<[4],f32>, %arg1: !torch.vtensor<[6],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    %true = torch.constant.bool true
    %str = torch.constant.str "sum"
    %0 = torch.vtensor.literal(dense<[0, 1, 0, 1, 2, 1]> : tensor<6xsi64>) : !torch.vtensor<[6],si64>
    %int0 = torch.constant.int 0
    %1 = torch.aten.scatter_reduce.two %arg0, %int0, %0, %arg1, %str, %true : !torch.vtensor<[4],f32>, !torch.int, !torch.vtensor<[6],si64>, !torch.vtensor<[6],f32>, !torch.str, !torch.bool -> !torch.vtensor<[4],f32>
    return %1 : !torch.vtensor<[4],f32>
  }
}
3. linalg-model.mlir
#map = affine_map<(d0) -> (d0, 0)>
#map1 = affine_map<(d0) -> (d0)>
module {
  func.func @scatter_graph(%arg0: !torch.vtensor<[4],f32>, %arg1: !torch.vtensor<[6],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    %0 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[6],f32> -> tensor<6xf32>
    %1 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[4],f32> -> tensor<4xf32>
    %true = torch.constant.bool true
    %str = torch.constant.str "sum"
    %2 = torch.vtensor.literal(dense<[0, 1, 0, 1, 2, 1]> : tensor<6xsi64>) : !torch.vtensor<[6],si64>
    %3 = torch_c.to_builtin_tensor %2 : !torch.vtensor<[6],si64> -> tensor<6xi64>
    %int0 = torch.constant.int 0
    %c0 = arith.constant 0 : index
    %c6 = arith.constant 6 : index
    %c1 = arith.constant 1 : index
    %4 = arith.muli %c1, %c6 : index
    %5 = arith.index_cast %4 : index to i64
    %6 = arith.index_cast %5 : i64 to index
    %c0_0 = arith.constant 0 : index
    %c6_1 = arith.constant 6 : index
    %c1_2 = arith.constant 1 : index
    %7 = tensor.empty(%6) : tensor<?x1xi32>
    %c0_i32 = arith.constant 0 : i32
    %8 = linalg.fill ins(%c0_i32 : i32) outs(%7 : tensor<?x1xi32>) -> tensor<?x1xi32>
    %9 = tensor.empty(%6) : tensor<?xf32>
    %cst = arith.constant 0.000000e+00 : f32
    %10 = linalg.fill ins(%cst : f32) outs(%9 : tensor<?xf32>) -> tensor<?xf32>
    %11:2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel"]} outs(%8, %10 : tensor<?x1xi32>, tensor<?xf32>) {
    ^bb0(%out: i32, %out_13: f32):
      %16 = linalg.index 0 : index
      %17 = arith.remsi %16, %c6_1 : index
      %18 = arith.divsi %16, %c6_1 : index
      %extracted = tensor.extract %3[%17] : tensor<6xi64>
      %extracted_14 = tensor.extract %0[%17] : tensor<6xf32>
      %19 = arith.index_cast %17 : index to i64
      %20 = arith.trunci %19 : i64 to i32
      %21 = arith.trunci %extracted : i64 to i32
      linalg.yield %21, %extracted_14 : i32, f32
    } -> (tensor<?x1xi32>, tensor<?xf32>)
    %c0_3 = arith.constant 0 : index
    %c0_4 = arith.constant 0 : index
    %c1_5 = arith.constant 1 : index
    %c1_6 = arith.constant 1 : index
    %c1_7 = arith.constant 1 : index
    %12 = tensor.empty(%6) : tensor<?x1xi32>
    %c0_i32_8 = arith.constant 0 : i32
    %13 = linalg.fill ins(%c0_i32_8 : i32) outs(%12 : tensor<?x1xi32>) -> tensor<?x1xi32>
    %c0_9 = arith.constant 0 : index
    %dim = tensor.dim %11#0, %c0_9 : tensor<?x1xi32>
    %c1_10 = arith.constant 1 : index
    %c1_11 = arith.constant 1 : index
    %inserted_slice = tensor.insert_slice %11#0 into %13[0, 0] [%dim, 1] [1, 1] : tensor<?x1xi32> into tensor<?x1xi32>
    %c1_12 = arith.constant 1 : index
    %14 = tm_tensor.scatter {dimension_map = array<i64: 0>} unique_indices(false) ins(%11#1, %inserted_slice : tensor<?xf32>, tensor<?x1xi32>) outs(%1 : tensor<4xf32>) {
    ^bb0(%arg2: f32, %arg3: f32):
      %16 = arith.addf %arg2, %arg3 : f32
      tm_tensor.yield %16 : f32
    } -> tensor<4xf32>
    %cast = tensor.cast %14 : tensor<4xf32> to tensor<4xf32>
    %15 = torch_c.from_builtin_tensor %cast : tensor<4xf32> -> !torch.vtensor<[4],f32>
    return %15 : !torch.vtensor<[4],f32>
  }
}
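
In plain terms (my reading of the IR above, not a spec of the tm_tensor dialect): the linalg.generic materializes the indices as a ?x1 i32 tensor and the updates as a ?xf32 tensor, and tm_tensor.scatter then folds each (index, update) pair into the output through the addf region; unique_indices(false) allows duplicate indices, each combined with the running value. A rough NumPy model:

import numpy as np

# Hedged NumPy model of the tm_tensor.scatter above: each update is
# combined with the *current* output value, so duplicates accumulate.
def scatter_add(original, indices, updates):
    out = original.copy()
    for idx, upd in zip(indices[:, 0], updates):
        out[idx] = out[idx] + upd  # the arith.addf region in the IR
    return out

original = np.array([1., 2., 3., 4.], dtype=np.float32)
indices = np.array([[0], [1], [0], [1], [2], [1]], dtype=np.int32)
updates = np.array([1., 2., 3., 4., 5., 6.], dtype=np.float32)
print(scatter_add(original, indices, updates))  # [ 5. 14.  8.  4.]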
| tests                          | model-run | onnx-import | torch-mlir | iree-compile | inference |
|:-------------------------------|:----------|:------------|:-----------|:-------------|:----------|
| onnx/operators/ScatterElements | passed    | passed      | passed     | passed       | passed    |

@Shukla-Gaurav (Collaborator) left a comment:

A few minor comments!

@Shukla-Gaurav (Collaborator) commented:

@AmosLewis Can you please add an end-to-end test case in SHARK-TestSuite to make sure this lowering works fine?

@AmosLewis (Collaborator, Author) commented Oct 4, 2024

Based on the e2e op IREE inference, the lowering is not correct: nod-ai/SHARK-TestSuite#363 (comment)
The trick I tried here is to use self as the output of linalg.generic. My understanding was that after tensor.scatter, the output (self) would be updated, so the next iteration of linalg.generic would see the updated self as its new input. But based on the inference result, the output is not updated: in my example the output value is still [1, 2, 3, 4], not the expected [5, 14, 8, 4]. So the trick does not work and self is never updated (a sketch of the dataflow follows the IR below). I have no idea how to fix it and need your help. @rsuderman @Shukla-Gaurav @vivekkhandelwal1 @zjgarvey

%0 = src = [1, 2, 3, 4, 5, 6]
%1 = self = [1, 2, 3, 4]
%2 = index = [0, 1, 0, 1, 2, 1]

#map = affine_map<(d0) -> (d0)>
module {
  func.func @scatter_graph(%arg0: !torch.vtensor<[4],f32>, %arg1: !torch.vtensor<[6],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    %0 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[6],f32> -> tensor<6xf32>
    %1 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[4],f32> -> tensor<4xf32>
    %str = torch.constant.str "sum"
    %2 = torch.vtensor.literal(dense<[0, 1, 0, 1, 2, 1]> : tensor<6xsi64>) : !torch.vtensor<[6],si64>
    %3 = torch_c.to_builtin_tensor %2 : !torch.vtensor<[6],si64> -> tensor<6xi64>
    %int0 = torch.constant.int 0
    %cast = tensor.cast %3 : tensor<6xi64> to tensor<?xi64>
    %cast_0 = tensor.cast %0 : tensor<6xf32> to tensor<?xf32>
    %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%cast, %cast_0 : tensor<?xi64>, tensor<?xf32>) outs(%1 : tensor<4xf32>) {
    ^bb0(%in: i64, %in_2: f32, %out: f32):
      %6 = arith.index_cast %in : i64 to index
      %extracted = tensor.extract %1[%6] : tensor<4xf32>
      %7 = arith.addf %extracted, %in_2 : f32
      %from_elements = tensor.from_elements %7 : tensor<1xf32>
      %from_elements_3 = tensor.from_elements %6 : tensor<1xindex>
      // Note: the result %scatter is never used, and the yield below returns
      // %out unchanged, so this generic leaves self untouched.
      %scatter = tensor.scatter %from_elements into %1[%from_elements_3] scatter_dims([0]) unique : (tensor<1xf32>, tensor<4xf32>, tensor<1xindex>) -> tensor<4xf32>
      linalg.yield %out : f32
    } -> tensor<4xf32>
    %cast_1 = tensor.cast %4 : tensor<4xf32> to tensor<4xf32>
    %5 = torch_c.from_builtin_tensor %cast_1 : tensor<4xf32> -> !torch.vtensor<[4],f32>
    return %5 : !torch.vtensor<[4],f32>
  }
}
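
To make the failure concrete, here is a hedged Python rendering of the generic body's dataflow (the helper name generic_like is mine): tensor.scatter produces a new SSA value %scatter that is never used, and linalg.yield %out writes back the value already in the output, so self comes out unchanged:

import numpy as np

# Mimics the linalg.generic body above in value (SSA) semantics.
def generic_like(self_, index, src):
    out = self_.copy()                    # outs(%1 : tensor<4xf32>)
    for i in range(len(src)):             # one "parallel" iteration per i
        acc = self_[index[i]] + src[i]    # tensor.extract + arith.addf
        scattered = self_.copy()
        scattered[index[i]] = acc         # tensor.scatter -> %scatter ...
        # ... but %scatter is dropped; `linalg.yield %out` keeps the old value.
    return out

self_ = np.array([1., 2., 3., 4.])
print(generic_like(self_, [0, 1, 0, 1, 2, 1], [1., 2., 3., 4., 5., 6.]))
# [1. 2. 3. 4.] -- matches the observed output, not the expected [5. 14. 8. 4.]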

@AmosLewis (Collaborator, Author) commented Oct 4, 2024

Found a simple fix for this op: just use the tm_tensor ScatterReduceTwoOp pass instead of writing a linalg scatter-reduce op. Here is the test result: nod-ai/SHARK-TestSuite#363

Status report for run: test-run using mode:onnx todtype:default backend:llvm-cpu

| tests                          | model-run   | onnx-import   | torch-mlir   | iree-compile   | inference   |
|:-------------------------------|:------------|:--------------|:-------------|:---------------|:------------|
| onnx/operators/ScatterElements | passed      | passed        | passed       | passed         | passed      |

…atterElements

This will enable the AtenScatterReduceTwoOp lowering to tm_tensor/linalg_ext
Remove the wrong AtenScatterReduce to linalg pass.
@AmosLewis changed the title from "[Linalg] Add torch.scatter.reduce to linalg lowering" to "[ONNX] Change AtenScatterReduce to AtenScatterReduceTwoOp with tm_tensor pass for onnx.ScatterElements" on Oct 6, 2024
@AmosLewis merged commit f4840ed into llvm:main on Oct 6, 2024 (3 checks passed)