diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h index 9d9d38cc066aa..e2cd0632b03d4 100644 --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -14,6 +14,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/IR/Builders.h" +#include "llvm/ADT/STLExtras.h" namespace mlir { @@ -58,7 +59,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern { std::is_base_of, SourceOp>::value, "expected single result op"); - if (op->getResultTypes().front() != op->getOperand(0).getType()) + auto originalResultType = op->getResult(0).getType(); + if (originalResultType != op->getOperand(0).getType()) return rewriter.notifyMatchFailure( op, "expected op with same operand and result types"); @@ -68,14 +70,38 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern { } SmallVector castedOperands; - for (Value operand : adaptor.getOperands()) + for (auto [index, operand] : llvm::enumerate(adaptor.getOperands())) { + // Only for math.ipowi and math.fpowi, the second operand must be an + // integer + if constexpr (std::is_same_v || + std::is_same_v) { + if (index == 1 && isa(operand.getType())) { + auto bitwidth = operand.getType().getIntOrFloatBitWidth(); + assert(bitwidth <= 32 && "expected integer type with bitwidth <= 32"); + if (bitwidth < 32) { + // extend the integer to i32: + operand = rewriter.create( + operand.getLoc(), rewriter.getIntegerType(32), operand); + castedOperands.push_back(operand); + } else { + castedOperands.push_back(operand); + } + continue; + } + } castedOperands.push_back(maybeCast(operand, rewriter)); + } Type resultType = castedOperands.front().getType(); Type funcType = getFunctionType(resultType, castedOperands); - StringRef funcName = - getFunctionName(cast(funcType).getReturnType(), - op.getFastmath()); + + auto fastmath = arith::FastMathFlags::none; + if 
constexpr (!std::is_same_v) { + fastmath = op.getFastmath(); + } + + StringRef funcName = getFunctionName( + cast(funcType).getReturnType(), fastmath); if (funcName.empty()) return failure(); @@ -88,6 +114,14 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern { return success(); } + if (isa(originalResultType)) { + // Convert the f64 call result back to the op's original result type: + Value siOp = rewriter.create( + op->getLoc(), originalResultType, callOp.getResult()); + rewriter.replaceOp(op, {siOp}); + return success(); + } + Value truncated = rewriter.create( op->getLoc(), adaptor.getOperands().front().getType(), callOp.getResult()); @@ -98,6 +132,14 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern { private: Value maybeCast(Value operand, PatternRewriter &rewriter) const { Type type = operand.getType(); + + if (isa(type)) { + // Convert the integer to f64 when an f64 function is available: + if (!f64Func.empty()) + return rewriter.create( + operand.getLoc(), Float64Type::get(rewriter.getContext()), operand); + } + if (!isa(type)) return operand; @@ -115,6 +157,13 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern { } StringRef getFunctionName(Type type, arith::FastMathFlags flag) const { + // Delegate integer functions to f64Func.
+ if (isa(type)) { + assert(!f64Func.empty() && + "expected f64Func to be set for integer types"); + return f64Func; + } + if (isa(type)) return f16Func; if (isa(type)) { diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp index 838eef30a938f..bca5f6a1996bc 100644 --- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp +++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp @@ -115,6 +115,8 @@ void mlir::populateMathToROCDLConversionPatterns( "__ocml_erf_f64", "__ocml_erf_f16"); populateOpPatterns(converter, patterns, "__ocml_pown_f32", "__ocml_pown_f64", "__ocml_pown_f16"); + populateOpPatterns(converter, patterns, "__ocml_pown_f64", + "__ocml_pown_f64", "__ocml_pown_f64"); // Single arith pattern that needs a ROCDL call, probably not // worth creating a separate pass for it. populateOpPatterns(converter, patterns, "__ocml_fmod_f32", diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir index e4b2f01d6544a..c7ce7957493ae 100644 --- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir +++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir @@ -502,6 +502,25 @@ module @test_module { // ----- +module @test_module { + // CHECK: llvm.func @__ocml_pown_f64(f64, i32) -> f64 + // CHECK-LABEL: func @math_ipowi + func.func @math_ipowi(%arg16: i16, %arg16_1: i16, %arg32: i32, %arg32_1: i32) -> (i16, i32) { + // CHECK: %[[F16:.*]] = llvm.sitofp %{{.*}} : i16 to f64 + // CHECK: %[[N16:.*]] = llvm.sext %{{.*}}: i16 to i32 + // CHECK: %[[RESULT_16:.*]] = llvm.call @__ocml_pown_f64(%[[F16]], %[[N16]]) : (f64, i32) -> f64 + // CHECK: llvm.fptosi %[[RESULT_16]] : f64 to i16 + %0 = math.ipowi %arg16, %arg16_1 : i16 + // CHECK: %[[F32:.*]] = llvm.sitofp %arg2 : i32 to f64 + // CHECK: %[[RESULT_32:.*]] = llvm.call @__ocml_pown_f64(%[[F32]], %{{.*}}) : (f64, i32) -> f64 + // CHECK: llvm.fptosi %[[RESULT_32]] : f64 to i32 + %1 = math.ipowi %arg32, %arg32_1 : i32 + 
return %0, %1 : i16, i32 + } +} + +// ----- + // Math operation not inside function // Ensure it not crash