Skip to content

Commit

Permalink
[mlir][gpu] Use known_block_size to set maxntid for NVVM target (#…
Browse files Browse the repository at this point in the history
…77301)

Setting the thread block size with `maxntid` on the kernel has great
performance benefits. In this way, the downstream PTX compiler can do better
register allocation.

MLIR's `gpu.launch` and `gpu.launch_func` already have an attribute
(`known_block_size`) that keeps the thread block size when it is known.
This PR simply uses this attribute to set `maxntid`.
  • Loading branch information
grypp committed Jan 8, 2024
1 parent 2edce42 commit 763109e
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 6 deletions.
20 changes: 19 additions & 1 deletion mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
Expand Up @@ -85,8 +85,26 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
// Add a dialect specific kernel attribute in addition to GPU kernel
// attribute. The former is necessary for further translation while the
// latter is expected by gpu.launch_func.
if (gpuFuncOp.isKernel())
if (gpuFuncOp.isKernel()) {
attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());

// Set the block size attribute if it is present.
if (kernelBlockSizeAttributeName.has_value()) {
std::optional<int32_t> dimX =
gpuFuncOp.getKnownBlockSize(gpu::Dimension::x);
std::optional<int32_t> dimY =
gpuFuncOp.getKnownBlockSize(gpu::Dimension::y);
std::optional<int32_t> dimZ =
gpuFuncOp.getKnownBlockSize(gpu::Dimension::z);
if (dimX.has_value() || dimY.has_value() || dimZ.has_value()) {
// If any of the dimensions are missing, fill them in with 1.
attributes.emplace_back(
kernelBlockSizeAttributeName.value(),
rewriter.getI32ArrayAttr(
{dimX.value_or(1), dimY.value_or(1), dimZ.value_or(1)}));
}
}
}
auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
LLVM::Linkage::External, /*dsoLocal=*/false, /*cconv=*/LLVM::CConv::C,
Expand Down
13 changes: 9 additions & 4 deletions mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
Expand Up @@ -36,13 +36,15 @@ struct GPUDynamicSharedMemoryOpLowering
};

struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
GPUFuncOpLowering(const LLVMTypeConverter &converter,
unsigned allocaAddrSpace, unsigned workgroupAddrSpace,
StringAttr kernelAttributeName)
GPUFuncOpLowering(
const LLVMTypeConverter &converter, unsigned allocaAddrSpace,
unsigned workgroupAddrSpace, StringAttr kernelAttributeName,
std::optional<StringAttr> kernelBlockSizeAttributeName = std::nullopt)
: ConvertOpToLLVMPattern<gpu::GPUFuncOp>(converter),
allocaAddrSpace(allocaAddrSpace),
workgroupAddrSpace(workgroupAddrSpace),
kernelAttributeName(kernelAttributeName) {}
kernelAttributeName(kernelAttributeName),
kernelBlockSizeAttributeName(kernelBlockSizeAttributeName) {}

LogicalResult
matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
Expand All @@ -56,6 +58,9 @@ struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {

/// The attribute name to use instead of `gpu.kernel`.
StringAttr kernelAttributeName;

/// The attribute name used to set the block size.
std::optional<StringAttr> kernelBlockSizeAttributeName;
};

/// The lowering of gpu.printf to a call to HIP hostcalls
Expand Down
4 changes: 3 additions & 1 deletion mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
Expand Up @@ -352,7 +352,9 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
/*workgroupAddrSpace=*/
static_cast<unsigned>(NVVM::NVVMMemorySpace::kSharedMemorySpace),
StringAttr::get(&converter.getContext(),
NVVM::NVVMDialect::getKernelFuncAttrName()));
NVVM::NVVMDialect::getKernelFuncAttrName()),
StringAttr::get(&converter.getContext(),
NVVM::NVVMDialect::getMaxntidAttrName()));

populateOpPatterns<math::AbsFOp>(converter, patterns, "__nv_fabsf",
"__nv_fabs");
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
Expand Up @@ -627,6 +627,15 @@ gpu.module @test_module_31 {
}
}

gpu.module @gpumodule {
// Checks that a kernel carrying gpu.known_block_size is lowered with an
// nvvm.maxntid attribute holding the same three dimensions, alongside the
// nvvm.kernel marker added for every GPU kernel.
// CHECK-LABEL: func @kernel_with_block_size()
// CHECK: attributes {gpu.kernel, gpu.known_block_size = array<i32: 128, 1, 1>, nvvm.kernel, nvvm.maxntid = [128 : i32, 1 : i32, 1 : i32]}
gpu.func @kernel_with_block_size() kernel attributes {gpu.known_block_size = array<i32: 128, 1, 1>} {
gpu.return
}
}


module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
%gpu_module = transform.structured.match ops{["gpu.module"]} in %toplevel_module
Expand Down

0 comments on commit 763109e

Please sign in to comment.