Commit e9ed4af

[TOSA] Add legalization for aten.index_select (#3760)
- Add Torch to TOSA legalization for aten.index_select
- Fix createOneDimTfIndices function in TosaLegalizeCommon.cpp to correctly
  convert Torch indices to TF-style indices, which is used in convertGatherNdOp
- Update e2e tests in xfail_sets.py
- Update basic.mlir with new LIT test for aten.index_select

Signed-off-by: Justin Ngo <justin.ngo@arm.com>
Change-Id: I52519246183949353a3cf22f0a685fe3df8ec8ff
1 parent 2374b9e commit e9ed4af
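
The conversion named in the commit message rests on one idea: PyTorch/ONNX
Gather takes a `dim` plus per-element index values, while TF-style gather_nd
indices spell out the full coordinates of every element to fetch. A minimal
NumPy model of the two halves (an illustrative sketch, not code from this
commit; both function names are hypothetical):

import numpy as np

def torch_to_tf_indices(idx, dim):
    # Start from each element's own coordinates, then overwrite the `dim`
    # coordinate with the Torch index value.
    coords = np.stack(np.meshgrid(*[np.arange(s) for s in idx.shape],
                                  indexing="ij"), axis=-1)
    coords[..., dim] = idx
    return coords

def gather_nd(params, indices):
    # TF-style gather: each trailing vector of `indices` is one full coordinate.
    return params[tuple(np.moveaxis(indices, -1, 0))]

For a [1,4,2] Torch index tensor, torch_to_tf_indices produces a [1,4,2,3]
coordinate tensor, the shape transition noted in the lowering below.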

File tree

4 files changed: +230 -57 lines changed


lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 119 additions & 0 deletions
@@ -3821,6 +3821,124 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
   return success();
 }

+template <>
+LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
+    AtenIndexSelectOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // Not a tensor type.
+  auto input = adaptor.getSelf();
+  auto inputType = dyn_cast<RankedTensorType>(input.getType());
+  if (!inputType)
+    return rewriter.notifyMatchFailure(
+        op, "Only RankedTensorType inputs are currently supported");
+
+  auto index = adaptor.getIndex();
+  auto indexType = dyn_cast<RankedTensorType>(index.getType());
+
+  if (!indexType)
+    return rewriter.notifyMatchFailure(
+        op, "Only RankedTensorType indices are currently supported");
+
+  auto inputShape = inputType.getShape();
+  int inputRank = inputType.getRank();
+
+  if (indexType.getRank() == 0)
+    return rewriter.notifyMatchFailure(
+        op, "Rank 0 index tensor is currently not supported");
+
+  // Dynamic shape check
+  if (!inputType.hasStaticShape() || !indexType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        op, "AtenIndexSelectOp: support for dynamic input "
+            "shape not implemented");
+
+  // Cast index from i64 to i32 for TOSA compatibility
+  if (indexType.getElementType() != rewriter.getIntegerType(32)) {
+    index = rewriter.create<tosa::CastOp>(
+        op->getLoc(),
+        RankedTensorType::get(indexType.getShape(),
+                              rewriter.getIntegerType(32)),
+        index);
+  }
+
+  // Get positive dim
+  int64_t dim;
+  if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+    return rewriter.notifyMatchFailure(
+        op, "Value `dim` should be a torch constant int");
+  dim = toPositiveDim(dim, inputRank);
+  if (!isValidDim(dim, inputRank))
+    return rewriter.notifyMatchFailure(op, "Value `dim` is invalid");
+
+  // Get the output type
+  auto outType = getTypeConverter()->convertType(op.getType());
+
+  // Reshape and expand the index tensor to have the same rank and the same
+  // dimensions (except for the targeted dim) as the input
+  //
+  // For example:
+  // Input shape = (4, 5, 6)
+  // Index vector shape = (2)
+  // Targeted dim = 1
+  // Reshaped and expanded index vector shape = (4, 2, 6)
+  //
+  // By reshaping and expanding the index vector, we can supply it into the
+  // gather op to mimic the functionality of aten.index_select
+  SmallVector<int64_t> indicesInputRankShape;
+  for (int64_t i = 0; i < inputRank; i++) {
+    if (i == dim) {
+      indicesInputRankShape.push_back(indexType.getShape()[0]);
+    } else {
+      indicesInputRankShape.push_back(1);
+    }
+  }
+
+  auto indicesInputRankType =
+      RankedTensorType::get(makeShapeLLVMCompatible(indicesInputRankShape),
+                            rewriter.getIntegerType(32));
+
+  auto reshapedIndices = rewriter.create<tosa::ReshapeOp>(
+      op->getLoc(), indicesInputRankType, index,
+      rewriter.getDenseI64ArrayAttr(indicesInputRankShape));
+
+  SmallVector<int64_t> tileShape(indicesInputRankShape);
+  SmallVector<int64_t> expandedIndicesShape(indicesInputRankShape);
+  for (int64_t i = 0; i < inputRank; i++) {
+    if (tileShape[i] == 1 && i != dim) {
+      tileShape[i] = inputShape[i];
+      expandedIndicesShape[i] = inputShape[i];
+    } else {
+      tileShape[i] = 1;
+    }
+  }
+
+  auto tileType =
+      RankedTensorType::get(makeShapeLLVMCompatible(expandedIndicesShape),
+                            rewriter.getIntegerType(32));
+
+  auto expandedIndices = rewriter.create<tosa::TileOp>(
+      op->getLoc(), tileType, reshapedIndices.getResult(),
+      rewriter.getDenseI64ArrayAttr(tileShape));
+
+  // Convert the Torch-style index and dim into TF-style indices
+  // tensor<[1,4,2],si64> -> tensor<[1,4,2,3],si64>
+  auto indicesTf = tosa::convertTorchIndexToTfIndices(
+      rewriter, op, input, expandedIndices.getResult(), dim);
+  if (!indicesTf)
+    return rewriter.notifyMatchFailure(
+        op, "Convert TorchIndex To TfIndices failed");
+
+  // Run the TF gather_nd algorithm with the TF-style indices as input
+  auto result =
+      tosa::convertGatherNdOp(rewriter, op, outType, input, indicesTf.value());
+
+  if (!result) {
+    return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed");
+  }
+  rewriter.replaceOp(op, {result.value()});
+  return success();
+}
+
 template <>
 LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
     AtenIndexPutHackedTwinOp op, OpAdaptor adaptor,
@@ -6240,6 +6358,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
     INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp);
     INSERT_ATENOP_PATTERN(AtenTrilOp);
     INSERT_ATENOP_PATTERN(AtenDiagonalOp);
+    INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
 #undef INSERT_ATENOP_PATTERN

 #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
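
As a sanity check, the reshape-and-tile trick in the pattern above can be
modeled in a few lines of NumPy (an illustrative sketch, not code from the
commit; the function name is hypothetical):

import numpy as np

def index_select_via_tiled_gather(x, dim, index):
    # Reshape the 1-D index vector to the input's rank: 1 everywhere except
    # `dim`, which keeps the index length (mirrors tosa.reshape).
    shape = [1] * x.ndim
    shape[dim] = index.shape[0]
    idx = index.reshape(shape)
    # Tile along every other dimension (mirrors tosa.tile), e.g. input
    # (4, 5, 6) with 2 indices on dim 1 gives an expanded index of (4, 2, 6).
    reps = list(x.shape)
    reps[dim] = 1
    idx = np.tile(idx, reps)
    # Gather with the expanded indices; this plays the role of the gather_nd
    # that convertGatherNdOp emits.
    return np.take_along_axis(x, idx, axis=dim)

x = np.arange(4 * 5 * 6).reshape(4, 5, 6)
assert np.array_equal(index_select_via_tiled_gather(x, 1, np.array([0, 3])),
                      np.take(x, np.array([0, 3]), axis=1))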

lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp

Lines changed: 48 additions & 33 deletions
@@ -23,56 +23,71 @@ namespace tosa {

 using namespace mlir::torch::Torch;

+// This function is a helper for `convertTorchIndexToTfIndices`.
+//
+// We convert a PyTorch index to TensorFlow-style indices so that we can use
+// the `convertGatherNdOp` and `convertScatterNdOp` functions, which lower
+// Gather and Scatter operators to TOSA using TensorFlow-style indices.
+// The difference between PyTorch/ONNX Gather/Scatter and TensorFlow
+// Gather/Scatter ops is that PyTorch/ONNX take in the dimension along which
+// you want to gather/scatter elements, while in TensorFlow the indices point
+// directly to the positions to gather/scatter elements at.
 std::optional<Value>
 createOneDimTfIndices(PatternRewriter &rewriter, Operation *op,
                       SmallVector<int64_t> indicesOneDimShape, int32_t dim,
                       ArrayRef<int64_t> indexShape) {
   unsigned indexRank = indexShape.size();
   SmallVector<int32_t> indicesVec;         // input vec to create tosaConstant
   SmallVector<int32_t> indicesMetaElement; // torch.meshgrid inputs
-  int indicesMetaElementRepeatTimes{1};    // For torch.stack(torch.meshgrid)

   // Create torch.meshgrid inputs
   // Example: indexShape=[1,4,2]
   // dim0: indicesMetaElement = torch.arange(0, 1) = [0]
   // dim1: indicesMetaElement = torch.arange(0, 4) = [0,1,2,3]
   // dim2: indicesMetaElement = torch.arange(0, 2) = [0,1]
-  for (int i = 0; i < indexShape[dim]; i++) {
+  for (int i = 0; i < indexShape[dim]; i++)
     indicesMetaElement.push_back(i);
-  }
-
-  // Compute total number of meta element repeat times:
-  // = product(indexShape[0:dim]) x product(indexShape[dim+1:-1]), skip dim
-  // dim0: indicesMetaElementRepeatTimes = 1 x 4*2 = 8
-  // dim1: indicesMetaElementRepeatTimes = 1 *1 x 2 = 2
-  // dim2: indicesMetaElementRepeatTimes = 1 *1*4 = 4
-  for (int i = 0; i < static_cast<int>(indexRank); i++) {
-    if (i == dim) {
-      continue;
-    } else {
-      indicesMetaElementRepeatTimes *= indexShape[i];
-    }
-  }

-  if (dim != static_cast<int>(indexShape.size()) - 1) {
-    // Create one dim indices for index except for last dim
-    // Create indices raw vector.
-    // torch.stack(torch.meshgrid)
-    // dim0: indicesVec = [0 0 0 0 0 0 0 0]
-    // dim1: indicesVec = [0 0 1 1 2 2 3 3]
+  int preDimMetaElementRepeatTimes = 1;
+  int postDimMetaElementRepeatTimes = 1;
+
+  // Compute the number of times the meta element range should repeat
+  // = product(indexShape[0:dim])
+  // dim0: preDimMetaElementRepeatTimes = 1
+  // dim1: preDimMetaElementRepeatTimes = 1
+  // dim2: preDimMetaElementRepeatTimes = 1 x 4 = 4
+  for (int i = 0; i < dim; i++)
+    preDimMetaElementRepeatTimes *= indexShape[i];
+
+  // Compute the number of times each meta element should repeat
+  // = product(indexShape[dim+1:indexRank])
+  // dim0: postDimMetaElementRepeatTimes = 4 x 2 = 8
+  // dim1: postDimMetaElementRepeatTimes = 2
+  // dim2: postDimMetaElementRepeatTimes = 1
+  for (int i = dim + 1; i < static_cast<int>(indexRank); i++)
+    postDimMetaElementRepeatTimes *= indexShape[i];
+
+  // Example using dim1:
+  // preDimMetaElementRepeatTimes = 1
+  // postDimMetaElementRepeatTimes = 2
+  // Using postDimMetaElementRepeatTimes, we get the meta element range:
+  // [0 0 1 1 2 2 3 3]
+  // Using preDimMetaElementRepeatTimes, we get the full one dim indices:
+  // [0 0 1 1 2 2 3 3]
+  //
+  // Let's use a clearer example:
+  // indexShape = [3, 3, 2]
+  // Target dim = 1
+  // => preDimMetaElementRepeatTimes = 3
+  //    postDimMetaElementRepeatTimes = 2
+  // Using postDimMetaElementRepeatTimes, we get the meta element range:
+  // [0 0 1 1 2 2]
+  // Using preDimMetaElementRepeatTimes, we get the full one dim indices:
+  // [0 0 1 1 2 2 0 0 1 1 2 2 0 0 1 1 2 2]
+  for (int i = 0; i < preDimMetaElementRepeatTimes; i++) {
     for (size_t elementId = 0; elementId < indicesMetaElement.size();
          elementId++) {
-      for (int i = 0; i < indicesMetaElementRepeatTimes; i++) {
-        indicesVec.push_back(indicesMetaElement[elementId]);
-      }
-    }
-  } else { // Create the one dim indices for last dim of index
-    // Create indices raw vector
-    // dim2: indicesVec= [0 1 0 1 0 1 0 1]
-    // Caution: indicesVec != [0 0 0 0 1 1 1 1]
-    for (int i = 0; i < indicesMetaElementRepeatTimes; i++) {
-      for (size_t elementId = 0; elementId < indicesMetaElement.size();
-           elementId++) {
+      for (int j = 0; j < postDimMetaElementRepeatTimes; j++) {
         indicesVec.push_back(indicesMetaElement[elementId]);
       }
     }
   }
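
The repeat logic this fix introduces is easy to sanity-check in Python (an
illustrative model of the comments above, not code from the commit; the
helper name is hypothetical):

import math

def one_dim_tf_indices(index_shape, dim):
    # Each element of arange(index_shape[dim]) repeats `post` times, and the
    # whole range then repeats `pre` times, exactly as in the fixed loops.
    pre = math.prod(index_shape[:dim])
    post = math.prod(index_shape[dim + 1:])
    meta = list(range(index_shape[dim]))
    return [e for _ in range(pre) for e in meta for _ in range(post)]

# indexShape = [1, 4, 2] from the comments above:
assert one_dim_tf_indices([1, 4, 2], 0) == [0] * 8              # pre=1, post=8
assert one_dim_tf_indices([1, 4, 2], 1) == [0, 0, 1, 1, 2, 2, 3, 3]
assert one_dim_tf_indices([1, 4, 2], 2) == [0, 1] * 4           # pre=4, post=1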

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 31 additions & 24 deletions
@@ -1663,6 +1663,17 @@
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
 TOSA_PASS_SET = {
+    "AtenLinalgCrossBroadcast_basic",
+    "AtenLinalgCrossCustomDim_basic",
+    "AtenLinalgCrossFloat_basic",
+    "AtenLinalgCrossInt_basic",
+    "AtenLinalgCrossNegativeDim_basic",
+    "BinaryCrossEntropyWithLogitsStaticModule_basic",
+    "IndexSelectNegativeDimModule_basic",
+    "IndexSelectSingleIdxModule_basic",
+    "IndexSelectTwoIdxModule_basic",
+    "IndexSelectWholeDimensionModule_basic",
+    "IndexSelectWholeTensorModule_basic",
     "DiagonalWithStaticShapeModule_basic",
     "EinsumStaticDiagonalDimensionModule_basic",
     "ElementwiseAtenFloorDivideBroadcastModule_basic",
@@ -2342,6 +2353,13 @@
 }
 ) - {
     ### Test failing in make_fx_tosa but not in tosa
+    "ChunkListUnpackUneven_Module_basic",
+    "ChunkListUnpack_Module_basic",
+    "SplitTensorGetItem_Module_basic",
+    "SplitTensorLastSmallerModule_basic",
+    "SplitTensorListUnpackModule_basic",
+    "SplitTensorNegativeDimModule_basic",
+    "SplitWithSizesListUnpackModule_basic",
     # Dynamic shape, has extra unsupported broadcast ops
     "Matmul_3d",
     "MatmulStaticBroadcast_basic",
@@ -3205,6 +3223,17 @@
 }

 FX_IMPORTER_TOSA_XFAIL_SET = {
+    "ChunkListUnpackDynamic_Module_basic",
+    "ChunkListUnpackUnevenDynamic_Module_basic",
+    "ChunkListUnpackUneven_Module_basic",
+    "ChunkListUnpack_Module_basic",
+    "SplitTensorGetItem_Module_basic",
+    "SplitTensorLastSmallerModule_basic",
+    "SplitTensorListUnpackModule_basic",
+    "SplitTensorNegativeDimModule_basic",
+    "SplitWithSizesListUnpackModule_basic",
+    "SplitWithSizes_Module_basic",
+    "ElementwiseCreateComplexModule_basic",
     "AdaptiveMaxPool1dDimOneStatic_basic",
     "AtenPolarDoubleModule_basic",
     "AtenPolarFloatModule_basic",
@@ -3302,12 +3331,6 @@
     "AtenIntTensorCharDtypeModule_basic",
     "AtenItemFpOpModule_basic",
     "AtenItemIntOpModule_basic",
-    "AtenLinalgCrossBroadcast_basic",
-    "AtenLinalgCrossCustomDim_basic",
-    "AtenLinalgCrossDynamic_basic",
-    "AtenLinalgCrossFloat_basic",
-    "AtenLinalgCrossInt_basic",
-    "AtenLinalgCrossNegativeDim_basic",
     "AtenMatmulQMixedSigni8Transpose_basic",
     "AtenMatmulQMixedSigni8_basic",
     "AtenMatmulQint8MV_basic",
@@ -3551,15 +3574,7 @@
     "IndexPutImpl3DFloatAccumulateModule_basic",
     "IndexPutImpl3DFloatNonAccumulateModule_basic",
     "IndexPutImplIndexWithNoneModule_basic",
-    "IndexSelectDynamicIndexSizeModule_basic",
-    "IndexSelectDynamicInputSizeModule_basic",
-    "IndexSelectDynamicModulebasic",
-    "IndexSelectNegativeDimModule_basic",
     "IndexSelectRank0IdxModule_basic",
-    "IndexSelectSingleIdxModule_basic",
-    "IndexSelectTwoIdxModule_basic",
-    "IndexSelectWholeDimensionModule_basic",
-    "IndexSelectWholeTensorModule_basic",
     "IndexTensorNegativeIndexModule_basic",
     "InterpolateDynamicModule_sizes_bilinear",
     "InterpolateDynamicModule_sizes_nearest",
@@ -3848,6 +3863,8 @@
 }

 ONNX_TOSA_XFAIL_SET = {
+    "ElementwiseCreateComplexModule_basic",
+    "ReduceAllDimFloatModule_basic",
     "AdaptiveMaxPool1dDimOneStatic_basic",
     "ScaledDotProductAttentionDifferentCausalModule_basic",
     "HstackBasicComplexModule_basic",
@@ -4269,7 +4286,6 @@
     "ElementwiseWhereSelfModule_basic",
     "EmbeddingModule1DIndices_basic",
     "EmbeddingModuleF16_basic",
-    "EmbeddingModuleI32Static_basic",
     "EmbeddingModuleI32_basic",
     "EmbeddingModuleI64_basic",
     "EmptyLikeMemoryFormatModule_basic",
@@ -4363,12 +4379,6 @@
     "IndexSelectDynamicIndexSizeModule_basic",
     "IndexSelectDynamicInputSizeModule_basic",
     "IndexSelectDynamicModulebasic",
-    "IndexSelectNegativeDimModule_basic",
-    "IndexSelectRank0IdxModule_basic",
-    "IndexSelectSingleIdxModule_basic",
-    "IndexSelectTwoIdxModule_basic",
-    "IndexSelectWholeDimensionModule_basic",
-    "IndexSelectWholeTensorModule_basic",
     "IndexTensorDyanmicInputContiguousWithNoneModule_basic",
     "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic",
     "IndexTensorHackedTwinModule3dInput_basic",
@@ -4386,10 +4396,8 @@
     "IndexTensorMultiInputOneDim_basic",
     "IndexTensorMultiInputThreeIndexers_basic",
     "IndexTensorMultiInput_basic",
-    "IndexTensorNegativeIndexModule_basic",
     "IndexTensorSelectDimModule_basic",
     "IndexTensorStaticContiguousWithNoneModule_basic",
-    "IndexTensorStaticModule_basic",
     "IndexTensorStaticNonContiguousWithNoneModule_basic",
     "InterpolateDynamicModule_sizes_bilinear",
     "InterpolateDynamicModule_sizes_nearest",
@@ -4688,7 +4696,6 @@
     "ScatterValueFloatModule_basic",
     "ScatterValueIntModule_basic",
     "SelectIntModule_basic",
-    "SelectIntNegativeDimAndIndexStaticModule_basic",
     "SelectScattertModule_basic",
     "SelectScattertStaticModule_basic",
     "SignAndLogarithmOfDeterminantModule_F32",
