diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index b0c781c7aff11..9468927021495 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -185,6 +185,30 @@ class ConvertWhileOpTypes
 };
 } // namespace
 
+namespace {
+class ConvertIndexSwitchOpTypes
+    : public Structural1ToNConversionPattern<IndexSwitchOp,
+                                             ConvertIndexSwitchOpTypes> {
+public:
+  using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
+
+  std::optional<IndexSwitchOp>
+  convertSourceOp(IndexSwitchOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter,
+                  TypeRange dstTypes) const {
+    auto newOp =
+        IndexSwitchOp::create(rewriter, op.getLoc(), dstTypes, op.getArg(),
+                              op.getCases(), op.getNumCases());
+
+    for (unsigned i = 0u; i < op.getNumRegions(); i++) {
+      auto &dstRegion = newOp.getRegion(i);
+      rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
+    }
+    return newOp;
+  }
+};
+} // namespace
+
 namespace {
 // When the result types of a ForOp/IfOp get changed, the operand types of the
 // corresponding yield op need to be changed. In order to trigger the
@@ -220,18 +244,19 @@ void mlir::scf::populateSCFStructuralTypeConversions(
     const TypeConverter &typeConverter, RewritePatternSet &patterns,
     PatternBenefit benefit) {
   patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes,
-               ConvertWhileOpTypes, ConvertConditionOpTypes>(
-      typeConverter, patterns.getContext(), benefit);
+               ConvertWhileOpTypes, ConvertConditionOpTypes,
+               ConvertIndexSwitchOpTypes>(typeConverter, patterns.getContext(),
+                                          benefit);
 }
 
 void mlir::scf::populateSCFStructuralTypeConversionTarget(
     const TypeConverter &typeConverter, ConversionTarget &target) {
-  target.addDynamicallyLegalOp<ForOp, IfOp, WhileOp>(
+  target.addDynamicallyLegalOp<ForOp, IfOp, WhileOp, IndexSwitchOp>(
       [&](Operation *op) { return typeConverter.isLegal(op->getResults()); });
   target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
     // We only have conversions for a subset of ops that use scf.yield
     // terminators.
-    if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp()))
+    if (!isa<ForOp, IfOp, WhileOp, IndexSwitchOp>(op->getParentOp()))
       return true;
     return typeConverter.isLegal(op.getOperands());
   });
diff --git a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
index f5d6a08b7de31..515de5502f322 100644
--- a/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
@@ -86,3 +86,47 @@ func.func @while(%arg0: tensor<1024xf32, #SparseVector>, %c: i1) -> tensor<1024x
   }
   return %0: tensor<1024xf32, #SparseVector>
 }
+
+// CHECK-LABEL: func.func @index_switch(
+// CHECK-SAME:    %[[PRED:.*0]]: index,
+// CHECK-SAME:    %[[VAL_A_1:.*1]]: memref<?xindex>,
+// CHECK-SAME:    %[[VAL_A_2:.*2]]: memref<?xindex>,
+// CHECK-SAME:    %[[VAL_A_3:.*3]]: memref<?xf32>,
+// CHECK-SAME:    %[[VAL_A_4:.*4]]: !sparse_tensor.storage_specifier
+// CHECK-SAME:    %[[VAL_B_1:.*5]]: memref<?xindex>,
+// CHECK-SAME:    %[[VAL_B_2:.*6]]: memref<?xindex>,
+// CHECK-SAME:    %[[VAL_B_3:.*7]]: memref<?xf32>,
+// CHECK-SAME:    %[[VAL_B_4:.*8]]: !sparse_tensor.storage_specifier
+// CHECK-SAME:    %[[VAL_C_1:.*9]]: memref<?xindex>,
+// CHECK-SAME:    %[[VAL_C_2:.*10]]: memref<?xindex>,
+// CHECK-SAME:    %[[VAL_C_3:.*11]]: memref<?xf32>,
+// CHECK-SAME:    %[[VAL_C_4:.*12]]: !sparse_tensor.storage_specifier
+
+// CHECK:         %[[RES:.*]]:4 = scf.index_switch %[[PRED]]
+// CHECK-SAME:        -> memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier
+// CHECK:         case 1 {
+// CHECK:           scf.yield %[[VAL_A_1]], %[[VAL_A_2]], %[[VAL_A_3]], %[[VAL_A_4]]
+// CHECK:         case 2 {
+// CHECK:           scf.yield %[[VAL_B_1]], %[[VAL_B_2]], %[[VAL_B_3]], %[[VAL_B_4]]
+// CHECK:         default {
+// CHECK:           scf.yield %[[VAL_C_1]], %[[VAL_C_2]], %[[VAL_C_3]], %[[VAL_C_4]]
+
+// CHECK:         return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3 :
+// CHECK-SAME:        memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier
+
+func.func @index_switch(%pred: index, %a: tensor<5xf32, #SparseVector>,
+                        %b: tensor<5xf32, #SparseVector>,
+                        %c: tensor<5xf32, #SparseVector>) -> tensor<5xf32, #SparseVector> {
+  %0 = scf.index_switch %pred -> tensor<5xf32, #SparseVector>
+  case 1 {
+    scf.yield %a : tensor<5xf32, #SparseVector>
+  }
+  case 2 {
+    scf.yield %b : tensor<5xf32, #SparseVector>
+  }
+  default {
+    scf.yield %c : tensor<5xf32, #SparseVector>
+  }
+
+  return %0 : tensor<5xf32, #SparseVector>
+}