diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir index 7cfd73c01bf4b..f597da6e4caf4 100644 --- a/flang/test/Fir/convert-to-llvm.fir +++ b/flang/test/Fir/convert-to-llvm.fir @@ -215,7 +215,7 @@ func @select(%arg : index, %arg2 : i32) -> i32 { // CHECK: %[[C1:.*]] = llvm.mlir.constant(2 : i32) : i32 // CHECK: %[[C2:.*]] = llvm.mlir.constant(3 : i32) : i32 // CHECK: %[[SELECTOR:.*]] = llvm.trunc %[[SELECTVALUE]] : i{{.*}} to i32 -// CHECK: llvm.switch %[[SELECTOR]], ^bb5 [ +// CHECK: llvm.switch %[[SELECTOR]] : i32, ^bb5 [ // CHECK: 1: ^bb1(%[[C0]] : i32), // CHECK: 2: ^bb2(%[[C2]], %[[SELECTVALUE]], %[[ARG1]] : i32, [[IDX]], i32), // CHECK: 3: ^bb3(%[[ARG1]], %[[C2]] : i32, i32), @@ -260,7 +260,7 @@ func @select_rank(%arg : i32, %arg2 : i32) -> i32 { // CHECK: %[[C0:.*]] = llvm.mlir.constant(1 : i32) : i32 // CHECK: %[[C1:.*]] = llvm.mlir.constant(2 : i32) : i32 // CHECK: %[[C2:.*]] = llvm.mlir.constant(3 : i32) : i32 -// CHECK: llvm.switch %[[SELECTVALUE]], ^bb5 [ +// CHECK: llvm.switch %[[SELECTVALUE]] : i32, ^bb5 [ // CHECK: 1: ^bb1(%[[C0]] : i32), // CHECK: 2: ^bb2(%[[C2]], %[[SELECTVALUE]], %[[ARG1]] : i32, i32, i32), // CHECK: 3: ^bb3(%[[ARG1]], %[[C2]] : i32, i32), diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 6bd64edf44c4f..055975ef58bcf 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -724,7 +724,7 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch", [AttrSizedOperandSegments, DeclareOpInterfaceMethods, NoSideEffect]> { let arguments = (ins - I32:$value, + AnyInteger:$value, Variadic:$defaultOperands, VariadicOfVariadic:$caseOperands, OptionalAttr:$case_values, @@ -738,9 +738,9 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch", let verifier = [{ return ::verify(*this); }]; let assemblyFormat = [{ - $value `,` + $value `:` type($value) `,` $defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)? - `[` `\n` custom($case_values, $caseDestinations, + `[` `\n` custom(ref(type($value)), $case_values, $caseDestinations, $caseOperands, type($caseOperands)) `]` attr-dict }]; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index dd2ffdabd8c5f..0afc64b2ffce0 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -269,20 +269,21 @@ void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, /// ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)? /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )? static ParseResult parseSwitchOpCases( - OpAsmParser &parser, ElementsAttr &caseValues, + OpAsmParser &parser, Type flagType, ElementsAttr &caseValues, SmallVectorImpl &caseDestinations, SmallVectorImpl> &caseOperands, SmallVectorImpl> &caseOperandTypes) { - SmallVector values; - int32_t value = 0; + SmallVector values; + unsigned bitWidth = flagType.getIntOrFloatBitWidth(); do { + int64_t value = 0; OptionalParseResult integerParseResult = parser.parseOptionalInteger(value); if (values.empty() && !integerParseResult.hasValue()) return success(); if (!integerParseResult.hasValue() || integerParseResult.getValue()) return failure(); - values.push_back(value); + values.push_back(APInt(bitWidth, value)); Block *destination; SmallVector operands; @@ -299,11 +300,13 @@ static ParseResult parseSwitchOpCases( caseOperandTypes.emplace_back(operandTypes); } while (!parser.parseOptionalComma()); - caseValues = parser.getBuilder().getI32VectorAttr(values); + ShapedType caseValueType = + VectorType::get(static_cast(values.size()), flagType); + caseValues = DenseIntElementsAttr::get(caseValueType, values); return success(); } -static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, +static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, ElementsAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir index 7f5500c875e7a..7e47448e7f3ba 100644 --- a/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir +++ b/mlir/test/Conversion/AsyncToLLVM/convert-coro-to-llvm.mlir @@ -64,7 +64,7 @@ func @coro_suspend() { // CHECK: %[[FINAL:.*]] = llvm.mlir.constant(false) : i1 // CHECK: %[[RET:.*]] = llvm.intr.coro.suspend %[[STATE]], %[[FINAL]] // CHECK: %[[SEXT:.*]] = llvm.sext %[[RET]] : i8 to i32 - // CHECK: llvm.switch %[[SEXT]], ^[[SUSPEND:[b0-9]+]] + // CHECK: llvm.switch %[[SEXT]] : i32, ^[[SUSPEND:[b0-9]+]] // CHECK-NEXT: 0: ^[[RESUME:[b0-9]+]] // CHECK-NEXT: 1: ^[[CLEANUP:[b0-9]+]] async.coro.suspend %2, ^suspend, ^resume, ^cleanup diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir index eb8ddbb13e5d3..46ff7501f4b89 100644 --- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir @@ -49,7 +49,7 @@ func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) { // Decide the next block based on the code returned from suspend. // CHECK: %[[SEXT:.*]] = llvm.sext %[[SUSPENDED]] : i8 to i32 -// CHECK: llvm.switch %[[SEXT]], ^[[SUSPEND:[b0-9]+]] +// CHECK: llvm.switch %[[SEXT]] : i32, ^[[SUSPEND:[b0-9]+]] // CHECK-NEXT: 0: ^[[RESUME:[b0-9]+]] // CHECK-NEXT: 1: ^[[CLEANUP:[b0-9]+]] diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir index 9e9636a20ab83..7d0942ca8691b 100644 --- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir +++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir @@ -592,3 +592,31 @@ func @select_2dvector(%arg0 : vector<4x3xi1>, %arg1 : vector<4x3xi32>, %arg2 : v %0 = select %arg0, %arg1, %arg2 : vector<4x3xi1>, vector<4x3xi32> std.return } + +// ----- + +// CHECK-LABEL: func @switchi8( +func @switchi8(%arg0 : i8) -> i32 { +switch %arg0 : i8, [ + default: ^bb1, + 42: ^bb1, + 43: ^bb3 + ] +^bb1: + %c_1 = arith.constant 1 : i32 + std.return %c_1 : i32 +^bb3: + %c_42 = arith.constant 42 : i32 + std.return %c_42: i32 +} +// CHECK: llvm.switch %arg0 : i8, ^bb1 [ +// CHECK-NEXT: 42: ^bb1, +// CHECK-NEXT: 43: ^bb2 +// CHECK-NEXT: ] +// CHECK: ^bb1: // 2 preds: ^bb0, ^bb0 +// CHECK-NEXT: %[[E0:.+]] = llvm.mlir.constant(1 : i32) : i32 +// CHECK-NEXT: llvm.return %[[E0]] : i32 +// CHECK: ^bb2: // pred: ^bb0 +// CHECK-NEXT: %[[E1:.+]] = llvm.mlir.constant(42 : i32) : i32 +// CHECK-NEXT: llvm.return %[[E1]] : i32 +// CHECK-NEXT: } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 3f07f173ec875..fd9b5765fa2f2 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -805,7 +805,7 @@ module attributes {llvm.data_layout = "#vjkr32"} { func @switch_wrong_number_of_weights(%arg0 : i32) { // expected-error@+1 {{expects number of branch weights to match number of successors: 3 vs 2}} - llvm.switch %arg0, ^bb1 [ + llvm.switch %arg0 : i32, ^bb1 [ 42: ^bb2(%arg0, %arg0 : i32, i32) ] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>} diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir index 8efd14ee4f597..b931c9bb69e86 100644 --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -84,12 +84,12 @@ func @ops(%arg0: i32, %arg1: f32, // CHECK: %{{.*}} = llvm.mlir.constant(42 : i64) : i47 %22 = llvm.mlir.undef : !llvm.struct<(i32, f64, i32)> %23 = llvm.mlir.constant(42) : i47 - // CHECK: llvm.switch %0, ^[[BB3]] [ + // CHECK: llvm.switch %0 : i32, ^[[BB3]] [ // CHECK-NEXT: 1: ^[[BB4:.*]], // CHECK-NEXT: 2: ^[[BB5:.*]], // CHECK-NEXT: 3: ^[[BB6:.*]] // CHECK-NEXT: ] - llvm.switch %0, ^bb3 [ + llvm.switch %0 : i32, ^bb3 [ 1: ^bb4, 2: ^bb5, 3: ^bb6 @@ -97,24 +97,24 @@ func @ops(%arg0: i32, %arg1: f32, // CHECK: ^[[BB3]] ^bb3: -// CHECK: llvm.switch %0, ^[[BB7:.*]] [ +// CHECK: llvm.switch %0 : i32, ^[[BB7:.*]] [ // CHECK-NEXT: ] - llvm.switch %0, ^bb7 [ + llvm.switch %0 : i32, ^bb7 [ ] // CHECK: ^[[BB4]] ^bb4: - llvm.switch %0, ^bb7 [ + llvm.switch %0 : i32, ^bb7 [ ] // CHECK: ^[[BB5]] ^bb5: - llvm.switch %0, ^bb7 [ + llvm.switch %0 : i32, ^bb7 [ ] // CHECK: ^[[BB6]] ^bb6: - llvm.switch %0, ^bb7 [ + llvm.switch %0 : i32, ^bb7 [ ] // CHECK: ^[[BB7]] diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir index 7677e59c3cd91..f5b6d60662ad2 100644 --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -1560,7 +1560,7 @@ llvm.func @switch_args(%arg0: i32) -> i32 { // CHECK-NEXT: i32 -1, label %[[SWITCHCASE_bb2:[0-9]+]] // CHECK-NEXT: i32 1, label %[[SWITCHCASE_bb3:[0-9]+]] // CHECK-NEXT: ] - llvm.switch %arg0, ^bb1 [ + llvm.switch %arg0 : i32, ^bb1 [ -1: ^bb2(%0 : i32), 1: ^bb3(%1, %2 : i32, i32) ] @@ -1590,7 +1590,7 @@ llvm.func @switch_weights(%arg0: i32) -> i32 { %1 = llvm.mlir.constant(23 : i32) : i32 %2 = llvm.mlir.constant(29 : i32) : i32 // CHECK: !prof ![[SWITCH_WEIGHT_NODE:[0-9]+]] - llvm.switch %arg0, ^bb1(%0 : i32) [ + llvm.switch %arg0 : i32, ^bb1(%0 : i32) [ 9: ^bb2(%1, %2 : i32, i32), 99: ^bb3 ] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>}