diff --git a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h index 770f257d89bd5..46573e7966ccc 100644 --- a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h +++ b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h @@ -9,7 +9,6 @@ #define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_ #include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/IR/PatternMatch.h" #include @@ -21,8 +20,7 @@ class Pass; /// Populate the given list with patterns that convert from Math to ROCDL calls. void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter, - RewritePatternSet &patterns, - amdgpu::Chipset chipset); + RewritePatternSet &patterns); } // namespace mlir #endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_ diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index a2eb335faac6c..25e9d34f3e653 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -778,10 +778,6 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> { let summary = "Convert Math dialect to ROCDL library calls"; let description = [{ This pass converts supported Math ops to ROCDL library calls. - - The chipset option specifies the target AMDGPU architecture. If the chipset - is empty, none of the chipset-dependent patterns are added and the pass - will not attempt to parse the chipset. }]; let dependentDialects = [ "arith::ArithDialect", @@ -789,9 +785,6 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> { "ROCDL::ROCDLDialect", "vector::VectorDialect", ]; - let options = [Option<"chipset", "chipset", "std::string", - /*default=*/"\"gfx000\"", - "Chipset that these operations will run on">]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index c03f3a5d3889c..b215211e131d4 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -484,5 +484,5 @@ void mlir::populateGpuToROCDLConversionPatterns( GPUSubgroupBroadcastOpToROCDL>(converter); patterns.add(converter, chipset); - populateMathToROCDLConversionPatterns(converter, patterns, chipset); + populateMathToROCDLConversionPatterns(converter, patterns); } diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp index 4ba7eab64a785..df219f3ff4f6e 100644 --- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp +++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp @@ -10,8 +10,6 @@ #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Conversion/LLVMCommon/VectorPattern.h" -#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" @@ -44,65 +42,8 @@ static void populateOpPatterns(const LLVMTypeConverter &converter, f32ApproxFunc, f16Func); } -struct ClampFOpConversion final - : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - ClampFOpConversion(const LLVMTypeConverter &converter, - amdgpu::Chipset chipset) - : ConvertOpToLLVMPattern(converter), chipset(chipset) {} - - LogicalResult - matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Only f16 and f32 types are supported by fmed3 - Type opTy = op.getType(); - auto resultType = getTypeConverter()->convertType(opTy); - - if (auto vectorType = dyn_cast(opTy)) { - opTy = vectorType.getElementType(); - } - - if (!isa(opTy)) { - return rewriter.notifyMatchFailure( - op, "fmed3 only supports f16 and f32 types"); - } - - // Handle multi-dimensional vectors (converted to LLVM arrays) - if (auto arrayType = dyn_cast(resultType)) { - // Handle multi-dimensional vectors (converted to LLVM arrays) - return LLVM::detail::handleMultidimensionalVectors( - op.getOperation(), adaptor.getOperands(), *getTypeConverter(), - [&](Type llvm1DVectorTy, ValueRange operands) -> Value { - typename math::ClampFOp::Adaptor adaptor(operands); - return ROCDL::FMed3Op::create(rewriter, op.getLoc(), llvm1DVectorTy, - adaptor.getValue(), adaptor.getMin(), - adaptor.getMax()); - }, - rewriter); - } - - // Handle 1D vectors and scalars directly - rewriter.replaceOpWithNewOp(op, op.getType(), op.getValue(), - op.getMin(), op.getMax()); - return success(); - } - - amdgpu::Chipset chipset; -}; - -static void addChipsetDependentPatterns(const LLVMTypeConverter &converter, - RewritePatternSet &patterns, - amdgpu::Chipset chipset) { - - // V_MED3_F16/F32 only exists in gfx9+ architectures - if (chipset.majorVersion >= 9) { - patterns.add(converter, chipset); - } -} - void mlir::populateMathToROCDLConversionPatterns( - const LLVMTypeConverter &converter, RewritePatternSet &patterns, - amdgpu::Chipset chipset) { + const LLVMTypeConverter &converter, RewritePatternSet &patterns) { // Handled by mathToLLVM: math::AbsIOp // Handled by mathToLLVM: math::AbsFOp // Handled by mathToLLVM: math::CopySignOp @@ -177,17 +118,15 @@ void mlir::populateMathToROCDLConversionPatterns( // worth creating a separate pass for it. populateOpPatterns(converter, patterns, "__ocml_fmod_f32", "__ocml_fmod_f64", "__ocml_fmod_f16"); - - addChipsetDependentPatterns(converter, patterns, chipset); } -struct ConvertMathToROCDLPass final - : impl::ConvertMathToROCDLBase { - using impl::ConvertMathToROCDLBase< - ConvertMathToROCDLPass>::ConvertMathToROCDLBase; - +namespace { +struct ConvertMathToROCDLPass + : public impl::ConvertMathToROCDLBase { + ConvertMathToROCDLPass() = default; void runOnOperation() override; }; +} // namespace void ConvertMathToROCDLPass::runOnOperation() { auto m = getOperation(); @@ -196,20 +135,10 @@ void ConvertMathToROCDLPass::runOnOperation() { RewritePatternSet patterns(&getContext()); LowerToLLVMOptions options(ctx, DataLayout(m)); LLVMTypeConverter converter(ctx, options); - - // Only populate chipset-dependent patterns if chipset is specified - if (!chipset.empty()) { - FailureOr maybeChipset = amdgpu::Chipset::parse(chipset); - if (failed(maybeChipset)) { - return signalPassFailure(); - } - populateMathToROCDLConversionPatterns(converter, patterns, *maybeChipset); - } - + populateMathToROCDLConversionPatterns(converter, patterns); ConversionTarget target(getContext()); - target - .addLegalDialect(); + target.addLegalDialect(); target.addIllegalOp f16 @@ -597,76 +596,3 @@ module @test_module { func.return %result : vector<2x2xf16> } } - -// ----- - -// f16 clamp → rocdl.fmed3 on gfx9+ -// CHECK-LABEL: func.func @clampf_f16 -func.func @clampf_f16(%x: f16, %lo: f16, %hi: f16) -> f16 { - %r = math.clampf %x to [%lo, %hi] : f16 - return %r : f16 - // POST9: rocdl.fmed3 {{.*}} : f16 - // PRE9-NOT: rocdl.fmed3 - // PRE9: math.clampf {{.*}} : f16 -} - -// f32 clamp → rocdl.fmed3 on gfx9+ -// CHECK-LABEL: func.func @clampf_f32 -func.func @clampf_f32(%x: f32, %lo: f32, %hi: f32) -> f32 { - %r = math.clampf %x to [%lo, %hi] : f32 - return %r : f32 - // POST9: rocdl.fmed3 {{.*}} : f32 - // PRE9-NOT: rocdl.fmed3 - // PRE9: math.clampf {{.*}} : f32 -} - -// ----- - -// Vector f16 clamp → rocdl.fmed3 on gfx9+ -// CHECK-LABEL: func.func @clampf_vector_f16 -func.func @clampf_vector_f16(%x: vector<2xf16>, %lo: vector<2xf16>, %hi: vector<2xf16>) -> vector<2xf16> { - %r = math.clampf %x to [%lo, %hi] : vector<2xf16> - return %r : vector<2xf16> - // POST9: rocdl.fmed3 {{.*}} : vector<2xf16> - // PRE9-NOT: rocdl.fmed3 - // PRE9: math.clampf {{.*}} : vector<2xf16> -} - -// ----- - -// Vector f32 clamp → rocdl.fmed3 on gfx9+ -// CHECK-LABEL: func.func @clampf_vector_f32 -func.func @clampf_vector_f32(%x: vector<2xf32>, %lo: vector<2xf32>, %hi: vector<2xf32>) -> vector<2xf32> { - %r = math.clampf %x to [%lo, %hi] : vector<2xf32> - return %r : vector<2xf32> - // POST9: rocdl.fmed3 {{.*}} : vector<2xf32> - // PRE9-NOT: rocdl.fmed3 - // PRE9: math.clampf {{.*}} : vector<2xf32> -} - -// ----- - -// Multi-dimensional vector f16 clamp → rocdl.fmed3 on gfx9+ (unrolled to 1D vectors) -// CHECK-LABEL: func.func @clampf_vector_2d_f16 -func.func @clampf_vector_2d_f16(%x: vector<2x2xf16>, %lo: vector<2x2xf16>, %hi: vector<2x2xf16>) -> vector<2x2xf16> { - %r = math.clampf %x to [%lo, %hi] : vector<2x2xf16> - return %r : vector<2x2xf16> - // POST9: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to !llvm.array<2 x vector<2xf16>> - // POST9: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>> - // POST9: rocdl.fmed3 {{.*}} : vector<2xf16> - // POST9: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>> - // POST9: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>> - // POST9: rocdl.fmed3 {{.*}} : vector<2xf16> - // POST9: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>> - // PRE9-NOT: rocdl.fmed3 - // PRE9: math.clampf {{.*}} : vector<2x2xf16> -} - -// ----- -// CHECK-LABEL: func.func @clampf_bf16 -func.func @clampf_bf16(%x: bf16, %lo: bf16, %hi: bf16) -> bf16 { - %r = math.clampf %x to [%lo, %hi] : bf16 - return %r : bf16 - // CHECK: math.clampf {{.*}} : bf16 - // CHECK-NOT: rocdl.fmed3 -}