diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index d48349630b5bac..bbe16717fa022b 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -1573,7 +1573,8 @@ LogicalResult spirv::Deserializer::processPhi(ArrayRef operands) { for (unsigned i = 2, e = operands.size(); i < e; i += 2) { uint32_t value = operands[i]; Block *predecessor = getOrCreateBlock(operands[i + 1]); - blockPhiInfo[predecessor].push_back(value); + std::pair predecessorTargetPair{predecessor, curBlock}; + blockPhiInfo[predecessorTargetPair].push_back(value); LLVM_DEBUG(llvm::dbgs() << "[phi] predecessor @ " << predecessor << " with arg id = " << value << '\n'); } @@ -1853,7 +1854,8 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() { OpBuilder::InsertionGuard guard(opBuilder); for (const auto &info : blockPhiInfo) { - Block *block = info.first; + Block *block = info.first.first; + Block *target = info.first.second; const BlockPhiInfo &phiInfo = info.second; LLVM_DEBUG(llvm::dbgs() << "[phi] block " << block << "\n"); LLVM_DEBUG(llvm::dbgs() << "[phi] before creating block argument:\n"); @@ -1882,6 +1884,24 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() { opBuilder.create(branchOp.getLoc(), branchOp.getTarget(), blockArgs); branchOp.erase(); + } else if (auto branchCondOp = dyn_cast(op)) { + assert((branchCondOp.getTrueBlock() == target || + branchCondOp.getFalseBlock() == target) && + "expected target to be either the true or false target"); + if (target == branchCondOp.trueTarget()) + opBuilder.create( + branchCondOp.getLoc(), branchCondOp.condition(), blockArgs, + branchCondOp.getFalseBlockArguments(), + branchCondOp.branch_weightsAttr(), branchCondOp.trueTarget(), + branchCondOp.falseTarget()); + else + opBuilder.create( + branchCondOp.getLoc(), branchCondOp.condition(), + branchCondOp.getTrueBlockArguments(), blockArgs, + branchCondOp.branch_weightsAttr(), branchCondOp.getTrueBlock(), + branchCondOp.getFalseBlock()); + + branchCondOp.erase(); } else { return emitError(unknownLoc, "unimplemented terminator for Phi creation"); } diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h index ac4846d63cad5a..17060dddc91989 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -560,8 +560,10 @@ class Deserializer { // Header block to its merge (and continue) target mapping. BlockMergeInfoMap blockMergeInfo; - // Block to its phi (block argument) mapping. - DenseMap blockPhiInfo; + // For each pair of {predecessor, target} blocks, maps the pair of blocks to + // the list of phi arguments passed from predecessor to target. + DenseMap, BlockPhiInfo> + blockPhiInfo; // Result to value mapping. DenseMap valueMap; diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index ab35315e7f1445..773fa863c08115 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -959,7 +959,7 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { // OpPhi | result type | result | (value , parent block ) pair // So we need to collect all predecessor blocks and the arguments they send // to this block. - SmallVector, 4> predecessors; + SmallVector, 4> predecessors; for (Block *predecessor : block->getPredecessors()) { auto *terminator = predecessor->getTerminator(); // The predecessor here is the immediate one according to MLIR's IR @@ -971,7 +971,21 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { // structured control flow op's merge block. predecessor = getPhiIncomingBlock(predecessor); if (auto branchOp = dyn_cast(terminator)) { - predecessors.emplace_back(predecessor, branchOp.operand_begin()); + predecessors.emplace_back(predecessor, branchOp.getOperands()); + } else if (auto branchCondOp = + dyn_cast(terminator)) { + Optional blockOperands; + + for (auto successorIdx : + llvm::seq(0, predecessor->getNumSuccessors())) + if (predecessor->getSuccessors()[successorIdx] == block) { + blockOperands = branchCondOp.getSuccessorOperands(successorIdx); + break; + } + + assert(blockOperands && !blockOperands->empty() && + "expected non-empty block operand range"); + predecessors.emplace_back(predecessor, *blockOperands); } else { return terminator->emitError("unimplemented terminator for Phi creation"); } @@ -996,7 +1010,7 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { phiArgs.push_back(phiID); for (auto predIndex : llvm::seq(0, predecessors.size())) { - Value value = *(predecessors[predIndex].second + argIndex); + Value value = predecessors[predIndex].second[argIndex]; uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first); LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId << ") value " << value << ' '); diff --git a/mlir/test/Target/SPIRV/phi.mlir b/mlir/test/Target/SPIRV/phi.mlir index 807783ae74ec44..63236aa495bb42 100644 --- a/mlir/test/Target/SPIRV/phi.mlir +++ b/mlir/test/Target/SPIRV/phi.mlir @@ -286,3 +286,60 @@ spv.module Logical GLSL450 requires #spv.vce { spv.EntryPoint "GLCompute" @fmul_kernel spv.ExecutionMode @fmul_kernel "LocalSize", 32, 1, 1 } + +// ----- + +spv.module Logical GLSL450 requires #spv.vce { +// CHECK-LABEL: @cond_branch_true_argument + spv.func @cond_branch_true_argument() -> () "None" { + %true = spv.Constant true + %zero = spv.Constant 0 : i32 + %one = spv.Constant 1 : i32 +// CHECK: spv.BranchConditional %{{.*}}, ^[[true1:.*]](%{{.*}}, %{{.*}} : i32, i32), ^[[false1:.*]] + spv.BranchConditional %true, ^true1(%zero, %zero: i32, i32), ^false1 +// CHECK: [[true1]](%{{.*}}: i32, %{{.*}}: i32) + ^true1(%arg0: i32, %arg1: i32): + spv.Return +// CHECK: [[false1]]: + ^false1: + spv.Return + } +} + +// ----- + +spv.module Logical GLSL450 requires #spv.vce { +// CHECK-LABEL: @cond_branch_false_argument + spv.func @cond_branch_false_argument() -> () "None" { + %true = spv.Constant true + %zero = spv.Constant 0 : i32 + %one = spv.Constant 1 : i32 +// CHECK: spv.BranchConditional %{{.*}}, ^[[true1:.*]], ^[[false1:.*]](%{{.*}}, %{{.*}} : i32, i32) + spv.BranchConditional %true, ^true1, ^false1(%zero, %zero: i32, i32) +// CHECK: [[true1]]: + ^true1: + spv.Return +// CHECK: [[false1]](%{{.*}}: i32, %{{.*}}: i32): + ^false1(%arg0: i32, %arg1: i32): + spv.Return + } +} + +// ----- + +spv.module Logical GLSL450 requires #spv.vce { +// CHECK-LABEL: @cond_branch_true_and_false_argument + spv.func @cond_branch_true_and_false_argument() -> () "None" { + %true = spv.Constant true + %zero = spv.Constant 0 : i32 + %one = spv.Constant 1 : i32 +// CHECK: spv.BranchConditional %{{.*}}, ^[[true1:.*]](%{{.*}} : i32), ^[[false1:.*]](%{{.*}}, %{{.*}} : i32, i32) + spv.BranchConditional %true, ^true1(%one: i32), ^false1(%zero, %zero: i32, i32) +// CHECK: [[true1]](%{{.*}}: i32): + ^true1(%arg0: i32): + spv.Return +// CHECK: [[false1]](%{{.*}}: i32, %{{.*}}: i32): + ^false1(%arg1: i32, %arg2: i32): + spv.Return + } +}