diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td index 80959ffbaf426..809b405b704bc 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td @@ -717,6 +717,26 @@ def AMDGPU_SchedBarrierOp : }]; } +def AMDGPU_WaitcntOp : + AMDGPU_Op<"waitcnt">, + Arguments<(ins + OptionalAttr:$vmcnt, + OptionalAttr:$expcnt, + OptionalAttr:$lgkmcnt + )> + { + let summary = "Wrapper on ROCDL SWaitcntOp"; + let description = [{ + Convenience wrapper on `rocdl.s.waitcnt`. Hides the architecture-specific + bitpacking from the user. Missing values will be assumed maximum values supported + by the architecture. Large values will also be clamped to the maximum + supported values. + }]; + let assemblyFormat = [{ + oilist( `vmcnt` `(` $vmcnt `)` | `expcnt` `(` $expcnt `)` | `lgkmcnt` `(` $lgkmcnt `)` ) attr-dict + }]; +} + def AMDGPU_MFMAPermB : I32EnumAttr<"MFMAPermB", "The possible permutations of the lanes storing B available in an MFMA", [ diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index ef35ee208f002..1940ef8775688 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -419,6 +419,82 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern { } }; +// TODO: AMDGPU backend already has all this bitpacking logic, we should move +// it to some common place. 
+/// \details Clamps each counter to the maximum its bit field can hold and packs the three into one value; returns failure for chipset major versions above 11. \p Vmcnt, \p Expcnt and \p Lgkmcnt are encoded as follows: +/// \p Vmcnt = \p Waitcnt[3:0] (pre-gfx9) +/// \p Vmcnt = \p Waitcnt[15:14,3:0] (gfx9,10) +/// \p Vmcnt = \p Waitcnt[15:10] (gfx11) +/// \p Expcnt = \p Waitcnt[6:4] (pre-gfx11) +/// \p Expcnt = \p Waitcnt[2:0] (gfx11) +/// \p Lgkmcnt = \p Waitcnt[11:8] (pre-gfx10) +/// \p Lgkmcnt = \p Waitcnt[13:8] (gfx10) +/// \p Lgkmcnt = \p Waitcnt[9:4] (gfx11) +static FailureOr encodeWaitcnt(Chipset chipset, unsigned vmcnt, + unsigned expcnt, unsigned lgkmcnt) { + if (chipset.majorVersion < 9) { + vmcnt = std::min(15u, vmcnt); + expcnt = std::min(7u, expcnt); + lgkmcnt = std::min(15u, lgkmcnt); + return vmcnt | (expcnt << 4) | (lgkmcnt << 8); + } + if (chipset.majorVersion == 9) { + vmcnt = std::min(63u, vmcnt); + expcnt = std::min(7u, expcnt); + lgkmcnt = std::min(15u, lgkmcnt); + unsigned lowBits = vmcnt & 0xF; + unsigned highBits = (vmcnt >> 4) << 14; + unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8); + return lowBits | highBits | otherCnts; + } + if (chipset.majorVersion == 10) { + vmcnt = std::min(63u, vmcnt); + expcnt = std::min(7u, expcnt); + lgkmcnt = std::min(63u, lgkmcnt); + unsigned lowBits = vmcnt & 0xF; + unsigned highBits = (vmcnt >> 4) << 14; + unsigned otherCnts = (expcnt << 4) | (lgkmcnt << 8); + return lowBits | highBits | otherCnts; + } + if (chipset.majorVersion == 11) { + vmcnt = std::min(63u, vmcnt); + expcnt = std::min(7u, expcnt); + lgkmcnt = std::min(63u, lgkmcnt); + return (vmcnt << 10) | expcnt | (lgkmcnt << 4); + } + return failure(); +} + +struct WaitcntOpLowering : public ConvertOpToLLVMPattern { + WaitcntOpLowering(const LLVMTypeConverter &converter, Chipset chipset) + : ConvertOpToLLVMPattern(converter), chipset(chipset) {} + + Chipset chipset; + + LogicalResult + matchAndRewrite(WaitcntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto getVal = [](Attribute attr) -> unsigned { + if (attr) + return cast(attr).getInt(); + + // This 
value will be clamped to the maximum value for the chipset. + return 1024 * 1024; + }; + unsigned vmcnt = getVal(adaptor.getVmcntAttr()); + unsigned expcnt = getVal(adaptor.getExpcntAttr()); + unsigned lgkmcnt = getVal(adaptor.getLgkmcntAttr()); + + FailureOr waitcnt = + encodeWaitcnt(chipset, vmcnt, expcnt, lgkmcnt); + if (failed(waitcnt)) + return op.emitOpError("unsupported chipset"); + + rewriter.replaceOpWithNewOp(op, *waitcnt); + return success(); + } +}; + struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern { LDSBarrierOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} @@ -1825,9 +1901,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, ROCDL::RawPtrBufferAtomicUminOp>, RawBufferOpLowering, - AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering, - MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering, - ExtPackedFp8OpLowering, ScaledExtPackedOpLowering, + AMDGPUDPPLowering, WaitcntOpLowering, LDSBarrierOpLowering, + SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering, + WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering, PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering, GatherToLDSOpLowering, TransposeLoadOpLowering>(converter, chipset); diff --git a/mlir/test/Conversion/AMDGPUToROCDL/waitcnt.mlir b/mlir/test/Conversion/AMDGPUToROCDL/waitcnt.mlir new file mode 100644 index 0000000000000..71617df05eb60 --- /dev/null +++ b/mlir/test/Conversion/AMDGPUToROCDL/waitcnt.mlir @@ -0,0 +1,29 @@ +// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s --check-prefixes=CHECK,GFX9 +// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1030 | FileCheck %s --check-prefixes=CHECK,GFX10 +// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 | FileCheck %s --check-prefixes=CHECK,GFX11 + + +// CHECK-LABEL: func @waitcnt +func.func @waitcnt() { + // GFX9: 
rocdl.s.waitcnt 53119 + // GFX10: rocdl.s.waitcnt 65407 + // GFX11: rocdl.s.waitcnt 65527 + amdgpu.waitcnt + + // GFX9: rocdl.s.waitcnt 3952 + // GFX10: rocdl.s.waitcnt 16240 + // GFX11: rocdl.s.waitcnt 1015 + amdgpu.waitcnt vmcnt(0) + + // GFX9: rocdl.s.waitcnt 53007 + // GFX10: rocdl.s.waitcnt 65295 + // GFX11: rocdl.s.waitcnt 65520 + amdgpu.waitcnt expcnt(0) + + // GFX9: rocdl.s.waitcnt 49279 + // GFX10: rocdl.s.waitcnt 49279 + // GFX11: rocdl.s.waitcnt 64519 + amdgpu.waitcnt lgkmcnt(0) + + return +} diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir index 5559ac8f1a5c3..3583e0e291d36 100644 --- a/mlir/test/Dialect/AMDGPU/ops.mlir +++ b/mlir/test/Dialect/AMDGPU/ops.mlir @@ -504,3 +504,18 @@ func.func @gather_to_lds(%idx1 : index, %idx2 : index, %mem1 : memref<32xf16>, % amdgpu.gather_to_lds %mem1[%idx1], %smem2[%idx1, %idx2] : vector<2xf16>, memref<32xf16>, memref<32x32xf16, #gpu.address_space> func.return } + +// CHECK-LABEL: func @waitcnt +func.func @waitcnt() { + // CHECK: amdgpu.waitcnt vmcnt(1) expcnt(2) lgkmcnt(3) + // CHECK: amdgpu.waitcnt vmcnt(3) expcnt(2) lgkmcnt(1) + // CHECK: amdgpu.waitcnt vmcnt(1) + // CHECK: amdgpu.waitcnt expcnt(2) + // CHECK: amdgpu.waitcnt lgkmcnt(3) + amdgpu.waitcnt vmcnt(1) expcnt(2) lgkmcnt(3) + amdgpu.waitcnt lgkmcnt(1) expcnt(2) vmcnt(3) + amdgpu.waitcnt vmcnt(1) + amdgpu.waitcnt expcnt(2) + amdgpu.waitcnt lgkmcnt(3) + func.return +}