[mlir][tosa] Fix bf16 reduction accumulator widening #192045
Use an f32 accumulator when lowering bf16 arithmetic reductions in `TosaToLinalg`, then truncate the result back to bf16.

Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir-linalg
Author: Georgios Pinitas (GeorgeARM)
Changes
Use f32 accumulator when lowering bf16 arithmetic reductions in `TosaToLinalg`; then truncate the result back to bf16.
Full diff: https://github.com/llvm/llvm-project/pull/192045.diff
2 Files Affected:
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 76346a766f1f7..50663ddd27346 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1172,9 +1172,11 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
Value input = op->getOperand(0);
// Figure out the accType if needed
- bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
- isa<FloatType>(elementTy) &&
- cast<FloatType>(elementTy).isBF16();
+ const bool needsFp32AccTy =
+ isa<FloatType>(elementTy) && cast<FloatType>(elementTy).isBF16();
+ const bool widenAccTy = (std::is_same_v<OpTy, tosa::ReduceSumOp> ||
+ std::is_same_v<OpTy, tosa::ReduceProductOp>) &&
+ needsFp32AccTy;
Type accTy = widenAccTy ? rewriter.getF32Type() : elementTy;
SmallVector<int64_t> reduceShape;
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index e6bd800a0cf0a..20c93c671f48c 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -983,6 +983,31 @@ func.func @reduce_bf16(%arg0: tensor<5x4xbf16>) -> () {
// -----
+// CHECK-LABEL: @reduce_product_bf16
+// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xbf16>
+func.func @reduce_product_bf16(%arg0: tensor<5x4xbf16>) -> () {
+ // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<5xf32>
+ // CHECK: [[CST1:%.+]] = arith.constant 1.0
+ // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST1]]{{.*}}outs([[INIT]]
+ // CHECK: [[REDUCE:%.+]] = linalg.reduce ins([[ARG0]] : tensor<5x4xbf16>) outs([[FILL]] : tensor<5xf32>) dimensions = [1]
+ // CHECK: (%[[ARG1:.*]]: bf16, %[[ARG2:.*]]: f32) {
+ // CHECK: [[EXTF:%.+]] = arith.extf %[[ARG1]] : bf16 to f32
+ // CHECK: [[ACC:%.+]] = arith.mulf [[EXTF]], %[[ARG2]] : f32
+ // CHECK: linalg.yield [[ACC]] : f32
+ // CHECK: }
+ // CHECK: [[INIT_RES:%.+]] = tensor.empty() : tensor<5xbf16>
+ // CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[REDUCE]] : tensor<5xf32>) outs([[INIT_RES]] : tensor<5xbf16>)
+ // CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: bf16):
+ // CHECK: [[TRUNCF:%.+]] = arith.truncf %[[IN]] : f32 to bf16
+ // CHECK: linalg.yield [[TRUNCF]] : bf16
+ // CHECK: }
+ // CHECK: tensor.expand_shape [[RES]] {{\[}}[0, 1]] output_shape [5, 1] : tensor<5xbf16> into tensor<5x1xbf16>
+ %0 = tosa.reduce_product %arg0 {axis = 1 : i32} : (tensor<5x4xbf16>) -> tensor<5x1xbf16>
+ return
+}
+
+// -----
+
// CHECK-LABEL: @reduce_float
// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xf32>
func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
const bool needsFp32AccTy =
    isa<FloatType>(elementTy) && cast<FloatType>(elementTy).isBF16();
const bool widenAccTy = (std::is_same_v<OpTy, tosa::ReduceSumOp> ||
                         std::is_same_v<OpTy, tosa::ReduceProductOp>) &&
The spec pseudo-code seems to suggest that REDUCE_PRODUCT uses an accumulator type of bf16 (see https://www.mlplatform.org/tosa/tosa_spec_1_0_1.html#_reduce_product). Is this an intentional deviation?
I see; this is a bit brittle then. Since this is not the canonical specification definition, it is probably not the best way forward.
In general an implementation can diverge from the spec pseudo-code as long as it passes conformance. Since this is widening the accumulator type, I don't see an issue here.
It seems the problem previously was that the linalg implementation using a bf16 accumulator type used truncation when multiplying, rather than using round-to-nearest. A round-to-nearest implementation would pass conformance. Perhaps it's worth leaving a comment to this effect so that the legalization can be improved in the future? Otherwise the changes LGTM!
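To make the rounding point above concrete, here is a minimal standalone C++ sketch, not taken from the PR or from MLIR. It models bf16 as the upper 16 bits of an IEEE-754 binary32 and compares a product reduction with a bf16 accumulator (rounded after every multiply, either by truncation or by round-to-nearest-even) against an f32 accumulator that is narrowed once at the end. All names (`bf16_bits`, `f32ToBf16Truncate`, `f32ToBf16RoundNearestEven`, `reduceProductBf16Acc`, `reduceProductF32Acc`) are hypothetical, NaN handling is omitted, and the choice of round-to-nearest-even for the final narrowing is an assumption of the sketch.

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical model: a bf16 value stored as its raw 16 bits, i.e. the
// upper half of an IEEE-754 binary32 bit pattern.
using bf16_bits = std::uint16_t;

// Widen bf16 to f32. This is exact: the low 16 mantissa bits become zero.
inline float bf16ToF32(bf16_bits b) {
  std::uint32_t u = static_cast<std::uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &u, sizeof f);
  return f;
}

// Narrow f32 to bf16 by dropping the low 16 bits (round toward zero).
inline bf16_bits f32ToBf16Truncate(float f) {
  std::uint32_t u;
  std::memcpy(&u, &f, sizeof u);
  return static_cast<bf16_bits>(u >> 16);
}

// Narrow f32 to bf16 with round-to-nearest, ties-to-even.
// Assumption: finite inputs only; NaN is not special-cased here.
inline bf16_bits f32ToBf16RoundNearestEven(float f) {
  std::uint32_t u;
  std::memcpy(&u, &f, sizeof u);
  // Adding 0x7FFF plus the current bf16 LSB rounds up when the discarded
  // bits exceed half an ulp, and on an exact tie only if the LSB is odd.
  std::uint32_t bias = 0x7FFFu + ((u >> 16) & 1u);
  return static_cast<bf16_bits>((u + bias) >> 16);
}

// Product reduction with a bf16 accumulator: every partial product is
// narrowed back to bf16 using the supplied rounding function.
template <typename RoundFn>
bf16_bits reduceProductBf16Acc(const std::vector<bf16_bits> &xs,
                               RoundFn round) {
  bf16_bits acc = round(1.0f);
  for (bf16_bits x : xs)
    acc = round(bf16ToF32(acc) * bf16ToF32(x));
  return acc;
}

// Product reduction with an f32 accumulator: each input is widened, the
// product is accumulated in f32, and the result is narrowed once at the end.
inline bf16_bits reduceProductF32Acc(const std::vector<bf16_bits> &xs) {
  float acc = 1.0f;
  for (bf16_bits x : xs)
    acc = bf16ToF32(x) * acc;
  return f32ToBf16RoundNearestEven(acc);
}
```

For a row of four copies of `0x3F89` (1.0703125), this sketch yields `0x3FA8` with the f32 accumulator and with the round-to-nearest bf16 accumulator, but `0x3FA6` with the truncating bf16 accumulator, which is two bf16 ulps lower.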