
[mlir][tensor] Fold tensor.reshape for dynamic reshape #88961

Merged: 4 commits into llvm:main, Apr 19, 2024

Conversation

@rsuderman (Contributor)

If `tensor.reshape` occurs with `d0, d1, d2, ...` for the dimensions, we know that the reshape is a no-op. Checking for this case lets us fold away the computation.

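As an illustration (a minimal sketch, not taken from the patch itself; the PR's own tests appear in the diff below), the fold applies when the shape operand is assembled from the source's dimensions in order:

```mlir
// The shape operand is exactly (dim %arg0, 0), (dim %arg0, 1), so the
// reshape is an identity and folds to %arg0.
func.func @identity_reshape(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %shape = tensor.from_elements %d0, %d1 : tensor<2xindex>
  %0 = tensor.reshape %arg0(%shape)
      : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
```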
@llvmbot (Collaborator) commented Apr 16, 2024

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: Rob Suderman (rsuderman)

Changes

If `tensor.reshape` occurs with `d0, d1, d2, ...` for the dimensions, we know that the reshape is a no-op. Checking for this case lets us fold away the computation.


Full diff: https://github.com/llvm/llvm-project/pull/88961.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+42)
  • (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+47)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 0ce40e81371209..50d3cd45a2dfe9 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1580,6 +1580,48 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
           getResult().getType()))
     return reshapedSource;
+
+  auto source = getSource();
+  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
+  auto resultTy = dyn_cast<RankedTensorType>(getType());
+
+  if (!sourceTy || !resultTy || sourceTy != resultTy)
+    return {};
+
+  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
+    auto elements = fromElements.getElements();
+    bool dynamicNoop =
+        sourceTy.getRank() == static_cast<int64_t>(elements.size());
+    for (auto [id, element] : llvm::enumerate(elements)) {
+      APSInt cstElement;
+      if (matchPattern(element, m_ConstantInt(&cstElement))) {
+        if (cstElement.getExtValue() != sourceTy.getDimSize(id)) {
+          dynamicNoop = false;
+          break;
+        }
+        continue;
+      }
+
+      if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
+        if (dimOp.getSource() != source) {
+          dynamicNoop = false;
+          break;
+        }
+
+        APSInt dim;
+        if (!matchPattern(dimOp.getIndex(), m_ConstantInt(&dim)) ||
+            dim.getExtValue() != static_cast<int64_t>(id)) {
+          dynamicNoop = false;
+          break;
+        }
+        continue;
+      }
+
+      // Neither a constant extent nor a `tensor.dim` of the source; we
+      // cannot prove the reshape is a no-op, so bail out conservatively.
+      dynamicNoop = false;
+      break;
+    }
+
+    if (dynamicNoop)
+      return source;
+  }
+
   return {};
 }
 
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ac365c9d297e88..751c57eacd7ae5 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2403,6 +2403,53 @@ func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xinde
 
 // -----
 
+// CHECK-LABEL: @reshape_fold_2d
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
+func.func @reshape_fold_2d(%arg0 : tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
+  %ds = tensor.from_elements %d0, %d1 : tensor<2xindex>
+  %reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
+  // CHECK: return %[[ARG0]]
+  return %reshape : tensor<?x?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @reshape_nofold_2d
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
+func.func @reshape_nofold_2d(%arg0 : tensor<?x?xi32>) -> tensor<?x?xi32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xi32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xi32>
+  %ds = tensor.from_elements %d1, %d0 : tensor<2xindex>
+  // CHECK: tensor.reshape
+  %reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
+  return %reshape : tensor<?x?xi32>
+}
+
+
+// -----
+
+// CHECK-LABEL: @reshape_fold_3d_cst
+// CHECK-SAME: %[[ARG0:.+]]: tensor<5x?x?xi32>
+func.func @reshape_fold_3d_cst(%arg0 : tensor<5x?x?xi32>) -> tensor<5x?x?xi32> {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %d0 = arith.constant 5 : index
+  %d1 = tensor.dim %arg0, %c1 : tensor<5x?x?xi32>
+  %d2 = tensor.dim %arg0, %c2 : tensor<5x?x?xi32>
+  %ds = tensor.from_elements %d0, %d1, %d2 : tensor<3xindex>
+  %reshape = tensor.reshape %arg0(%ds) : (tensor<5x?x?xi32>, tensor<3xindex>) -> tensor<5x?x?xi32>
+  // CHECK: return %[[ARG0]]
+  return %reshape : tensor<5x?x?xi32>
+}
+
+// -----
+
 // Test case: This test fails to fold because the index of tensor.dim is out_of_bounds
 // CHECK-LABEL: func @dim_out_of_bounds(
 //       CHECK: %[[IDX:.*]] = index.constant 28

@hanhanW (Contributor) commented Apr 16, 2024

I thought we don't use tensor.reshape ops... I don't have context about this operation; maybe @MaheshRavishankar does? If you need me to review the change, I can study the op semantics and review the PR.

@rsuderman (Contributor, Author)

> I thought we don't use tensor.reshape ops... I don't have context about this operation; maybe @MaheshRavishankar does? If you need me to review the change, I can study the op semantics and review the PR.

We do have to support it eventually, namely for the dynamic case where a reshape maps dynamic dimensions to dynamic dimensions; e.g. a reshape from `tensor<?xi32>` to `tensor<?x?xi32>` cannot be performed via a `tensor.expand_shape`.
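A minimal sketch of that case (illustrative, not from this PR): the target extents only exist as runtime values, so they must be passed through a shape operand, which `tensor.expand_shape`'s static reassociation cannot express:

```mlir
// %d0 and %d1 are runtime values; tensor.reshape accepts them via a shape
// tensor, whereas a ? -> ?x? expansion has no static reassociation to give
// to tensor.expand_shape.
func.func @dynamic_reshape(%arg0: tensor<?xi32>, %d0: index, %d1: index)
    -> tensor<?x?xi32> {
  %shape = tensor.from_elements %d0, %d1 : tensor<2xindex>
  %0 = tensor.reshape %arg0(%shape)
      : (tensor<?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
  return %0 : tensor<?x?xi32>
}
```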

@MaheshRavishankar (Contributor) left a comment

Left a nit comment, but otherwise looks good.

Three review threads on mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (outdated, resolved)
rsuderman merged commit c045955 into llvm:main on Apr 19, 2024
4 checks passed
aniplcc pushed a commit to aniplcc/llvm-project that referenced this pull request Apr 21, 2024