Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,12 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
"ROCDL::ROCDLDialect",
"vector::VectorDialect",
];
let options = [
Option<"chipset", "chipset", "std::string",
/*default=*/"\"gfx000\"",
"Chipset that these operations will run on">
];

}

//===----------------------------------------------------------------------===//
Expand Down
41 changes: 34 additions & 7 deletions mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
Expand Down Expand Up @@ -120,25 +121,51 @@ void mlir::populateMathToROCDLConversionPatterns(
"__ocml_fmod_f64", "__ocml_fmod_f16");
}

namespace {
struct ConvertMathToROCDLPass
: public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
ConvertMathToROCDLPass() = default;
struct ClampFOpConversion : public ConvertOpToLLVMPattern<math::ClampFOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
ClampFOpConversion(const LLVMTypeConverter &converter,
amdgpu::Chipset chipset)
: ConvertOpToLLVMPattern<math::ClampFOp>(converter), chipset(chipset) {}

LogicalResult
matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// V_MED3_F16/F32 only exists in gfx9+ artchitectures
if (chipset.majorVersion < 9) {
std::string msg =
("pre-gfx9 (gfx" + std::to_string(chipset.majorVersion) +
"): V_MED_F16 / V_MED3_F32 not supported.");
return rewriter.notifyMatchFailure(op, msg);
}
rewriter.replaceOpWithNewOp<ROCDL::FMed3Op>(op, op.getType(), op.getValue(),
op.getMin(), op.getMax());
return success();
}
amdgpu::Chipset chipset;
};

struct ConvertMathToROCDLPass final
: impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
using impl::ConvertMathToROCDLBase<
ConvertMathToROCDLPass>::ConvertMathToROCDLBase;

void runOnOperation() override;
};
} // namespace

void ConvertMathToROCDLPass::runOnOperation() {
auto m = getOperation();
MLIRContext *ctx = m.getContext();
FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);

RewritePatternSet patterns(&getContext());
LowerToLLVMOptions options(ctx, DataLayout(m));
LLVMTypeConverter converter(ctx, options);
patterns.add<ClampFOpConversion>(converter, *maybeChipset);
populateMathToROCDLConversionPatterns(converter, patterns);
ConversionTarget target(getContext());
target.addLegalDialect<BuiltinDialect, func::FuncDialect,
vector::VectorDialect, LLVM::LLVMDialect>();
target
.addLegalDialect<BuiltinDialect, func::FuncDialect, vector::VectorDialect,
LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
Expand Down
33 changes: 32 additions & 1 deletion mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -convert-math-to-rocdl -allow-unregistered-dialect -split-input-file | FileCheck %s
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx803})' | FileCheck %s --check-prefix=PRE9
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx942})' | FileCheck %s --check-prefix=POST9

module @test_module {
// CHECK: llvm.func @__ocml_fmod_f16(f16, f16) -> f16
Expand Down Expand Up @@ -596,3 +597,33 @@ module @test_module {
func.return %result : vector<2x2xf16>
}
}

// -----

// f16 clamp → rocdl.fmed3 on gfx9+
func.func @clampf_f16(%x: f16, %lo: f16, %hi: f16) -> f16 {
%r = math.clampf %x to [%lo, %hi] : f16
return %r : f16
}

// f32 clamp → rocdl.fmed3 on gfx9+
func.func @clampf_f32(%x: f32, %lo: f32, %hi: f32) -> f32 {
%r = math.clampf %x to [%lo, %hi] : f32
return %r : f32
}

// POST9-LABEL: func.func @clampf_f16
// POST9: rocdl.fmed3 {{.*}} : f16
// POST9: return

// POST9-LABEL: func.func @clampf_f32
// POST9: rocdl.fmed3 {{.*}} : f32
// POST9: return

// PRE9-LABEL: func.func @clampf_f16
// PRE9-NOT: rocdl.fmed3
// PRE9: math.clampf {{.*}} : f16

// PRE9-LABEL: func.func @clampf_f32
// PRE9-NOT: rocdl.fmed3
// PRE9: math.clampf {{.*}} : f32