From eab003fdc0fbc6372d0575c776e7b98772f73e78 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Wed, 15 Oct 2025 17:06:23 +0100 Subject: [PATCH] [mlir][tosa] Fix argmax folder when output type is i64 Previously the following IR: ``` tosa.argmax %arg0 {axis = 0 : i32} : (tensor<1xi8>) -> tensor<i64> ``` Would result in a crash with the assertion: ``` expected dense element bit width 64 to match data size 32 for type i64 ``` This commit ensures that zero is constructed with the correct bitwidth while folding, therefore fixing the crash. Change-Id: I4531d9f36fb2e682d46075229ad15f12f53d7cf5 --- mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 8 ++++++-- mlir/test/Dialect/Tosa/canonicalize.mlir | 9 +++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index caf80165fc640..99b7cda49094e 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1001,8 +1001,12 @@ OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) { !outputTy.hasStaticShape()) return {}; - if (inputTy.getDimSize(getAxis()) == 1) - return DenseElementsAttr::get(outputTy, 0); + const Type outputElementTy = getElementTypeOrSelf(outputTy); + if (inputTy.getDimSize(getAxis()) == 1 && outputElementTy.isInteger()) { + const auto outputElemIntTy = cast<IntegerType>(outputElementTy); + const APInt zero = APInt::getZero(outputElemIntTy.getWidth()); + return DenseElementsAttr::get(outputTy, zero); + } return {}; } diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index e8525a5d2ed62..7574afa215e78 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -9,6 +9,15 @@ func.func @argmax_nofold(%arg0: tensor) -> tensor<1xi32> { // ----- +// CHECK-LABEL: @test_argmax_fold_i64_index +func.func @test_argmax_fold_i64_index(%arg0: tensor<1xi8>) -> tensor<i64> { + // 
CHECK: "tosa.const"() <{values = dense<0> : tensor<i64>}> : () -> tensor<i64> + %0 = tosa.argmax %arg0 {axis = 0 : i32} : (tensor<1xi8>) -> tensor<i64> + return %0 : tensor<i64> +} + +// ----- + // CHECK-LABEL: @pad_wh_avg_pool2d_fold func.func @pad_wh_avg_pool2d_fold(%input: tensor<1x10x8x3xf32>) -> tensor<1x6x5x3xf32> { // CHECK-NOT: tosa.pad