Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions lib/Conversion/TorchToSCF/TorchToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,14 @@ class ConvertTorchPrimLoopForLikeOp : public OpConversionPattern<PrimLoopOp> {
targetType = Torch::IntType::get(op->getContext());
torchArg = typeConverter->materializeSourceConversion(
rewriter, scfForOp.getLoc(), targetType, {to});
} else if (auto tty = dyn_cast<RankedTensorType>(targetType)) {
targetType =
op.getIterArgsInit()[barg.index() - scfForOp.getNumInductionVars()]
.getType();
torchArg = typeConverter->materializeSourceConversion(
rewriter, scfForOp.getLoc(), targetType, {to});
}

if (!torchArg)
return rewriter.notifyMatchFailure(op,
"unsupported type of the operand");
Expand All @@ -289,14 +296,6 @@ class ConvertTorchPrimLoopForLikeOp : public OpConversionPattern<PrimLoopOp> {
// Fix up the terminator.
SmallVector<Value> loopConditionIterArgs;
for (auto torchArg : primLoopConditionOp.getIterArgs()) {
Type torchType = torchArg.getType();

// If the argument is a torch tensor, directly add it in the list of
// iter args.
if (torchType.isa<Torch::BaseTensorType>()) {
loopConditionIterArgs.push_back(torchArg);
continue;
}
Value arg = typeConverter->materializeTargetConversion(
rewriter, scfForOp.getLoc(),
typeConverter->convertType(torchArg.getType()), {torchArg});
Expand Down
4 changes: 4 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,7 @@
"LinspaceTwoSizeModule_basic",
"FakeQuantizePerTensorAffineModule_basic",
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
"TorchPrimLoopForLikeTensorArgModule_basic",
}

STABLEHLO_CRASHING_SET = {
Expand Down Expand Up @@ -1273,6 +1274,7 @@
"LinspaceModule_basic",
"LinspaceOneSizeModule_basic",
"LinspaceTwoSizeModule_basic",
"TorchPrimLoopForLikeTensorArgModule_basic"
}

MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
Expand All @@ -1289,6 +1291,7 @@
"TensorIntModule_basic",
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"TorchPrimLoopForLikeTensorArgModule_basic",
}) - {
### Test failing in make_fx_tosa but not in tosa

Expand Down Expand Up @@ -1326,6 +1329,7 @@
}

LTC_XFAIL_SET = {
"TorchPrimLoopForLikeTensorArgModule_basic"
"CollapseAllDimensionsModule_basic",
"CollapseRank1DynamicModule_basic",
"CollapseStaticModule_basic",
Expand Down
23 changes: 23 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,26 @@ def forward(self, x):
@register_test_case(module_factory=lambda: TorchPrimLoopWhileLikeModule())
def TorchPrimLoopWhileLikeModule_basic(module, tu: TestUtils):
module.forward(tu.randint(6, 8, high=10))


# ==============================================================================

class TorchPrimLoopForLikeTensorArgModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([7,9], torch.float32, True),
])
def forward(self, x: torch.Tensor) -> torch.Tensor:
for i in range(50):
x = x + i
return x

@register_test_case(module_factory=lambda: TorchPrimLoopForLikeTensorArgModule())
def TorchPrimLoopForLikeTensorArgModule_basic(module, tu: TestUtils):
x_test = torch.zeros([7, 9]).float()

module.forward(x_test)
30 changes: 29 additions & 1 deletion test/Conversion/TorchToSCF/basic.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-scf | FileCheck %s
// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-to-scf| FileCheck %s

// CHECK-LABEL: func.func @torch.prim.if(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.bool) -> !torch.int {
Expand Down Expand Up @@ -214,3 +214,31 @@ func.func @torch.prim.Loop$for_with_multiple_results(%arg0: !torch.int) -> (!tor
} : (!torch.int, !torch.bool, !torch.float, !torch.float) -> (!torch.float, !torch.float)
return %0#0, %0#1 : !torch.float, !torch.float
}


// -----

// CHECK-LABEL: func.func @torch.prim.Loop$for_with_tensor_arg() -> !torch.vtensor<[2,3],f32> {
// CHECK: %[[LOOP_RESULT:.*]] = scf.for %[[LOOP_VARIABLE:.*]] = %[[RANGE_START:.*]] to %[[RANGE_END:.*]] step %[[RANGE_STEP:.*]] iter_args(%[[LOOP_TENSOR_ARG:.*]] = %[[LOOP_TENSOR_ARG_INIT_VAL:.*]]) -> (tensor<2x3xf32>) {
// CHECK: %[[LOOP_TENSOR_ARG_TORCH_TENSOR:.*]] = torch_c.from_builtin_tensor %[[LOOP_TENSOR_ARG]]
// CHECK: }
// CHECK: }
func.func @torch.prim.Loop$for_with_tensor_arg() -> (!torch.vtensor<[2,3],f32>) {
%true = torch.constant.bool true
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%int5 = torch.constant.int 5
%int6 = torch.constant.int 6
%none = torch.constant.none
%0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.zeros %0, %int6, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
%2 = torch.aten.ones %0, %int6, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3],f32>
%3:1 = torch.prim.Loop %int5, %true, init(%1) {
^bb0(%arg1: !torch.int, %arg2: !torch.vtensor<[2,3],f32>):
%4 = torch.aten.add.Tensor %arg2, %2, %int1 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.int -> !torch.vtensor<[2,3],f32>
torch.prim.Loop.condition %true, iter(%4 : !torch.vtensor<[2,3],f32>)
} : (!torch.int, !torch.bool, !torch.vtensor<[2,3],f32>) -> (!torch.vtensor<[2,3],f32>)
return %3#0 : !torch.vtensor<[2,3],f32>
}