Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/IR/PatternMatch.h"
#include <memory>

Expand All @@ -19,8 +20,11 @@ class Pass;
#include "mlir/Conversion/Passes.h.inc"

/// Populate the given list with patterns that convert from Math to ROCDL calls.
void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns);
// `chipset` specifies the AMDGPU chipset to target. If `std::nullopt`,
// none of the chipset dependent patterns are added.
void populateMathToROCDLConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
std::optional<amdgpu::Chipset> chipset);
} // namespace mlir

#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
7 changes: 7 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -778,13 +778,20 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
let summary = "Convert Math dialect to ROCDL library calls";
let description = [{
This pass converts supported Math ops to ROCDL library calls.

The chipset option specifies the target AMDGPU architecture. If the chipset
is empty, none of the chipset-dependent patterns are added, and the pass
will not attempt to parse the chipset.
}];
let dependentDialects = [
"arith::ArithDialect",
"func::FuncDialect",
"ROCDL::ROCDLDialect",
"vector::VectorDialect",
];
let options = [Option<"chipset", "chipset", "std::string",
/*default=*/"\"\"",
"Chipset that these operations will run on">];
}

//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -484,5 +484,5 @@ void mlir::populateGpuToROCDLConversionPatterns(
GPUSubgroupBroadcastOpToROCDL>(converter);
patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset);

populateMathToROCDLConversionPatterns(converter, patterns);
populateMathToROCDLConversionPatterns(converter, patterns, chipset);
}
1 change: 1 addition & 0 deletions mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_mlir_conversion_library(MLIRMathToROCDL
Core

LINK_LIBS PUBLIC
MLIRAMDGPUUtils
MLIRDialectUtils
MLIRFuncDialect
MLIRGPUToGPURuntimeTransforms
Expand Down
76 changes: 67 additions & 9 deletions mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
Expand All @@ -19,6 +21,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/DebugLog.h"

#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"
Expand All @@ -42,8 +45,46 @@ static void populateOpPatterns(const LLVMTypeConverter &converter,
f32ApproxFunc, f16Func);
}

struct ClampFOpConversion final
: public ConvertOpToLLVMPattern<math::ClampFOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Only f16 and f32 types are supported by fmed3
Type opTy = op.getType();
Type resultType = getTypeConverter()->convertType(opTy);

if (auto vectorType = dyn_cast<VectorType>(opTy))
opTy = vectorType.getElementType();

if (!isa<Float16Type, Float32Type>(opTy))
return rewriter.notifyMatchFailure(
op, "fmed3 only supports f16 and f32 types");

// Handle multi-dimensional vectors (converted to LLVM arrays)
if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(resultType))
return LLVM::detail::handleMultidimensionalVectors(
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
[&](Type llvm1DVectorTy, ValueRange operands) -> Value {
typename math::ClampFOp::Adaptor adaptor(operands);
return ROCDL::FMed3Op::create(rewriter, op.getLoc(), llvm1DVectorTy,
adaptor.getValue(), adaptor.getMin(),
adaptor.getMax());
},
rewriter);

// Handle 1D vectors and scalars directly
rewriter.replaceOpWithNewOp<ROCDL::FMed3Op>(op, op.getType(), op.getValue(),
op.getMin(), op.getMax());
return success();
}
};

void mlir::populateMathToROCDLConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
std::optional<amdgpu::Chipset> chipset) {
// Handled by mathToLLVM: math::AbsIOp
// Handled by mathToLLVM: math::AbsFOp
// Handled by mathToLLVM: math::CopySignOp
Expand Down Expand Up @@ -118,15 +159,21 @@ void mlir::populateMathToROCDLConversionPatterns(
// worth creating a separate pass for it.
populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
"__ocml_fmod_f64", "__ocml_fmod_f16");

if (chipset.has_value() && chipset->majorVersion >= 9) {
patterns.add<ClampFOpConversion>(converter);
} else {
LDBG() << "Chipset dependent patterns were not added";
}
}

namespace {
struct ConvertMathToROCDLPass
: public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
ConvertMathToROCDLPass() = default;
struct ConvertMathToROCDLPass final
: impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
using impl::ConvertMathToROCDLBase<
ConvertMathToROCDLPass>::ConvertMathToROCDLBase;

void runOnOperation() override;
};
} // namespace

void ConvertMathToROCDLPass::runOnOperation() {
auto m = getOperation();
Expand All @@ -135,10 +182,21 @@ void ConvertMathToROCDLPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
LowerToLLVMOptions options(ctx, DataLayout(m));
LLVMTypeConverter converter(ctx, options);
populateMathToROCDLConversionPatterns(converter, patterns);

FailureOr<amdgpu::Chipset> maybeChipset;
if (!chipset.empty()) {
maybeChipset = amdgpu::Chipset::parse(chipset);
if (failed(maybeChipset))
return signalPassFailure();
}
populateMathToROCDLConversionPatterns(
converter, patterns,
succeeded(maybeChipset) ? std::optional(*maybeChipset) : std::nullopt);

ConversionTarget target(getContext());
target.addLegalDialect<BuiltinDialect, func::FuncDialect,
vector::VectorDialect, LLVM::LLVMDialect>();
target
.addLegalDialect<BuiltinDialect, func::FuncDialect, vector::VectorDialect,
LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
Expand Down
76 changes: 75 additions & 1 deletion mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -convert-math-to-rocdl -allow-unregistered-dialect -split-input-file | FileCheck %s
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx803})' | FileCheck %s --check-prefix=PRE9
// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx942})' | FileCheck %s --check-prefix=POST9

module @test_module {
// CHECK: llvm.func @__ocml_fmod_f16(f16, f16) -> f16
Expand Down Expand Up @@ -596,3 +597,76 @@ module @test_module {
func.return %result : vector<2x2xf16>
}
}

// -----

// f16 clamp → rocdl.fmed3 on gfx9+
// CHECK-LABEL: func.func @clampf_f16
func.func @clampf_f16(%x: f16, %lo: f16, %hi: f16) -> f16 {
%r = math.clampf %x to [%lo, %hi] : f16
return %r : f16
// POST9: rocdl.fmed3 {{.*}} : f16
// PRE9-NOT: rocdl.fmed3
// PRE9: math.clampf {{.*}} : f16
}

// f32 clamp → rocdl.fmed3 on gfx9+
// CHECK-LABEL: func.func @clampf_f32
func.func @clampf_f32(%x: f32, %lo: f32, %hi: f32) -> f32 {
%r = math.clampf %x to [%lo, %hi] : f32
return %r : f32
// POST9: rocdl.fmed3 {{.*}} : f32
// PRE9-NOT: rocdl.fmed3
// PRE9: math.clampf {{.*}} : f32
}

// -----

// Vector f16 clamp → rocdl.fmed3 on gfx9+
// CHECK-LABEL: func.func @clampf_vector_f16
func.func @clampf_vector_f16(%x: vector<2xf16>, %lo: vector<2xf16>, %hi: vector<2xf16>) -> vector<2xf16> {
%r = math.clampf %x to [%lo, %hi] : vector<2xf16>
return %r : vector<2xf16>
// POST9: rocdl.fmed3 {{.*}} : vector<2xf16>
// PRE9-NOT: rocdl.fmed3
// PRE9: math.clampf {{.*}} : vector<2xf16>
}

// -----

// Vector f32 clamp → rocdl.fmed3 on gfx9+
// CHECK-LABEL: func.func @clampf_vector_f32
func.func @clampf_vector_f32(%x: vector<2xf32>, %lo: vector<2xf32>, %hi: vector<2xf32>) -> vector<2xf32> {
%r = math.clampf %x to [%lo, %hi] : vector<2xf32>
return %r : vector<2xf32>
// POST9: rocdl.fmed3 {{.*}} : vector<2xf32>
// PRE9-NOT: rocdl.fmed3
// PRE9: math.clampf {{.*}} : vector<2xf32>
}

// -----

// Multi-dimensional vector f16 clamp → rocdl.fmed3 on gfx9+ (unrolled to 1D vectors)
// CHECK-LABEL: func.func @clampf_vector_2d_f16
func.func @clampf_vector_2d_f16(%x: vector<2x2xf16>, %lo: vector<2x2xf16>, %hi: vector<2x2xf16>) -> vector<2x2xf16> {
%r = math.clampf %x to [%lo, %hi] : vector<2x2xf16>
return %r : vector<2x2xf16>
// POST9: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to !llvm.array<2 x vector<2xf16>>
// POST9: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
// POST9: rocdl.fmed3 {{.*}} : vector<2xf16>
// POST9: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
// POST9: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
// POST9: rocdl.fmed3 {{.*}} : vector<2xf16>
// POST9: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
// PRE9-NOT: rocdl.fmed3
// PRE9: math.clampf {{.*}} : vector<2x2xf16>
}

// -----
// CHECK-LABEL: func.func @clampf_bf16
func.func @clampf_bf16(%x: bf16, %lo: bf16, %hi: bf16) -> bf16 {
%r = math.clampf %x to [%lo, %hi] : bf16
return %r : bf16
// CHECK: math.clampf {{.*}} : bf16
// CHECK-NOT: rocdl.fmed3
}