Skip to content

Commit 288cd5e

Browse files
authored
fix: remove split.Tensor and split_with_sizes from decomp tables (#4340)
More details are in the issue: #4339 Since `split.Tensor` and `split_with_sizes` decompose to `as_strided` (see issue above why `as_strided` is problematic), we remove them from the decomp tables. cc @zjgarvey --------- Signed-off-by: raayandhar <rdhar@amd.com>
1 parent 52dbb8d commit 288cd5e

File tree

4 files changed

+1
-25
lines changed

4 files changed

+1
-25
lines changed

lib/Dialect/Torch/Transforms/Passes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ void mlir::torch::Torch::createTorchDynamoExportToTorchBackendPipeline(
7676
if (options.decompose) {
7777
pm.addNestedPass<func::FuncOp>(
7878
Torch::createDecomposeComplexOpsPass(options.backendLegalOps));
79+
pm.addNestedPass<func::FuncOp>(Torch::createRecomposeComplexOpsPass());
7980
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
8081
}
8182
}

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -922,17 +922,6 @@
922922
"UpSampleNearest2dVecNoneShape_basic",
923923
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
924924
"ViewSizeFromOtherTensor_basic",
925-
# Error: `aten.as_strided` op is not supported
926-
"ChunkListUnpackDynamic_Module_basic",
927-
"ChunkListUnpackUnevenDynamic_Module_basic",
928-
"ChunkListUnpackUneven_Module_basic",
929-
"ChunkListUnpack_Module_basic",
930-
"SplitTensorGetItem_Module_basic",
931-
"SplitTensorLastSmallerModule_basic",
932-
"SplitTensorListUnpackModule_basic",
933-
"SplitTensorNegativeDimModule_basic",
934-
"SplitWithSizesListUnpackModule_basic",
935-
"SplitWithSizes_Module_basic",
936925
"Unfold_Module_basic",
937926
"Unfold_Module_Rank_4",
938927
"Unfold_Module_Rank_Zero_basic",
@@ -4018,17 +4007,7 @@
40184007
"AtenAsStridedModule_basic",
40194008
"AtenAsStridedNoStorageOffsetModule_basic",
40204009
"AtenAsStridedUnknownSizeModule_basic",
4021-
"ChunkListUnpackDynamic_Module_basic",
4022-
"ChunkListUnpackUnevenDynamic_Module_basic",
4023-
"ChunkListUnpackUneven_Module_basic",
4024-
"ChunkListUnpack_Module_basic",
40254010
"NativeGroupNormModule_basic",
4026-
"SplitTensorGetItem_Module_basic",
4027-
"SplitTensorLastSmallerModule_basic",
4028-
"SplitTensorListUnpackModule_basic",
4029-
"SplitTensorNegativeDimModule_basic",
4030-
"SplitWithSizesListUnpackModule_basic",
4031-
"SplitWithSizes_Module_basic",
40324011
# error: argument must be a memref of f32, f64, i32, i64, i8, i1, c32, c64, but got 'memref<3x5xbf16>'
40334012
"ElementwiseClampMaxModule_bfloat16",
40344013
"ElementwiseClampMinModule_bfloat16",

projects/pt1/python/torch_mlir/dynamo.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,6 @@ def _get_decomposition_table():
5252
# support for aten.native_batch_norm_backward.
5353
aten._native_batch_norm_legit_functional,
5454
aten.native_group_norm,
55-
aten.split.Tensor,
56-
aten.split_with_sizes,
5755
aten.norm.ScalarOpt_dim,
5856
aten.embedding_dense_backward,
5957
aten.native_layer_norm_backward,

python/torch_mlir/extras/fx_decomp_util.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
torch.ops.aten.norm.ScalarOpt_dim,
1111
torch.ops.aten.native_group_norm,
1212
torch.ops.aten.upsample_bilinear2d.vec,
13-
torch.ops.aten.split.Tensor,
14-
torch.ops.aten.split_with_sizes,
1513
torch.ops.aten.native_layer_norm,
1614
torch.ops.aten.masked_fill.Tensor,
1715
torch.ops.aten.masked_fill.Scalar,

0 commit comments

Comments
 (0)