diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md index 716dd7773aefa..dd68e6ec8b7b8 100644 --- a/mlir/docs/Dialects/SPIR-V.md +++ b/mlir/docs/Dialects/SPIR-V.md @@ -566,7 +566,7 @@ merge block. For example, for the given function ```c++ -void loop(bool cond) { +void if(bool cond) { int x = 0; if (cond) { x = 1; @@ -605,6 +605,62 @@ func.func @selection(%cond: i1) -> () { } ``` +Similarly, for the give function with a `switch` statement + +```c++ +void switch(int selector) { + int x = 0; + switch (selector) { + case 0: + x = 2; + break; + case 1: + x = 3; + break; + default: + x = 1; + break; + } + // ... +} +``` + +It will be represented as + +```mlir +func.func @selection(%selector: i32) -> () { + %zero = spirv.Constant 0: i32 + %one = spirv.Constant 1: i32 + %two = spirv.Constant 2: i32 + %three = spirv.Constant 3: i32 + %var = spirv.Variable init(%zero) : !spirv.ptr + + spirv.mlir.selection { + spirv.Switch %selector : i32, [ + default: ^default, + 0: ^case0, + 1: ^case1 + ] + ^default: + spirv.Store "Function" %var, %one : i32 + spirv.Branch ^merge + + ^case0: + spirv.Store "Function" %var, %two : i32 + spirv.Branch ^merge + + ^case1: + spirv.Store "Function" %var, %three : i32 + spirv.Branch ^merge + + ^merge: + spirv.mlir.merge + } + + // ... +} +``` + The selection can return values by yielding them with `spirv.mlir.merge`. This mechanism allows values defined within the selection region to be used outside of it. Without this, values that were sunk into the selection region, but used outside, would diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index b628f1a3f7b20..7b363fac6e627 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4531,6 +4531,7 @@ def SPIRV_OC_OpSelectionMerge : I32EnumAttrCase<"OpSelectionMerg def SPIRV_OC_OpLabel : I32EnumAttrCase<"OpLabel", 248>; def SPIRV_OC_OpBranch : I32EnumAttrCase<"OpBranch", 249>; def SPIRV_OC_OpBranchConditional : I32EnumAttrCase<"OpBranchConditional", 250>; +def SPIRV_OC_OpSwitch : I32EnumAttrCase<"OpSwitch", 251>; def SPIRV_OC_OpKill : I32EnumAttrCase<"OpKill", 252>; def SPIRV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>; def SPIRV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>; @@ -4681,7 +4682,7 @@ def SPIRV_OpcodeAttr : SPIRV_OC_OpAtomicAnd, SPIRV_OC_OpAtomicOr, SPIRV_OC_OpAtomicXor, SPIRV_OC_OpPhi, SPIRV_OC_OpLoopMerge, SPIRV_OC_OpSelectionMerge, SPIRV_OC_OpLabel, SPIRV_OC_OpBranch, SPIRV_OC_OpBranchConditional, - SPIRV_OC_OpKill, SPIRV_OC_OpReturn, SPIRV_OC_OpReturnValue, + SPIRV_OC_OpSwitch, SPIRV_OC_OpKill, SPIRV_OC_OpReturn, SPIRV_OC_OpReturnValue, SPIRV_OC_OpUnreachable, SPIRV_OC_OpGroupBroadcast, SPIRV_OC_OpGroupIAdd, SPIRV_OC_OpGroupFAdd, SPIRV_OC_OpGroupFMin, SPIRV_OC_OpGroupUMin, SPIRV_OC_OpGroupSMin, SPIRV_OC_OpGroupFMax, SPIRV_OC_OpGroupUMax, diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td index acb6467132be9..27c9add7d43af 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td @@ -242,6 +242,112 @@ def SPIRV_FunctionCallOp : SPIRV_Op<"FunctionCall", [ }]; } +// ----- + +def SPIRV_SwitchOp : SPIRV_Op<"Switch", + [AttrSizedOperandSegments, InFunctionScope, + DeclareOpInterfaceMethods, + Pure, Terminator]> { + let summary = [{ + Multi-way branch to one of the operand label . + }]; + + let description = [{ + Selector must have a type of OpTypeInt. Selector is compared for equality to + the Target literals. + + Default must be the of a label. If Selector does not equal any of the + Target literals, control flow branches to the Default label . + + Target must be alternating scalar integer literals and the of a label. + If Selector equals a literal, control flow branches to the following label + . It is invalid for any two literal to be equal to each other. If Selector + does not equal any literal, control flow branches to the Default label . + Each literal is interpreted with the type of Selector: The bit width of + Selector’s type is the width of each literal’s type. If this width is not a + multiple of 32-bits and the OpTypeInt Signedness is set to 1, the literal values + are interpreted as being sign extended. + + If Selector is an OpUndef, behavior is undefined. + + This instruction must be the last instruction in a block. + + #### Example: + + ```mlir + spirv.Switch %selector : si32, [ + default: ^bb1(%a : i32), + 0: ^bb1(%b : i32), + 1: ^bb3(%c : i32) + ] + ``` + }]; + + let arguments = (ins + SPIRV_Integer:$selector, + Variadic:$defaultOperands, + VariadicOfVariadic:$targetOperands, + OptionalAttr:$literals, + DenseI32ArrayAttr:$case_operand_segments + ); + + let results = (outs); + + let successors = (successor AnySuccessor:$defaultTarget, + VariadicSuccessor:$targets); + + let builders = [ + OpBuilder<(ins "Value":$selector, + "Block *":$defaultTarget, + "ValueRange":$defaultOperands, + CArg<"ArrayRef", "{}">:$literals, + CArg<"BlockRange", "{}">:$targets, + CArg<"ArrayRef", "{}">:$targetOperands)>, + OpBuilder<(ins "Value":$selector, + "Block *":$defaultTarget, + "ValueRange":$defaultOperands, + CArg<"ArrayRef", "{}">:$literals, + CArg<"BlockRange", "{}">:$targets, + CArg<"ArrayRef", "{}">:$targetOperands)>, + OpBuilder<(ins "Value":$selector, + "Block *":$defaultTarget, + "ValueRange":$defaultOperands, + CArg<"DenseIntElementsAttr", "{}">:$literals, + CArg<"BlockRange", "{}">:$targets, + CArg<"ArrayRef", "{}">:$targetOperands)> + ]; + + let assemblyFormat = [{ + $selector `:` type($selector) `,` `[` `\n` + custom(ref(type($selector)),$defaultTarget, + $defaultOperands, + type($defaultOperands), + $literals, + $targets, + $targetOperands, + type($targetOperands)) + `]` + attr-dict + }]; + + let extraClassDeclaration = [{ + /// Return the operands for the target block at the given index. + OperandRange getTargetOperands(unsigned index) { + return getTargetOperands()[index]; + } + + /// Return a mutable range of operands for the target block at the + /// given index. + MutableOperandRange getTargetOperandsMutable(unsigned index) { + return getTargetOperandsMutable()[index]; + } + }]; + + let autogenSerialization = 0; + let hasVerifier = 1; +} + + // ----- def SPIRV_KillOp : SPIRV_Op<"Kill", [Terminator]> { diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp index f0b46e61965f4..a846d7e60024c 100644 --- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp @@ -219,6 +219,89 @@ MutableOperandRange FunctionCallOp::getArgOperandsMutable() { return getArgumentsMutable(); } +//===----------------------------------------------------------------------===// +// spirv.Switch +//===----------------------------------------------------------------------===// + +void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector, + Block *defaultTarget, ValueRange defaultOperands, + DenseIntElementsAttr literals, BlockRange targets, + ArrayRef targetOperands) { + build(builder, result, selector, defaultOperands, targetOperands, literals, + defaultTarget, targets); +} + +void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector, + Block *defaultTarget, ValueRange defaultOperands, + ArrayRef literals, BlockRange targets, + ArrayRef targetOperands) { + DenseIntElementsAttr literalsAttr; + if (!literals.empty()) { + ShapedType literalType = VectorType::get( + static_cast(literals.size()), selector.getType()); + literalsAttr = DenseIntElementsAttr::get(literalType, literals); + } + build(builder, result, selector, defaultTarget, defaultOperands, literalsAttr, + targets, targetOperands); +} + +void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector, + Block *defaultTarget, ValueRange defaultOperands, + ArrayRef literals, BlockRange targets, + ArrayRef targetOperands) { + DenseIntElementsAttr literalsAttr; + if (!literals.empty()) { + ShapedType literalType = VectorType::get( + static_cast(literals.size()), selector.getType()); + literalsAttr = DenseIntElementsAttr::get(literalType, literals); + } + build(builder, result, selector, defaultTarget, defaultOperands, literalsAttr, + targets, targetOperands); +} + +LogicalResult SwitchOp::verify() { + std::optional literals = getLiterals(); + BlockRange targets = getTargets(); + + if (!literals && targets.empty()) + return success(); + + Type selectorType = getSelector().getType(); + Type literalType = literals->getType().getElementType(); + if (literalType != selectorType) + return emitOpError() << "'selector' type (" << selectorType + << ") should match literals type (" << literalType + << ")"; + + if (literals && literals->size() != static_cast(targets.size())) + return emitOpError() << "number of literals (" << literals->size() + << ") should match number of targets (" + << targets.size() << ")"; + return success(); +} + +SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) { + assert(index < getNumSuccessors() && "invalid successor index"); + return SuccessorOperands(index == 0 ? getDefaultOperandsMutable() + : getTargetOperandsMutable(index - 1)); +} + +Block *SwitchOp::getSuccessorForOperands(ArrayRef operands) { + std::optional literals = getLiterals(); + + if (!literals) + return getDefaultTarget(); + + SuccessorRange targets = getTargets(); + if (auto value = dyn_cast_or_null(operands.front())) { + for (auto [index, literal] : llvm::enumerate(literals->getValues())) + if (literal == value.getValue()) + return targets[index]; + return getDefaultTarget(); + } + return nullptr; +} + //===----------------------------------------------------------------------===// // spirv.mlir.loop //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp index 2f3a28ff16173..8575487ff52cc 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp @@ -81,6 +81,83 @@ static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp, } } +/// Adapted from the cf.switch implementation. +/// ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)? +/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )* +static ParseResult parseSwitchOpCases( + OpAsmParser &parser, Type &selectorType, Block *&defaultTarget, + SmallVectorImpl &defaultOperands, + SmallVectorImpl &defaultOperandTypes, DenseIntElementsAttr &literals, + SmallVectorImpl &targets, + SmallVectorImpl> + &targetOperands, + SmallVectorImpl> &targetOperandTypes) { + if (parser.parseKeyword("default") || parser.parseColon() || + parser.parseSuccessor(defaultTarget)) + return failure(); + if (succeeded(parser.parseOptionalLParen())) { + if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None, + /*allowResultNumber=*/false) || + parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen()) + return failure(); + } + + SmallVector values; + unsigned bitWidth = selectorType.getIntOrFloatBitWidth(); + while (succeeded(parser.parseOptionalComma())) { + int64_t value = 0; + if (failed(parser.parseInteger(value))) + return failure(); + values.push_back(APInt(bitWidth, value, /*isSigned=*/true)); + + Block *target; + SmallVector operands; + SmallVector operandTypes; + if (failed(parser.parseColon()) || failed(parser.parseSuccessor(target))) + return failure(); + if (succeeded(parser.parseOptionalLParen())) { + if (failed(parser.parseOperandList(operands, + OpAsmParser::Delimiter::None)) || + failed(parser.parseColonTypeList(operandTypes)) || + failed(parser.parseRParen())) + return failure(); + } + targets.push_back(target); + targetOperands.emplace_back(operands); + targetOperandTypes.emplace_back(operandTypes); + } + + if (!values.empty()) { + ShapedType literalType = + VectorType::get(static_cast(values.size()), selectorType); + literals = DenseIntElementsAttr::get(literalType, values); + } + return success(); +} + +static void +printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type selectorType, + Block *defaultTarget, OperandRange defaultOperands, + TypeRange defaultOperandTypes, DenseIntElementsAttr literals, + SuccessorRange targets, OperandRangeRange targetOperands, + const TypeRangeRange &targetOperandTypes) { + p << " default: "; + p.printSuccessorAndUseList(defaultTarget, defaultOperands); + + if (!literals) + return; + + for (auto [index, literal] : llvm::enumerate(literals.getValues())) { + p << ','; + p.printNewline(); + p << " "; + p << literal.getLimitedValue(); + p << ": "; + p.printSuccessorAndUseList(targets[index], targetOperands[index]); + } + p.printNewline(); +} + } // namespace mlir::spirv // TablenGen'erated operation definitions. diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp index c27f9aa91332c..5b04a14a78036 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp @@ -248,6 +248,8 @@ LogicalResult spirv::Deserializer::processInstruction( return processLoopMerge(operands); case spirv::Opcode::OpPhi: return processPhi(operands); + case spirv::Opcode::OpSwitch: + return processSwitch(operands); case spirv::Opcode::OpUndef: return processUndef(operands); default: diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 6492708694cc5..252be796488c5 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -2292,6 +2292,38 @@ LogicalResult spirv::Deserializer::processPhi(ArrayRef operands) { return success(); } +LogicalResult spirv::Deserializer::processSwitch(ArrayRef operands) { + if (!curBlock) + return emitError(unknownLoc, "OpSwitch must appear in a block"); + + if (operands.size() < 2) + return emitError(unknownLoc, "OpSwitch must at least specify selector and " + "a default target"); + + if (operands.size() % 2) + return emitError(unknownLoc, + "OpSwitch must at have an even number of operands: " + "selector, default target and any number of literal and " + "label pairs"); + + Value selector = getValue(operands[0]); + Block *defaultBlock = getOrCreateBlock(operands[1]); + Location loc = createFileLineColLoc(opBuilder); + + SmallVector literals; + SmallVector blocks; + for (unsigned i = 2, e = operands.size(); i < e; i += 2) { + literals.push_back(operands[i]); + blocks.push_back(getOrCreateBlock(operands[i + 1])); + } + + SmallVector targetOperands(blocks.size(), {}); + spirv::SwitchOp::create(opBuilder, loc, selector, defaultBlock, + ArrayRef(), literals, blocks, targetOperands); + + return success(); +} + namespace { /// A class for putting all blocks in a structured selection/loop in a /// spirv.mlir.selection/spirv.mlir.loop op. diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h index 6027f1ac94c23..243e6fd70ae43 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -472,6 +472,9 @@ class Deserializer { /// Processes a SPIR-V OpPhi instruction with the given `operands`. LogicalResult processPhi(ArrayRef operands); + /// Processes a SPIR-V OpSwitch instruction with the given `operands`. + LogicalResult processSwitch(ArrayRef operands); + /// Creates block arguments on predecessors previously recorded when handling /// OpPhi instructions. LogicalResult wireUpBlockArgument(); diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp index 85e92c7ced394..6397d2c005c16 100644 --- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp @@ -775,6 +775,27 @@ LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) { return success(); } +LogicalResult Serializer::processSwitchOp(spirv::SwitchOp switchOp) { + uint32_t selectorID = getValueID(switchOp.getSelector()); + uint32_t defaultLabelID = getOrCreateBlockID(switchOp.getDefaultTarget()); + SmallVector arguments{selectorID, defaultLabelID}; + + std::optional literals = switchOp.getLiterals(); + BlockRange targets = switchOp.getTargets(); + if (literals) { + for (auto [literal, target] : llvm::zip_equal(*literals, targets)) { + arguments.push_back(literal.getLimitedValue()); + uint32_t targetLabelID = getOrCreateBlockID(target); + arguments.push_back(targetLabelID); + } + } + + if (failed(emitDebugLine(functionBody, switchOp.getLoc()))) + return failure(); + encodeInstructionInto(functionBody, spirv::Opcode::OpSwitch, arguments); + return success(); +} + LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) { auto varName = addressOfOp.getVariable(); auto variableID = getVariableID(varName); diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index 29ed5a4fc139e..4e03a809bd0bc 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -1579,6 +1579,7 @@ LogicalResult Serializer::processOperation(Operation *opInst) { .Case([&](spirv::SpecConstantOperationOp op) { return processSpecConstantOperationOp(op); }) + .Case([&](spirv::SwitchOp op) { return processSwitchOp(op); }) .Case([&](spirv::UndefOp op) { return processUndefOp(op); }) .Case([&](spirv::VariableOp op) { return processVariableOp(op); }) diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h index add372b19b5af..6e79c133eb6af 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h @@ -304,6 +304,8 @@ class Serializer { LogicalResult processBranchOp(spirv::BranchOp branchOp); + LogicalResult processSwitchOp(spirv::SwitchOp switchOp); + //===--------------------------------------------------------------------===// // Operations //===--------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir index 8e29ff6679068..b70bb40dae97f 100644 --- a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir @@ -795,6 +795,53 @@ func.func @selection(%cond: i1) -> () { // ----- +func.func @selection_switch(%selector: i32) -> () { + %zero = spirv.Constant 0: i32 + %one = spirv.Constant 1: i32 + %two = spirv.Constant 2: i32 + %three = spirv.Constant 3: i32 + %var = spirv.Variable init(%zero) : !spirv.ptr + + // CHECK: spirv.mlir.selection { + spirv.mlir.selection { + // CHECK-NEXT: spirv.Switch {{%.*}} : i32, [ + // CHECK-NEXT: default: ^bb1, + // CHECK-NEXT: 0: ^bb2, + // CHECK-NEXT: 1: ^bb3 + spirv.Switch %selector : i32, [ + default: ^default, + 0: ^case0, + 1: ^case1 + ] + // CHECK: ^bb1 + ^default: + spirv.Store "Function" %var, %one : i32 + // CHECK: spirv.Branch ^bb4 + spirv.Branch ^merge + + // CHECK: ^bb2 + ^case0: + spirv.Store "Function" %var, %two : i32 + // CHECK: spirv.Branch ^bb4 + spirv.Branch ^merge + + // CHECK: ^bb3 + ^case1: + spirv.Store "Function" %var, %three : i32 + // CHECK: spirv.Branch ^bb4 + spirv.Branch ^merge + + // CHECK: ^bb4 + ^merge: + // CHECK-NEXT: spirv.mlir.merge + spirv.mlir.merge + } + + spirv.Return +} + +// ----- + // CHECK-LABEL: @empty_region func.func @empty_region() -> () { // CHECK: spirv.mlir.selection @@ -918,3 +965,171 @@ func.func @kill() { // CHECK: spirv.Kill spirv.Kill } + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.Switch +//===----------------------------------------------------------------------===// + +func.func @switch(%selector: i32) -> () { + // CHECK: spirv.Switch {{%.*}} : i32, [ + // CHECK-NEXT: default: ^bb1, + // CHECK-NEXT: 0: ^bb2, + // CHECK-NEXT: 1: ^bb3, + // CHECK-NEXT: 2: ^bb4 + spirv.Switch %selector : i32, [ + default: ^default, + 0: ^case0, + 1: ^case1, + 2: ^case2 + ] +^default: + spirv.Branch ^merge + +^case0: + spirv.Branch ^merge + +^case1: + spirv.Branch ^merge + +^case2: + spirv.Branch ^merge + +^merge: + spirv.Return +} + +func.func @switch_only_default(%selector: i32) -> () { + // CHECK: spirv.Switch {{%.*}} : i32, [ + // CHECK-NEXT: default: ^bb1 + spirv.Switch %selector : i32, [ + default: ^default + ] +^default: + spirv.Branch ^merge + +^merge: + spirv.Return +} + +func.func @switch_operands(%selector : i32, %operand : i32) { + // CHECK: spirv.Switch {{%.*}} : i32, [ + // CHECK-NEXT: default: ^bb1({{%.*}} : i32), + // CHECK-NEXT: 0: ^bb2({{%.*}} : i32), + // CHECK-NEXT: 1: ^bb3({{%.*}} : i32) + spirv.Switch %selector : i32, [ + default: ^default(%operand : i32), + 0: ^case0(%operand : i32), + 1: ^case1(%operand : i32) + ] +^default(%argd : i32): + spirv.Branch ^merge + +^case0(%arg0 : i32): + spirv.Branch ^merge + +^case1(%arg1 : i32): + spirv.Branch ^merge + +^merge: + spirv.Return +} + +// ----- + +func.func @switch_float_selector(%selector: f32) -> () { + // expected-error@+1 {{expected builtin.integer, but found 'f32'}} + spirv.Switch %selector : f32, [ + default: ^default + ] +^default: + spirv.Branch ^merge + +^merge: + spirv.Return +} + +// ----- + +func.func @switch_float_selector(%selector: i32) -> () { + // expected-error@+3 {{expected integer value}} + spirv.Switch %selector : i32, [ + default: ^default, + 0.0: ^case0 + ] +^default: + spirv.Branch ^merge + +^case 0: + spirv.Branch ^merge + +^merge: + spirv.Return +} + +// ----- + +func.func @switch_missing_default(%selector: i32) -> () { + // expected-error@+2 {{expected 'default'}} + spirv.Switch %selector : i32, [ + 0: ^case0 + ] +^case 0: + spirv.Branch ^merge + +^merge: + spirv.Return +} + +// ----- + +func.func @switch_default_no_target(%selector: i32) -> () { + // expected-error@+2 {{expected block name}} + spirv.Switch %selector : i32, [ + default: + ] +^default: + spirv.Branch ^merge + +^merge: + spirv.Return +} + +// ----- + +func.func @switch_case_no_target(%selector: i32) -> () { + // expected-error@+3 {{expected block name}} + spirv.Switch %selector : i32, [ + default: ^default, + 0: + ] +^default: + spirv.Branch ^merge + +^case 0: + spirv.Branch ^merge + +^merge: + spirv.Return +} + +// ----- + +func.func @switch_missing_operand_type(%selector: i32) -> () { + %0 = spirv.Constant 0 : i32 + // expected-error@+2 {{expected ':'}} + spirv.Switch %selector : i32, [ + default: ^default (%0), + 0.0: ^case0 + ] +^default(%argd : i32): + spirv.Branch ^merge + +^case 0: + spirv.Branch ^merge + +^merge: + spirv.Return +} + diff --git a/mlir/test/Target/SPIRV/selection.mlir b/mlir/test/Target/SPIRV/selection.mlir index 12daf68538d0a..3f762920015aa 100644 --- a/mlir/test/Target/SPIRV/selection.mlir +++ b/mlir/test/Target/SPIRV/selection.mlir @@ -220,3 +220,71 @@ spirv.module Logical GLSL450 requires #spirv.vce { spirv.EntryPoint "GLCompute" @main spirv.ExecutionMode @main "LocalSize", 1, 1, 1 } + +// ----- + +// Selection with switch + +spirv.module Logical GLSL450 requires #spirv.vce { +// CHECK-LABEL: @selection_switch + spirv.func @selection_switch(%selector: i32) -> () "None" { + %zero = spirv.Constant 0: i32 + %one = spirv.Constant 1: i32 + %two = spirv.Constant 2: i32 + %three = spirv.Constant 3: i32 + %four = spirv.Constant 4: i32 +// CHECK: {{%.*}} = spirv.Variable init({{%.*}}) : !spirv.ptr + %var = spirv.Variable init(%zero) : !spirv.ptr +// CHECK: spirv.mlir.selection { + spirv.mlir.selection { +// CHECK-NEXT: spirv.Switch {{%.*}} : i32, [ +// CHECK-NEXT: default: ^[[DEFAULT:.+]], +// CHECK-NEXT: 0: ^[[CASE0:.+]], +// CHECK-NEXT: 1: ^[[CASE1:.+]], +// CHECK-NEXT: 2: ^[[CASE2:.+]] + spirv.Switch %selector : i32, [ + default: ^default, + 0: ^case0, + 1: ^case1, + 2: ^case2 + ] +// CHECK: ^[[DEFAULT]] + ^default: +// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32 + spirv.Store "Function" %var, %one : i32 +// CHECK-NEXT: spirv.Branch ^[[MERGE:.+]] + spirv.Branch ^merge +// CHECK-NEXT: ^[[CASE0]] + ^case0: +// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32 + spirv.Store "Function" %var, %two : i32 +// CHECK-NEXT: spirv.Branch ^[[MERGE:.+]] + spirv.Branch ^merge +// CHECK-NEXT: ^[[CASE1]] + ^case1: +// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32 + spirv.Store "Function" %var, %three : i32 +// CHECK-NEXT: spirv.Branch ^[[MERGE:.+]] + spirv.Branch ^merge +// CHECK-NEXT: ^[[CASE2]] + ^case2: +// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32 + spirv.Store "Function" %var, %four : i32 +// CHECK-NEXT: spirv.Branch ^[[MERGE:.+]] + spirv.Branch ^merge +// CHECK-NEXT: ^[[MERGE]] + ^merge: +// CHECK-NEXT: spirv.mlir.merge + spirv.mlir.merge +// CHECK-NEXT: } + } +// CHECK-NEXT: spirv.Return + spirv.Return + } + + spirv.func @main() -> () "None" { + spirv.Return + } + spirv.EntryPoint "GLCompute" @main + spirv.ExecutionMode @main "LocalSize", 1, 1, 1 +}