diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 2b36b2c5113e1..bb8faf01802fa 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -2271,7 +2271,7 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> { let arguments = (ins Tosa_Tensor3D:$values, - Tosa_Int32Tensor2D:$indices + Tosa_IndexTensor2D:$indices ); let results = (outs @@ -2308,7 +2308,7 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> { let arguments = (ins Tosa_Tensor3D:$values_in, - Tosa_Int32Tensor2D:$indices, + Tosa_IndexTensor2D:$indices, Tosa_Tensor3D:$input ); diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td index 414b51bf4b135..266a9e3a7d946 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -202,10 +202,8 @@ def Tosa_Tensor1Dto6D : AnyTypeOf<[ def Tosa_TensorUpto4D : AnyTypeOf<[ Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>; -def Tosa_Int32TensorUpto4D : AnyTypeOf<[ - Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>; -def Tosa_Int32Tensor2D : AnyTypeOf<[ - Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [2]>]>; +def Tosa_IndexTensor2D : AnyTypeOf<[ + Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32, Tosa_Int64], [2]>]>; def Tosa_TensorAtLeast1D : AnyTypeOf<[ Tosa_UnrankedTensor, TosaRankedTensorOf<[Tosa_AnyNumber], [AtLeastRankOne]>], "tosa-conformant tensor of at least rank 1", "::mlir::TensorType">; diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index b2a71ab882230..a4591f7ffd393 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -757,10 +757,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> } // ----- -// CHECK-LABEL: scatter -func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> { - %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32> - return %0 : tensor<13x52x3xf32> +// CHECK-LABEL: gather_int64 +func.func @test_gather_int64(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi64>) -> tensor<13x26x3xf32> { + %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi64>) -> tensor<13x26x3xf32> + return %0 : tensor<13x26x3xf32> } // ----- @@ -770,6 +770,20 @@ func.func @test_gather_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tenso return %0 : tensor<13x26x3xf32> } +// ----- +// CHECK-LABEL: scatter +func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> { + %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32> + return %0 : tensor<13x52x3xf32> +} + +// ----- +// CHECK-LABEL: scatter_int64 +func.func @test_scatter_int64(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi64>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> { + %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi64>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32> + return %0 : tensor<13x52x3xf32> +} + // ----- // CHECK-LABEL: scatter_unranked_indices func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {