diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index f9cd58de8915f..68eb56a90a6ab 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -83,6 +83,15 @@ def NVVM_Dialect : Dialect { /// are grid constants. static StringRef getGridConstantAttrName() { return "nvvm.grid_constant"; } + /// Get the name of the attribute used to annotate the `.blocksareclusters` + /// PTX directive for kernel functions. + /// This attribute indicates that the grid launch configuration for the + /// corresponding kernel function specifies the number of clusters + /// instead of the number of thread blocks. This attribute is only + /// allowed for kernel functions and requires the nvvm.reqntid and + /// nvvm.cluster_dim attributes. + static StringRef getBlocksAreClustersAttrName() { return "nvvm.blocksareclusters"; } + /// Verify an attribute from this dialect on the argument at 'argIndex' for /// the region at 'regionIndex' on the given operation. Returns failure if /// the verification failed, success otherwise. 
This hook may optionally be diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index dbcc738b4419f..7cd5ceeff5a1b 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -1908,19 +1908,31 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op, attrName == NVVMDialect::getReqntidAttrName() || attrName == NVVMDialect::getClusterDimAttrName()) { auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue()); - if (!values || values.empty() || values.size() > 3) + if (!values || values.empty() || values.size() > 3) { return op->emitError() << "'" << attrName << "' attribute must be integer array with maximum 3 index"; + } } // If minctasm / maxnreg / cluster_max_blocks exist, it must be an integer // attribute if (attrName == NVVMDialect::getMinctasmAttrName() || attrName == NVVMDialect::getMaxnregAttrName() || attrName == NVVMDialect::getClusterMaxBlocksAttrName()) { - if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) + if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) { return op->emitError() << "'" << attrName << "' attribute must be integer constant"; + } + } + // blocksareclusters must be used along with reqntid and cluster_dim + if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) { + if (!op->hasAttr(NVVMDialect::getReqntidAttrName()) || + !op->hasAttr(NVVMDialect::getClusterDimAttrName())) { + return op->emitError() + << "'" << attrName << "' attribute must be used along with " + << "'" << NVVMDialect::getReqntidAttrName() << "' and " + << "'" << NVVMDialect::getClusterDimAttrName() << "'"; + } } return success(); diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp index e67cfed983255..a20701ce75bc0 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -468,7 +468,11 @@ class 
NVVMDialectLLVMIRTranslationInterface } else if (attribute.getName() == NVVM::NVVMDialect::getKernelFuncAttrName()) { llvmFunc->setCallingConv(llvm::CallingConv::PTX_Kernel); + } else if (attribute.getName() == + NVVM::NVVMDialect::getBlocksAreClustersAttrName()) { + llvmFunc->addFnAttr("nvvm.blocksareclusters"); } + return success(); } diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir index 33398cfb92429..90cf9b5593054 100644 --- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir @@ -56,6 +56,22 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array} { + llvm.return +} + +// ----- + +// expected-error @below {{'"nvvm.blocksareclusters"' attribute must be used along with 'nvvm.reqntid' and 'nvvm.cluster_dim'}} +llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.blocksareclusters, + nvvm.reqntid = array} { + llvm.return +} + +// ----- + llvm.func @nvvm_fence_proxy_acquire(%addr : !llvm.ptr, %size : i32) { // expected-error @below {{'nvvm.fence.proxy.acquire' op uni-directional proxies only support generic for from_proxy attribute}} nvvm.fence.proxy.acquire #nvvm.mem_scope %addr, %size from_proxy=#nvvm.proxy_kind to_proxy=#nvvm.proxy_kind diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index c8ba91efbff4d..554042dee81ae 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -692,7 +692,16 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array, + nvvm.cluster_dim = array} { + llvm.return +} + +// CHECK: define ptx_kernel void @kernel_func() #[[ATTR0:[0-9]+]] +// CHECK: attributes #[[ATTR0]] = { "nvvm.blocksareclusters" "nvvm.cluster_dim"="3,5,7" "nvvm.reqntid"="1,23,32" } // ----- // CHECK: define ptx_kernel void @kernel_func // CHECK: !nvvm.annotations =