Skip to content

Commit

Permalink
[MLIR] Improve KernelOutlining to avoid introducing an extra block (#90359)
Browse files Browse the repository at this point in the history

This fixes a TODO in the code.
  • Loading branch information
joker-eph committed Apr 29, 2024
1 parent cd68d7b commit d566a5c
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 20 deletions.
34 changes: 18 additions & 16 deletions mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,24 +241,26 @@ static gpu::GPUFuncOp outlineKernelFuncImpl(gpu::LaunchOp launchOp,
map.map(operand.value(), entryBlock.getArgument(operand.index()));

// Clone the region of the gpu.launch operation into the gpu.func operation.
// TODO: If cloneInto can be modified such that if a mapping for
// a block exists, that block will be used to clone operations into (at the
// end of the block), instead of creating a new block, this would be much
// cleaner.
launchOpBody.cloneInto(&outlinedFuncBody, map);

// Branch from entry of the gpu.func operation to the block that is cloned
// from the entry block of the gpu.launch operation.
Block &launchOpEntry = launchOpBody.front();
Block *clonedLaunchOpEntry = map.lookup(&launchOpEntry);
builder.setInsertionPointToEnd(&entryBlock);
builder.create<cf::BranchOp>(loc, clonedLaunchOpEntry);

outlinedFunc.walk([](gpu::TerminatorOp op) {
OpBuilder replacer(op);
replacer.create<gpu::ReturnOp>(op.getLoc());
op.erase();
});
// Replace each cloned gpu.terminator op with a gpu.return op.
for (Block &block : launchOpBody) {
Block *clonedBlock = map.lookup(&block);
auto terminator = dyn_cast<gpu::TerminatorOp>(clonedBlock->getTerminator());
if (!terminator)
continue;
OpBuilder replacer(terminator);
replacer.create<gpu::ReturnOp>(terminator->getLoc());
terminator->erase();
}

// Now splice the operations of the cloned gpu.launch entry block onto the end
// of the gpu.func entry block, then erase the emptied, redundant cloned block.
Block *clonedLaunchOpEntry = map.lookup(&launchOpBody.front());
entryBlock.getOperations().splice(entryBlock.getOperations().end(),
clonedLaunchOpEntry->getOperations());
clonedLaunchOpEntry->erase();

return outlinedFunc;
}

Expand Down
35 changes: 31 additions & 4 deletions mlir/test/Dialect/GPU/outlining.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,41 @@ func.func @launch() {
// CHECK-NEXT: %[[BDIM:.*]] = gpu.block_dim x
// CHECK-NEXT: = gpu.block_dim y
// CHECK-NEXT: = gpu.block_dim z
// CHECK-NEXT: cf.br ^[[BLOCK:.*]]
// CHECK-NEXT: ^[[BLOCK]]:
// CHECK-NEXT: "use"(%[[KERNEL_ARG0]]) : (f32) -> ()
// CHECK-NEXT: "some_op"(%[[BID]], %[[BDIM]]) : (index, index) -> ()
// CHECK-NEXT: = memref.load %[[KERNEL_ARG1]][%[[TID]]] : memref<?xf32, 1>

// -----

// Verify that we can outline a CFG
// CHECK-LABEL: gpu.func @launchCFG_kernel(
// CHECK: cf.br
// CHECK: gpu.return
func.func @launchCFG() {
// Test input: a gpu.launch whose body is split across two blocks joined by an
// unconditional cf.br, so the launch region contains more than one block.
// The FileCheck directives preceding this function expect a cf.br followed by
// a gpu.return inside gpu.func @launchCFG_kernel.

// Values defined outside the launch and used inside its body.
%0 = "op"() : () -> (f32)
%1 = "op"() : () -> (memref<?xf32, 1>)
// Grid sizes (%gDim*) and block sizes (%bDim*) passed to gpu.launch below.
%gDimX = arith.constant 8 : index
%gDimY = arith.constant 12 : index
%gDimZ = arith.constant 16 : index
%bDimX = arith.constant 20 : index
%bDimY = arith.constant 24 : index
%bDimZ = arith.constant 28 : index

gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %gDimX, %grid_y = %gDimY,
%grid_z = %gDimZ)
threads(%tx, %ty, %tz) in (%block_x = %bDimX, %block_y = %bDimY,
%block_z = %bDimZ) {
// Entry block of the launch body: uses %0, then branches to ^bb1.
"use"(%0): (f32) -> ()
cf.br ^bb1
// Second block: uses launch ids/sizes and %1, then ends the launch body.
^bb1:
"some_op"(%bx, %block_x) : (index, index) -> ()
%42 = memref.load %1[%tx] : memref<?xf32, 1>
gpu.terminator
}
return
}


// -----

// This test checks that GPU kernel outlining can handle a gpu.launch kernel from an llvm.func
Expand Down Expand Up @@ -475,8 +504,6 @@ func.func @launch_cluster() {
// CHECK-NEXT: %[[CDIM:.*]] = gpu.cluster_dim x
// CHECK-NEXT: = gpu.cluster_dim y
// CHECK-NEXT: = gpu.cluster_dim z
// CHECK-NEXT: cf.br ^[[BLOCK:.*]]
// CHECK-NEXT: ^[[BLOCK]]:
// CHECK-NEXT: "use"(%[[KERNEL_ARG0]]) : (f32) -> ()
// CHECK-NEXT: "some_op"(%[[CID]], %[[BID]], %[[BDIM]]) : (index, index, index) -> ()
// CHECK-NEXT: = memref.load %[[KERNEL_ARG1]][%[[TID]]] : memref<?xf32, 1>
Expand Down

0 comments on commit d566a5c

Please sign in to comment.