Skip to content

Conversation

@lhutton1
Copy link
Contributor

This commit ensures that gather and scatter operations with int64 index tensors can be created. This aligns with the EXT_INT64 extension.

This commit ensures that gather and scatter operations
with int64 index tensors can be created. This aligns with
the EXT_INT64 extension.

Change-Id: Iea5e72dc9f1c4f085755325a8f9177df2e7fb8d7
@llvmbot
Copy link
Member

llvmbot commented Nov 13, 2025

@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

Changes

This commit ensures that gather and scatter operations with int64 index tensors can be created. This aligns with the EXT_INT64 extension.


Full diff: https://github.com/llvm/llvm-project/pull/167894.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+2-2)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+2-4)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+18-4)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 2b36b2c5113e1..bb8faf01802fa 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2271,7 +2271,7 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
 
   let arguments = (ins
     Tosa_Tensor3D:$values,
-    Tosa_Int32Tensor2D:$indices
+    Tosa_IndexTensor2D:$indices
   );
 
   let results = (outs
@@ -2308,7 +2308,7 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
 
   let arguments = (ins
     Tosa_Tensor3D:$values_in,
-    Tosa_Int32Tensor2D:$indices,
+    Tosa_IndexTensor2D:$indices,
     Tosa_Tensor3D:$input
   );
 
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 414b51bf4b135..266a9e3a7d946 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -202,10 +202,8 @@ def Tosa_Tensor1Dto6D : AnyTypeOf<[
 def Tosa_TensorUpto4D : AnyTypeOf<[
   Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
 
-def Tosa_Int32TensorUpto4D : AnyTypeOf<[
-  Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
-def Tosa_Int32Tensor2D : AnyTypeOf<[
-  Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [2]>]>;
+def Tosa_IndexTensor2D : AnyTypeOf<[
+  Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32, Tosa_Int64], [2]>]>;
 
 def Tosa_TensorAtLeast1D : AnyTypeOf<[
   Tosa_UnrankedTensor, TosaRankedTensorOf<[Tosa_AnyNumber], [AtLeastRankOne]>], "tosa-conformant tensor of at least rank 1", "::mlir::TensorType">;
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index b2a71ab882230..a4591f7ffd393 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -757,10 +757,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
 }
 
 // -----
-// CHECK-LABEL: scatter
-func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
-  return %0 : tensor<13x52x3xf32>
+// CHECK-LABEL: gather_int64
+func.func @test_gather_int64(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi64>) -> tensor<13x26x3xf32> {
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi64>) -> tensor<13x26x3xf32>
+  return %0 : tensor<13x26x3xf32>
 }
 
 // -----
@@ -770,6 +770,20 @@ func.func @test_gather_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tenso
   return %0 : tensor<13x26x3xf32>
 }
 
+// -----
+// CHECK-LABEL: scatter
+func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
+  return %0 : tensor<13x52x3xf32>
+}
+
+// -----
+// CHECK-LABEL: scatter_int64
+func.func @test_scatter_int64(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi64>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi64>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
+  return %0 : tensor<13x52x3xf32>
+}
+
 // -----
 // CHECK-LABEL: scatter_unranked_indices
 func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {

@llvmbot
Copy link
Member

llvmbot commented Nov 13, 2025

@llvm/pr-subscribers-mlir-tosa

Author: Luke Hutton (lhutton1)

Changes

This commit ensures that gather and scatter operations with int64 index tensors can be created. This aligns with the EXT_INT64 extension.


Full diff: https://github.com/llvm/llvm-project/pull/167894.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+2-2)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+2-4)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+18-4)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 2b36b2c5113e1..bb8faf01802fa 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2271,7 +2271,7 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
 
   let arguments = (ins
     Tosa_Tensor3D:$values,
-    Tosa_Int32Tensor2D:$indices
+    Tosa_IndexTensor2D:$indices
   );
 
   let results = (outs
@@ -2308,7 +2308,7 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
 
   let arguments = (ins
     Tosa_Tensor3D:$values_in,
-    Tosa_Int32Tensor2D:$indices,
+    Tosa_IndexTensor2D:$indices,
     Tosa_Tensor3D:$input
   );
 
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 414b51bf4b135..266a9e3a7d946 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -202,10 +202,8 @@ def Tosa_Tensor1Dto6D : AnyTypeOf<[
 def Tosa_TensorUpto4D : AnyTypeOf<[
   Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
 
-def Tosa_Int32TensorUpto4D : AnyTypeOf<[
-  Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
-def Tosa_Int32Tensor2D : AnyTypeOf<[
-  Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [2]>]>;
+def Tosa_IndexTensor2D : AnyTypeOf<[
+  Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32, Tosa_Int64], [2]>]>;
 
 def Tosa_TensorAtLeast1D : AnyTypeOf<[
   Tosa_UnrankedTensor, TosaRankedTensorOf<[Tosa_AnyNumber], [AtLeastRankOne]>], "tosa-conformant tensor of at least rank 1", "::mlir::TensorType">;
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index b2a71ab882230..a4591f7ffd393 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -757,10 +757,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
 }
 
 // -----
-// CHECK-LABEL: scatter
-func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
-  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
-  return %0 : tensor<13x52x3xf32>
+// CHECK-LABEL: gather_int64
+func.func @test_gather_int64(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi64>) -> tensor<13x26x3xf32> {
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi64>) -> tensor<13x26x3xf32>
+  return %0 : tensor<13x26x3xf32>
 }
 
 // -----
@@ -770,6 +770,20 @@ func.func @test_gather_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tenso
   return %0 : tensor<13x26x3xf32>
 }
 
+// -----
+// CHECK-LABEL: scatter
+func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
+  return %0 : tensor<13x52x3xf32>
+}
+
+// -----
+// CHECK-LABEL: scatter_int64
+func.func @test_scatter_int64(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi64>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
+  %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi64>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
+  return %0 : tensor<13x52x3xf32>
+}
+
 // -----
 // CHECK-LABEL: scatter_unranked_indices
 func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {

@lhutton1 lhutton1 requested a review from psunn November 13, 2025 15:49
@lhutton1
Copy link
Contributor Author

cc @udaya-ranga @Tai78641

@Tai78641
Copy link
Contributor

LGTM

@IanTaylerLessa-arm
Copy link
Contributor

Thanks for the quick fix! Also LGTM on my side.

Copy link
Contributor

@psunn psunn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@lhutton1 lhutton1 merged commit 8723fe5 into llvm:main Nov 14, 2025
13 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants