[mlir][tensor] Fold tensor.reshape for dynamic reshape #88961
Conversation
If `tensor.reshape` occurs with `d0, d1, d2, ...` for the dimensions, we know that the reshape is a no-op. Checking for this case lets us fold away the computation.
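Concretely, the fold targets IR like the following (a minimal sketch; the function and value names are illustrative, not taken from the PR):

```mlir
func.func @no_op_reshape(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  // Each shape element is tensor.dim of the source at the matching index.
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %shape = tensor.from_elements %d0, %d1 : tensor<2xindex>
  // The reshape preserves every dimension, so it folds to %arg0.
  %0 = tensor.reshape %arg0(%shape) : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
```

After canonicalization, the body reduces to `return %arg0`.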
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir

Author: Rob Suderman (rsuderman)

Changes: If `tensor.reshape` occurs with `d0, d1, d2, ...` for the dimensions, we know that the reshape is a no-op. Checking for this case lets us fold away the computation.

Full diff: https://github.com/llvm/llvm-project/pull/88961.diff

2 Files Affected:
- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
- mlir/test/Dialect/Tensor/canonicalize.mlir
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 0ce40e81371209..50d3cd45a2dfe9 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1580,6 +1580,57 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
           getResult().getType()))
     return reshapedSource;
+
+  auto source = getSource();
+  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
+  auto resultTy = dyn_cast<RankedTensorType>(getType());
+
+  if (!sourceTy || !resultTy || sourceTy != resultTy)
+    return {};
+
+  // If the shape operand is built by tensor.from_elements and every element
+  // reproduces the corresponding source dimension (either as the matching
+  // constant or as tensor.dim of the source at the same index), the reshape
+  // is a no-op and folds to the source.
+  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
+    auto elements = fromElements.getElements();
+    bool dynamicNoop =
+        sourceTy.getRank() == static_cast<int64_t>(elements.size());
+    for (auto [id, element] : llvm::enumerate(elements)) {
+      APSInt cstElement;
+      if (matchPattern(element, m_ConstantInt(&cstElement))) {
+        if (cstElement.getExtValue() != sourceTy.getDimSize(id)) {
+          dynamicNoop = false;
+          break;
+        }
+        continue;
+      }
+
+      if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
+        if (dimOp.getSource() != source) {
+          dynamicNoop = false;
+          break;
+        }
+
+        APSInt dim;
+        if (!matchPattern(dimOp.getIndex(), m_ConstantInt(&dim)) ||
+            dim.getExtValue() != static_cast<int64_t>(id)) {
+          dynamicNoop = false;
+          break;
+        }
+        continue;
+      }
+
+      // Neither a constant nor a tensor.dim of the source: we cannot prove
+      // this dimension is unchanged, so conservatively give up.
+      dynamicNoop = false;
+      break;
+    }
+
+    if (dynamicNoop)
+      return source;
+  }
+
return {};
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ac365c9d297e88..751c57eacd7ae5 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2403,6 +2403,53 @@ func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xinde
// -----
+// CHECK-LABEL: @reshape_fold_2d
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
+func.func @reshape_fold_2d(%arg0 : tensor<?x?xi32>) -> tensor<?x?xi32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %d0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
+ %d1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
+ %ds = tensor.from_elements %d0, %d1 : tensor<2xindex>
+ %reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
+ // CHECK: return %[[ARG0]]
+ return %reshape : tensor<?x?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @reshape_nofold_2d
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
+func.func @reshape_nofold_2d(%arg0 : tensor<?x?xi32>) -> tensor<?x?xi32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %d0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
+ %d1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
+ %ds = tensor.from_elements %d1, %d0 : tensor<2xindex>
+ // CHECK: tensor.reshape
+ %reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
+ return %reshape : tensor<?x?xi32>
+}
+
+
+// -----
+
+// CHECK-LABEL: @reshape_fold_3d_cst
+// CHECK-SAME: %[[ARG0:.+]]: tensor<5x?x?xi32>
+func.func @reshape_fold_3d_cst(%arg0 : tensor<5x?x?xi32>) -> tensor<5x?x?xi32> {
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %d0 = arith.constant 5 : index
+ %d1 = tensor.dim %arg0, %c1 : tensor<5x?x?xi32>
+ %d2 = tensor.dim %arg0, %c2 : tensor<5x?x?xi32>
+ %ds = tensor.from_elements %d0, %d1, %d2 : tensor<3xindex>
+ %reshape = tensor.reshape %arg0(%ds) : (tensor<5x?x?xi32>, tensor<3xindex>) -> tensor<5x?x?xi32>
+ // CHECK: return %[[ARG0]]
+ return %reshape : tensor<5x?x?xi32>
+}
+
+// -----
+
// Test case: This test fails to fold because the index of tensor.dim is out_of_bounds
// CHECK-LABEL: func @dim_out_of_bounds(
// CHECK: %[[IDX:.*]] = index.constant 28
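Note the conservative side of the folder: a shape element that is neither a constant nor a `tensor.dim` of the source tells us nothing about the dimension, so the fold must bail out. A minimal sketch of such a case (names are illustrative, not part of the PR's tests):

```mlir
func.func @reshape_nofold_unknown(%arg0 : tensor<?x?xf32>, %sz : index) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  // %sz is an arbitrary index value; it may or may not equal dim 1 of %arg0,
  // so the reshape cannot be proven to be a no-op and must not fold.
  %shape = tensor.from_elements %d0, %sz : tensor<2xindex>
  %reshape = tensor.reshape %arg0(%shape) : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
  return %reshape : tensor<?x?xf32>
}
```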
I thought we don't
We do have to support it eventually, namely for the dynamic case where we have a reshape with dynamic dimensions to dynamic dimensions, e.g. a reshape from
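For illustration, a reshape of that flavor might look like the following (a hypothetical sketch, not taken from this thread), where the rank changes and the shape is only known at runtime, so no folding is possible:

```mlir
func.func @dynamic_reshape(%arg0 : tensor<?x?xf32>, %shape : tensor<3xindex>) -> tensor<?x?x?xf32> {
  // Rank-changing reshape with a runtime shape operand: it must still be
  // supported, and can never fold to the source.
  %0 = tensor.reshape %arg0(%shape) : (tensor<?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
  return %0 : tensor<?x?x?xf32>
}
```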
Left a nit comment, but otherwise looks good.