diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 1a37d057776e2..060c7183fcb3a 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -785,6 +785,12 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> { "ROCDL::ROCDLDialect", "vector::VectorDialect", ]; + let options = [ + Option<"chipset", "chipset", "std::string", + /*default=*/"\"gfx000\"", + "Chipset that these operations will run on"> + ]; + } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp index df219f3ff4f6e..6da7e9a850ef7 100644 --- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp +++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp @@ -10,6 +10,7 @@ #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" @@ -120,25 +121,60 @@ void mlir::populateMathToROCDLConversionPatterns( "__ocml_fmod_f64", "__ocml_fmod_f16"); } -namespace { -struct ConvertMathToROCDLPass - : public impl::ConvertMathToROCDLBase { - ConvertMathToROCDLPass() = default; +/// Rewrites math.clampf into a single three-operand op built from (value, min, max); +/// the lit checks in this patch expect rocdl.fmed3 on gfx9+. On older chipsets +/// the match is rejected with a diagnostic note and math.clampf is left untouched. +struct ClampFOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + ClampFOpConversion(const LLVMTypeConverter &converter, + amdgpu::Chipset chipset) + : ConvertOpToLLVMPattern(converter), chipset(chipset) {} + + LogicalResult + matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // V_MED3_F16/F32 only exists in gfx9+ architectures + if (chipset.majorVersion < 9) { + std::string msg = + ("pre-gfx9 (gfx" + 
std::to_string(chipset.majorVersion) + + "): V_MED3_F16 / V_MED3_F32 not supported."); + return rewriter.notifyMatchFailure(op, msg); + } + rewriter.replaceOpWithNewOp(op, op.getType(), op.getValue(), + op.getMin(), op.getMax()); + return success(); + } + /// Chipset whose major version gates the rewrite. + amdgpu::Chipset chipset; +}; + +struct ConvertMathToROCDLPass final + : impl::ConvertMathToROCDLBase { + using impl::ConvertMathToROCDLBase< + ConvertMathToROCDLPass>::ConvertMathToROCDLBase; + void runOnOperation() override; }; -} // namespace void ConvertMathToROCDLPass::runOnOperation() { auto m = getOperation(); MLIRContext *ctx = m.getContext(); + FailureOr maybeChipset = amdgpu::Chipset::parse(chipset); + // An unparsable `chipset` option must fail the pass; the result is dereferenced below. + if (failed(maybeChipset)) { + m.emitError("Invalid chipset name: " + chipset); + return signalPassFailure(); + } RewritePatternSet patterns(&getContext()); LowerToLLVMOptions options(ctx, DataLayout(m)); LLVMTypeConverter converter(ctx, options); + patterns.add(converter, *maybeChipset); populateMathToROCDLConversionPatterns(converter, patterns); ConversionTarget target(getContext()); - target.addLegalDialect(); + target + .addLegalDialect(); target.addIllegalOp f16 @@ -596,3 +597,33 @@ module @test_module { func.return %result : vector<2x2xf16> } } + +// ----- + +// f16 clamp → rocdl.fmed3 on gfx9+ +func.func @clampf_f16(%x: f16, %lo: f16, %hi: f16) -> f16 { + %r = math.clampf %x to [%lo, %hi] : f16 + return %r : f16 +} + +// f32 clamp → rocdl.fmed3 on gfx9+ +func.func @clampf_f32(%x: f32, %lo: f32, %hi: f32) -> f32 { + %r = math.clampf %x to [%lo, %hi] : f32 + return %r : f32 +} + +// POST9-LABEL: func.func @clampf_f16 +// POST9: rocdl.fmed3 {{.*}} : f16 +// POST9: return + +// POST9-LABEL: func.func @clampf_f32 +// POST9: rocdl.fmed3 {{.*}} : f32 +// POST9: return + +// PRE9-LABEL: func.func @clampf_f16 +// PRE9-NOT: rocdl.fmed3 +// PRE9: math.clampf {{.*}} : f16 + +// PRE9-LABEL: func.func @clampf_f32 +// PRE9-NOT: rocdl.fmed3 +// PRE9: math.clampf {{.*}} : f32