diff --git a/lib/Conversion/TorchToSCF/TorchToSCF.cpp b/lib/Conversion/TorchToSCF/TorchToSCF.cpp
index 96e14f0fdd6e..208dcefcc85f 100644
--- a/lib/Conversion/TorchToSCF/TorchToSCF.cpp
+++ b/lib/Conversion/TorchToSCF/TorchToSCF.cpp
@@ -274,7 +274,14 @@ class ConvertTorchPrimLoopForLikeOp : public OpConversionPattern<PrimLoopOp> {
         targetType = Torch::IntType::get(op->getContext());
         torchArg = typeConverter->materializeSourceConversion(
             rewriter, scfForOp.getLoc(), targetType, {to});
+      } else if (auto tty = dyn_cast<RankedTensorType>(targetType)) {
+        targetType =
+            op.getIterArgsInit()[barg.index() - scfForOp.getNumInductionVars()]
+                .getType();
+        torchArg = typeConverter->materializeSourceConversion(
+            rewriter, scfForOp.getLoc(), targetType, {to});
       }
+
       if (!torchArg)
         return rewriter.notifyMatchFailure(op,
                                            "unsupported type of the operand");
@@ -289,14 +296,6 @@ class ConvertTorchPrimLoopForLikeOp : public OpConversionPattern<PrimLoopOp> {
     // Fix up the terminator.
     SmallVector<Value> loopConditionIterArgs;
     for (auto torchArg : primLoopConditionOp.getIterArgs()) {
-      Type torchType = torchArg.getType();
-
-      // If the argument is a torch tensor, directly add it in the list of
-      // iter args.
-      if (torchType.isa<Torch::BaseTensorType>()) {
-        loopConditionIterArgs.push_back(torchArg);
-        continue;
-      }
       Value arg = typeConverter->materializeTargetConversion(
           rewriter, scfForOp.getLoc(),
           typeConverter->convertType(torchArg.getType()), {torchArg});
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 6c6666a28619..576c33cc4212 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -851,6 +851,7 @@
     "LinspaceTwoSizeModule_basic",
     "FakeQuantizePerTensorAffineModule_basic",
     "FakeQuantizePerTensorAffineRoundToEvenModule_basic",
+    "TorchPrimLoopForLikeTensorArgModule_basic",
 }
 
 STABLEHLO_CRASHING_SET = {
@@ -1273,6 +1274,7 @@
     "LinspaceModule_basic",
     "LinspaceOneSizeModule_basic",
     "LinspaceTwoSizeModule_basic",
+    "TorchPrimLoopForLikeTensorArgModule_basic",
 }
 
 MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
@@ -1289,6 +1291,7 @@
     "TensorIntModule_basic",
     "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
     "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
+    "TorchPrimLoopForLikeTensorArgModule_basic",
 }) - {
 ### Test failing in make_fx_tosa but not in tosa
@@ -1326,6 +1329,7 @@
 }
 
 LTC_XFAIL_SET = {
+    "TorchPrimLoopForLikeTensorArgModule_basic",
     "CollapseAllDimensionsModule_basic",
     "CollapseRank1DynamicModule_basic",
     "CollapseStaticModule_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py
index 6f8240f54d89..d40a77bb6dbf 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py
@@ -55,3 +55,26 @@ def forward(self, x):
 @register_test_case(module_factory=lambda: TorchPrimLoopWhileLikeModule())
 def TorchPrimLoopWhileLikeModule_basic(module, tu: TestUtils):
     module.forward(tu.randint(6, 8, high=10))
+
+
+# ==============================================================================
+
+class TorchPrimLoopForLikeTensorArgModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([7,9], torch.float32, True),
+    ])
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        for i in range(50):
+            x = x + i
+        return x
+
+@register_test_case(module_factory=lambda: TorchPrimLoopForLikeTensorArgModule())
+def TorchPrimLoopForLikeTensorArgModule_basic(module, tu: TestUtils):
+    x_test = torch.zeros([7, 9]).float()
+
+    module.forward(x_test)
diff --git a/test/Conversion/TorchToSCF/basic.mlir b/test/Conversion/TorchToSCF/basic.mlir
index fadac3b4f97d..fa4f46f044ca 100644
--- a/test/Conversion/TorchToSCF/basic.mlir
+++ b/test/Conversion/TorchToSCF/basic.mlir
@@ -1,4 +1,4 @@
-// RUN: torch-mlir-opt <%s -convert-torch-to-scf | FileCheck %s
+// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-to-scf | FileCheck %s
 
 // CHECK-LABEL:   func.func @torch.prim.if(
 // CHECK-SAME:                             %[[VAL_0:.*]]: !torch.bool) -> !torch.int {
@@ -214,3 +214,31 @@ func.func @torch.prim.Loop$for_with_multiple_results(%arg0: !torch.int) -> (!torch.float, !torch.float) {
   } : (!torch.int, !torch.bool, !torch.float, !torch.float) -> (!torch.float, !torch.float)
   return %0#0, %0#1 : !torch.float, !torch.float
 }
+
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.prim.Loop$for_with_tensor_arg() -> !torch.vtensor<[2,3],f32> {
+// CHECK:           %[[LOOP_RESULT:.*]] = scf.for %[[LOOP_VARIABLE:.*]] = %[[RANGE_START:.*]] to %[[RANGE_END:.*]] step %[[RANGE_STEP:.*]] iter_args(%[[LOOP_TENSOR_ARG:.*]] = %[[LOOP_TENSOR_ARG_INIT_VAL:.*]]) -> (tensor<2x3xf32>) {
+// CHECK:             %[[LOOP_TENSOR_ARG_TORCH_TENSOR:.*]] = torch_c.from_builtin_tensor %[[LOOP_TENSOR_ARG]]
+// CHECK:           }
+// CHECK:         }
+func.func @torch.prim.Loop$for_with_tensor_arg() -> (!torch.vtensor<[2,3],f32>) {
+  %true = torch.constant.bool true
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %int5 = torch.constant.int 5
+  %int6 = torch.constant.int 6
+  %none = torch.constant.none
+  %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.zeros %0, %int6, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
+  %2 = torch.aten.ones %0, %int6, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
+  %3:1 = torch.prim.Loop %int5, %true, init(%1) {
+  ^bb0(%arg1: !torch.int, %arg2: !torch.vtensor<[2,3],f32>):
+    %4 = torch.aten.add.Tensor %arg2, %2, %int1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32>
+    torch.prim.Loop.condition %true, iter(%4 : !torch.vtensor<[2,3],f32>)
+  } : (!torch.int, !torch.bool, !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[2,3],f32>)
+  return %3#0 : !torch.vtensor<[2,3],f32>
+}
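
Beyond the FileCheck test, a quick end-to-end way to exercise the new tensor iter-arg path is to push a loop-carrying module through the pt1 Python frontend. A minimal sketch, assuming the pt1-era torch_mlir.compile entry point is importable and that the linalg-on-tensors output type (whose backend pipeline runs -convert-torch-to-scf) is available; the module name is illustrative:

import torch
import torch_mlir

class LoopWithTensorArg(torch.nn.Module):
    def forward(self, x):
        # x is loop-carried, so after -convert-torch-to-scf it becomes an
        # iter_arg of the generated scf.for rather than a captured value.
        for i in range(50):
            x = x + i
        return x

# torch_mlir.compile scripts the module by default (use_tracing=False),
# which preserves the loop as torch.prim.Loop; tracing would unroll the
# loop and never reach the new tensor iter-arg branch.
compiled = torch_mlir.compile(LoopWithTensorArg(), torch.zeros(7, 9),
                              output_type="linalg-on-tensors")
print(compiled)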