Skip to content

Commit

Permalink
Add support for OpPhi in loop header block
Browse files Browse the repository at this point in the history
During deserialization, the loop header block will be moved into the
spv.loop's region. If the loop header block has block arguments,
we need to make sure it is correctly carried over to the block where
the new spv.loop resides.

During serialization, we need to make sure block arguments from the
spv.loop's entry block are not silently dropped.

PiperOrigin-RevId: 280021777
  • Loading branch information
antiagainst authored and tensorflower-gardener committed Nov 12, 2019
1 parent 626e1fd commit b259c26
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 9 deletions.
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
Expand Up @@ -300,7 +300,12 @@ def SPV_LoopOp : SPV_Op<"loop", [InFunctionScope]> {

let regions = (region AnyRegion:$body);

let builders = [OpBuilder<"Builder *builder, OperationState &state">];

let extraClassDeclaration = [{
// Returns the entry block.
Block *getEntryBlock();

// Returns the loop header block.
Block *getHeaderBlock();

Expand Down
12 changes: 12 additions & 0 deletions mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
Expand Up @@ -1442,6 +1442,13 @@ static LogicalResult verify(spirv::LoadOp loadOp) {
// spv.loop
//===----------------------------------------------------------------------===//

void spirv::LoopOp::build(Builder *builder, OperationState &state) {
state.addAttribute("loop_control",
builder->getI32IntegerAttr(
static_cast<uint32_t>(spirv::LoopControl::None)));
state.addRegion();
}

static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &state) {
// TODO(antiagainst): support loop control properly
Builder builder = parser.getBuilder();
Expand Down Expand Up @@ -1557,6 +1564,11 @@ static LogicalResult verify(spirv::LoopOp loopOp) {
return success();
}

Block *spirv::LoopOp::getEntryBlock() {
assert(!body().empty() && "op region should not be empty!");
return &body().front();
}

Block *spirv::LoopOp::getHeaderBlock() {
assert(!body().empty() && "op region should not be empty!");
// The second block is the loop header block.
Expand Down
37 changes: 29 additions & 8 deletions mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
Expand Up @@ -1700,9 +1700,8 @@ spirv::LoopOp ControlFlowStructurizer::createLoopOp() {
// merge block so that the newly created LoopOp will be inserted there.
OpBuilder builder(&mergeBlock->front());

auto control = builder.getI32IntegerAttr(
static_cast<uint32_t>(spirv::LoopControl::None));
auto loopOp = builder.create<spirv::LoopOp>(location, control);
// TODO(antiagainst): handle loop control properly
auto loopOp = builder.create<spirv::LoopOp>(location);
loopOp.addEntryAndMergeBlock();

return loopOp;
Expand Down Expand Up @@ -1810,10 +1809,25 @@ LogicalResult ControlFlowStructurizer::structurizeImpl() {
headerBlock->replaceAllUsesWith(mergeBlock);

if (isLoop) {
// The loop selection/loop header block may have block arguments. Since now
// we place the selection/loop op inside the old merge block, we need to
// make sure the old merge block has the same block argument list.
assert(mergeBlock->args_empty() && "OpPhi in loop merge block unsupported");
for (BlockArgument *blockArg : headerBlock->getArguments()) {
mergeBlock->addArgument(blockArg->getType());
}

// If the loop header block has block arguments, make sure the spv.branch op
// matches.
SmallVector<Value *, 4> blockArgs;
if (!headerBlock->args_empty())
blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};

// The loop entry block should have a unconditional branch jumping to the
// loop header block.
builder.setInsertionPointToEnd(&body.front());
builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock));
builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock),
ArrayRef<Value *>(blockArgs));
}

// All the blocks cloned into the SelectionOp/LoopOp's region can now be
Expand Down Expand Up @@ -1901,16 +1915,23 @@ LogicalResult Deserializer::structurizeControlFlow() {

for (const auto &info : blockMergeInfo) {
auto *headerBlock = info.first;
LLVM_DEBUG(llvm::dbgs() << "[cf] header block " << headerBlock << "\n");
LLVM_DEBUG(llvm::dbgs() << "[cf] header block " << headerBlock << ":\n");
LLVM_DEBUG(headerBlock->print(llvm::dbgs()));

const auto &mergeInfo = info.second;

auto *mergeBlock = mergeInfo.mergeBlock;
auto *continueBlock = mergeInfo.continueBlock;
assert(mergeBlock && "merge block cannot be nullptr");
LLVM_DEBUG(llvm::dbgs() << "[cf] merge block " << mergeBlock << "\n");
if (!mergeBlock->args_empty())
return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
LLVM_DEBUG(llvm::dbgs() << "[cf] merge block " << mergeBlock << ":\n");
LLVM_DEBUG(mergeBlock->print(llvm::dbgs()));

auto *continueBlock = mergeInfo.continueBlock;
if (continueBlock) {
LLVM_DEBUG(llvm::dbgs()
<< "[cf] continue block " << continueBlock << "\n");
<< "[cf] continue block " << continueBlock << ":\n");
LLVM_DEBUG(continueBlock->print(llvm::dbgs()));
}

if (failed(ControlFlowStructurizer::structurize(unknownLoc, headerBlock,
Expand Down
11 changes: 11 additions & 0 deletions mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
Expand Up @@ -1515,6 +1515,17 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
// afterwards.
encodeInstructionInto(functions, spirv::Opcode::OpBranch, {headerID});

// We omit the LoopOp's entry block and start serialization from the loop
// header block. The entry block should not contain any additional ops other
// than a single spv.Branch that jumps to the loop header block. However,
// the spv.Branch can contain additional block arguments. Those block
// arguments must come from out of the loop using implicit capture. We will
// need to query the <id> for the value sent and the <id> for the incoming
// parent block. For the latter, we need to make sure this block is
// registered. The value sent should come from the block this loop resides in.
blockIDMap[loopOp.getEntryBlock()] =
getBlockID(loopOp.getOperation()->getBlock());

// Emit the loop header block, which dominates all other blocks, first. We
// need to emit an OpLoopMerge instruction before the loop header block's
// terminator.
Expand Down
46 changes: 45 additions & 1 deletion mlir/test/Dialect/SPIRV/Serialization/loop.mlir
@@ -1,4 +1,4 @@
// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s
// RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s

// Single loop

Expand Down Expand Up @@ -61,6 +61,50 @@ spv.module "Logical" "GLSL450" {

// -----

spv.module "Logical" "GLSL450" {
spv.globalVariable @GV1 bind(0, 0) : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer>
spv.globalVariable @GV2 bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer>
func @loop_kernel() {
%0 = spv._address_of @GV1 : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer>
%1 = spv.constant 0 : i32
%2 = spv.AccessChain %0[%1] : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer>
%3 = spv._address_of @GV2 : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer>
%5 = spv.AccessChain %3[%1] : !spv.ptr<!spv.struct<!spv.array<10 x f32 [4]> [0]>, StorageBuffer>
%6 = spv.constant 4 : i32
%7 = spv.constant 42 : i32
%8 = spv.constant 2 : i32
// CHECK: spv.Branch ^bb1(%{{.*}} : i32)
// CHECK-NEXT: ^bb1(%[[OUTARG:.*]]: i32):
// CHECK-NEXT: spv.loop {
spv.loop {
// CHECK-NEXT: spv.Branch ^bb1(%[[OUTARG]] : i32)
spv.Branch ^header(%6 : i32)
// CHECK-NEXT: ^bb1(%[[HEADARG:.*]]: i32):
^header(%9: i32):
%10 = spv.SLessThan %9, %7 : i32
// CHECK: spv.BranchConditional %{{.*}}, ^bb2, ^bb3
spv.BranchConditional %10, ^body, ^merge
// CHECK-NEXT: ^bb2: // pred: ^bb1
^body:
%11 = spv.AccessChain %2[%9] : !spv.ptr<!spv.array<10 x f32 [4]>, StorageBuffer>
%12 = spv.Load "StorageBuffer" %11 : f32
%13 = spv.AccessChain %5[%9] : !spv.ptr<!spv.array<10 x f32 [4]>, StorageBuffer>
spv.Store "StorageBuffer" %13, %12 : f32
// CHECK: %[[ADD:.*]] = spv.IAdd
%14 = spv.IAdd %9, %8 : i32
// CHECK-NEXT: spv.Branch ^bb1(%[[ADD]] : i32)
spv.Branch ^header(%14 : i32)
// CHECK-NEXT: ^bb3:
^merge:
// CHECK-NEXT: spv._merge
spv._merge
}
spv.Return
}
spv.EntryPoint "GLCompute" @loop_kernel
spv.ExecutionMode @loop_kernel "LocalSize", 1, 1, 1
} attributes {capabilities = ["Shader"], extensions = ["SPV_KHR_storage_buffer_storage_class"]}

// TODO(antiagainst): re-enable this after fixing the assertion failure.
// Nested loop

Expand Down

0 comments on commit b259c26

Please sign in to comment.