Skip to content

Commit

Permalink
[MLIR] Add view support in case of single dynamic dim
Browse files Browse the repository at this point in the history
This commit adds the lowering of the `aten.view` op for the case where
both the input and the output have exactly one dynamic dimension.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
  • Loading branch information
Shukla-Gaurav committed Sep 22, 2022
1 parent 4ef6e69 commit 18424b5
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
17 changes: 17 additions & 0 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,23 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
}
}

// Special case: the input and the result each contain exactly one
// dynamic (kUnknownSize) dimension. Lower the view as a full collapse
// of the input to a rank-1 tensor followed by an expansion to the
// result shape.
if (llvm::count(outputShape, kUnknownSize) == 1 &&
    llvm::count(inputShape, kUnknownSize) == 1) {
  // Build a single reassociation group per side that covers every
  // dimension: input dims [0, inputRank) collapse into one dim, which
  // is then expanded into result dims [0, resultRank).
  SmallVector<ReassociationIndices> inputAssociations;
  SmallVector<ReassociationIndices> outputAssociations;
  inputAssociations.emplace_back();
  outputAssociations.emplace_back();
  for (int i = 0; i < inputRank; i++)
    inputAssociations.back().push_back(i);
  for (int i = 0; i < resultRank; i++)
    outputAssociations.back().push_back(i);
  // Collapse to rank-1 first; the single unknown extent on each side
  // keeps both reassociations unambiguous.
  // NOTE(review): tensor.expand_shape requires the output sizes within a
  // group to be inferable; this presumably relies on the result type
  // having exactly one dynamic dim per the guard above — confirm the
  // verifier accepts this for all shapes reaching here.
  Value collapsedTensor = rewriter.create<tensor::CollapseShapeOp>(
      op->getLoc(), input, inputAssociations);
  // Replace aten.view with the expansion so uses see the result type
  // the type converter computed.
  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
      op, resultType, collapsedTensor, outputAssociations);
  return success();
}

// Mark the end of the input/output shapes
unchangedDims.emplace_back();
unchangedDims.back().push_back(inputRank);
Expand Down
19 changes: 19 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/reshape_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,3 +561,22 @@ def forward(self, a):
@register_test_case(module_factory=lambda: ReshapeAliasCollapseModule())
def ReshapeAliasCollapseModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4))

# ==============================================================================

class ViewExpandCollapseDynamicDimModule(torch.nn.Module):
    # E2E test module for the aten.view lowering where both the input and
    # the result have exactly one dynamic dimension: input dim 0 is
    # annotated dynamic (-1) below, and the view output infers its last
    # dimension (-1).
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        # (?, 16, 128) float32 tensor; -1 marks dim 0 as dynamic.
        ([-1, 16, 128], torch.float32, True),
    ])

    def forward(self, a):
        # Reshape (?, 16, 128) -> (16, 1, ?); view infers the single
        # unknown extent from the element count.
        return a.view(16, 1, -1)

@register_test_case(module_factory=lambda: ViewExpandCollapseDynamicDimModule())
def ViewExpandCollapseDynamicDimModule_basic(module, tu: TestUtils):
    # Exercise the dynamic dim (annotated -1) with a concrete size of 1;
    # the harness compares the compiled result against eager PyTorch.
    module.forward(tu.rand(1, 16, 128))

0 comments on commit 18424b5

Please sign in to comment.