[LLVM][NVPTX]: Add intrinsic for setmaxnreg #77289

Merged
merged 1 commit into from Jan 9, 2024

durga4github
Contributor

@durga4github durga4github commented Jan 8, 2024

This patch adds an intrinsic for the setmaxnreg PTX instruction.

  • PTX Doc link for this instruction: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-setmaxnreg

  • The i32 argument, an immediate value, specifies the actual
    absolute register count for the instruction.

  • The setmaxnreg instruction is available in SM90a,
    so this patch adds a 'hasSM90a' predicate for use in
    the NVPTX backend.

  • lit tests are added to verify the lowering of the intrinsic.

  • Verifier logic (and tests) are added to test the register
    count range and divisibility-by-8 requirements.
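
As a rough illustration of the 'hasSM90a' predicate, here is a minimal C++ sketch of the condition as it appears in the diff shown on this page. `SubtargetSketch` is a hypothetical stand-in for the real NVPTX subtarget, and the encoding it assumes (sm_90 stored as 900, sm_90a as 901, with the plain SM version obtained by dividing by 10) is inferred from the constants in the predicate, not confirmed by this PR.

```cpp
#include <cassert>

// Hypothetical stand-in for the NVPTX subtarget. Only the two getters that
// the predicate string calls are modelled here.
struct SubtargetSketch {
  unsigned FullSmVersion; // assumed encoding: sm_90 -> 900, sm_90a -> 901

  constexpr unsigned getFullSmVersion() const { return FullSmVersion; }
  // Assumption: the plain SM version drops the last (feature) digit.
  constexpr unsigned getSmVersion() const { return FullSmVersion / 10; }
};

// Same condition as the hasSM90a predicate string in the diff.
constexpr bool hasSM90a(const SubtargetSketch &ST) {
  return ST.getSmVersion() == 90 && ST.getFullSmVersion() == 901;
}

static_assert(hasSM90a(SubtargetSketch{901}), "sm_90a is accepted");
static_assert(!hasSM90a(SubtargetSketch{900}), "plain sm_90 is rejected");
```

Unlike the generic `hasSM<N>` class, which uses `>=`, this predicate is an exact match, which is consistent with the instruction being tied to the arch-accelerated target.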

@llvmbot
Collaborator

llvmbot commented Jan 8, 2024

@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-llvm-support

Author: Durgadoss R (durga4github)

Changes

This patch adds an intrinsic for the setmaxnreg PTX instruction.

  • PTX Doc link for this instruction: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-setmaxnreg

  • The first argument, an i32 flags value, is a compile-time constant that selects the inc (flag=0) or dec (flag=1) modifier.

  • The second argument, an immediate value, specifies the actual absolute register count for the instruction.

  • The setmaxnreg instruction is available in SM90a, so this patch adds a 'hasSM90a' predicate for use in the NVPTX backend.

  • lit tests are added to verify the lowering of the intrinsic.

The modifiers are encoded into flags so that the same intrinsic can be extended with more options in the future without having to add separate intrinsics.

The flags are defined in Support/NVVMIntrinsicFlags.h, to facilitate usage by both upstream and downstream clients.
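
To make the flag encoding concrete, here is a small C++ sketch of how a flags word could be built, decoded, and turned into the instruction text that the lit test checks for. It is not the code from this patch: it uses an explicit bit mask instead of the bitfield union in NVVMIntrinsicFlags.h, it assumes the action occupies the least significant bit (as the flag values 0 and 1 imply), and `encodeFlags`, `decodeAction`, and `formatSetMaxNReg` are hypothetical helper names.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Values taken from the SetMaxNRegAction enum in the diff.
enum SetMaxNRegAction : uint32_t { ACTION_INC = 0, ACTION_DEC = 1 };

// Assumption: bit 0 carries the action; bits 1..31 are reserved.
constexpr uint32_t ActionMask = 0x1;

// Hypothetical helper: packs an action into a flags word.
constexpr uint32_t encodeFlags(SetMaxNRegAction A) {
  return static_cast<uint32_t>(A) & ActionMask;
}

// Hypothetical helper: extracts the action, ignoring reserved bits.
constexpr SetMaxNRegAction decodeAction(uint32_t Flags) {
  return static_cast<SetMaxNRegAction>(Flags & ActionMask);
}

// Hypothetical helper: builds the PTX text for a flags word and a count.
std::string formatSetMaxNReg(uint32_t Flags, uint32_t RegCount) {
  const char *Suffix = decodeAction(Flags) == ACTION_INC ? ".inc" : ".dec";
  return std::string("setmaxnreg") + Suffix + ".sync.aligned.u32 " +
         std::to_string(RegCount) + ";";
}
```

With these helpers, `formatSetMaxNReg(encodeFlags(ACTION_INC), 96)` yields `setmaxnreg.inc.sync.aligned.u32 96;`, matching the first CHECK line in setmaxnreg.ll. The mask-based form sidesteps the implementation-defined layout of bitfields, at the cost of not sharing the union type from the header.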


Full diff: https://github.com/llvm/llvm-project/pull/77289.diff

7 Files Affected:

  • (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+6)
  • (added) llvm/include/llvm/Support/NVVMIntrinsicFlags.h (+39)
  • (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp (+20)
  • (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h (+2)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+4)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+13)
  • (added) llvm/test/CodeGen/NVPTX/setmaxnreg.ll (+15)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 6fd8e80013cee5..81c56ca3c6ee03 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -4710,4 +4710,10 @@ def int_nvvm_is_explicit_cluster
               [IntrNoMem, IntrSpeculatable, NoUndef<RetIndex>],
               "llvm.nvvm.is_explicit_cluster">;
 
+// Setmaxnreg intrinsic
+def int_nvvm_setmaxnreg_sync_aligned_u32
+  : DefaultAttrsIntrinsic<[], [llvm_i32_ty, llvm_i32_ty],
+              [IntrConvergent, IntrNoMem, IntrHasSideEffects, ImmArg<ArgIndex<1>>],
+              "llvm.nvvm.setmaxnreg.sync.aligned.u32">;
+
 } // let TargetPrefix = "nvvm"
diff --git a/llvm/include/llvm/Support/NVVMIntrinsicFlags.h b/llvm/include/llvm/Support/NVVMIntrinsicFlags.h
new file mode 100644
index 00000000000000..23c265831ae4e2
--- /dev/null
+++ b/llvm/include/llvm/Support/NVVMIntrinsicFlags.h
@@ -0,0 +1,39 @@
+//===--- NVVMIntrinsicFlags.h -----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// This file contains the definitions of the enumerations and flags
+/// associated with NVVM Intrinsics.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_SUPPORT_NVVMINTRINSICFLAGS_H
+#define LLVM_SUPPORT_NVVMINTRINSICFLAGS_H
+
+#include <stdint.h>
+
+namespace llvm {
+namespace nvvm {
+
+enum SetMaxNRegAction {
+  ACTION_INC = 0,
+  ACTION_DEC = 1,
+};
+
+typedef union {
+  uint32_t V;
+  struct {
+    uint32_t Action : 1;    // inc(0) or dec(1)
+    uint32_t reserved : 31;
+  } U;
+} SetMaxNRegFlags;
+
+} // namespace nvvm
+} // namespace llvm
+
+#endif // LLVM_SUPPORT_NVVMINTRINSICFLAGS_H
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index b7a20c351f5ff6..c5e575e805ab93 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -20,6 +20,7 @@
 #include "llvm/MC/MCSymbol.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/FormattedStream.h"
+#include "llvm/Support/NVVMIntrinsicFlags.h"
 #include <cctype>
 using namespace llvm;
 
@@ -340,3 +341,22 @@ void NVPTXInstPrinter::printPrmtMode(const MCInst *MI, int OpNum,
     break;
   }
 }
+
+void NVPTXInstPrinter::printSetMaxNRegActionFlag(const MCInst *MI, int OpNum,
+                                                 raw_ostream &O,
+                                                 const char *Modifier) {
+  nvvm::SetMaxNRegFlags Flags;
+  Flags.V = (int)MI->getOperand(OpNum).getImm();
+
+  using Action = nvvm::SetMaxNRegAction;
+  switch (Flags.U.Action) {
+    case Action::ACTION_INC:
+      O << ".inc";
+      break;
+    case Action::ACTION_DEC:
+      O << ".dec";
+      break;
+    default:
+      llvm_unreachable("Invalid action flag for setmaxnreg intrinsic");
+  }
+}
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index e6954f861cd10e..234d5f139ad496 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -49,6 +49,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
                        raw_ostream &O, const char *Modifier = nullptr);
   void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O,
                      const char *Modifier = nullptr);
+  void printSetMaxNRegActionFlag(const MCInst *MI, int OpNum, raw_ostream &O,
+                                 const char *Modifier = nullptr);
 };
 
 }
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 13665985f52eba..aa02814061da4f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -164,6 +164,10 @@ def True : Predicate<"true">;
 class hasPTX<int version>: Predicate<"Subtarget->getPTXVersion() >= " # version>;
 class hasSM<int version>: Predicate<"Subtarget->getSmVersion() >= " # version>;
 
+// Explicit records for arch-accelerated SM versions
+def hasSM90a : Predicate<"Subtarget->getSmVersion() == 90"
+                          "&& Subtarget->getFullSmVersion() == 901">;
+
 // non-sync shfl instructions are not available on sm_70+ in PTX6.4+
 def hasSHFL : Predicate<"!(Subtarget->getSmVersion() >= 70"
                           "&& Subtarget->getPTXVersion() >= 64)">;
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 85eae44f349aa3..9594d47fedf5ee 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -6727,3 +6727,16 @@ def is_explicit_cluster: NVPTXInst<(outs Int1Regs:$d), (ins),
               "mov.pred\t$d, %is_explicit_cluster;",
               [(set Int1Regs:$d, (int_nvvm_is_explicit_cluster))]>,
     Requires<[hasSM<90>, hasPTX<78>]>;
+
+// setmaxnreg intrinsic
+def SetMaxNRegFlags : Operand<i32> {
+  let PrintMethod = "printSetMaxNRegActionFlag";
+}
+
+let isConvergent = true in {
+def INT_SET_MAXNREG : NVPTXInst<(outs),
+    (ins SetMaxNRegFlags:$flags, i32imm:$reg_count),
+    "setmaxnreg${flags:action}.sync.aligned.u32 $reg_count;",
+    [(int_nvvm_setmaxnreg_sync_aligned_u32 imm:$flags, timm:$reg_count)]>,
+    Requires<[hasSM90a, hasPTX<80>]>;
+} // isConvergent
diff --git a/llvm/test/CodeGen/NVPTX/setmaxnreg.ll b/llvm/test/CodeGen/NVPTX/setmaxnreg.ll
new file mode 100644
index 00000000000000..25698088e1cf7b
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/setmaxnreg.ll
@@ -0,0 +1,15 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_90a -mattr=+ptx80| FileCheck --check-prefixes=CHECK %s
+; RUN: %if ptxas-12.0 %{ llc < %s -march=nvptx64 -mcpu=sm_90a -mattr=+ptx80| %ptxas-verify -arch=sm_90a %}
+
+declare void @llvm.nvvm.setmaxnreg.sync.aligned.u32(i32 %flags, i32 %reg_count)
+
+; CHECK-LABEL: test_set_maxn_reg
+define void @test_set_maxn_reg() {
+  ; CHECK: setmaxnreg.inc.sync.aligned.u32 96;
+  call void @llvm.nvvm.setmaxnreg.sync.aligned.u32(i32 0, i32 96)
+
+  ; CHECK: setmaxnreg.dec.sync.aligned.u32 64;
+  call void @llvm.nvvm.setmaxnreg.sync.aligned.u32(i32 1, i32 64)
+
+  ret void
+}

github-actions bot commented Jan 8, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@durga4github
Contributor Author

durga4github commented Jan 8, 2024

@grypp , @joker-eph , @jholewinski Could you please help with the review?

Member

@grypp grypp left a comment

Looks good to me in general.

The PTX doc places restrictions on this instruction. Is it possible to check them?

Operand imm-reg-count is an integer constant. The value of imm-reg-count must be in the range 24 to 256 (both inclusive) and must be a multiple of 8.
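
The quoted restriction can be sketched as a standalone C++ check. This is only an illustration of the rule, not the logic that was added to llvm/lib/IR/Verifier.cpp; the function name `isValidSetMaxNRegCount` is hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: true when RegCount satisfies the PTX restriction
// quoted above (24 <= RegCount <= 256 and RegCount is a multiple of 8).
constexpr bool isValidSetMaxNRegCount(uint32_t RegCount) {
  return RegCount >= 24 && RegCount <= 256 && RegCount % 8 == 0;
}

// The two counts used in setmaxnreg.ll both satisfy the rule.
static_assert(isValidSetMaxNRegCount(96), "96 is in range and a multiple of 8");
static_assert(isValidSetMaxNRegCount(64), "64 is in range and a multiple of 8");
```

A check of this shape rejects both out-of-range values (such as 16 or 264) and in-range values that are not multiples of 8 (such as 100).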

Review thread on llvm/include/llvm/IR/IntrinsicsNVVM.td (outdated, resolved)
@grypp grypp requested a review from Artem-B January 8, 2024 17:09
Review thread on llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (outdated, resolved)
Review thread on llvm/test/CodeGen/NVPTX/setmaxnreg.ll (outdated, resolved)
Member

@grypp grypp left a comment

Thank you for addressing the comments. It looks good to me, but let's wait for others to review it.

Member

@Artem-B Artem-B left a comment

LGTM with a small nit.

Review thread on llvm/lib/IR/Verifier.cpp (outdated, resolved)
This patch adds an intrinsic for setmaxnreg instruction.
* PTX Doc link for this instruction:
  https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-setmaxnreg

* The i32 argument, an immediate value, specifies the actual
  absolute register count for the instruction.
* The `setmaxnreg` instruction is available in SM90a.
  So, this patch adds 'hasSM90a' predicate to use in
  the NVPTX backend.
* lit tests are added to verify the lowering of the intrinsic.
* Verifier logic (and tests) are added to test the register
  count range and divisibility-by-8 requirements.

Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
@durga4github
Contributor Author

@Artem-B , I don't have commit access. Please help merge the commit.

@Artem-B Artem-B merged commit 340cc17 into llvm:main Jan 9, 2024
3 of 4 checks passed
@durga4github durga4github deleted the durgadossr/setmaxnreg branch January 10, 2024 09:02
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024