diff --git a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h index 46573e7966ccc..770f257d89bd5 100644 --- a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h +++ b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h @@ -9,6 +9,7 @@ #define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_ #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/IR/PatternMatch.h" #include @@ -20,7 +21,8 @@ class Pass; /// Populate the given list with patterns that convert from Math to ROCDL calls. void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter, - RewritePatternSet &patterns); + RewritePatternSet &patterns, + amdgpu::Chipset chipset); } // namespace mlir #endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_ diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 25e9d34f3e653..a2eb335faac6c 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -778,6 +778,10 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> { let summary = "Convert Math dialect to ROCDL library calls"; let description = [{ This pass converts supported Math ops to ROCDL library calls. + + The chipset option specifies the target AMDGPU architecture. If the chipset + is empty, none of the chipset-dependent patterns are added and the pass + will not attempt to parse the chipset. }]; let dependentDialects = [ "arith::ArithDialect", @@ -785,6 +789,9 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> { "ROCDL::ROCDLDialect", "vector::VectorDialect", ]; + let options = [Option<"chipset", "chipset", "std::string", + /*default=*/"\"gfx000\"", + "Chipset that these operations will run on">]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index b215211e131d4..c03f3a5d3889c 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -484,5 +484,5 @@ void mlir::populateGpuToROCDLConversionPatterns( GPUSubgroupBroadcastOpToROCDL>(converter); patterns.add(converter, chipset); - populateMathToROCDLConversionPatterns(converter, patterns); + populateMathToROCDLConversionPatterns(converter, patterns, chipset); } diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp index df219f3ff4f6e..4ba7eab64a785 100644 --- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp +++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp @@ -10,6 +10,8 @@ #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" @@ -42,8 +44,65 @@ static void populateOpPatterns(const LLVMTypeConverter &converter, f32ApproxFunc, f16Func); } +struct ClampFOpConversion final + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + ClampFOpConversion(const LLVMTypeConverter &converter, + amdgpu::Chipset chipset) + : ConvertOpToLLVMPattern(converter), chipset(chipset) {} + + LogicalResult + matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Only f16 and f32 types are supported by fmed3 + Type opTy = op.getType(); + auto resultType = getTypeConverter()->convertType(opTy); + + if (auto vectorType = dyn_cast(opTy)) { + opTy = vectorType.getElementType(); + } + + if (!isa(opTy)) { + return rewriter.notifyMatchFailure( + op, "fmed3 only supports f16 and f32 types"); + } + + // Handle multi-dimensional vectors (converted to LLVM arrays) + if (auto arrayType = dyn_cast(resultType)) { + // Handle multi-dimensional vectors (converted to LLVM arrays) + return LLVM::detail::handleMultidimensionalVectors( + op.getOperation(), adaptor.getOperands(), *getTypeConverter(), + [&](Type llvm1DVectorTy, ValueRange operands) -> Value { + typename math::ClampFOp::Adaptor adaptor(operands); + return ROCDL::FMed3Op::create(rewriter, op.getLoc(), llvm1DVectorTy, + adaptor.getValue(), adaptor.getMin(), + adaptor.getMax()); + }, + rewriter); + } + + // Handle 1D vectors and scalars directly + rewriter.replaceOpWithNewOp(op, op.getType(), op.getValue(), + op.getMin(), op.getMax()); + return success(); + } + + amdgpu::Chipset chipset; +}; + +static void addChipsetDependentPatterns(const LLVMTypeConverter &converter, + RewritePatternSet &patterns, + amdgpu::Chipset chipset) { + + // V_MED3_F16/F32 only exists in gfx9+ architectures + if (chipset.majorVersion >= 9) { + patterns.add(converter, chipset); + } +} + void mlir::populateMathToROCDLConversionPatterns( - const LLVMTypeConverter &converter, RewritePatternSet &patterns) { + const LLVMTypeConverter &converter, RewritePatternSet &patterns, + amdgpu::Chipset chipset) { // Handled by mathToLLVM: math::AbsIOp // Handled by mathToLLVM: math::AbsFOp // Handled by mathToLLVM: math::CopySignOp @@ -118,15 +177,17 @@ void mlir::populateMathToROCDLConversionPatterns( // worth creating a separate pass for it. populateOpPatterns(converter, patterns, "__ocml_fmod_f32", "__ocml_fmod_f64", "__ocml_fmod_f16"); + + addChipsetDependentPatterns(converter, patterns, chipset); } -namespace { -struct ConvertMathToROCDLPass - : public impl::ConvertMathToROCDLBase { - ConvertMathToROCDLPass() = default; +struct ConvertMathToROCDLPass final + : impl::ConvertMathToROCDLBase { + using impl::ConvertMathToROCDLBase< + ConvertMathToROCDLPass>::ConvertMathToROCDLBase; + void runOnOperation() override; }; -} // namespace void ConvertMathToROCDLPass::runOnOperation() { auto m = getOperation(); @@ -135,10 +196,20 @@ void ConvertMathToROCDLPass::runOnOperation() { RewritePatternSet patterns(&getContext()); LowerToLLVMOptions options(ctx, DataLayout(m)); LLVMTypeConverter converter(ctx, options); - populateMathToROCDLConversionPatterns(converter, patterns); + + // Only populate chipset-dependent patterns if chipset is specified + if (!chipset.empty()) { + FailureOr maybeChipset = amdgpu::Chipset::parse(chipset); + if (failed(maybeChipset)) { + return signalPassFailure(); + } + populateMathToROCDLConversionPatterns(converter, patterns, *maybeChipset); + } + ConversionTarget target(getContext()); - target.addLegalDialect(); + target + .addLegalDialect(); target.addIllegalOp f16 @@ -596,3 +597,76 @@ module @test_module { func.return %result : vector<2x2xf16> } } + +// ----- + +// f16 clamp → rocdl.fmed3 on gfx9+ +// CHECK-LABEL: func.func @clampf_f16 +func.func @clampf_f16(%x: f16, %lo: f16, %hi: f16) -> f16 { + %r = math.clampf %x to [%lo, %hi] : f16 + return %r : f16 + // POST9: rocdl.fmed3 {{.*}} : f16 + // PRE9-NOT: rocdl.fmed3 + // PRE9: math.clampf {{.*}} : f16 +} + +// f32 clamp → rocdl.fmed3 on gfx9+ +// CHECK-LABEL: func.func @clampf_f32 +func.func @clampf_f32(%x: f32, %lo: f32, %hi: f32) -> f32 { + %r = math.clampf %x to [%lo, %hi] : f32 + return %r : f32 + // POST9: rocdl.fmed3 {{.*}} : f32 + // PRE9-NOT: rocdl.fmed3 + // PRE9: math.clampf {{.*}} : f32 +} + +// ----- + +// Vector f16 clamp → rocdl.fmed3 on gfx9+ +// CHECK-LABEL: func.func @clampf_vector_f16 +func.func @clampf_vector_f16(%x: vector<2xf16>, %lo: vector<2xf16>, %hi: vector<2xf16>) -> vector<2xf16> { + %r = math.clampf %x to [%lo, %hi] : vector<2xf16> + return %r : vector<2xf16> + // POST9: rocdl.fmed3 {{.*}} : vector<2xf16> + // PRE9-NOT: rocdl.fmed3 + // PRE9: math.clampf {{.*}} : vector<2xf16> +} + +// ----- + +// Vector f32 clamp → rocdl.fmed3 on gfx9+ +// CHECK-LABEL: func.func @clampf_vector_f32 +func.func @clampf_vector_f32(%x: vector<2xf32>, %lo: vector<2xf32>, %hi: vector<2xf32>) -> vector<2xf32> { + %r = math.clampf %x to [%lo, %hi] : vector<2xf32> + return %r : vector<2xf32> + // POST9: rocdl.fmed3 {{.*}} : vector<2xf32> + // PRE9-NOT: rocdl.fmed3 + // PRE9: math.clampf {{.*}} : vector<2xf32> +} + +// ----- + +// Multi-dimensional vector f16 clamp → rocdl.fmed3 on gfx9+ (unrolled to 1D vectors) +// CHECK-LABEL: func.func @clampf_vector_2d_f16 +func.func @clampf_vector_2d_f16(%x: vector<2x2xf16>, %lo: vector<2x2xf16>, %hi: vector<2x2xf16>) -> vector<2x2xf16> { + %r = math.clampf %x to [%lo, %hi] : vector<2x2xf16> + return %r : vector<2x2xf16> + // POST9: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to !llvm.array<2 x vector<2xf16>> + // POST9: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>> + // POST9: rocdl.fmed3 {{.*}} : vector<2xf16> + // POST9: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>> + // POST9: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>> + // POST9: rocdl.fmed3 {{.*}} : vector<2xf16> + // POST9: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>> + // PRE9-NOT: rocdl.fmed3 + // PRE9: math.clampf {{.*}} : vector<2x2xf16> +} + +// ----- +// CHECK-LABEL: func.func @clampf_bf16 +func.func @clampf_bf16(%x: bf16, %lo: bf16, %hi: bf16) -> bf16 { + %r = math.clampf %x to [%lo, %hi] : bf16 + return %r : bf16 + // CHECK: math.clampf {{.*}} : bf16 + // CHECK-NOT: rocdl.fmed3 +}