-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][nvvm] Introduce setmaxregister.sync.aligned
Op
#73780
Conversation
@llvm/pr-subscribers-mlir Author: Guray Ozen (grypp) Changes: This PR introduces the `setmaxregister.sync.aligned` Op to increase or decrease the register size. Full diff: https://github.com/llvm/llvm-project/pull/73780.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 829fb68549307c8..cbe7c3919d62043 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -400,6 +400,28 @@ def NVVM_FenceScClusterOp : NVVM_Op<"fence.sc.cluster"> {
let assemblyFormat = "attr-dict";
}
+def SetMaxRegisterActionIncrease : I32EnumAttrCase<"increase", 0>;
+def SetMaxRegisterActionDecrease : I32EnumAttrCase<"decrease", 1>;
+def SetMaxRegisterAction : I32EnumAttr<"SetMaxRegisterAction", "NVVM set max register action",
+ [SetMaxRegisterActionDecrease, SetMaxRegisterActionIncrease]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def SetMaxRegisterActionAttr : EnumAttr<NVVM_Dialect, SetMaxRegisterAction, "action">;
+
+def NVVM_SetMaxRegisterOp : NVVM_PTXBuilder_Op<"setmaxregister.sync.aligned"> {
+ let arguments = (ins I32Attr:$count, SetMaxRegisterActionAttr:$action);
+ let assemblyFormat = "$action $count attr-dict";
+ let extraClassDefinition = [{
+ std::string $cppClass::getPtx() {
+ if(getAction() == NVVM::SetMaxRegisterAction::increase)
+ return std::string("setmaxnreg.inc.sync.aligned.u32 %0;");
+ return std::string("setmaxnreg.dec.sync.aligned.u32 %0;");
+ }
+ }];
+ let hasVerifier = 1;
+}
+
def ShflKindBfly : I32EnumAttrCase<"bfly", 0>;
def ShflKindUp : I32EnumAttrCase<"up", 1>;
def ShflKindDown : I32EnumAttrCase<"down", 2>;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 63ceebb08e5baa7..e2280c398153411 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -213,8 +213,7 @@ void MmaOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
// Print the types of the operands and result.
- p << " : "
- << "(";
+ p << " : " << "(";
llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
frags[1].regs[0].getType(),
frags[2].regs[0].getType()},
@@ -954,9 +953,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
ss << "},";
// Need to map read/write registers correctly.
regCnt = (regCnt * 2);
- ss << " $" << (regCnt) << ","
- << " $" << (regCnt + 1) << ","
- << " p";
+ ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
if (!outputType.isInteger(32)) {
ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
}
@@ -964,8 +961,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
if (isF16) {
ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
}
- ss << ";\n"
- << "}\n";
+ ss << ";\n" << "}\n";
ss.flush();
return ptx;
}
@@ -1007,6 +1003,12 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
}
}
+LogicalResult NVVM::SetMaxRegisterOp::verify() {
+ if (getCount() % 8)
+ return emitOpError("new register size must be multiple of 8");
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 5fa907850cedf30..fe4c33854485cda 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -611,3 +611,13 @@ llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
nvvm.prefetch.tensormap %desc, predicate = %pred : !llvm.ptr, i1
llvm.return
}
+
+// -----
+
+func.func @set_max_register() {
+ //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "setmaxnreg.inc.sync.aligned.u32 [$0];", "n"
+ nvvm.setmaxregister.sync.aligned increase 232
+ //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "setmaxnreg.dec.sync.aligned.u32 [$0];", "n"
+ nvvm.setmaxregister.sync.aligned decrease 40
+ func.return
+}
|
@llvm/pr-subscribers-mlir-llvm Author: Guray Ozen (grypp) Changes: This PR introduces the `setmaxregister.sync.aligned` Op to increase or decrease the register size. Full diff: https://github.com/llvm/llvm-project/pull/73780.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 829fb68549307c8..cbe7c3919d62043 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -400,6 +400,28 @@ def NVVM_FenceScClusterOp : NVVM_Op<"fence.sc.cluster"> {
let assemblyFormat = "attr-dict";
}
+def SetMaxRegisterActionIncrease : I32EnumAttrCase<"increase", 0>;
+def SetMaxRegisterActionDecrease : I32EnumAttrCase<"decrease", 1>;
+def SetMaxRegisterAction : I32EnumAttr<"SetMaxRegisterAction", "NVVM set max register action",
+ [SetMaxRegisterActionDecrease, SetMaxRegisterActionIncrease]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def SetMaxRegisterActionAttr : EnumAttr<NVVM_Dialect, SetMaxRegisterAction, "action">;
+
+def NVVM_SetMaxRegisterOp : NVVM_PTXBuilder_Op<"setmaxregister.sync.aligned"> {
+ let arguments = (ins I32Attr:$count, SetMaxRegisterActionAttr:$action);
+ let assemblyFormat = "$action $count attr-dict";
+ let extraClassDefinition = [{
+ std::string $cppClass::getPtx() {
+ if(getAction() == NVVM::SetMaxRegisterAction::increase)
+ return std::string("setmaxnreg.inc.sync.aligned.u32 %0;");
+ return std::string("setmaxnreg.dec.sync.aligned.u32 %0;");
+ }
+ }];
+ let hasVerifier = 1;
+}
+
def ShflKindBfly : I32EnumAttrCase<"bfly", 0>;
def ShflKindUp : I32EnumAttrCase<"up", 1>;
def ShflKindDown : I32EnumAttrCase<"down", 2>;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 63ceebb08e5baa7..e2280c398153411 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -213,8 +213,7 @@ void MmaOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
// Print the types of the operands and result.
- p << " : "
- << "(";
+ p << " : " << "(";
llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
frags[1].regs[0].getType(),
frags[2].regs[0].getType()},
@@ -954,9 +953,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
ss << "},";
// Need to map read/write registers correctly.
regCnt = (regCnt * 2);
- ss << " $" << (regCnt) << ","
- << " $" << (regCnt + 1) << ","
- << " p";
+ ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
if (!outputType.isInteger(32)) {
ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
}
@@ -964,8 +961,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
if (isF16) {
ss << ", $" << (regCnt + 5) << ", $" << (regCnt + 6);
}
- ss << ";\n"
- << "}\n";
+ ss << ";\n" << "}\n";
ss.flush();
return ptx;
}
@@ -1007,6 +1003,12 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
}
}
+LogicalResult NVVM::SetMaxRegisterOp::verify() {
+ if (getCount() % 8)
+ return emitOpError("new register size must be multiple of 8");
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 5fa907850cedf30..fe4c33854485cda 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -611,3 +611,13 @@ llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
nvvm.prefetch.tensormap %desc, predicate = %pred : !llvm.ptr, i1
llvm.return
}
+
+// -----
+
+func.func @set_max_register() {
+ //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "setmaxnreg.inc.sync.aligned.u32 [$0];", "n"
+ nvvm.setmaxregister.sync.aligned increase 232
+ //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "setmaxnreg.dec.sync.aligned.u32 [$0];", "n"
+ nvvm.setmaxregister.sync.aligned decrease 40
+ func.return
+}
|
@durga4github would be great if you can review |
✅ With the latest revision this PR passed the C/C++ code formatter. |
This PR introduces the `setmaxregister.sync.aligned` Op to increase or decrease the register size. See: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-setmaxnreg
let arguments = (ins I32Attr:$regCount, SetMaxRegisterActionAttr:$action); | ||
let assemblyFormat = "$action $regCount attr-dict"; | ||
let extraClassDefinition = [{ | ||
std::string $cppClass::getPtx() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please clarify the questions below, for my understanding?
If we happen to add an LLVM intrinsic for this at a later point:
a) Can we extend the same Op to lower to that intrinsic instead of inline PTX (without any changes to the Op's interface)?
b) I believe we can achieve that by overriding `hasIntrinsic()` to return true (within the extraClassDeclaration). Do we need anything more?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FWIW, I think we should add an LLVM intrinsic (this is "the right thing").
a) We can (and should) keep the Op's interface the same, so the nvgpu->nvvm lowering remains unchanged (as do the other lowerings we have internally). We can use the LLVM intrinsic once we have it.
Let's say we have an LLVM intrinsic for this Op; what we need to do is add the `llvmBuilder` part and remove `getPtx`:
def NVVM_SetMaxRegisterOp : NVVM_Op<"setmaxregister"> {
let arguments = (ins I32Attr:$regCount, SetMaxRegisterActionAttr:$action);
let assemblyFormat = "$action $regCount attr-dict";
string llvmBuilder = [{
if(getAction() == NVVM::SetMaxRegisterAction::increase)
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_set_max_register_increase);
else createIntrinsicCall(builder, llvm::Intrinsic::nvvm_set_max_register_descreate);
}];
let hasVerifier = 1;
}
b) I've added a `hasIntrinsic()` method to the `BasicPtxBuilderOpInterface` interface so that inline PTX and LLVM intrinsics can be used together. If LLVM contains all the necessary intrinsics for the Op, we don't need `BasicPtxBuilderOpInterface` and `hasIntrinsic()`.
For example, you can take a look at the `mbarrier.init` Op.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the explanation, Guray!
I am with you on a).
On b): Yes, I was looking at the mbarrier.init as an example, and wanted to confirm if both the intrinsic and the inline PTX forms can co-exist. It seems they can, without any issues.
(I was running the mbarrier.init Op example and it was always generating the intrinsic version.
I later realized I was using mlir-translate, and not "mlir-opt with conversion-to-llvm", which is what shows the inline-PTX implementation.)
Thanks for all the clarifications!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, happy to answer.
`mlir-translate` generates the LLVM code. We use LLVM's intrinsic here.
We have the `convert-nvvm-to-llvm` MLIR pass. It matches the Ops that implement `BasicPtxBuilderOpInterface` and generates inline PTX for them.
This PR introduces the `setmaxregister.sync.aligned` Op to increase or decrease the register size. See: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-setmaxnreg