diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index fd9827772547..cce4aa8a6350 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -17,6 +17,7 @@
     # Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed
    # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
     "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
+    "UnflattenStaticModule_basic",
 }
 
 TORCHDYNAMO_XFAIL_SET = {
@@ -1056,6 +1057,7 @@
     "BatchNorm3DModule_basic",
     "BatchNorm1DStaticShapeModule_basic",
     "FlattenStaticModule_basic",
+    "UnflattenStaticModule_basic",
     "FlattenRank0Module_basic",
     "ElementwiseFlattenBroadcastModule_basic",
     "SquareModule_basic",
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index f1338142d197..4f4fa561fbba 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7537,6 +7537,30 @@ def Torch_AtenFlattenUsingIntsOp : Torch_Op<"aten.flatten.using_ints", [
   }];
 }
 
+def Torch_AtenUnflattenIntOp : Torch_Op<"aten.unflatten.int", [
+    AllowsTypeRefinement,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$dim,
+    AnyTorchListOfTorchIntType:$sizes
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenUnflattenIntOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenUnflattenIntOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
 def Torch_AtenDimOp : Torch_Op<"aten.dim", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 1e71f51b8598..d2adefc4d3c7 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -2525,6 +2525,60 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
   return success();
 }
 
+template <>
+LogicalResult ConvertAtenOp<AtenUnflattenIntOp>::matchAndRewrite(
+    AtenUnflattenIntOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  // Only ranked tensors with a static shape are supported.
+  auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
+  if (!selfType || !selfType.hasStaticShape())
+    return rewriter.notifyMatchFailure(
+        op,
+        "Only ranked tensor types with static shapes are currently supported");
+
+  int64_t selfRank = selfType.getRank();
+  int64_t dim;
+
+  if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+    return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant");
+
+  SmallVector<int64_t> sizes;
+  if (!matchPattern(op.getSizes(), m_TorchListOfConstantInts(sizes)))
+    return rewriter.notifyMatchFailure(
+        op, "Only constant sizes are currently supported");
+
+  if (selfRank > 0 && !isValidDim(dim, selfRank))
+    return rewriter.notifyMatchFailure(op, "dim is statically invalid");
+
+  SmallVector<int64_t> newShape;
+  for (auto s :
+       llvm::enumerate(makeShapeTorchCompatible(selfType.getShape()))) {
+    int64_t idx = s.index();
+    if (idx != dim) {
+      newShape.push_back(s.value());
+    } else {
+      int64_t sum = 1;
+      for (auto newDims : sizes) {
+        newShape.push_back(newDims);
+        sum *= newDims;
+      }
+      if (sum != s.value())
+        return rewriter.notifyMatchFailure(op,
+                                           "sizes mismatch with original dim");
+    }
+  }
+
+  auto newType = RankedTensorType::get(makeShapeLLVMCompatible(newShape),
+                                       selfType.getElementType());
+
+  rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
+      op, getTypeConverter()->convertType(newType), adaptor.getSelf(),
+      rewriter.getDenseI64ArrayAttr(newShape));
+
+  return success();
+}
+
 template <>
 LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
     AtenPermuteOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
@@ -5050,6 +5104,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
     INSERT_ATENOP_PATTERN(AtenBatchNormOp);
     INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
     INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp);
+    INSERT_ATENOP_PATTERN(AtenUnflattenIntOp);
     INSERT_ATENOP_PATTERN(AtenPermuteOp);
     INSERT_ATENOP_PATTERN(AtenLog2Op);
     INSERT_ATENOP_PATTERN(AtenThresholdOp);
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index e8f5aa568f59..513d7b018d46 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7205,6 +7205,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.flatten(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.int) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.unflatten.int\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
+"    %none = torch.constant.none\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.aten.slice.t %arg0, %none, %arg1, %int1 : !torch.list<int>, !torch.none, !torch.int, !torch.int -> !torch.list<int>\n"
+"    %1 = torch.aten.add.t %0, %arg2 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
+"    %2 = torch.aten.add.int %arg1, %int1 : !torch.int, !torch.int -> !torch.int\n"
+"    %3 = torch.aten.slice.t %arg0, %2, %none, %int1 : !torch.list<int>, !torch.int, !torch.none, !torch.int -> !torch.list<int>\n"
+"    %4 = torch.aten.add.t %1, %3 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
+"    return %4 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.linear\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.linear(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
@@ -8580,6 +8590,10 @@
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.unflatten.int\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.flip\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index 5de777763ea5..ddc95bd4b2fd 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -199,10 +199,10 @@ bool Torch::isViewLikeOp(Operation *op) {
   // that it does not return a view and treat those as having value
   // semantics.
   return isa<AtenAsStridedOp, AtenBroadcastToOp, AtenContiguousOp,
-             AtenDetachOp, AtenExpandAsOp, AtenExpandOp, AtenFlattenUsingIntsOp,
-             AtenNarrowOp, AtenNarrowTensorOp, AtenSelectIntOp, AtenSliceTensorOp,
-             AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp,
+             AtenDetachOp, AtenExpandAsOp, AtenExpandOp, AtenFlattenUsingIntsOp,
+             AtenUnflattenIntOp, AtenNarrowOp, AtenNarrowTensorOp, AtenSelectIntOp,
+             AtenSliceTensorOp, AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp,
              AtenToDtypeOp, AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
              TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
             AtenToDeviceOp, PrimsSqueezeOp>(op);
 }
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -?,6 +?,9 @@
 def aten〇flatten〇using_ints〡shape(self: List[int], start_dim: int = 0, end_dim: int = -1) -> List[int]:
     return upstream_shape_functions.flatten(self, start_dim, end_dim)
 
+def aten〇unflatten〇int〡shape(self: List[int], dim: int, sizes: List[int]) -> List[int]:
+    return self[:dim] + sizes + self[dim + 1:]
+
 def aten〇linear〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None) -> List[int]:
     return upstream_shape_functions.linear(input, weight, bias)
 
@@ -1656,6 +1659,11 @@ def aten〇flatten〇using_ints〡dtype(self_rank_dtype: Tuple[int, int], start_
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, sizes=[1]))
+def aten〇unflatten〇int〡dtype(self_rank_dtype: Tuple[int, int], dim: int, sizes: List[int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[0]))
 def aten〇flip〡dtype(self_rank_dtype: Tuple[int, int], dims: List[int]) -> int:
     self_rank, self_dtype = self_rank_dtype
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index f540a1ad2a7d..3916f313620b 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -516,6 +516,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
     emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)
     emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)")
+    emit("aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)")
     emit("aten::dim : (Tensor) -> (int)", has_folder=True)
     emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
     emit("aten::Bool.Tensor : (Tensor) -> (bool)")
diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py
index 1b5f62715a30..e0269e68ce33 100644
--- a/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -304,6 +304,28 @@ def AddmmModule_differentRankBroadcastable(module, tu: TestUtils):
 # ==============================================================================
 
 
+class UnflattenStaticModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([1, 6, 4], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.unflatten(x, 1, (2, 3))
+
+
+@register_test_case(module_factory=lambda: UnflattenStaticModule())
+def UnflattenStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 6, 4))
+
+
+# ==============================================================================
+
+
 class FlattenStaticModule(torch.nn.Module):
 
     def __init__(self):
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index 49907a98a56d..dc4e4793a67d 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -556,6 +556,28 @@ func.func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor
 
 // -----
 
+// CHECK-LABEL: func.func @forward(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[1,6,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
+// CHECK:           %[[VAL:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,6,4],f32> -> tensor<1x6x4xf32>
+// CHECK:           %[[VAL_1:.*]] = torch.constant.int 1
+// CHECK:           %[[VAL_2:.*]] = torch.constant.int 2
+// CHECK:           %[[VAL_3:.*]] = torch.constant.int 3
+// CHECK:           %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:           %[[VAL_5:.*]] = tosa.reshape %[[VAL]] {new_shape = array<i64: 1, 2, 3, 4>} : (tensor<1x6x4xf32>) -> tensor<1x2x3x4xf32>
+// CHECK:           %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x2x3x4xf32> -> !torch.vtensor<[1,2,3,4],f32>
+// CHECK:           return %[[VAL_6]] : !torch.vtensor<[1,2,3,4],f32>
+// CHECK:         }
+func.func @forward(%arg0: !torch.vtensor<[1,6,4],f32> ) -> !torch.vtensor<[1,2,3,4],f32> {
+  %int1 = torch.constant.int 1
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[1,6,4],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[1,2,3,4],f32>
+  return %1 : !torch.vtensor<[1,2,3,4],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @forward(
 // CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[5,2,2,3],f32>,
 // CHECK-SAME:      %[[VAL_1:.*]]: !torch.vtensor<[2,2,3],f32>,
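As a quick sanity check on the shape rule introduced above (`self[:dim] + sizes + self[dim + 1:]`), here is a small standalone Python sketch that is not part of the patch; the helper name `unflatten_shape` is ours, purely for illustration. It compares the rule against eager PyTorch on the same operands used by UnflattenStaticModule_basic and the TOSA lit test.

from typing import List

import torch


def unflatten_shape(self: List[int], dim: int, sizes: List[int]) -> List[int]:
    # Mirrors the aten〇unflatten〇int〡shape rule: dimension `dim` is replaced by `sizes`.
    return self[:dim] + sizes + self[dim + 1:]


# Same operands as the e2e and lit tests: shape [1, 6, 4], dim = 1, sizes = (2, 3).
x = torch.rand(1, 6, 4)
y = torch.ops.aten.unflatten(x, 1, (2, 3))
assert list(y.shape) == unflatten_shape([1, 6, 4], 1, [2, 3]) == [1, 2, 3, 4]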