diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 00057a2269895..0dc3f27781bd3 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -2868,7 +2868,7 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() { return success(); } -LogicalResult spirv::Deserializer::splitConditionalBlocks() { +LogicalResult spirv::Deserializer::splitSelectionHeader() { // Create a copy, so we can modify keys in the original. BlockMergeInfoMap blockMergeInfoCopy = blockMergeInfo; for (auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end(); @@ -2885,7 +2885,7 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() { Operation *terminator = block->getTerminator(); assert(terminator); - if (!isa(terminator)) + if (!isa(terminator)) continue; // Check if the current header block is a merge block of another construct. @@ -2895,10 +2895,10 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() { splitHeaderMergeBlock = true; } - // Do not split a block that only contains a conditional branch, unless it - // is also a merge block of another construct - in that case we want to - // split the block. We do not want two constructs to share header / merge - // block. + // Do not split a block that only contains a conditional branch / switch, + // unless it is also a merge block of another construct - in that case we + // want to split the block. We do not want two constructs to share header / + // merge block. if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) { Block *newBlock = block->splitBlock(terminator); OpBuilder builder(block, block->end()); @@ -2936,7 +2936,7 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() { logger.startLine() << "\n"; }); - if (failed(splitConditionalBlocks())) { + if (failed(splitSelectionHeader())) { return failure(); } diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h index 6d09d556c4d02..50c935036158c 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -280,11 +280,11 @@ class Deserializer { return opBuilder.getStringAttr(attrName); } - /// Move a conditional branch into a separate basic block to avoid unnecessary - /// sinking of defs that may be required outside a selection region. This - /// function also ensures that a single block cannot be a header block of one - /// selection construct and the merge block of another. - LogicalResult splitConditionalBlocks(); + /// Move a conditional branch or a switch into a separate basic block to avoid + /// unnecessary sinking of defs that may be required outside a selection + /// region. This function also ensures that a single block cannot be a header + /// block of one selection construct and the merge block of another. + LogicalResult splitSelectionHeader(); //===--------------------------------------------------------------------===// // Type diff --git a/mlir/test/Target/SPIRV/selection_switch.spvasm b/mlir/test/Target/SPIRV/selection_switch.spvasm new file mode 100644 index 0000000000000..81fecf307eb7a --- /dev/null +++ b/mlir/test/Target/SPIRV/selection_switch.spvasm @@ -0,0 +1,69 @@ +; RUN: %if spirv-tools %{ spirv-as --target-env spv1.0 %s -o - | mlir-translate --deserialize-spirv - -o - | FileCheck %s %} + +; This test is analogous to selection.spv but tests switch op. + +; CHECK: spirv.module Logical GLSL450 requires #spirv.vce { +; CHECK-NEXT: spirv.func @switch({{%.*}}: si32) "None" { +; CHECK: {{%.*}} = spirv.Constant 1.000000e+00 : f32 +; CHECK-NEXT: {{%.*}} = spirv.Undef : vector<3xf32> +; CHECK-NEXT: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[0 : i32] : f32 into vector<3xf32> +; CHECK-NEXT: spirv.Branch ^[[bb:.+]] +; CHECK-NEXT: ^[[bb:.+]]: +; CHECK-NEXT: {{%.*}} = spirv.mlir.selection -> vector<3xf32> { +; CHECK-NEXT: spirv.Switch {{%.*}} : si32, [ +; CHECK-NEXT: default: ^[[bb:.+]]({{%.*}}: vector<3xf32>), +; CHECK-NEXT: 0: ^[[bb:.+]]({{%.*}}: vector<3xf32>), +; CHECK-NEXT: 1: ^[[bb:.+]]({{%.*}}: vector<3xf32>) +; CHECK: ^[[bb:.+]]({{%.*}}: vector<3xf32>): +; CHECK: spirv.Branch ^[[bb:.+]]({{%.*}}: vector<3xf32>) +; CHECK-NEXT: ^[[bb:.+]]({{%.*}}: vector<3xf32>): +; CHECK: spirv.Branch ^[[bb:.+]]({{%.*}}: vector<3xf32>) +; CHECK-NEXT: ^[[bb:.+]]({{%.*}}: vector<3xf32>): +; CHECK-NEXT: spirv.mlir.merge %8 : vector<3xf32> +; CHECK-NEXT } +; CHECK: spirv.Return +; CHECK-NEXT: } +; CHECK: } + + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "main" + OpExecutionMode %main LocalSize 1 1 1 + OpName %switch "switch" + OpName %main "main" + %void = OpTypeVoid + %int = OpTypeInt 32 1 + %1 = OpTypeFunction %void %int + %float = OpTypeFloat 32 + %float_1 = OpConstant %float 1 + %v3float = OpTypeVector %float 3 + %9 = OpUndef %v3float + %float_3 = OpConstant %float 3 + %float_4 = OpConstant %float 4 + %float_2 = OpConstant %float 2 + %25 = OpTypeFunction %void + %switch = OpFunction %void None %1 + %5 = OpFunctionParameter %int + %6 = OpLabel + OpBranch %12 + %12 = OpLabel + %11 = OpCompositeInsert %v3float %float_1 %9 0 + OpSelectionMerge %15 None + OpSwitch %5 %15 0 %13 1 %14 + %13 = OpLabel + %16 = OpPhi %v3float %11 %12 + %18 = OpCompositeInsert %v3float %float_3 %16 1 + OpBranch %15 + %14 = OpLabel + %19 = OpPhi %v3float %11 %12 + %21 = OpCompositeInsert %v3float %float_4 %19 1 + OpBranch %15 + %15 = OpLabel + %22 = OpPhi %v3float %21 %14 %18 %13 %11 %12 + %24 = OpCompositeInsert %v3float %float_2 %22 2 + OpReturn + OpFunctionEnd + %main = OpFunction %void None %25 + %27 = OpLabel + OpReturn + OpFunctionEnd