Merge the two NVVM grid-dependency-control ops (`nvvm.griddepcontrol.wait` and
`nvvm.griddepcontrol.launch.dependents`) into a single `nvvm.griddepcontrol` op
whose behavior is selected by a GridDepActionKind enum attribute (`wait` or
`launch_dependents`). The op's llvmBuilder picks the matching LLVM intrinsic,
and the dialect round-trip and LLVM IR translation tests use the new syntax.

NOTE(review): the line breaks and indentation of this patch were reconstructed
from a whitespace-collapsed copy, and the template arguments of `EnumAttr` were
restored after having been stripped (the attribute mnemonic in particular should
be confirmed). Verify against the upstream tree before applying.

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 30df3b739e5ca..fe4019f6d7d54 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2995,30 +2995,46 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
 // NVVM Griddepcontrol Ops
 //===----------------------------------------------------------------------===//
 
-def NVVM_GriddepcontrolWaitOp : NVVM_IntrOp<"griddepcontrol.wait", [], 0> {
-  let assemblyFormat = "attr-dict";
+def GridDepActionWait : I32EnumCase<"wait", 0>;
+def GridDepActionLaunchDependent : I32EnumCase<"launch_dependents", 1>;
+
+def GridDepActionKind : I32Enum<"GridDepActionKind", "Action kind for grid dependency control",
+                        [GridDepActionWait, GridDepActionLaunchDependent]> {
+  let cppNamespace = "::mlir::NVVM";
+}
 
+def GridDepActionAttr : EnumAttr<NVVM_Dialect, GridDepActionKind,
+                                 "grid_dep_action">;
+
+def NVVM_GriddepcontrolOp : NVVM_Op<"griddepcontrol", []> {
   let description = [{
-    Causes the executing thread to wait until all prerequisite grids in flight
+    If the $kind attribute is set to `wait`, it causes the
+    executing thread to wait until all prerequisite grids in flight
     have completed and all the memory operations from the prerequisite grids
     are performed and made visible to the current grid.
+    When the $kind is launch_dependents, it signals that specific dependents
+    the runtime system designated to react to this instruction can be scheduled
+    as soon as all other CTAs in the grid issue the same instruction or have
+    completed.
 
 
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-griddepcontrol)
   }];
-}
 
-def NVVM_GriddepcontrolLaunchDependentsOp
-    : NVVM_IntrOp<"griddepcontrol.launch.dependents", [], 0> {
-  let assemblyFormat = "attr-dict";
-
-  let description = [{
-    Signals that specific dependents the runtime system designated to react to
-    this instruction can be scheduled as soon as all other CTAs in the grid
-    issue the same instruction or have completed.
+  let arguments = (ins GridDepActionAttr:$kind);
+  let assemblyFormat = "$kind attr-dict";
 
-
-    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-griddepcontrol)
+  string llvmBuilder = [{
+    llvm::Intrinsic::ID id;
+    switch ($kind) {
+      case NVVM::GridDepActionKind::wait:
+        id = llvm::Intrinsic::nvvm_griddepcontrol_wait;
+        break;
+      case NVVM::GridDepActionKind::launch_dependents:
+        id = llvm::Intrinsic::nvvm_griddepcontrol_launch_dependents;
+        break;
+    }
+    createIntrinsicCall(builder, id);
   }];
 }
 
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index c7fa41c98ac92..cd14be5473432 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -535,15 +535,15 @@ func.func @wgmma_wait_group_sync_aligned() {
 }
 
 func.func @griddepcontrol_wait() {
-  // CHECK: nvvm.griddepcontrol.wait
-  nvvm.griddepcontrol.wait
+  // CHECK: nvvm.griddepcontrol wait
+  nvvm.griddepcontrol wait
   return
 }
 
 func.func @griddepcontrol_launch_dependents()
 {
-  // CHECK: nvvm.griddepcontrol.launch.dependents
-  nvvm.griddepcontrol.launch.dependents
+  // CHECK: nvvm.griddepcontrol launch_dependents
+  nvvm.griddepcontrol launch_dependents
   return
 }
 
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 5c2cfa4683104..ff588268fc9d2 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -766,7 +766,7 @@ llvm.func @nvvm_wgmma_wait_group_aligned() {
 // CHECK-LABEL: @nvvm_griddepcontrol_wait
 llvm.func @nvvm_griddepcontrol_wait() {
   // CHECK: call void @llvm.nvvm.griddepcontrol.wait()
-  nvvm.griddepcontrol.wait
+  nvvm.griddepcontrol wait
   llvm.return
 }
 
@@ -774,7 +774,7 @@ llvm.func @nvvm_griddepcontrol_wait() {
 // CHECK-LABEL: @nvvm_griddepcontrol_launch_dependents
 llvm.func @nvvm_griddepcontrol_launch_dependents() {
   // CHECK: call void @llvm.nvvm.griddepcontrol.launch.dependents()
-  nvvm.griddepcontrol.launch.dependents
+  nvvm.griddepcontrol launch_dependents
   llvm.return
 }
 