[Torch] eliminate "getWithLeastStaticInformation" in DecomposeAtenTriuOp #3330

Merged
merged 7 commits on May 22, 2024
Changes from 3 commits
4 changes: 4 additions & 0 deletions include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -84,6 +84,10 @@ int64_t getNumberOfElements(RankedTensorType inputType);
 SmallVector<int64_t> makeShapeLLVMCompatible(ArrayRef<int64_t> shape);
 SmallVector<int64_t> makeShapeTorchCompatible(ArrayRef<int64_t> shape);
 
+ValueTensorType getTensorTypeFromValueVector(ArrayRef<Value> shapes,
+                                             Type dtype);
+Value getTensorDimSize(PatternRewriter &rewriter, Value tensor, int64_t dim);
+
 // Helper function to squeeze the input tensor at given dim.
 // Return the squeezed tensor or failure.
 FailureOr<Value> squeezeTensor(PatternRewriter &rewriter, Operation *op,
39 changes: 24 additions & 15 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -685,37 +685,46 @@ class DecomposeAtenTriuOp : public OpRewritePattern<AtenTriuOp> {
       return rewriter.notifyMatchFailure(op, "the rank of tensor should >= 2");
     }
 
-    auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
     Value cstZero =
         rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
     Value cstOne =
         rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
     Value none = rewriter.create<ConstantNoneOp>(loc);
 
-    Value rowDim = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(-2));
-    Value colDim = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(-1));
-    Value rowSize = rewriter.create<AtenSizeIntOp>(loc, input, rowDim);
-    Value colSize = rewriter.create<AtenSizeIntOp>(loc, input, colDim);
+    Value rowSize = getTensorDimSize(rewriter, input, -2);
+    Value colSize = getTensorDimSize(rewriter, input, -1);
 
+    auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
     Value rowArange = rewriter.create<AtenArangeOp>(
-        loc, baseType, rowSize, /*dtype=*/none, /*layout=*/none,
+        loc, getTensorTypeFromValueVector({rowSize}, si64Type), rowSize,
+        /*dtype=*/none, /*layout=*/none,
         /*device=*/none, /*pin_memory=*/none);
     Value colArange = rewriter.create<AtenArangeOp>(
-        loc, baseType, colSize, /*dtype=*/none, /*layout=*/none,
+        loc, getTensorTypeFromValueVector({colSize}, si64Type), colSize,
+        /*dtype=*/none, /*layout=*/none,
         /*device=*/none, /*pin_memory=*/none);
 
-    Value unsqueezeRowArange =
-        rewriter.create<AtenUnsqueezeOp>(loc, baseType, rowArange, cstOne);
-    Value unsqueezeColArange =
-        rewriter.create<AtenUnsqueezeOp>(loc, baseType, colArange, cstZero);
+    auto unsqueezeRowArangeInfo =
+        unsqueezeTensor(rewriter, op, rowArange, cstOne);
+    auto unsqueezeColArangeInfo =
+        unsqueezeTensor(rewriter, op, colArange, cstZero);
+
+    if (failed(unsqueezeRowArangeInfo) || failed(unsqueezeColArangeInfo)) {
+      return rewriter.notifyMatchFailure(op,
+                                         "cannot generate unsqueeze tensor");
+    }
+
+    Value unsqueezeRowArange = unsqueezeRowArangeInfo.value();
+    Value unsqueezeColArange = unsqueezeColArangeInfo.value();
 
     Value unsqueezeRowArangePlusDiagonal = rewriter.create<AtenAddScalarOp>(
-        loc, baseType, unsqueezeRowArange, op.getDiagonal(), cstOne);
+        loc, unsqueezeRowArange.getType(), unsqueezeRowArange, op.getDiagonal(),
+        cstOne);
 
+    auto boolType = rewriter.getI1Type();
     Value condTensor = rewriter.create<AtenGeTensorOp>(
-        loc, baseType, unsqueezeColArange, unsqueezeRowArangePlusDiagonal);
+        loc, getTensorTypeFromValueVector({rowSize, colSize}, boolType),
+        unsqueezeColArange, unsqueezeRowArangePlusDiagonal);
 
     rewriter.replaceOpWithNewOp<AtenWhereScalarOtherOp>(
         op, op.getResult().getType(), condTensor, input, cstZero);
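Note: the rewritten pattern above still implements the same index-grid decomposition of aten.triu; only the result types become more precise. For reference, a minimal PyTorch sketch of that decomposition (illustrative only, not part of this patch; torch.zeros_like stands in for the scalar 0 passed to AtenWhereScalarOtherOp):

import torch

def triu_by_decomposition(x: torch.Tensor, diagonal: int = 0) -> torch.Tensor:
    # Row and column index grids over the last two dimensions.
    rows, cols = x.shape[-2], x.shape[-1]
    row_idx = torch.arange(rows).unsqueeze(1)  # shape (rows, 1)
    col_idx = torch.arange(cols).unsqueeze(0)  # shape (1, cols)
    # Keep x[..., i, j] where j >= i + diagonal, otherwise take 0.
    cond = col_idx >= (row_idx + diagonal)
    return torch.where(cond, x, torch.zeros_like(x))

x = torch.randn(3, 4)
assert torch.equal(triu_by_decomposition(x, diagonal=1), torch.triu(x, diagonal=1))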
26 changes: 26 additions & 0 deletions lib/Dialect/Torch/Utils/Utils.cpp
@@ -275,6 +275,32 @@ SmallVector<int64_t> Torch::makeShapeTorchCompatible(ArrayRef<int64_t> shape) {
   return updatedShape;
 }
 
+ValueTensorType Torch::getTensorTypeFromValueVector(ArrayRef<Value> shapes,
+                                                    Type dtype) {
+  assert(!shapes.empty() && "shape vector cannot be empty");
+  SmallVector<int64_t> shapeInts;
+  for (Value shape : shapes) {
+    int64_t dim;
+    if (matchPattern(shape, m_TorchConstantInt(&dim)))
+      shapeInts.push_back(dim);
+    else
+      shapeInts.push_back(kUnknownSize);
+  }
+  return Torch::ValueTensorType::get(shapes[0].getContext(), shapeInts, dtype);
+}
+
+// Helper function to get the size of the tensor at the given dimension.
+Value Torch::getTensorDimSize(PatternRewriter &rewriter, Value tensor,
+                              int64_t dim) {
+  auto loc = tensor.getLoc();
+  auto dimVal =
+      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(dim));
+  // Use 'createOrFold' instead of 'create':
+  // If the dimension is a constant, then the AtenSizeIntOp is folded to a
+  // ConstantIntOp.
+  return rewriter.createOrFold<AtenSizeIntOp>(loc, tensor, dimVal);
+}
+
 // Helper function to squeeze the input tensor at given dim.
 // Return the squeezed tensor or failure.
 FailureOr<Value> Torch::squeezeTensor(PatternRewriter &rewriter, Operation *op,
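Note: together these helpers let the decomposition propagate whatever static shape information the input already carries, instead of the fully-unknown type that getWithLeastStaticInformation produced. A rough Python mirror of the shape logic in getTensorTypeFromValueVector (illustrative only; None stands in for the kUnknownSize sentinel used above):

def shape_from_dim_values(dim_values):
    # A dim backed by a compile-time constant stays static; anything only
    # known at runtime becomes an unknown (dynamic) extent.
    return [d if isinstance(d, int) else None for d in dim_values]

print(shape_from_dim_values([4]))             # [4]        e.g. the row arange type
print(shape_from_dim_values(["colSize"]))     # [None]     e.g. the col arange type
print(shape_from_dim_values([4, "colSize"]))  # [4, None]  e.g. the comparison result type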