Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,30 @@ class ConvertWhileOpTypes
};
} // namespace

namespace {
class ConvertIndexSwitchOpTypes
: public Structural1ToNConversionPattern<IndexSwitchOp,
ConvertIndexSwitchOpTypes> {
public:
using Structural1ToNConversionPattern::Structural1ToNConversionPattern;

std::optional<IndexSwitchOp>
convertSourceOp(IndexSwitchOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
TypeRange dstTypes) const {
auto newOp =
IndexSwitchOp::create(rewriter, op.getLoc(), dstTypes, op.getArg(),
op.getCases(), op.getNumCases());

for (unsigned i = 0u; i < op.getNumRegions(); i++) {
auto &dstRegion = newOp.getRegion(i);
rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
}
return newOp;
}
};
} // namespace

namespace {
// When the result types of a ForOp/IfOp get changed, the operand types of the
// corresponding yield op need to be changed. In order to trigger the
Expand Down Expand Up @@ -220,18 +244,19 @@ void mlir::scf::populateSCFStructuralTypeConversions(
const TypeConverter &typeConverter, RewritePatternSet &patterns,
PatternBenefit benefit) {
patterns.add<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes,
ConvertWhileOpTypes, ConvertConditionOpTypes>(
typeConverter, patterns.getContext(), benefit);
ConvertWhileOpTypes, ConvertConditionOpTypes,
ConvertIndexSwitchOpTypes>(typeConverter, patterns.getContext(),
benefit);
}

void mlir::scf::populateSCFStructuralTypeConversionTarget(
const TypeConverter &typeConverter, ConversionTarget &target) {
target.addDynamicallyLegalOp<ForOp, IfOp>(
target.addDynamicallyLegalOp<ForOp, IfOp, IndexSwitchOp>(
[&](Operation *op) { return typeConverter.isLegal(op->getResults()); });
target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
// We only have conversions for a subset of ops that use scf.yield
// terminators.
if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp()))
if (!isa<ForOp, IfOp, WhileOp, IndexSwitchOp>(op->getParentOp()))
return true;
return typeConverter.isLegal(op.getOperands());
});
Expand Down
44 changes: 44 additions & 0 deletions mlir/test/Dialect/SparseTensor/scf_1_N_conversion.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,47 @@ func.func @while(%arg0: tensor<1024xf32, #SparseVector>, %c: i1) -> tensor<1024x
}
return %0: tensor<1024xf32, #SparseVector>
}

// CHECK-LABEL: func.func @index_switch(
// CHECK-SAME: %[[PRED:.*0]]: index,
// CHECK-SAME: %[[VAL_A_1:.*1]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_A_2:.*2]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_A_3:.*3]]: memref<?xf32>,
// CHECK-SAME: %[[VAL_A_4:.*4]]: !sparse_tensor.storage_specifier
// CHECK-SAME: %[[VAL_B_1:.*5]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_B_2:.*6]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_B_3:.*7]]: memref<?xf32>,
// CHECK-SAME: %[[VAL_B_4:.*8]]: !sparse_tensor.storage_specifier
// CHECK-SAME: %[[VAL_C_1:.*9]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_C_2:.*10]]: memref<?xindex>,
// CHECK-SAME: %[[VAL_C_3:.*11]]: memref<?xf32>,
// CHECK-SAME: %[[VAL_C_4:.*12]]: !sparse_tensor.storage_specifier

// CHECK: %[[RES:.*]]:4 = scf.index_switch %[[PRED]]
// CHECK-SAME: -> memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier
// CHECK: case 1 {
// CHECK: scf.yield %[[VAL_A_1]], %[[VAL_A_2]], %[[VAL_A_3]], %[[VAL_A_4]]
// CHECK: case 2 {
// CHECK: scf.yield %[[VAL_B_1]], %[[VAL_B_2]], %[[VAL_B_3]], %[[VAL_B_4]]
// CHECK: default {
// CHECK: scf.yield %[[VAL_C_1]], %[[VAL_C_2]], %[[VAL_C_3]], %[[VAL_C_4]]

// CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3 :
// CHECK-SAME: memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier

func.func @index_switch(%pred: index, %a: tensor<5xf32, #SparseVector>,
%b: tensor<5xf32, #SparseVector>,
%c: tensor<5xf32, #SparseVector>) -> tensor<5xf32, #SparseVector> {
%0 = scf.index_switch %pred -> tensor<5xf32, #SparseVector>
case 1 {
scf.yield %a : tensor<5xf32, #SparseVector>
}
case 2 {
scf.yield %b : tensor<5xf32, #SparseVector>
}
default {
scf.yield %c : tensor<5xf32, #SparseVector>
}

return %0 : tensor<5xf32, #SparseVector>
}