// Patch summary: git diff touching MathToLibm.cpp and its convert-to-libm.mlir
// test.
//
// NOTE(review): in this copy of the patch the line breaks are collapsed and the
// text inside angle brackets is not visible (e.g. `template`, `isa()`,
// `rewriter.create(...)`, `patterns.add, ...`). Hunk line counts and template
// arguments cannot be checked from this text -- recover them from the original
// commit before applying.
//
// MathToLibm.cpp:
//  - Declares a new rewrite pattern, PromoteOpToF32, described in the patch as
//    promoting "an op of a smaller floating point type to F32".
//  - PromoteOpToF32::matchAndRewrite returns failure() unless the op's result
//    type passes an `isa` check (the type list is not visible here; presumably
//    f16/bf16 given the test below -- TODO confirm). Otherwise it:
//      1. maps every operand through a newly created op that takes the F32 type
//         (arith.extf according to the test expectations),
//      2. creates a new op with an F32 result from those extended operands,
//      3. replaces the original op with a new op built from the original type
//         and the F32 result (arith.truncf according to the test expectations).
//  - Removes the "TODO: Support Float16 by upcasting to Float32" comment from
//    ScalarOpToLibmCall::matchAndRewrite.
//  - Adds a patterns.add call registering PromoteOpToF32 instantiations in
//    populateMathToLibmConversionPatterns, passing the same context and benefit
//    as the neighbouring VecOpToScalarOp registration (op list not visible).
//
// convert-to-libm.mlir:
//  - @atan2_caller gains f16 and bf16 arguments and results.
//  - For each narrow type the expected output is two arith.extf to f32, a call
//    to @atan2f on the extended values, then arith.truncf back to the narrow
//    type.
//  - The existing f32/f64 call checks change from CHECK-DAG to CHECK, so they
//    are now matched in order.
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp index 9bea594d87df6..6c9d02c273e5b 100644 --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -30,6 +30,14 @@ struct VecOpToScalarOp : public OpRewritePattern { LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; }; +// Pattern to promote an op of a smaller floating point type to F32. +template +struct PromoteOpToF32 : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final; +}; // Pattern to convert scalar math operations to calls to libm functions. // Additionally the libm function signatures are declared. template @@ -82,13 +90,30 @@ VecOpToScalarOp::matchAndRewrite(Op op, PatternRewriter &rewriter) const { return success(); } +template +LogicalResult +PromoteOpToF32::matchAndRewrite(Op op, PatternRewriter &rewriter) const { + auto opType = op.getType(); + if (!opType.template isa()) + return failure(); + + auto loc = op.getLoc(); + auto f32 = rewriter.getF32Type(); + auto extendedOperands = llvm::to_vector( + llvm::map_range(op->getOperands(), [&](Value operand) -> Value { + return rewriter.create(loc, f32, operand); + })); + auto newOp = rewriter.create(loc, f32, extendedOperands); + rewriter.replaceOpWithNewOp(op, opType, newOp); + return success(); +} + template LogicalResult ScalarOpToLibmCall::matchAndRewrite(Op op, PatternRewriter &rewriter) const { auto module = SymbolTable::getNearestSymbolTable(op); auto type = op.getType(); - // TODO: Support Float16 by upcasting to Float32 if (!type.template isa()) return failure(); @@ -117,6 +142,8 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add, VecOpToScalarOp, VecOpToScalarOp>(patterns.getContext(), benefit); + patterns.add, PromoteOpToF32, 
PromoteOpToF32>(patterns.getContext(), benefit); patterns.add>(patterns.getContext(), "atan2f", "atan2", benefit); patterns.add>(patterns.getContext(), "erff", diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir index 57af89badd635..7cdb56e783e74 100644 --- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir +++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir @@ -25,13 +25,25 @@ func.func @tanh_caller(%float: f32, %double: f64) -> (f32, f64) { // CHECK-LABEL: func @atan2_caller // CHECK-SAME: %[[FLOAT:.*]]: f32 // CHECK-SAME: %[[DOUBLE:.*]]: f64 -func.func @atan2_caller(%float: f32, %double: f64) -> (f32, f64) { - // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @atan2f(%[[FLOAT]], %[[FLOAT]]) : (f32, f32) -> f32 +// CHECK-SAME: %[[HALF:.*]]: f16 +// CHECK-SAME: %[[BFLOAT:.*]]: bf16 +func.func @atan2_caller(%float: f32, %double: f64, %half: f16, %bfloat: bf16) -> (f32, f64, f16, bf16) { + // CHECK: %[[FLOAT_RESULT:.*]] = call @atan2f(%[[FLOAT]], %[[FLOAT]]) : (f32, f32) -> f32 %float_result = math.atan2 %float, %float : f32 - // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @atan2(%[[DOUBLE]], %[[DOUBLE]]) : (f64, f64) -> f64 + // CHECK: %[[DOUBLE_RESULT:.*]] = call @atan2(%[[DOUBLE]], %[[DOUBLE]]) : (f64, f64) -> f64 %double_result = math.atan2 %double, %double : f64 - // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]] - return %float_result, %double_result : f32, f64 + // CHECK: %[[HALF_PROMOTED1:.*]] = arith.extf %[[HALF]] : f16 to f32 + // CHECK: %[[HALF_PROMOTED2:.*]] = arith.extf %[[HALF]] : f16 to f32 + // CHECK: %[[HALF_CALL:.*]] = call @atan2f(%[[HALF_PROMOTED1]], %[[HALF_PROMOTED2]]) : (f32, f32) -> f32 + // CHECK: %[[HALF_RESULT:.*]] = arith.truncf %[[HALF_CALL]] : f32 to f16 + %half_result = math.atan2 %half, %half : f16 + // CHECK: %[[BFLOAT_PROMOTED1:.*]] = arith.extf %[[BFLOAT]] : bf16 to f32 + // CHECK: %[[BFLOAT_PROMOTED2:.*]] = arith.extf %[[BFLOAT]] : bf16 to f32 + // CHECK: 
%[[BFLOAT_CALL:.*]] = call @atan2f(%[[BFLOAT_PROMOTED1]], %[[BFLOAT_PROMOTED2]]) : (f32, f32) -> f32 + // CHECK: %[[BFLOAT_RESULT:.*]] = arith.truncf %[[BFLOAT_CALL]] : f32 to bf16 + %bfloat_result = math.atan2 %bfloat, %bfloat : bf16 + // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]], %[[HALF_RESULT]], %[[BFLOAT_RESULT]] + return %float_result, %double_result, %half_result, %bfloat_result : f32, f64, f16, bf16 } // CHECK-LABEL: func @erf_caller