Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2868,7 +2868,7 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
return success();
}

LogicalResult spirv::Deserializer::splitConditionalBlocks() {
LogicalResult spirv::Deserializer::splitSelectionHeader() {
// Create a copy, so we can modify keys in the original.
BlockMergeInfoMap blockMergeInfoCopy = blockMergeInfo;
for (auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end();
Expand All @@ -2885,7 +2885,7 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() {
Operation *terminator = block->getTerminator();
assert(terminator);

if (!isa<spirv::BranchConditionalOp>(terminator))
if (!isa<spirv::BranchConditionalOp, spirv::SwitchOp>(terminator))
continue;

// Check if the current header block is a merge block of another construct.
Expand All @@ -2895,10 +2895,10 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() {
splitHeaderMergeBlock = true;
}

// Do not split a block that only contains a conditional branch, unless it
// is also a merge block of another construct - in that case we want to
// split the block. We do not want two constructs to share header / merge
// block.
// Do not split a block that only contains a conditional branch / switch,
// unless it is also a merge block of another construct - in that case we
// want to split the block. We do not want two constructs to share header /
// merge block.
if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
Block *newBlock = block->splitBlock(terminator);
OpBuilder builder(block, block->end());
Expand Down Expand Up @@ -2936,7 +2936,7 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() {
logger.startLine() << "\n";
});

if (failed(splitConditionalBlocks())) {
if (failed(splitSelectionHeader())) {
return failure();
}

Expand Down
10 changes: 5 additions & 5 deletions mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -280,11 +280,11 @@ class Deserializer {
return opBuilder.getStringAttr(attrName);
}

/// Move a conditional branch into a separate basic block to avoid unnecessary
/// sinking of defs that may be required outside a selection region. This
/// function also ensures that a single block cannot be a header block of one
/// selection construct and the merge block of another.
LogicalResult splitConditionalBlocks();
/// Move a conditional branch or a switch into a separate basic block to avoid
/// unnecessary sinking of defs that may be required outside a selection
/// region. This function also ensures that a single block cannot be a header
/// block of one selection construct and the merge block of another.
LogicalResult splitSelectionHeader();

//===--------------------------------------------------------------------===//
// Type
Expand Down
69 changes: 69 additions & 0 deletions mlir/test/Target/SPIRV/selection_switch.spvasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
; RUN: %if spirv-tools %{ spirv-as --target-env spv1.0 %s -o - | mlir-translate --deserialize-spirv - -o - | FileCheck %s %}

; This test is analogous to selection.spv but tests switch op.

; CHECK: spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
; CHECK-NEXT: spirv.func @switch({{%.*}}: si32) "None" {
; CHECK: {{%.*}} = spirv.Constant 1.000000e+00 : f32
; CHECK-NEXT: {{%.*}} = spirv.Undef : vector<3xf32>
; CHECK-NEXT: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[0 : i32] : f32 into vector<3xf32>
; CHECK-NEXT: spirv.Branch ^[[bb:.+]]
; CHECK-NEXT: ^[[bb:.+]]:
; CHECK-NEXT: {{%.*}} = spirv.mlir.selection -> vector<3xf32> {
; CHECK-NEXT: spirv.Switch {{%.*}} : si32, [
; CHECK-NEXT: default: ^[[bb:.+]]({{%.*}}: vector<3xf32>),
; CHECK-NEXT: 0: ^[[bb:.+]]({{%.*}}: vector<3xf32>),
; CHECK-NEXT: 1: ^[[bb:.+]]({{%.*}}: vector<3xf32>)
; CHECK: ^[[bb:.+]]({{%.*}}: vector<3xf32>):
; CHECK: spirv.Branch ^[[bb:.+]]({{%.*}}: vector<3xf32>)
; CHECK-NEXT: ^[[bb:.+]]({{%.*}}: vector<3xf32>):
; CHECK: spirv.Branch ^[[bb:.+]]({{%.*}}: vector<3xf32>)
; CHECK-NEXT: ^[[bb:.+]]({{%.*}}: vector<3xf32>):
; CHECK-NEXT: spirv.mlir.merge %8 : vector<3xf32>
; CHECK-NEXT }
; CHECK: spirv.Return
; CHECK-NEXT: }
; CHECK: }

OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 1 1 1
OpName %switch "switch"
OpName %main "main"
%void = OpTypeVoid
%int = OpTypeInt 32 1
%1 = OpTypeFunction %void %int
%float = OpTypeFloat 32
%float_1 = OpConstant %float 1
%v3float = OpTypeVector %float 3
%9 = OpUndef %v3float
%float_3 = OpConstant %float 3
%float_4 = OpConstant %float 4
%float_2 = OpConstant %float 2
%25 = OpTypeFunction %void
%switch = OpFunction %void None %1
%5 = OpFunctionParameter %int
%6 = OpLabel
OpBranch %12
%12 = OpLabel
%11 = OpCompositeInsert %v3float %float_1 %9 0
OpSelectionMerge %15 None
OpSwitch %5 %15 0 %13 1 %14
%13 = OpLabel
%16 = OpPhi %v3float %11 %12
%18 = OpCompositeInsert %v3float %float_3 %16 1
OpBranch %15
%14 = OpLabel
%19 = OpPhi %v3float %11 %12
%21 = OpCompositeInsert %v3float %float_4 %19 1
OpBranch %15
%15 = OpLabel
%22 = OpPhi %v3float %21 %14 %18 %13 %11 %12
%24 = OpCompositeInsert %v3float %float_2 %22 2
OpReturn
OpFunctionEnd
%main = OpFunction %void None %25
%27 = OpLabel
OpReturn
OpFunctionEnd