Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][nvvm] Introduce setmaxregister.sync.aligned Op #73780

Merged
merged 3 commits into from
Nov 29, 2023

Conversation

grypp
Copy link
Member

@grypp grypp commented Nov 29, 2023

This PR introduce setmaxregister.sync.aligned Op to increase or decrease the register size.

https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-setmaxnreg

@llvmbot
Copy link
Collaborator

llvmbot commented Nov 29, 2023

@llvm/pr-subscribers-mlir

Author: Guray Ozen (grypp)

Changes

This PR introduce setmaxregister.sync.aligned Op to increase or decrease the register size.

https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-setmaxnreg


Full diff: https://github.com/llvm/llvm-project/pull/73780.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+22)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+9-7)
  • (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+10)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 829fb68549307c8..cbe7c3919d62043 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -400,6 +400,28 @@ def NVVM_FenceScClusterOp : NVVM_Op<"fence.sc.cluster"> {
   let assemblyFormat = "attr-dict";
 }
 
+def SetMaxRegisterActionIncrease : I32EnumAttrCase<"increase", 0>;
+def SetMaxRegisterActionDecrease   : I32EnumAttrCase<"decrease", 1>;
+def SetMaxRegisterAction : I32EnumAttr<"SetMaxRegisterAction", "NVVM set max register action",
+  [SetMaxRegisterActionDecrease, SetMaxRegisterActionIncrease]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def SetMaxRegisterActionAttr : EnumAttr<NVVM_Dialect, SetMaxRegisterAction, "action">;
+
+def NVVM_SetMaxRegisterOp : NVVM_PTXBuilder_Op<"setmaxregister.sync.aligned"> {
+  let arguments = (ins I32Attr:$count, SetMaxRegisterActionAttr:$action);
+  let assemblyFormat = "$action $count attr-dict";
+  let extraClassDefinition = [{        
+    std::string $cppClass::getPtx() {
+      if(getAction() == NVVM::SetMaxRegisterAction::increase)
+        return std::string("setmaxnreg.inc.sync.aligned.u32 %0;");
+      return std::string("setmaxnreg.dec.sync.aligned.u32 %0;");
+    }
+  }];
+  let hasVerifier = 1;
+}
+
 def ShflKindBfly : I32EnumAttrCase<"bfly", 0>;
 def ShflKindUp   : I32EnumAttrCase<"up", 1>;
 def ShflKindDown : I32EnumAttrCase<"down", 2>;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 63ceebb08e5baa7..e2280c398153411 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -213,8 +213,7 @@ void MmaOp::print(OpAsmPrinter &p) {
   p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
 
   // Print the types of the operands and result.
-  p << " : "
-    << "(";
+  p << " : " << "(";
   llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                              frags[1].regs[0].getType(),
                                              frags[2].regs[0].getType()},
@@ -954,9 +953,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
   ss << "},";
   // Need to map read/write registers correctly.
   regCnt = (regCnt * 2);
-  ss << " $" << (regCnt) << ","
-     << " $" << (regCnt + 1) << ","
-     << " p";
+  ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
   if (!outputType.isInteger(32)) {
     ss << ", $" << (regCnt + 3) << ",  $" << (regCnt + 4);
   }
@@ -964,8 +961,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
   if (isF16) {
     ss << ", $" << (regCnt + 5) << ",  $" << (regCnt + 6);
   }
-  ss << ";\n"
-     << "}\n";
+  ss << ";\n" << "}\n";
   ss.flush();
   return ptx;
 }
@@ -1007,6 +1003,12 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
   }
 }
 
+LogicalResult NVVM::SetMaxRegisterOp::verify() {
+  if (getCount() % 8)
+    return emitOpError("new register size must be multiple of 8");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 5fa907850cedf30..fe4c33854485cda 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -611,3 +611,13 @@ llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
   nvvm.prefetch.tensormap %desc, predicate = %pred : !llvm.ptr, i1
   llvm.return
 }
+
+// -----
+
+func.func @set_max_register() {
+  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "setmaxnreg.inc.sync.aligned.u32 [$0];", "n"
+  nvvm.setmaxregister.sync.aligned increase 232
+  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "setmaxnreg.dec.sync.aligned.u32 [$0];", "n"
+  nvvm.setmaxregister.sync.aligned decrease 40
+  func.return
+}

@llvmbot
Copy link
Collaborator

llvmbot commented Nov 29, 2023

@llvm/pr-subscribers-mlir-llvm

Author: Guray Ozen (grypp)

Changes

This PR introduce setmaxregister.sync.aligned Op to increase or decrease the register size.

https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-setmaxnreg


Full diff: https://github.com/llvm/llvm-project/pull/73780.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+22)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+9-7)
  • (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+10)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 829fb68549307c8..cbe7c3919d62043 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -400,6 +400,28 @@ def NVVM_FenceScClusterOp : NVVM_Op<"fence.sc.cluster"> {
   let assemblyFormat = "attr-dict";
 }
 
+def SetMaxRegisterActionIncrease : I32EnumAttrCase<"increase", 0>;
+def SetMaxRegisterActionDecrease   : I32EnumAttrCase<"decrease", 1>;
+def SetMaxRegisterAction : I32EnumAttr<"SetMaxRegisterAction", "NVVM set max register action",
+  [SetMaxRegisterActionDecrease, SetMaxRegisterActionIncrease]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def SetMaxRegisterActionAttr : EnumAttr<NVVM_Dialect, SetMaxRegisterAction, "action">;
+
+def NVVM_SetMaxRegisterOp : NVVM_PTXBuilder_Op<"setmaxregister.sync.aligned"> {
+  let arguments = (ins I32Attr:$count, SetMaxRegisterActionAttr:$action);
+  let assemblyFormat = "$action $count attr-dict";
+  let extraClassDefinition = [{        
+    std::string $cppClass::getPtx() {
+      if(getAction() == NVVM::SetMaxRegisterAction::increase)
+        return std::string("setmaxnreg.inc.sync.aligned.u32 %0;");
+      return std::string("setmaxnreg.dec.sync.aligned.u32 %0;");
+    }
+  }];
+  let hasVerifier = 1;
+}
+
 def ShflKindBfly : I32EnumAttrCase<"bfly", 0>;
 def ShflKindUp   : I32EnumAttrCase<"up", 1>;
 def ShflKindDown : I32EnumAttrCase<"down", 2>;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 63ceebb08e5baa7..e2280c398153411 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -213,8 +213,7 @@ void MmaOp::print(OpAsmPrinter &p) {
   p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
 
   // Print the types of the operands and result.
-  p << " : "
-    << "(";
+  p << " : " << "(";
   llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                              frags[1].regs[0].getType(),
                                              frags[2].regs[0].getType()},
@@ -954,9 +953,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
   ss << "},";
   // Need to map read/write registers correctly.
   regCnt = (regCnt * 2);
-  ss << " $" << (regCnt) << ","
-     << " $" << (regCnt + 1) << ","
-     << " p";
+  ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
   if (!outputType.isInteger(32)) {
     ss << ", $" << (regCnt + 3) << ",  $" << (regCnt + 4);
   }
@@ -964,8 +961,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
   if (isF16) {
     ss << ", $" << (regCnt + 5) << ",  $" << (regCnt + 6);
   }
-  ss << ";\n"
-     << "}\n";
+  ss << ";\n" << "}\n";
   ss.flush();
   return ptx;
 }
@@ -1007,6 +1003,12 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
   }
 }
 
+LogicalResult NVVM::SetMaxRegisterOp::verify() {
+  if (getCount() % 8)
+    return emitOpError("new register size must be multiple of 8");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 5fa907850cedf30..fe4c33854485cda 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -611,3 +611,13 @@ llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
   nvvm.prefetch.tensormap %desc, predicate = %pred : !llvm.ptr, i1
   llvm.return
 }
+
+// -----
+
+func.func @set_max_register() {
+  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "setmaxnreg.inc.sync.aligned.u32 [$0];", "n"
+  nvvm.setmaxregister.sync.aligned increase 232
+  //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "setmaxnreg.dec.sync.aligned.u32 [$0];", "n"
+  nvvm.setmaxregister.sync.aligned decrease 40
+  func.return
+}

@grypp
Copy link
Member Author

grypp commented Nov 29, 2023

@durga4github would be great if you can review

Copy link

github-actions bot commented Nov 29, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

let arguments = (ins I32Attr:$regCount, SetMaxRegisterActionAttr:$action);
let assemblyFormat = "$action $regCount attr-dict";
let extraClassDefinition = [{
std::string $cppClass::getPtx() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please clarify the below questions, for my understanding:

If we happen to add an LLVM intrinsic for this at a later point,
a)
Can we extend the same Op to lower to that instead of inline-PTX ?
(without any changes to the Op's interface)
b)
I believe we can achieve that by overriding the 'hasIntrinsic()' to return true (within the extraClassDeclaration).
Do we need anything more?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW, I think we should add LLVM intrinsic (this is "the right thing").

a) We can (should) keep the Op's interface same. So nvgpu->nvvm lowering remains unchanged (and other lowerings we have internally). We can use LLVM intrinsic when we have it.

Let's say we have LLVM intrinsic for this Op, what we need to is to add llvmBuilder part and remove getPtx.

def NVVM_SetMaxRegisterOp : NVVM_Op<"setmaxregister"> {
  let arguments = (ins I32Attr:$regCount, SetMaxRegisterActionAttr:$action);
  let assemblyFormat = "$action $regCount attr-dict";
  string llvmBuilder = [{
    if(getAction() == NVVM::SetMaxRegisterAction::increase)
      createIntrinsicCall(builder, llvm::Intrinsic::nvvm_set_max_register_increase);
    else createIntrinsicCall(builder,   llvm::Intrinsic::nvvm_set_max_register_descreate);
  }];
  let hasVerifier = 1;
}

b) I've added a hasIntrinsic() in BasicPtxBuilderOpInterface interface to use inline PTX and LLVM intrinsic together. If LLVM contains all necessary intrinsics for the Op, we don't need BasicPtxBuilderOpInterface and hasIntrinsic().

For example you can take a look at mbarrier.init Op

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explanation, Guray!
I am with you on a).

On b): Yes, I was looking at the mbarrier.init as an example, and wanted to confirm if both the intrinsic and the inline PTX forms can co-exist. It seems they can, without any issues.
(I was running the mbarrier.init Op example and was always generating the intrinsic version.
Later realized I was using mlir-translate and not "mlir-opt with conversion-to-llvm" to see the inline-PTX impl.)

Thanks for all the clarifications!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure happy to answer.

mlir-translate generates the llvm code. We use llvm's intrinsic here.

We have convert-nvvm-to-llvm mlir pass. It matches the OPs that implements BasicPtxBuilderOpInterface and generates inline ptx for them.

@grypp grypp merged commit 68433f6 into llvm:main Nov 29, 2023
3 checks passed
@grypp grypp deleted the setregister branch November 29, 2023 14:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants