[mlir][tosa] Fix bf16 reduction accumulator widening#192045

Closed
GeorgeARM wants to merge 1 commit into llvm:main from GeorgeARM:fix-reduce-prod
@GeorgeARM
Contributor

Use f32 accumulator when lowering bf16 arithmetic reductions in TosaToLinalg; then truncate the result back to bf16.

Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
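
The effect of widening the accumulator can be sketched outside MLIR. The standalone C++ below is an illustration only, not code from this patch: bf16 is modelled as the upper 16 bits of an f32, the helper names (`bf16ToF32`, `f32ToBF16`, `reduceProductNarrow`, `reduceProductWide`) are hypothetical, and the per-step rounding in the narrow variant is assumed to be round-to-nearest-even purely for demonstration.

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical helper: widen a bf16 bit pattern to f32 (exact).
static float bf16ToF32(uint16_t v) {
  uint32_t bits = static_cast<uint32_t>(v) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Hypothetical helper: narrow f32 to bf16 with round-to-nearest-even.
// NaN handling is omitted for brevity.
static uint16_t f32ToBF16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  uint32_t roundingBias = 0x7FFFu + ((bits >> 16) & 1u);
  return static_cast<uint16_t>((bits + roundingBias) >> 16);
}

// bf16 accumulator: the running product is narrowed after every multiply.
static uint16_t reduceProductNarrow(const std::vector<uint16_t> &in) {
  uint16_t acc = f32ToBF16(1.0f);
  for (uint16_t x : in)
    acc = f32ToBF16(bf16ToF32(x) * bf16ToF32(acc));
  return acc;
}

// f32 accumulator: each element is extended, multiplied in f32, and the
// result is narrowed once at the end (mirrors extf / mulf / truncf).
static uint16_t reduceProductWide(const std::vector<uint16_t> &in) {
  float acc = 1.0f;
  for (uint16_t x : in)
    acc = bf16ToF32(x) * acc;
  return f32ToBF16(acc);
}
```

For the inputs `0x3F81, 0x3F81, 0x3FA0` (1.0078125, 1.0078125, 1.25), the narrow variant returns `0x3FA2` and the wide variant returns `0x3FA3`. The wide result is the correctly rounded value of the exact product, because every intermediate product in this example is exactly representable in f32.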
@llvmbot
Member

llvmbot commented Apr 14, 2026

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir-linalg

Author: Georgios Pinitas (GeorgeARM)

Full diff: https://github.com/llvm/llvm-project/pull/192045.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+5-3)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+25)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 76346a766f1f7..50663ddd27346 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1172,9 +1172,11 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
   Value input = op->getOperand(0);
 
   // Figure out the accType if needed
-  bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
-                    isa<FloatType>(elementTy) &&
-                    cast<FloatType>(elementTy).isBF16();
+  const bool needsFp32AccTy =
+      isa<FloatType>(elementTy) && cast<FloatType>(elementTy).isBF16();
+  const bool widenAccTy = (std::is_same_v<OpTy, tosa::ReduceSumOp> ||
+                           std::is_same_v<OpTy, tosa::ReduceProductOp>) &&
+                          needsFp32AccTy;
   Type accTy = widenAccTy ? rewriter.getF32Type() : elementTy;
 
   SmallVector<int64_t> reduceShape;
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index e6bd800a0cf0a..20c93c671f48c 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -983,6 +983,31 @@ func.func @reduce_bf16(%arg0: tensor<5x4xbf16>) -> () {
 
 // -----
 
+// CHECK-LABEL: @reduce_product_bf16
+// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xbf16>
+func.func @reduce_product_bf16(%arg0: tensor<5x4xbf16>) -> () {
+  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5xf32>
+  // CHECK: [[CST1:%.+]] = arith.constant 1.0
+  // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST1]]{{.*}}outs([[INIT]]
+  // CHECK: [[REDUCE:%.+]] = linalg.reduce ins([[ARG0]] : tensor<5x4xbf16>) outs([[FILL]] : tensor<5xf32>) dimensions = [1]
+  // CHECK:  (%[[ARG1:.*]]: bf16, %[[ARG2:.*]]: f32) {
+  // CHECK:   [[EXTF:%.+]] = arith.extf %[[ARG1]] : bf16 to f32
+  // CHECK:   [[ACC:%.+]] = arith.mulf [[EXTF]], %[[ARG2]] : f32
+  // CHECK:   linalg.yield [[ACC]] : f32
+  // CHECK:  }
+  // CHECK: [[INIT_RES:%.+]] = tensor.empty() : tensor<5xbf16>
+  // CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[REDUCE]] : tensor<5xf32>) outs([[INIT_RES]] : tensor<5xbf16>)
+  // CHECK:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: bf16):
+  // CHECK:   [[TRUNCF:%.+]] = arith.truncf %[[IN]] : f32 to bf16
+  // CHECK:   linalg.yield [[TRUNCF]] : bf16
+  // CHECK:  }
+  // CHECK: tensor.expand_shape [[RES]] {{\[}}[0, 1]] output_shape [5, 1] : tensor<5xbf16> into tensor<5x1xbf16>
+  %0 = tosa.reduce_product %arg0 {axis = 1 : i32} : (tensor<5x4xbf16>) -> tensor<5x1xbf16>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @reduce_float
 // CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xf32>
 func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {

  const bool needsFp32AccTy =
      isa<FloatType>(elementTy) && cast<FloatType>(elementTy).isBF16();
  const bool widenAccTy = (std::is_same_v<OpTy, tosa::ReduceSumOp> ||
                           std::is_same_v<OpTy, tosa::ReduceProductOp>) &&
Contributor

The spec pseudo-code seems to suggest that REDUCE_PRODUCT uses a bf16 accumulator type (see https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_reduce_product). Is this an intentional deviation?

Contributor Author

I see; this is a bit brittle, then. Since this deviates from the canonical specification definition, it is probably not the best way forward.

Contributor

In general an implementation can diverge from the spec pseudo-code as long as it passes conformance. Since this is widening the accumulator type, I don't see an issue here.

It seems the problem previously was that the linalg implementation using a bf16 accumulator type used truncation when multiplying, rather than using round-to-nearest. A round-to-nearest implementation would pass conformance. Perhaps it's worth leaving a comment to this effect so that the legalization can be improved in the future? Otherwise the changes LGTM!
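
To make the truncation versus round-to-nearest distinction concrete, here is a small standalone C++ sketch of the two f32 to bf16 narrowing modes. It is not the lowering itself: the function names (`f32Bits`, `truncateToBF16`, `roundToNearestEvenBF16`) are hypothetical, bf16 is modelled as the upper 16 bits of an f32, and NaN handling is left out.

```cpp
#include <cstdint>
#include <cstring>

// Reinterpret an f32 as its raw 32-bit pattern.
static uint32_t f32Bits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return bits;
}

// Truncation: drop the low 16 mantissa bits (rounds toward zero).
static uint16_t truncateToBF16(float f) {
  return static_cast<uint16_t>(f32Bits(f) >> 16);
}

// Round-to-nearest, ties-to-even: add a bias before dropping the low bits.
// The bias is 0x7FFF when the kept LSB is 0 and 0x8000 when it is 1, so an
// exact tie moves to the neighbour whose LSB is even.
static uint16_t roundToNearestEvenBF16(float f) {
  uint32_t bits = f32Bits(f);
  uint32_t roundingBias = 0x7FFFu + ((bits >> 16) & 1u);
  return static_cast<uint16_t>((bits + roundingBias) >> 16);
}
```

For example, 1.005859375 (bit pattern `0x3F80C000`) truncates to `0x3F80` (1.0) but rounds to `0x3F81` (1.0078125). Truncation always errs toward zero, so its error can build up across the multiplies of a reduction.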

GeorgeARM closed this on Apr 17, 2026