Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"UnflattenStaticModule_basic",
}

TORCHDYNAMO_XFAIL_SET = {
Expand Down Expand Up @@ -1056,6 +1057,7 @@
"BatchNorm3DModule_basic",
"BatchNorm1DStaticShapeModule_basic",
"FlattenStaticModule_basic",
"UnflattenStaticModule_basic",
"FlattenRank0Module_basic",
"ElementwiseFlattenBroadcastModule_basic",
"SquareModule_basic",
Expand Down
24 changes: 24 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -7537,6 +7537,30 @@ def Torch_AtenFlattenUsingIntsOp : Torch_Op<"aten.flatten.using_ints", [
}];
}

def Torch_AtenUnflattenIntOp : Torch_Op<"aten.unflatten.int", [
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchListOfTorchIntType:$sizes
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenUnflattenIntOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenUnflattenIntOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}

def Torch_AtenDimOp : Torch_Op<"aten.dim", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
55 changes: 55 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2525,6 +2525,60 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
return success();
}

template <>
LogicalResult ConvertAtenOp<AtenUnflattenIntOp>::matchAndRewrite(
AtenUnflattenIntOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {

// Not a ranked tensor type
auto selfType = adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
if (!selfType || !selfType.hasStaticShape())
return rewriter.notifyMatchFailure(
op,
"Only ranked tensor types with static shapes are currently supported");

int64_t selfRank = selfType.getRank();
int64_t dim;

if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(op, "dim must be a Scalar constant");

SmallVector<int64_t> sizes;
if (!matchPattern(op.getSizes(), m_TorchListOfConstantInts(sizes)))
return rewriter.notifyMatchFailure(
op, "Only constant sizes are currently supported");

if (selfRank > 0 && !isValidDim(dim, selfRank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");

SmallVector<int64_t> newShape;
for (auto s :
llvm::enumerate(makeShapeTorchCompatible(selfType.getShape()))) {
int64_t idx = s.index();
if (idx < dim || idx > dim) {
newShape.push_back(s.value());
} else {
auto sum = 1;
for (auto newDims : sizes) {
newShape.push_back(newDims);
sum *= newDims;
}
if (sum != s.value())
return rewriter.notifyMatchFailure(op,
"sizes mismatch with original dim");
}
}

auto newType = RankedTensorType::get(makeShapeLLVMCompatible(newShape),
selfType.getElementType());

rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
op, getTypeConverter()->convertType(newType), adaptor.getSelf(),
rewriter.getDenseI64ArrayAttr(newShape));

return success();
}

template <>
LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
AtenPermuteOp op, OpAdaptor adaptor,
Expand Down Expand Up @@ -5050,6 +5104,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_ATENOP_PATTERN(AtenBatchNormOp);
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp);
INSERT_ATENOP_PATTERN(AtenUnflattenIntOp);
INSERT_ATENOP_PATTERN(AtenPermuteOp);
INSERT_ATENOP_PATTERN(AtenLog2Op);
INSERT_ATENOP_PATTERN(AtenThresholdOp);
Expand Down
14 changes: 14 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7205,6 +7205,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.flatten(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.unflatten.int\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.slice.t %arg0, %none, %arg1, %int1 : !torch.list<int>, !torch.none, !torch.int, !torch.int -> !torch.list<int>\n"
" %1 = torch.aten.add.t %0, %arg2 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
" %2 = torch.aten.add.int %arg1, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %3 = torch.aten.slice.t %arg0, %2, %none, %int1 : !torch.list<int>, !torch.int, !torch.none, !torch.int -> !torch.list<int>\n"
" %4 = torch.aten.add.t %1, %3 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
" return %4 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.linear\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.linear(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
Expand Down Expand Up @@ -8580,6 +8590,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.unflatten.int\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.flip\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
Expand Down
8 changes: 4 additions & 4 deletions lib/Dialect/Torch/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,10 @@ bool Torch::isViewLikeOp(Operation *op) {
// that it does not return a view and treat those as having value
// semantics.
return isa<AtenBroadcastToOp, AtenContiguousOp, AtenDetachOp, AtenExpandAsOp,
AtenExpandOp, AtenFlattenUsingIntsOp, AtenPermuteOp, AtenReshapeOp,
Aten_ReshapeAliasOp, AtenSelectIntOp, AtenSliceTensorOp,
AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
AtenExpandOp, AtenFlattenUsingIntsOp, AtenUnflattenIntOp,
AtenPermuteOp, AtenReshapeOp, Aten_ReshapeAliasOp, AtenSelectIntOp,
AtenSliceTensorOp, AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp,
AtenToDtypeOp, AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,9 @@ def aten〇adaptive_avg_pool2d〡shape(self: List[int], output_size: List[int])
def aten〇flatten〇using_ints〡shape(self: List[int], start_dim: int = 0, end_dim: int = -1) -> List[int]:
return upstream_shape_functions.flatten(self, start_dim, end_dim)

def aten〇unflatten〇int〡shape(self: List[int], dim: int, sizes: List[int]) -> List[int]:
return self[:dim] + sizes + self[dim + 1:]

def aten〇linear〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None) -> List[int]:
return upstream_shape_functions.linear(input, weight, bias)

Expand Down Expand Up @@ -1656,6 +1659,11 @@ def aten〇flatten〇using_ints〡dtype(self_rank_dtype: Tuple[int, int], start_
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, sizes=[1]))
def aten〇unflatten〇int〡dtype(self_rank_dtype: Tuple[int, int], dim: int, sizes: List[int]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[0]))
def aten〇flip〡dtype(self_rank_dtype: Tuple[int, int], dims: List[int]) -> int:
self_rank, self_dtype = self_rank_dtype
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True)
emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True)
emit("aten::flatten.using_ints : (Tensor, int, int) -> (Tensor)")
emit("aten::unflatten.int : (Tensor, int, int[]) -> (Tensor)")
emit("aten::dim : (Tensor) -> (int)", has_folder=True)
emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
emit("aten::Bool.Tensor : (Tensor) -> (bool)")
Expand Down
22 changes: 22 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,28 @@ def AddmmModule_differentRankBroadcastable(module, tu: TestUtils):
# ==============================================================================


class UnflattenStaticModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([1, 6, 4], torch.float32, True),
])
def forward(self, x):
return torch.ops.aten.unflatten(x, 1, (2, 3))


@register_test_case(module_factory=lambda: UnflattenStaticModule())
def UnflattenStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 6, 4))


# ==============================================================================


class FlattenStaticModule(torch.nn.Module):

def __init__(self):
Expand Down
22 changes: 22 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,28 @@ func.func @forward(%arg0: !torch.vtensor<[10,3,8,9,3,4],f32> ) -> !torch.vtensor

// -----

// CHECK-LABEL: func.func @forward(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,6,4],f32>) -> !torch.vtensor<[1,2,3,4],f32> {
// CHECK: %[[VAL:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,6,4],f32> -> tensor<1x6x4xf32>
// CHECK: %[[VAL_1:.*]] = torch.constant.int 1
// CHECK: %[[VAL_2:.*]] = torch.constant.int 2
// CHECK: %[[VAL_3:.*]] = torch.constant.int 3
// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL]] {new_shape = array<i64: 1, 2, 3, 4>} : (tensor<1x6x4xf32>) -> tensor<1x2x3x4xf32>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x2x3x4xf32> -> !torch.vtensor<[1,2,3,4],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,2,3,4],f32>
// CHECK: }
func.func @forward(%arg0: !torch.vtensor<[1,6,4],f32> ) -> !torch.vtensor<[1,2,3,4],f32> {
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%int3 = torch.constant.int 3
%0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.aten.unflatten.int %arg0, %int1, %0 : !torch.vtensor<[1,6,4],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[1,2,3,4],f32>
return %1 : !torch.vtensor<[1,2,3,4],f32>
}

// -----

// CHECK-LABEL: func.func @forward(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,2,2,3],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[2,2,3],f32>,
Expand Down