diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h index 18ac0dbc8d13e..5a07a01d0928a 100644 --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h @@ -68,6 +68,11 @@ void populateExpandF8E8M0Patterns(RewritePatternSet &patterns); /// Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns); +/// Add patterns to expand `arith.flush_denormals` into integer arithmetic +/// (bitcast + bit masks + compare + select). Only matches IEEE-like +/// floating-point types. +void populateExpandFlushDenormalsPatterns(RewritePatternSet &patterns); + /// Add patterns to expand Arith ops. void populateArithExpandOpsPatterns(RewritePatternSet &patterns); diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td index c7370b83fdb6c..27e9146ec3606 100644 --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td @@ -21,6 +21,10 @@ def ArithExpandOpsPass : Pass<"arith-expand"> { "Enable the F8E8M0 expansion patterns">, Option<"includeF4E2M1", "include-f4e2m1", "bool", /*default=*/"false", "Enable the F4E2M1 expansion patterns">, + Option<"includeFlushDenormals", "include-flush-denormals", "bool", + /*default=*/"false", + "Enable expansion of `arith.flush_denormals` on IEEE-like " + "floating-point types">, ]; } diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp index 46f8c1037d47b..c9217c57a5f25 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp @@ -34,6 +34,17 @@ static Value createConst(Location loc, Type type, int value, return arith::ConstantOp::create(rewriter, loc, attr); } +/// Create an integer constant from an APInt. +static Value createAPIntConst(Location loc, Type type, const APInt &value, + PatternRewriter &rewriter) { + auto attr = IntegerAttr::get(getElementTypeOrSelf(type), value); + if (auto shapedTy = dyn_cast(type)) { + return arith::ConstantOp::create(rewriter, loc, + DenseElementsAttr::get(shapedTy, attr)); + } + return arith::ConstantOp::create(rewriter, loc, attr); +} + /// Create a float constant. static Value createFloatConst(Location loc, Type type, const APFloat &value, PatternRewriter &rewriter) { @@ -729,6 +740,81 @@ struct ScalingTruncFOpConverter } }; +/// Expands `arith.flush_denormals` into integer arithmetic. +/// +/// For an IEEE-like floating-point value with a sign|exponent|mantissa bit +/// layout, a value is denormal iff its biased exponent field is zero and its +/// stored mantissa is non-zero. When the exponent field is zero, the value is +/// either pos/neg 0 (mantissa = 0) or a denormal (mantissa != 0); in both +/// cases, clearing the mantissa bits produces the desired sign-preserved zero +/// (a no-op for pos/neg 0, a flush for denormals). When the exponent field is +/// non-zero, the value passes through unchanged. +/// +/// Pseudocode: +/// bits = bitcast(x, iN) +/// expIsZero = (bits & expMask) == 0 +/// cleared = bits & ~manMask +/// resultBits = select(expIsZero, cleared, bits) +/// result = bitcast(resultBits, floatTy) +struct FlushDenormalsOpConverter + : public OpRewritePattern { + using Base::Base; + LogicalResult matchAndRewrite(arith::FlushDenormalsOp op, + PatternRewriter &rewriter) const final { + Location loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, rewriter); + Value operand = op.getOperand(); + Type operandTy = operand.getType(); + auto floatTy = dyn_cast(getElementTypeOrSelf(operandTy)); + if (!floatTy) + return rewriter.notifyMatchFailure(op, "operand is not a float type"); + + const llvm::fltSemantics &sem = floatTy.getFloatSemantics(); + // Restrict to IEEE-like encodings, where the sign bit is the MSB and + // denormals are exactly "biased exponent == 0 and non-zero mantissa". + if (!llvm::APFloatBase::isIEEELikeFP(sem)) + return rewriter.notifyMatchFailure( + op, "only IEEE-like floating-point types are supported"); + + unsigned totalBits = llvm::APFloatBase::semanticsSizeInBits(sem); + unsigned precision = llvm::APFloatBase::semanticsPrecision(sem); + // Stored mantissa bits = precision - 1 (implicit leading bit not stored). + // Exponent field bits = totalBits - 1 (sign) - storedMantissa. + if (precision < 1 || precision > totalBits) + return rewriter.notifyMatchFailure(op, "unexpected float semantics"); + unsigned mantissaBits = precision - 1; + unsigned expBits = totalBits - 1 - mantissaBits; + if (expBits == 0 || mantissaBits == 0) + return rewriter.notifyMatchFailure( + op, "degenerate float encoding has no exponent or mantissa"); + + Type intTy = + cloneToShapedType(operandTy, rewriter.getIntegerType(totalBits)); + Value bits = arith::BitcastOp::create(b, intTy, operand); + APInt expMaskVal = + APInt::getBitsSet(totalBits, mantissaBits, mantissaBits + expBits); + APInt clearMantissaMaskVal = ~APInt::getLowBitsSet(totalBits, mantissaBits); + APInt zeroVal = APInt::getZero(totalBits); + Value expMask = createAPIntConst(loc, intTy, expMaskVal, rewriter); + Value clearMantissaMask = + createAPIntConst(loc, intTy, clearMantissaMaskVal, rewriter); + Value zero = createAPIntConst(loc, intTy, zeroVal, rewriter); + + // expField == 0 + Value expField = arith::AndIOp::create(b, bits, expMask); + Value expIsZero = + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, expField, zero); + + // Clear mantissa bits: when exp == 0, this produces pos/neg 0.0. + Value cleared = arith::AndIOp::create(b, bits, clearMantissaMask); + Value resultBits = arith::SelectOp::create(b, expIsZero, cleared, bits); + Value result = arith::BitcastOp::create(b, operandTy, resultBits); + + rewriter.replaceOp(op, result); + return success(); + } +}; + struct ArithExpandOpsPass : public arith::impl::ArithExpandOpsPassBase { using ArithExpandOpsPassBase::ArithExpandOpsPassBase; @@ -765,6 +851,20 @@ struct ArithExpandOpsPass arith::populateExpandF8E8M0Patterns(patterns); if (includeF4E2M1) arith::populateExpandF4E2M1Patterns(patterns); + if (includeFlushDenormals) { + arith::populateExpandFlushDenormalsPatterns(patterns); + // Only IEEE-like floating-point types are expanded by the pattern; + // leave `arith.flush_denormals` on other types alone. + target.addDynamicallyLegalOp( + [](arith::FlushDenormalsOp op) { + auto floatTy = + dyn_cast(getElementTypeOrSelf(op.getType())); + if (!floatTy) + return true; + return !llvm::APFloatBase::isIEEELikeFP( + floatTy.getFloatSemantics()); + }); + } target.addDynamicallyLegalOp( [=](arith::ExtFOp op) { @@ -831,6 +931,11 @@ void mlir::arith::populateExpandScalingExtTruncPatterns( patterns.getContext()); } +void mlir::arith::populateExpandFlushDenormalsPatterns( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) { populateCeilFloorDivExpandOpsPatterns(patterns); populateExpandScalingExtTruncPatterns(patterns); diff --git a/mlir/test/Dialect/Arith/expand-flush-denormals.mlir b/mlir/test/Dialect/Arith/expand-flush-denormals.mlir new file mode 100644 index 0000000000000..69a8e6b789cc1 --- /dev/null +++ b/mlir/test/Dialect/Arith/expand-flush-denormals.mlir @@ -0,0 +1,96 @@ +// RUN: mlir-opt %s -arith-expand=include-flush-denormals=true -split-input-file | FileCheck %s + +// Expansion for f32: +// exp mask = 0x7f800000 (sign 0, exp all 1, mantissa 0) +// clear-man mask = 0xff800000 (sign 1, exp all 1, mantissa 0) +// When the exponent field is zero (±0 or denormal), the mantissa bits are +// cleared, yielding a sign-preserved zero. Otherwise the bits pass through. + +// CHECK-LABEL: func @flush_denormals_f32 +// CHECK-SAME: (%[[ARG0:.+]]: f32) -> f32 +// CHECK: %[[BITS:.+]] = arith.bitcast %[[ARG0]] : f32 to i32 +// CHECK: %[[EXP_MASK:.+]] = arith.constant 2139095040 : i32 +// CHECK: %[[CLEAR_MAN_MASK:.+]] = arith.constant -8388608 : i32 +// CHECK: %[[ZERO:.+]] = arith.constant 0 : i32 +// CHECK: %[[EXP:.+]] = arith.andi %[[BITS]], %[[EXP_MASK]] : i32 +// CHECK: %[[EXP_ZERO:.+]] = arith.cmpi eq, %[[EXP]], %[[ZERO]] : i32 +// CHECK: %[[CLEARED:.+]] = arith.andi %[[BITS]], %[[CLEAR_MAN_MASK]] : i32 +// CHECK: %[[RES_BITS:.+]] = arith.select %[[EXP_ZERO]], %[[CLEARED]], %[[BITS]] : i32 +// CHECK: %[[RES:.+]] = arith.bitcast %[[RES_BITS]] : i32 to f32 +// CHECK: return %[[RES]] : f32 +func.func @flush_denormals_f32(%arg0: f32) -> f32 { + %0 = arith.flush_denormals %arg0 : f32 + return %0 : f32 +} + +// ----- + +// Expansion for bf16: +// exp mask = 0x7f80 +// clear-man mask = 0xff80 (-128 as signed i16) + +// CHECK-LABEL: func @flush_denormals_bf16 +// CHECK: arith.bitcast %{{.*}} : bf16 to i16 +// CHECK: %[[EXP_MASK:.+]] = arith.constant 32640 : i16 +// CHECK: %[[CLEAR_MAN_MASK:.+]] = arith.constant -128 : i16 +// CHECK: arith.bitcast %{{.*}} : i16 to bf16 +func.func @flush_denormals_bf16(%arg0: bf16) -> bf16 { + %0 = arith.flush_denormals %arg0 : bf16 + return %0 : bf16 +} + +// ----- + +// Expansion for f16: +// exp mask = 0x7c00 +// clear-man mask = 0xfc00 (-1024 as signed i16) + +// CHECK-LABEL: func @flush_denormals_f16 +// CHECK: arith.bitcast %{{.*}} : f16 to i16 +// CHECK: %[[EXP_MASK:.+]] = arith.constant 31744 : i16 +// CHECK: %[[CLEAR_MAN_MASK:.+]] = arith.constant -1024 : i16 +// CHECK: arith.bitcast %{{.*}} : i16 to f16 +func.func @flush_denormals_f16(%arg0: f16) -> f16 { + %0 = arith.flush_denormals %arg0 : f16 + return %0 : f16 +} + +// ----- + +// Expansion for f64 (verifies wide APInt masks work): +// exp mask = 0x7ff0000000000000 = 9218868437227405312 +// clear-man mask = 0xfff0000000000000 = -4503599627370496 (signed i64) + +// CHECK-LABEL: func @flush_denormals_f64 +// CHECK: arith.bitcast %{{.*}} : f64 to i64 +// CHECK: %[[EXP_MASK:.+]] = arith.constant 9218868437227405312 : i64 +// CHECK: %[[CLEAR_MAN_MASK:.+]] = arith.constant -4503599627370496 : i64 +// CHECK: arith.bitcast %{{.*}} : i64 to f64 +func.func @flush_denormals_f64(%arg0: f64) -> f64 { + %0 = arith.flush_denormals %arg0 : f64 + return %0 : f64 +} + +// ----- + +// CHECK-LABEL: func @flush_denormals_vector +// CHECK: arith.bitcast %{{.*}} : vector<4xf32> to vector<4xi32> +// CHECK: arith.andi %{{.*}} : vector<4xi32> +// CHECK: arith.cmpi eq, %{{.*}} : vector<4xi32> +// CHECK: arith.andi %{{.*}} : vector<4xi32> +// CHECK: arith.select %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi1>, vector<4xi32> +// CHECK: arith.bitcast %{{.*}} : vector<4xi32> to vector<4xf32> +func.func @flush_denormals_vector(%arg0: vector<4xf32>) -> vector<4xf32> { + %0 = arith.flush_denormals %arg0 : vector<4xf32> + return %0 : vector<4xf32> +} + +// ----- + +// CHECK-LABEL: func @flush_denormals_tensor +// CHECK: arith.bitcast %{{.*}} : tensor<8xf32> to tensor<8xi32> +// CHECK: arith.bitcast %{{.*}} : tensor<8xi32> to tensor<8xf32> +func.func @flush_denormals_tensor(%arg0: tensor<8xf32>) -> tensor<8xf32> { + %0 = arith.flush_denormals %arg0 : tensor<8xf32> + return %0 : tensor<8xf32> +}