diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp index 76ee9d00b53ad..446646913a4e1 100644 --- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -403,6 +403,56 @@ struct IsFiniteOpLowering } }; +// A `clampf` is converted into `minimum(value, max)` followed by +// `maximum(result, min)`, i.e. clampf(x, lo, hi) = maximum(minimum(x, hi), lo) +struct ClampFOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + math::ClampFOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + const auto &typeConverter = *this->getTypeConverter(); + auto operandType = adaptor.getValue().getType(); + auto llvmOperandType = typeConverter.convertType(operandType); + if (!llvmOperandType) + return failure(); + + auto loc = op.getLoc(); + ConvertFastMath minAttrs(op); + ConvertFastMath maxAttrs(op); + + if (!isa(llvmOperandType)) { + auto minOp = LLVM::MinimumOp::create( + rewriter, loc, llvmOperandType, + ValueRange{adaptor.getValue(), adaptor.getMax()}, + minAttrs.getAttrs()); + rewriter.replaceOpWithNewOp( + op, llvmOperandType, ValueRange{minOp.getResult(), adaptor.getMin()}, + maxAttrs.getAttrs()); + return success(); + } + + if (!isa(op.getResult().getType())) + return rewriter.notifyMatchFailure(op, "expected vector result type"); + + return LLVM::detail::handleMultidimensionalVectors( + op.getOperation(), adaptor.getOperands(), typeConverter, + [&](Type llvm1DVectorTy, ValueRange operands) { + // operands order: value, min, max + auto minOp = LLVM::MinimumOp::create( + rewriter, loc, llvm1DVectorTy, + ValueRange{operands[0], operands[2]}, minAttrs.getAttrs()); + return LLVM::MaximumOp::create( + rewriter, loc, llvm1DVectorTy, + ValueRange{minOp.getResult(), operands[1]}, maxAttrs.getAttrs()); + }, + rewriter); + } +}; + struct ConvertMathToLLVMPass : public impl::ConvertMathToLLVMPassBase { using Base::Base; @@ -431,6 +481,7 @@ void mlir::populateMathToLLVMConversionPatterns( AbsFOpLowering, AbsIOpLowering, CeilOpLowering, + ClampFOpLowering, CopySignOpLowering, CosOpLowering, CoshOpLowering, diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir index 504dc1afb0eef..301a95bf716b7 100644 --- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir +++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir @@ -641,3 +641,50 @@ func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: f4E2M1FN, %arg2: f4E2M1FN %2 = math.fma %arg1, %arg1, %arg2 : f4E2M1FN return } + +// ----- + +// CHECK-LABEL: func @clampf( +// CHECK-SAME: %[[VAL:.*]]: f32, %[[MIN:.*]]: f32, %[[MAX:.*]]: f32 +func.func @clampf(%arg0: f32, %arg1: f32, %arg2: f32) -> f32 { + // CHECK: %[[MIN_VAL:.*]] = llvm.intr.minimum(%[[VAL]], %[[MAX]]) : (f32, f32) -> f32 + // CHECK: %[[RESULT:.*]] = llvm.intr.maximum(%[[MIN_VAL]], %[[MIN]]) : (f32, f32) -> f32 + %0 = math.clampf %arg0 to [%arg1, %arg2] : f32 + return %0 : f32 +} + +// ----- + +// CHECK-LABEL: func @clampf_fmf( +// CHECK-SAME: %[[VAL:.*]]: f32, %[[MIN:.*]]: f32, %[[MAX:.*]]: f32 +func.func @clampf_fmf(%arg0: f32, %arg1: f32, %arg2: f32) -> f32 { + // CHECK: %[[MIN_VAL:.*]] = llvm.intr.minimum(%[[VAL]], %[[MAX]]) {fastmathFlags = #llvm.fastmath} : (f32, f32) -> f32 + // CHECK: %[[RESULT:.*]] = llvm.intr.maximum(%[[MIN_VAL]], %[[MIN]]) {fastmathFlags = #llvm.fastmath} : (f32, f32) -> f32 + %0 = math.clampf %arg0 to [%arg1, %arg2] fastmath : f32 + return %0 : f32 +} + +// ----- + +// CHECK-LABEL: func @clampf_vector( +// CHECK-SAME: %[[VAL:.*]]: vector<4xf32>, %[[MIN:.*]]: vector<4xf32>, %[[MAX:.*]]: vector<4xf32> +func.func @clampf_vector(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> { + // CHECK: %[[MIN_VAL:.*]] = llvm.intr.minimum(%[[VAL]], %[[MAX]]) : (vector<4xf32>, vector<4xf32>) -> vector<4xf32> + // CHECK: %[[RESULT:.*]] = llvm.intr.maximum(%[[MIN_VAL]], %[[MIN]]) : (vector<4xf32>, vector<4xf32>) -> vector<4xf32> + %0 = math.clampf %arg0 to [%arg1, %arg2] : vector<4xf32> + return %0 : vector<4xf32> +} + +// ----- + +// CHECK-LABEL: func @clampf_2dvector( +func.func @clampf_2dvector(%arg0: vector<4x3xf32>, %arg1: vector<4x3xf32>, %arg2: vector<4x3xf32>) -> vector<4x3xf32> { + // CHECK: %[[EXTRACT_VAL:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>> + // CHECK: %[[EXTRACT_MIN:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>> + // CHECK: %[[EXTRACT_MAX:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>> + // CHECK: %[[MIN_VAL:.*]] = llvm.intr.minimum(%[[EXTRACT_VAL]], %[[EXTRACT_MAX]]) : (vector<3xf32>, vector<3xf32>) -> vector<3xf32> + // CHECK: %[[MAX_VAL:.*]] = llvm.intr.maximum(%[[MIN_VAL]], %[[EXTRACT_MIN]]) : (vector<3xf32>, vector<3xf32>) -> vector<3xf32> + // CHECK: llvm.insertvalue %[[MAX_VAL]], %{{.*}}[0] : !llvm.array<4 x vector<3xf32>> + %0 = math.clampf %arg0 to [%arg1, %arg2] : vector<4x3xf32> + return %0 : vector<4x3xf32> +}