diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index 1b5a8728dd3f8..93251016966a7 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -122,6 +122,21 @@ static bool isBoolScalarOrVector(Type type) { return false; } +/// Converts arith fast-math flags to SPIR-V FPFastMathMode flags. +static spirv::FPFastMathMode +convertArithFastMathFlagsToSPIRV(arith::FastMathFlags arithFMF) { + spirv::FPFastMathMode spirvFMF = spirv::FPFastMathMode::None; + if (bitEnumContainsAll(arithFMF, arith::FastMathFlags::nnan)) + spirvFMF = spirvFMF | spirv::FPFastMathMode::NotNaN; + if (bitEnumContainsAll(arithFMF, arith::FastMathFlags::ninf)) + spirvFMF = spirvFMF | spirv::FPFastMathMode::NotInf; + if (bitEnumContainsAll(arithFMF, arith::FastMathFlags::nsz)) + spirvFMF = spirvFMF | spirv::FPFastMathMode::NSZ; + if (bitEnumContainsAll(arithFMF, arith::FastMathFlags::arcp)) + spirvFMF = spirvFMF | spirv::FPFastMathMode::AllowRecip; + return spirvFMF; +} + /// Creates a scalar/vector integer constant. static Value getScalarOrVectorConstInt(Type type, uint64_t value, OpBuilder &builder, Location loc) { @@ -225,6 +240,42 @@ struct ElementwiseArithOpPattern final : OpConversionPattern { } }; +/// Converts elementwise unary, binary and ternary floating-point arith +/// operations to SPIR-V operations, propagating fast-math flags as +/// FPFastMathMode decorations. 
+template +struct ElementwiseFPOpPattern final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + assert(adaptor.getOperands().size() <= 3); + Type dstType = this->getTypeConverter()->convertType(op.getType()); + if (!dstType) { + return rewriter.notifyMatchFailure( + op->getLoc(), + llvm::formatv("failed to convert type {0} for SPIR-V", op.getType())); + } + + auto newOp = rewriter.template replaceOpWithNewOp( + op, dstType, adaptor.getOperands()); + + auto *converter = this->template getTypeConverter(); + if (!converter->getTargetEnv().allows(spirv::Capability::Kernel)) + return success(); + + spirv::FPFastMathMode spirvFMF = + convertArithFastMathFlagsToSPIRV(op.getFastmath()); + if (spirvFMF != spirv::FPFastMathMode::None) { + newOp->setAttr("fp_fast_math_mode", + spirv::FPFastMathModeAttr::get(op.getContext(), spirvFMF)); + } + + return success(); + } +}; + //===----------------------------------------------------------------------===// // ConstantOp //===----------------------------------------------------------------------===// @@ -1530,12 +1581,12 @@ void mlir::arith::populateArithToSPIRVPatterns( spirv::ElementwiseOpPattern, ShRSIBoolPattern, // shrsi(a,b) = a (identity; see pattern comment) spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, - spirv::ElementwiseOpPattern, + ElementwiseFPOpPattern, + ElementwiseFPOpPattern, + ElementwiseFPOpPattern, + ElementwiseFPOpPattern, + ElementwiseFPOpPattern, + ElementwiseFPOpPattern, ExtUIPattern, ExtUII1Pattern, ExtSIPattern, ExtSII1Pattern, TypeCastingOpPattern, diff --git a/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir b/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir index 9bbe28fb127a7..8355cc4e8eead 100644 --- 
// -----

// FPFastMathMode decoration tests (requires Kernel capability)
//
// NOTE(review): in this view `fastmath`, `#spirv.fastmath_mode` and
// `#spirv.vce` appear without their `<...>` parameter lists. Presumably they
// were lost in extraction; confirm the exact flag sets against the original
// test file before relying on this text.

module attributes {
  spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>
} {

// Scalar addf carrying fast-math flags: the resulting spirv.FAdd is expected
// to have an fp_fast_math_mode attribute.
// CHECK-LABEL: @addf_fast_math
func.func @addf_fast_math(%arg0 : f32, %arg1 : f32) -> f32 {
  // CHECK: spirv.FAdd %{{.*}}, %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode} : f32
  %0 = arith.addf %arg0, %arg1 fastmath : f32
  return %0: f32
}

// mulf without any fast-math flags: the resulting spirv.FMul must not carry
// the attribute.
// CHECK-LABEL: @mulf_no_fast_math
func.func @mulf_no_fast_math(%arg0 : f32, %arg1 : f32) -> f32 {
  // CHECK: spirv.FMul %{{.*}}, %{{.*}} : f32
  // CHECK-NOT: fp_fast_math_mode
  %0 = arith.mulf %arg0, %arg1 : f32
  return %0: f32
}

// subf with (per the test name) all flags set: the resulting spirv.FSub is
// expected to have an fp_fast_math_mode attribute.
// CHECK-LABEL: @subf_all_flags
func.func @subf_all_flags(%arg0 : f32, %arg1 : f32) -> f32 {
  // CHECK: spirv.FSub %{{.*}}, %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode} : f32
  %0 = arith.subf %arg0, %arg1 fastmath : f32
  return %0: f32
}

// Unary case: negf with fast-math flags becomes spirv.FNegate with the
// attribute attached.
// CHECK-LABEL: @negf_fast_math
func.func @negf_fast_math(%arg0 : f32) -> f32 {
  // CHECK: spirv.FNegate %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode} : f32
  %0 = arith.negf %arg0 fastmath : f32
  return %0: f32
}

// divf with fast-math flags becomes spirv.FDiv with the attribute attached.
// CHECK-LABEL: @divf_fast_math
func.func @divf_fast_math(%arg0 : f32, %arg1 : f32) -> f32 {
  // CHECK: spirv.FDiv %{{.*}}, %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode} : f32
  %0 = arith.divf %arg0, %arg1 fastmath : f32
  return %0: f32
}

// remf with fast-math flags becomes spirv.FRem with the attribute attached.
// CHECK-LABEL: @remf_fast_math
func.func @remf_fast_math(%arg0 : f32, %arg1 : f32) -> f32 {
  // CHECK: spirv.FRem %{{.*}}, %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode} : f32
  %0 = arith.remf %arg0, %arg1 fastmath : f32
  return %0: f32
}

// Test that unsupported flags (reassoc, contract, afn) are silently dropped
// CHECK-LABEL: @addf_unsupported_flags_only
func.func @addf_unsupported_flags_only(%arg0 : f32, %arg1 : f32) -> f32 {
  // CHECK: spirv.FAdd %{{.*}}, %{{.*}} : f32
  // CHECK-NOT: fp_fast_math_mode
  %0 = arith.addf %arg0, %arg1 fastmath : f32
  return %0: f32
}

// Vector case: addf on vector<4xf32> with fast-math flags keeps the vector
// type and gets the attribute.
// CHECK-LABEL: @addf_vector_fast_math
func.func @addf_vector_fast_math(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
  // CHECK: spirv.FAdd %{{.*}}, %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode} : vector<4xf32>
  %0 = arith.addf %arg0, %arg1 fastmath : vector<4xf32>
  return %0: vector<4xf32>
}

} // end module

// -----

// FPFastMathMode decoration requires the Kernel capability. Without it the decoration is dropped.
//
// NOTE(review): the capability list of `#spirv.vce` is not visible here;
// presumably this module's target environment omits Kernel - confirm.

module attributes {
  spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>
} {

// addf with fast-math flags under this target environment: the resulting
// spirv.FAdd must not carry the attribute.
// CHECK-LABEL: @addf_fast_math_no_kernel
func.func @addf_fast_math_no_kernel(%arg0 : f32, %arg1 : f32) -> f32 {
  // CHECK: spirv.FAdd %{{.*}}, %{{.*}} : f32
  // CHECK-NOT: fp_fast_math_mode
  %0 = arith.addf %arg0, %arg1 fastmath : f32
  return %0: f32
}

} // end module