Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,62 @@ class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
};
} // namespace

namespace {
class InferTensorOp : public OpRewritePattern<AtenTensorOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenTensorOp op,
PatternRewriter &rewriter) const override {
auto context = op.getContext();
auto loc = op.getLoc();
auto result = op.getResult();
auto resultType = cast<BaseTensorType>(result.getType());
if (resultType.hasSizes() && resultType.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "The result of aten.tensor is already a BaseTensorType.");
}

auto inputList = op.getOperand(0);
auto listConstruct = inputList.getDefiningOp<PrimListConstructOp>();
if (!listConstruct) {
return rewriter.notifyMatchFailure(
op, "The operand 0 of aten.tensor is not PrimListConstructOp.");
}

// Currently only support the 1d input list.
SmallVector<int64_t> sizes;
sizes.push_back(listConstruct->getOperands().size());
FailureOr<Type> torchType;
auto eleType = listConstruct->getOperands()[0].getType();
if (isa<Torch::IntType>(eleType)) {
torchType = getTypeForScalarType(op->getContext(),
torch_upstream::ScalarType::Long);
} else if (isa<Torch::FloatType>(eleType)) {
torchType = getTypeForScalarType(op->getContext(),
torch_upstream::ScalarType::Float);
} else {
return rewriter.notifyMatchFailure(
op, "Currently only support Int and Float Type.");
}
auto newResultType = ValueTensorType::get(context, sizes, *torchType);

Value originalTypedValue;
for (OpOperand &use : llvm::make_early_inc_range(result.getUses())) {
if (!originalTypedValue) {
rewriter.setInsertionPointAfter(op);
originalTypedValue =
rewriter.create<TensorStaticInfoCastOp>(loc, resultType, result);
}
use.set(originalTypedValue);
}

result.setType(newResultType);

return success();
}
};
} // namespace

static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
int resultNum,
PatternRewriter &rewriter) {
Expand Down Expand Up @@ -135,6 +191,7 @@ class SimplifyShapeCalculationsPass
populateFoldPrimUncheckedCastOpPattern(patterns, context);
patterns.insert<DecomposeAtenSizeOp>(context);
patterns.insert<RefineShapeCalculateOp>(context);
patterns.insert<InferTensorOp>(context);

PrimIfOp::getCanonicalizationPatterns(patterns, context);
Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context);
Expand Down
24 changes: 24 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5621,6 +5621,30 @@ def ConstantBoolParameterModule_basic(module, tu: TestUtils):
# ==============================================================================


class TensorAlloc1dStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([2, 4, 6], torch.int, True),
]
)
def forward(self, x):
res = torch.tensor([x.shape[0]])
return res


@register_test_case(module_factory=lambda: TensorAlloc1dStaticModule())
def TensorAlloc1dStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 6))


# ==============================================================================


class ScalarTensorFloat32Module(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
Loading