diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index caf80165fc640..99b7cda49094e 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1001,8 +1001,12 @@ OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) { !outputTy.hasStaticShape()) return {}; - if (inputTy.getDimSize(getAxis()) == 1) - return DenseElementsAttr::get(outputTy, 0); + const Type outputElementTy = getElementTypeOrSelf(outputTy); + if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.isInteger()) { + const auto outputElemIntTy = cast(outputElementTy); + const APInt zero = APInt::getZero(outputElemIntTy.getWidth()); + return DenseElementsAttr::get(outputTy, zero); + } return {}; } diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index e8525a5d2ed62..7574afa215e78 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -9,6 +9,15 @@ func.func @argmax_nofold(%arg0: tensor) -> tensor<1xi32> { // ----- +// CHECK-LABEL: @test_argmax_fold_i64_index +func.func @test_argmax_fold_i64_index(%arg0: tensor<1xi8>) -> tensor { + // CHECK: "tosa.const"() <{values = dense<0> : tensor}> : () -> tensor + %0 = tosa.argmax %arg0 {axis = 0 : i32} : (tensor<1xi8>) -> tensor + return %0 : tensor +} + +// ----- + // CHECK-LABEL: @pad_wh_avg_pool2d_fold func.func @pad_wh_avg_pool2d_fold(%input: tensor<1x10x8x3xf32>) -> tensor<1x6x5x3xf32> { // CHECK-NOT: tosa.pad