From 2761997ffff3355d4e0fd392603c51fde6405380 Mon Sep 17 00:00:00 2001 From: keshavvinayak01 Date: Mon, 22 Sep 2025 13:17:34 +0000 Subject: [PATCH 1/2] Added arith.clampf -> rocdl.fmed3 conversion{ Signed-off-by: keshavvinayak01 --- mlir/include/mlir/Conversion/Passes.td | 6 +++ .../Conversion/MathToROCDL/MathToROCDL.cpp | 41 +++++++++++++++---- .../Conversion/MathToROCDL/math-to-rocdl.mlir | 33 ++++++++++++++- 3 files changed, 72 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 1a37d057776e2..060c7183fcb3a 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -785,6 +785,12 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> { "ROCDL::ROCDLDialect", "vector::VectorDialect", ]; + let options = [ + Option<"chipset", "chipset", "std::string", + /*default=*/"\"gfx000\"", + "Chipset that these operations will run on"> + ]; + } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp index df219f3ff4f6e..6da7e9a850ef7 100644 --- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp +++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp @@ -10,6 +10,7 @@ #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" @@ -120,25 +121,51 @@ void mlir::populateMathToROCDLConversionPatterns( "__ocml_fmod_f64", "__ocml_fmod_f16"); } -namespace { -struct ConvertMathToROCDLPass - : public impl::ConvertMathToROCDLBase { - ConvertMathToROCDLPass() = default; +struct ClampFOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + ClampFOpConversion(const LLVMTypeConverter &converter, + amdgpu::Chipset chipset) + : ConvertOpToLLVMPattern(converter), chipset(chipset) {} + + LogicalResult + matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // V_MED3_F16/F32 only exists in gfx9+ artchitectures + if (chipset.majorVersion < 9) { + std::string msg = + ("pre-gfx9 (gfx" + std::to_string(chipset.majorVersion) + + "): V_MED_F16 / V_MED3_F32 not supported."); + return rewriter.notifyMatchFailure(op, msg); + } + rewriter.replaceOpWithNewOp(op, op.getType(), op.getValue(), + op.getMin(), op.getMax()); + return success(); + } + amdgpu::Chipset chipset; +}; + +struct ConvertMathToROCDLPass final + : impl::ConvertMathToROCDLBase { + using impl::ConvertMathToROCDLBase< + ConvertMathToROCDLPass>::ConvertMathToROCDLBase; + void runOnOperation() override; }; -} // namespace void ConvertMathToROCDLPass::runOnOperation() { auto m = getOperation(); MLIRContext *ctx = m.getContext(); + FailureOr maybeChipset = amdgpu::Chipset::parse(chipset); RewritePatternSet patterns(&getContext()); LowerToLLVMOptions options(ctx, DataLayout(m)); LLVMTypeConverter converter(ctx, options); + patterns.add(converter, *maybeChipset); populateMathToROCDLConversionPatterns(converter, patterns); ConversionTarget target(getContext()); - target.addLegalDialect(); + target + .addLegalDialect(); target.addIllegalOp f16 @@ -596,3 +597,33 @@ module @test_module { func.return %result : vector<2x2xf16> } } + +// ----- + +// f16 clamp → rocdl.fmed3 on gfx9+ +func.func @clampf_f16(%x: f16, %lo: f16, %hi: f16) -> f16 { + %r = math.clampf %x to [%lo, %hi] : f16 + return %r : f16 +} + +// f32 clamp → rocdl.fmed3 on gfx9+ +func.func @clampf_f32(%x: f32, %lo: f32, %hi: f32) -> f32 { + %r = math.clampf %x to [%lo, %hi] : f32 + return %r : f32 +} + +// POST9-LABEL: func.func @clampf_f16 +// POST9: rocdl.fmed3 {{.*}} : f16 +// POST9: return + +// POST9-LABEL: func.func @clampf_f32 +// POST9: rocdl.fmed3 {{.*}} : f32 +// POST9: return + +// PRE9-LABEL: func.func @clampf_f16 +// PRE9-NOT: rocdl.fmed3 +// PRE9: math.clampf {{.*}} : f16 + +// PRE9-LABEL: func.func @clampf_f32 +// PRE9-NOT: rocdl.fmed3 +// PRE9: math.clampf {{.*}} : f32 \ No newline at end of file From da4645f37fe1382f590a27103b01abf9e16f03c5 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha <31160700+keshavvinayak01@users.noreply.github.com> Date: Mon, 22 Sep 2025 18:52:45 +0530 Subject: [PATCH 2/2] Update math-to-rocdl.mlir --- mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir index 488133ad8bddc..541d8d53cac4c 100644 --- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir +++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir @@ -626,4 +626,4 @@ func.func @clampf_f32(%x: f32, %lo: f32, %hi: f32) -> f32 { // PRE9-LABEL: func.func @clampf_f32 // PRE9-NOT: rocdl.fmed3 -// PRE9: math.clampf {{.*}} : f32 \ No newline at end of file +// PRE9: math.clampf {{.*}} : f32