From 99325af0c4ee18d25e94461eeccaa1eba5f4389c Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Thu, 2 Oct 2025 12:13:15 -0700 Subject: [PATCH 1/5] Decompose AtenUpsampleNearestVecOps to interpolate Signed-off-by: zjgarvey --- .../Torch/Transforms/DecomposeComplexOps.cpp | 30 +++++++ .../Transforms/LowerToBackendContract.cpp | 2 + .../torch_mlir_e2e_test/test_suite/conv.py | 83 +++++++++++++++++++ 3 files changed, 115 insertions(+) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index d34f6dfcaff4..cc36ceeb953b 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -5185,6 +5185,30 @@ class DecomposeAtenUnflattenIntOp }; } // namespace +namespace { +template +class DecomposeAtenUpsampleNearestVecOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(UpsampleVecOp op, + PatternRewriter &rewriter) const override { + Value scales = op.getScaleFactors(); + static_assert(std::is_same_v || + std::is_same_v); + Value cstMode = rewriter.create( + op.getLoc(), rewriter.getStringAttr("nearest")); + Value cstNone = rewriter.create(op.getLoc()); + Value cstAntialias = + rewriter.create(op.getLoc(), false); + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getInput(), op.getOutputSize(), + op.getScaleFactors(), cstMode, cstNone, cstNone, cstAntialias); + return success(); + } +}; +} // namespace + // Decompose aten.expand into aten.broadcast_to op. namespace { class DecomposeAtenExpandOp : public OpRewritePattern { @@ -12983,6 +13007,12 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenUpsampleNearestVecOp>( + patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenUpsampleNearestVecOp>( + patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index d746862193aa..cfc8bb96118b 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -593,6 +593,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); for (auto &opName : backendLegalOpsSet) { target.addLegalOp( diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index 2ec87b9fee43..eaeb17a0a711 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -1088,6 +1088,89 @@ def UpSampleNearest2dStaticFactor_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4, 4)) +class UpSampleNearest2dVecNoneShape(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float64, True), + ] + ) + def forward(self, input): + return torch.ops.aten.upsample_nearest2d.vec( + input, output_size=None, scale_factors=[3.66, 4.2] + ) + +@register_test_case(module_factory=lambda: UpSampleNearest2dVecNoneShape()) +def UpSampleNearest2dVecNoneShape_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 6, 12).to(torch.float64)) + + +class UpSampleNearest2dVecNoneScales(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.float64, True), + ] + ) + def forward(self, input): + return torch.ops.aten.upsample_nearest2d.vec( + input, output_size=[18, 48], scale_factors=None, + ) + +@register_test_case(module_factory=lambda: UpSampleNearest2dVecNoneScales()) +def UpSampleNearest2dVecNoneScales_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 6, 12).to(torch.float64)) + + +class UpSampleNearest1dVecNoneShape(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ] + ) + def forward(self, input): + return torch.ops.aten.upsample_nearest1d.vec( + input, output_size=None, scale_factors=[3.0] + ) + +@register_test_case(module_factory=lambda: UpSampleNearest1dVecNoneShape()) +def UpSampleNearest1dVecNoneShape_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 6).to(torch.float64)) + +class UpSampleNearest1dVecNoneScales(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float64, True), + ] + ) + def forward(self, input): + return torch.ops.aten.upsample_nearest1d.vec( + input, [18], None + ) + +@register_test_case(module_factory=lambda: UpSampleNearest1dVecNoneScales()) +def UpSampleNearest1dVecNoneScales_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 6).to(torch.float64)) + + class Conv1dModule(torch.nn.Module): def __init__(self): super().__init__() From c80833dece1ebf4e050874fec099b902bf8221c9 Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Thu, 2 Oct 2025 12:20:54 -0700 Subject: [PATCH 2/5] lint Signed-off-by: zjgarvey --- .../python/torch_mlir_e2e_test/test_suite/conv.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index eaeb17a0a711..b9dc855b7c0a 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -1104,6 +1104,7 @@ def forward(self, input): input, output_size=None, scale_factors=[3.66, 4.2] ) + @register_test_case(module_factory=lambda: UpSampleNearest2dVecNoneShape()) def UpSampleNearest2dVecNoneShape_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1, 6, 12).to(torch.float64)) @@ -1122,9 +1123,12 @@ def __init__(self): ) def forward(self, input): return torch.ops.aten.upsample_nearest2d.vec( - input, output_size=[18, 48], scale_factors=None, + input, + output_size=[18, 48], + scale_factors=None, ) + @register_test_case(module_factory=lambda: UpSampleNearest2dVecNoneScales()) def UpSampleNearest2dVecNoneScales_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1, 6, 12).to(torch.float64)) @@ -1146,10 +1150,12 @@ def forward(self, input): input, output_size=None, scale_factors=[3.0] ) + @register_test_case(module_factory=lambda: UpSampleNearest1dVecNoneShape()) def UpSampleNearest1dVecNoneShape_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1, 6).to(torch.float64)) + class UpSampleNearest1dVecNoneScales(torch.nn.Module): def __init__(self): super().__init__() @@ -1162,9 +1168,8 @@ def __init__(self): ] ) def forward(self, input): - return torch.ops.aten.upsample_nearest1d.vec( - input, [18], None - ) + return torch.ops.aten.upsample_nearest1d.vec(input, [18], None) + @register_test_case(module_factory=lambda: UpSampleNearest1dVecNoneScales()) def UpSampleNearest1dVecNoneScales_basic(module, tu: TestUtils): From 41d741dd34cd908f8f22c99fb4323d68785fe0d6 Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Fri, 3 Oct 2025 09:04:30 -0700 Subject: [PATCH 3/5] Unfail newly passing test. Signed-off-by: zjgarvey --- projects/pt1/e2e_testing/xfail_sets.py | 1 - 1 file changed, 1 deletion(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index e4a2e319d7fe..c98bfc72d24f 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -497,7 +497,6 @@ "CrossEntropyLossModule_basic", "CrossEntropyLossNoReductionModule_basic", "IsInfiniteModule_basic", - "InterpolateDynamicModule_sizes_nearest", "IouOfModule_basic", "MeshgridIndexingIJ_basic", "MeshgridIndexingXY_basic", From b858c79379d99f47b272ba39331e3e29c154be06 Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Fri, 3 Oct 2025 09:18:59 -0700 Subject: [PATCH 4/5] Add upsample tests to fx_importer_stablehlo xfailsets Signed-off-by: zjgarvey --- projects/pt1/e2e_testing/xfail_sets.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index c98bfc72d24f..685dc57aa8ed 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -914,8 +914,12 @@ "TraceUnsignedIntModule_empty", "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic", "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", + "UpSampleNearest1dVecNoneScales_basic", + "UpSampleNearest1dVecNoneShape_basic", "UpSampleNearest2dBackwardScalesNone_basic", "UpSampleNearest2dBackward_basic", + "UpSampleNearest2dVecNoneScales_basic", + "UpSampleNearest2dVecNoneShape_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewSizeFromOtherTensor_basic", # Error: `aten.as_strided` op is not supported From b9fa64f9123437234bbf2585643505e9a8b8daa3 Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Fri, 3 Oct 2025 14:52:17 -0700 Subject: [PATCH 5/5] Add to tosa xfails Signed-off-by: zjgarvey --- projects/pt1/e2e_testing/xfail_sets.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 685dc57aa8ed..7e048f2b0143 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -3959,8 +3959,13 @@ "TransposedConv2dNegativePadding_basic", "TransposedConv3dNegativePadding_basic", "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", + "InterpolateDynamicModule_sizes_nearest", + "UpSampleNearest1dVecNoneScales_basic", + "UpSampleNearest1dVecNoneShape_basic", "UpSampleNearest2dBackwardScalesNone_basic", "UpSampleNearest2dBackward_basic", + "UpSampleNearest2dVecNoneScales_basic", + "UpSampleNearest2dVecNoneShape_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewSizeFromOtherTensor_basic", "VisionTransformerModule_basic",