// Patch summary. These lines precede the first `diff --git` header and belong to no hunk.
//
// Hunk 1 - mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp, inside
// reduceMatchAndRewriteHelper:
//   The single `widenAccTy` expression is split in two:
//     * `needsFp32AccTy` holds the element-type test: an `isa` check on `elementTy`
//       combined with `cast(elementTy).isBF16()`.
//     * `widenAccTy` is that test AND-ed with an OR of two `std::is_same_v` op-type
//       checks, where the removed code had a single `std::is_same_v` check.
//   Both locals become `const`. When `widenAccTy` is true the accumulator type
//   `accTy` is `rewriter.getF32Type()`; otherwise it stays `elementTy`.
//   NOTE(review): the template argument lists are missing in this copy
//   (`std::is_same_v ||`, `isa(elementTy)`, `cast(elementTy)`, `SmallVector reduceShape`),
//   so the op types being compared are not visible here. Presumably they are the
//   already-handled reduce-sum op plus the reduce-product op exercised by the new
//   test - confirm against the original commit.
//
// Hunk 2 - mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir:
//   Adds the FileCheck case `@reduce_product_bf16`. For `tosa.reduce_product` with
//   axis = 1 on a tensor<5x4xbf16> the CHECK lines expect, in order:
//     1. a tensor<5xf32> init tensor filled with the constant 1.0;
//     2. a `linalg.reduce` over dimension 1 taking the bf16 input and the f32
//        accumulator, whose body extends the bf16 element with `arith.extf` and
//        multiplies it into the accumulator with `arith.mulf`;
//     3. a `linalg.generic` that truncates the f32 result to bf16 with `arith.truncf`
//        into a tensor<5xbf16>;
//     4. a `tensor.expand_shape` from tensor<5xbf16> to tensor<5x1xbf16>.
//
// NOTE(review): the hunk headers declare 11 and 31 post-image lines, but in this copy
// the whole patch sits on two physical lines, so the original line breaks have been
// lost. Restore them from the original commit before trying to apply the patch.
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 76346a766f1f7..50663ddd27346 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1172,9 +1172,11 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, Value input = op->getOperand(0); // Figure out the accType if needed - bool widenAccTy = std::is_same_v && - isa(elementTy) && - cast(elementTy).isBF16(); + const bool needsFp32AccTy = + isa(elementTy) && cast(elementTy).isBF16(); + const bool widenAccTy = (std::is_same_v || + std::is_same_v) && + needsFp32AccTy; Type accTy = widenAccTy ? rewriter.getF32Type() : elementTy; SmallVector reduceShape; diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index e6bd800a0cf0a..20c93c671f48c 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -983,6 +983,31 @@ func.func @reduce_bf16(%arg0: tensor<5x4xbf16>) -> () { // ----- +// CHECK-LABEL: @reduce_product_bf16 +// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xbf16> +func.func @reduce_product_bf16(%arg0: tensor<5x4xbf16>) -> () { + // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5xf32> + // CHECK: [[CST1:%.+]] = arith.constant 1.0 + // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST1]]{{.*}}outs([[INIT]] + // CHECK: [[REDUCE:%.+]] = linalg.reduce ins([[ARG0]] : tensor<5x4xbf16>) outs([[FILL]] : tensor<5xf32>) dimensions = [1] + // CHECK: (%[[ARG1:.*]]: bf16, %[[ARG2:.*]]: f32) { + // CHECK: [[EXTF:%.+]] = arith.extf %[[ARG1]] : bf16 to f32 + // CHECK: [[ACC:%.+]] = arith.mulf [[EXTF]], %[[ARG2]] : f32 + // CHECK: linalg.yield [[ACC]] : f32 + // CHECK: } + // CHECK: [[INIT_RES:%.+]] = tensor.empty() : tensor<5xbf16> + // CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} 
ins([[REDUCE]] : tensor<5xf32>) outs([[INIT_RES]] : tensor<5xbf16>) + // CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: bf16): + // CHECK: [[TRUNCF:%.+]] = arith.truncf %[[IN]] : f32 to bf16 + // CHECK: linalg.yield [[TRUNCF]] : bf16 + // CHECK: } + // CHECK: tensor.expand_shape [[RES]] {{\[}}[0, 1]] output_shape [5, 1] : tensor<5xbf16> into tensor<5x1xbf16> + %0 = tosa.reduce_product %arg0 {axis = 1 : i32} : (tensor<5x4xbf16>) -> tensor<5x1xbf16> + return +} + +// ----- + // CHECK-LABEL: @reduce_float // CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xf32> func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {