From 41c3e059e1d8ad0838c80c43b242bb5943f345b8 Mon Sep 17 00:00:00 2001 From: Xinyu Yang Date: Fri, 10 May 2024 22:28:18 +0800 Subject: [PATCH 1/5] remove getWithLeastStaticInformation in DecomposeAtenTriuOp --- .../torch-mlir/Dialect/Torch/Utils/Utils.h | 4 +++ .../Torch/Transforms/DecomposeComplexOps.cpp | 28 +++++++++---------- lib/Dialect/Torch/Utils/Utils.cpp | 27 ++++++++++++++++++ 3 files changed, 45 insertions(+), 14 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index 33a1c9f91fe..60249bfcb64 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -84,6 +84,10 @@ int64_t getNumberOfElements(RankedTensorType inputType); SmallVector makeShapeLLVMCompatible(ArrayRef shape); SmallVector makeShapeTorchCompatible(ArrayRef shape); +ValueTensorType getTensorTypeFromValueVector(ArrayRef shapes, + Type dtype); +Value getTensorDimSize(PatternRewriter &rewriter, Value tensor, int64_t dim); + // Helper function to squeeze the input tensor at given dim. // Return the squeezed tensor or failure. FailureOr squeezeTensor(PatternRewriter &rewriter, Operation *op, diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 9fad15e132f..1db5c2982fe 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -685,37 +685,37 @@ class DecomposeAtenTriuOp : public OpRewritePattern { return rewriter.notifyMatchFailure(op, "the rank of tensor should >= 2"); } - auto baseType = ValueTensorType::getWithLeastStaticInformation(context); Value cstZero = rewriter.create(loc, rewriter.getI64IntegerAttr(0)); Value cstOne = rewriter.create(loc, rewriter.getI64IntegerAttr(1)); Value none = rewriter.create(loc); - Value rowDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(-2)); - Value colDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(-1)); - Value rowSize = rewriter.create(loc, input, rowDim); - Value colSize = rewriter.create(loc, input, colDim); + Value rowSize = getTensorDimSize(rewriter, input, -2); + Value colSize = getTensorDimSize(rewriter, input, -1); + auto si64Type = IntegerType::get(context, 64, IntegerType::Signed); Value rowArange = rewriter.create( - loc, baseType, rowSize, /*dtype=*/none, /*layout=*/none, + loc, getTensorTypeFromValueVector({rowSize}, si64Type), rowSize, + /*dtype=*/none, /*layout=*/none, /*device=*/none, /*pin_memory=*/none); Value colArange = rewriter.create( - loc, baseType, colSize, /*dtype=*/none, /*layout=*/none, + loc, getTensorTypeFromValueVector({colSize}, si64Type), colSize, + /*dtype=*/none, /*layout=*/none, /*device=*/none, /*pin_memory=*/none); - Value unsqueezeRowArange = - rewriter.create(loc, baseType, rowArange, cstOne); + Value unsqueezeRowArange = unsqueezeTensor(rewriter, op, rowArange, cstOne); Value unsqueezeColArange = - rewriter.create(loc, baseType, colArange, cstZero); + unsqueezeTensor(rewriter, op, colArange, cstZero); Value unsqueezeRowArangePlusDiagonal = rewriter.create( - loc, baseType, unsqueezeRowArange, op.getDiagonal(), cstOne); + loc, unsqueezeRowArange.getType(), unsqueezeRowArange, op.getDiagonal(), + cstOne); + auto boolType = rewriter.getI1Type(); Value condTensor = rewriter.create( - loc, baseType, unsqueezeColArange, unsqueezeRowArangePlusDiagonal); + loc, getTensorTypeFromValueVector({rowSize, colSize}, boolType), + unsqueezeColArange, unsqueezeRowArangePlusDiagonal); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), condTensor, input, cstZero); diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index d634556c98a..691e362d6eb 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -275,6 +275,33 @@ SmallVector Torch::makeShapeTorchCompatible(ArrayRef shape) { return updatedShape; } +ValueTensorType Torch::getTensorTypeFromValueVector(ArrayRef shapes, + Type dtype) { + assert(!shapes.empty() && "shape vector cannot be empty"); + SmallVector shapeInts; + for (Value shape : shapes) { + int64_t dim; + if (matchPattern(shape, m_TorchConstantInt(&dim))) + shapeInts.push_back(dim); + else + shapeInts.push_back(kUnknownSize); + } + return Torch::ValueTensorType::get(shapeInts[0].getContext(), shapeInts, + dtype); +} + +// Helper function to get the size of the tensor at the given dimension. +Value Torch::getTensorDimSize(PatternRewriter &rewriter, Value tensor, + int64_t dim) { + auto loc = tensor.getLoc(); + auto dimVal = + rewriter.create(loc, rewriter.getI64IntegerAttr(dim)); + // Use 'createOrFold' instead of 'create': + // If the dimension is a constant, then the AtenSizeIntOp is folded to a + // ContantIntOp. + return rewriter.createOrFold(loc, tensor, dimVal); +} + // Helper function to squeeze the input tensor at given dim. // Return the squeezed tensor or failure. FailureOr Torch::squeezeTensor(PatternRewriter &rewriter, Operation *op, From 45657b5f8b6a61795dc4e0bec5d55d39d42f2d55 Mon Sep 17 00:00:00 2001 From: Xinyu Yang Date: Fri, 10 May 2024 22:48:48 +0800 Subject: [PATCH 2/5] fix bug --- .../Torch/Transforms/DecomposeComplexOps.cpp | 13 +++++++++++-- lib/Dialect/Torch/Utils/Utils.cpp | 3 +-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 1db5c2982fe..507b596c62e 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -704,10 +704,19 @@ class DecomposeAtenTriuOp : public OpRewritePattern { /*dtype=*/none, /*layout=*/none, /*device=*/none, /*pin_memory=*/none); - Value unsqueezeRowArange = unsqueezeTensor(rewriter, op, rowArange, cstOne); - Value unsqueezeColArange = + auto unsqueezeRowArangeInfo = + unsqueezeTensor(rewriter, op, rowArange, cstOne); + auto unsqueezeColArangeInfo = unsqueezeTensor(rewriter, op, colArange, cstZero); + if (failed(unsqueezeRowArangeInfo) || failed(unsqueezeColArangeInfo)) { + return rewriter.notifyMatchFailure(op, + "cannot generate unsqueeze tensor"); + } + + Value unsqueezeRowArange = unsqueezeRowArangeInfo.value(); + Value unsqueezeColArange = unsqueezeColArangeInfo.value(); + Value unsqueezeRowArangePlusDiagonal = rewriter.create( loc, unsqueezeRowArange.getType(), unsqueezeRowArange, op.getDiagonal(), cstOne); diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 691e362d6eb..4b98dad4074 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -286,8 +286,7 @@ ValueTensorType Torch::getTensorTypeFromValueVector(ArrayRef shapes, else shapeInts.push_back(kUnknownSize); } - return Torch::ValueTensorType::get(shapeInts[0].getContext(), shapeInts, - dtype); + return Torch::ValueTensorType::get(shapes[0].getContext(), shapeInts, dtype); } // Helper function to get the size of the tensor at the given dimension. From e74132018e14cae0b1c854fbffa27fe946cdf310 Mon Sep 17 00:00:00 2001 From: Xinyu Yang Date: Tue, 14 May 2024 11:18:23 +0800 Subject: [PATCH 3/5] modify code --- .../torch-mlir/Dialect/Torch/Utils/Utils.h | 2 +- .../Torch/Transforms/DecomposeComplexOps.cpp | 26 +++++++++++-------- lib/Dialect/Torch/Utils/Utils.cpp | 2 +- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index 60249bfcb64..e65e9f2df90 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -84,7 +84,7 @@ int64_t getNumberOfElements(RankedTensorType inputType); SmallVector makeShapeLLVMCompatible(ArrayRef shape); SmallVector makeShapeTorchCompatible(ArrayRef shape); -ValueTensorType getTensorTypeFromValueVector(ArrayRef shapes, +ValueTensorType getResultTypeFromValueVector(ArrayRef shapes, Type dtype); Value getTensorDimSize(PatternRewriter &rewriter, Value tensor, int64_t dim); diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 507b596c62e..a373712030a 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -694,15 +694,19 @@ class DecomposeAtenTriuOp : public OpRewritePattern { Value rowSize = getTensorDimSize(rewriter, input, -2); Value colSize = getTensorDimSize(rewriter, input, -1); - auto si64Type = IntegerType::get(context, 64, IntegerType::Signed); - Value rowArange = rewriter.create( - loc, getTensorTypeFromValueVector({rowSize}, si64Type), rowSize, - /*dtype=*/none, /*layout=*/none, - /*device=*/none, /*pin_memory=*/none); - Value colArange = rewriter.create( - loc, getTensorTypeFromValueVector({colSize}, si64Type), colSize, - /*dtype=*/none, /*layout=*/none, - /*device=*/none, /*pin_memory=*/none); + auto si64Type = rewriter.getIntegerType(/*width=*/64, /*isSigned*/ true); + auto int64DtypeInt = getDtypeIntValueForType(rewriter, loc, si64Type); + auto rowArrangeType = getResultTypeFromValueVector({rowSize}, si64Type); + auto colArrangeType = getResultTypeFromValueVector({colSize}, si64Type); + + Value rowArange = + rewriter.create(loc, rowArrangeType, rowSize, + /*dtype=*/int64DtypeInt, /*layout=*/none, + /*device=*/none, /*pin_memory=*/none); + Value colArange = + rewriter.create(loc, colArrangeType, colSize, + /*dtype=*/int64DtypeInt, /*layout=*/none, + /*device=*/none, /*pin_memory=*/none); auto unsqueezeRowArangeInfo = unsqueezeTensor(rewriter, op, rowArange, cstOne); @@ -722,9 +726,9 @@ class DecomposeAtenTriuOp : public OpRewritePattern { cstOne); auto boolType = rewriter.getI1Type(); + auto condType = getResultTypeFromValueVector({rowSize, colSize}, boolType); Value condTensor = rewriter.create( - loc, getTensorTypeFromValueVector({rowSize, colSize}, boolType), - unsqueezeColArange, unsqueezeRowArangePlusDiagonal); + loc, condType, unsqueezeColArange, unsqueezeRowArangePlusDiagonal); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), condTensor, input, cstZero); diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 4b98dad4074..882a4dbc891 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -275,7 +275,7 @@ SmallVector Torch::makeShapeTorchCompatible(ArrayRef shape) { return updatedShape; } -ValueTensorType Torch::getTensorTypeFromValueVector(ArrayRef shapes, +ValueTensorType Torch::getResultTypeFromValueVector(ArrayRef shapes, Type dtype) { assert(!shapes.empty() && "shape vector cannot be empty"); SmallVector shapeInts; From bcc5314178d4284e96736f8efa8303715102cf68 Mon Sep 17 00:00:00 2001 From: Xinyu Yang Date: Tue, 14 May 2024 11:39:48 +0800 Subject: [PATCH 4/5] fix warning --- lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index a373712030a..342e3616ce9 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -674,7 +674,6 @@ class DecomposeAtenTriuOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenTriuOp op, PatternRewriter &rewriter) const override { - MLIRContext *context = op.getContext(); Location loc = op.getLoc(); Value input = op.getSelf(); auto inputType = cast(input.getType()); From 0702e21c53b4f88db03943003df42e6d51694e21 Mon Sep 17 00:00:00 2001 From: Xinyu Yang Date: Wed, 15 May 2024 13:08:09 +0800 Subject: [PATCH 5/5] rename helper function name --- include/torch-mlir/Dialect/Torch/Utils/Utils.h | 2 +- lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | 6 +++--- lib/Dialect/Torch/Utils/Utils.cpp | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index e65e9f2df90..b0b74c50858 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -84,7 +84,7 @@ int64_t getNumberOfElements(RankedTensorType inputType); SmallVector makeShapeLLVMCompatible(ArrayRef shape); SmallVector makeShapeTorchCompatible(ArrayRef shape); -ValueTensorType getResultTypeFromValueVector(ArrayRef shapes, +ValueTensorType getTensorTypeFromShapeValues(ArrayRef shapes, Type dtype); Value getTensorDimSize(PatternRewriter &rewriter, Value tensor, int64_t dim); diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 342e3616ce9..fc0b68de09f 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -695,8 +695,8 @@ class DecomposeAtenTriuOp : public OpRewritePattern { auto si64Type = rewriter.getIntegerType(/*width=*/64, /*isSigned*/ true); auto int64DtypeInt = getDtypeIntValueForType(rewriter, loc, si64Type); - auto rowArrangeType = getResultTypeFromValueVector({rowSize}, si64Type); - auto colArrangeType = getResultTypeFromValueVector({colSize}, si64Type); + auto rowArrangeType = getTensorTypeFromShapeValues({rowSize}, si64Type); + auto colArrangeType = getTensorTypeFromShapeValues({colSize}, si64Type); Value rowArange = rewriter.create(loc, rowArrangeType, rowSize, @@ -725,7 +725,7 @@ class DecomposeAtenTriuOp : public OpRewritePattern { cstOne); auto boolType = rewriter.getI1Type(); - auto condType = getResultTypeFromValueVector({rowSize, colSize}, boolType); + auto condType = getTensorTypeFromShapeValues({rowSize, colSize}, boolType); Value condTensor = rewriter.create( loc, condType, unsqueezeColArange, unsqueezeRowArangePlusDiagonal); diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 882a4dbc891..848a5e13344 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -275,7 +275,7 @@ SmallVector Torch::makeShapeTorchCompatible(ArrayRef shape) { return updatedShape; } -ValueTensorType Torch::getResultTypeFromValueVector(ArrayRef shapes, +ValueTensorType Torch::getTensorTypeFromShapeValues(ArrayRef shapes, Type dtype) { assert(!shapes.empty() && "shape vector cannot be empty"); SmallVector shapeInts;