-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][tosa] Allow int64 index tensors in gather/scatter #167894
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This commit ensures that gather and scatter operations with int64 index tensors can be created. This aligns with the EXT_INT64 extension. Change-Id: Iea5e72dc9f1c4f085755325a8f9177df2e7fb8d7
|
@llvm/pr-subscribers-mlir Author: Luke Hutton (lhutton1) ChangesThis commit ensures that gather and scatter operations with int64 index tensors can be created. This aligns with the EXT_INT64 extension. Full diff: https://github.com/llvm/llvm-project/pull/167894.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 2b36b2c5113e1..bb8faf01802fa 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2271,7 +2271,7 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
let arguments = (ins
Tosa_Tensor3D:$values,
- Tosa_Int32Tensor2D:$indices
+ Tosa_IndexTensor2D:$indices
);
let results = (outs
@@ -2308,7 +2308,7 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
let arguments = (ins
Tosa_Tensor3D:$values_in,
- Tosa_Int32Tensor2D:$indices,
+ Tosa_IndexTensor2D:$indices,
Tosa_Tensor3D:$input
);
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 414b51bf4b135..266a9e3a7d946 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -202,10 +202,8 @@ def Tosa_Tensor1Dto6D : AnyTypeOf<[
def Tosa_TensorUpto4D : AnyTypeOf<[
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
-def Tosa_Int32TensorUpto4D : AnyTypeOf<[
- Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
-def Tosa_Int32Tensor2D : AnyTypeOf<[
- Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [2]>]>;
+def Tosa_IndexTensor2D : AnyTypeOf<[
+ Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32, Tosa_Int64], [2]>]>;
def Tosa_TensorAtLeast1D : AnyTypeOf<[
Tosa_UnrankedTensor, TosaRankedTensorOf<[Tosa_AnyNumber], [AtLeastRankOne]>], "tosa-conformant tensor of at least rank 1", "::mlir::TensorType">;
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index b2a71ab882230..a4591f7ffd393 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -757,10 +757,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
}
// -----
-// CHECK-LABEL: scatter
-func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
- return %0 : tensor<13x52x3xf32>
+// CHECK-LABEL: gather_int64
+func.func @test_gather_int64(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi64>) -> tensor<13x26x3xf32> {
+ %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi64>) -> tensor<13x26x3xf32>
+ return %0 : tensor<13x26x3xf32>
}
// -----
@@ -770,6 +770,20 @@ func.func @test_gather_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tenso
return %0 : tensor<13x26x3xf32>
}
+// -----
+// CHECK-LABEL: scatter
+func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
+ return %0 : tensor<13x52x3xf32>
+}
+
+// -----
+// CHECK-LABEL: scatter_int64
+func.func @test_scatter_int64(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi64>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi64>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
+ return %0 : tensor<13x52x3xf32>
+}
+
// -----
// CHECK-LABEL: scatter_unranked_indices
func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
|
|
@llvm/pr-subscribers-mlir-tosa Author: Luke Hutton (lhutton1) ChangesThis commit ensures that gather and scatter operations with int64 index tensors can be created. This aligns with the EXT_INT64 extension. Full diff: https://github.com/llvm/llvm-project/pull/167894.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 2b36b2c5113e1..bb8faf01802fa 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2271,7 +2271,7 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
let arguments = (ins
Tosa_Tensor3D:$values,
- Tosa_Int32Tensor2D:$indices
+ Tosa_IndexTensor2D:$indices
);
let results = (outs
@@ -2308,7 +2308,7 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
let arguments = (ins
Tosa_Tensor3D:$values_in,
- Tosa_Int32Tensor2D:$indices,
+ Tosa_IndexTensor2D:$indices,
Tosa_Tensor3D:$input
);
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 414b51bf4b135..266a9e3a7d946 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -202,10 +202,8 @@ def Tosa_Tensor1Dto6D : AnyTypeOf<[
def Tosa_TensorUpto4D : AnyTypeOf<[
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
-def Tosa_Int32TensorUpto4D : AnyTypeOf<[
- Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
-def Tosa_Int32Tensor2D : AnyTypeOf<[
- Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [2]>]>;
+def Tosa_IndexTensor2D : AnyTypeOf<[
+ Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32, Tosa_Int64], [2]>]>;
def Tosa_TensorAtLeast1D : AnyTypeOf<[
Tosa_UnrankedTensor, TosaRankedTensorOf<[Tosa_AnyNumber], [AtLeastRankOne]>], "tosa-conformant tensor of at least rank 1", "::mlir::TensorType">;
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index b2a71ab882230..a4591f7ffd393 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -757,10 +757,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
}
// -----
-// CHECK-LABEL: scatter
-func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
- %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
- return %0 : tensor<13x52x3xf32>
+// CHECK-LABEL: gather_int64
+func.func @test_gather_int64(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi64>) -> tensor<13x26x3xf32> {
+ %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi64>) -> tensor<13x26x3xf32>
+ return %0 : tensor<13x26x3xf32>
}
// -----
@@ -770,6 +770,20 @@ func.func @test_gather_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tenso
return %0 : tensor<13x26x3xf32>
}
+// -----
+// CHECK-LABEL: scatter
+func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
+ return %0 : tensor<13x52x3xf32>
+}
+
+// -----
+// CHECK-LABEL: scatter_int64
+func.func @test_scatter_int64(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi64>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
+ %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi64>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
+ return %0 : tensor<13x52x3xf32>
+}
+
// -----
// CHECK-LABEL: scatter_unranked_indices
func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
|
|
LGTM |
|
Thanks for the quick fix! Also LGTM on my side. |
psunn
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
This commit ensures that gather and scatter operations with int64 index tensors can be created. This aligns with the EXT_INT64 extension.