Skip to content

Conversation

keshavvinayak01
Copy link
Contributor

@keshavvinayak01 keshavvinayak01 commented Oct 3, 2025

Description

Adds a new canonicalizer that folds vector.from_elements(vector.transpose)) => vector.from_elements. This canonicalization reorders the input elements for vector.from_elements, adjusts the output shape to match the effect of the transpose op and eliminating its need.

Testing

Added a 2D vector lit test that verifies the working of the rewrite.

…ctor.transpose)

Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
@llvmbot
Copy link
Member

llvmbot commented Oct 3, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Keshav Vinayak Jha (keshavvinayak01)

Changes

Description

Adds a new canonicalizer that folds vector.from_elements(vector.broadcast)) => vector.from_elements. This canonicalization reorders the input elements for vector.from_elements, adjusts the output shape to match the effect of the broadcast op and eliminating its need.

Testing

Added a 2D vector lit test that verifies the working of the rewrite.


Full diff: https://github.com/llvm/llvm-project/pull/161841.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+57-1)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+12)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b0132e889302f..7f6313c11ea18 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6723,6 +6723,61 @@ class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
   }
 };
 
+/// Folds transpose(from_elements(...)) into a new from_elements with permuted
+/// operands matching the transposed shape.
+class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
+public:
+  using Base::Base;
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto fromElementsOp =
+        transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
+    if (!fromElementsOp)
+      return failure();
+
+    VectorType srcTy = fromElementsOp.getDest().getType();
+    VectorType dstTy = transposeOp.getType();
+
+    ArrayRef<int64_t> permutation = transposeOp.getPermutation();
+    int64_t rank = srcTy.getRank();
+
+    // Build inverse permutation to map destination indices back to source.
+    SmallVector<int64_t, 4> inversePerm(rank, 0);
+    for (int64_t i = 0; i < rank; ++i)
+      inversePerm[permutation[i]] = i;
+
+    ArrayRef<int64_t> srcShape = srcTy.getShape();
+    ArrayRef<int64_t> dstShape = dstTy.getShape();
+    SmallVector<int64_t, 4> srcIdx(rank, 0);
+    SmallVector<int64_t, 4> dstIdx(rank, 0);
+    SmallVector<int64_t, 4> srcStrides = computeStrides(srcShape);
+    SmallVector<int64_t, 4> dstStrides = computeStrides(dstShape);
+
+    auto elements = fromElementsOp.getElements();
+    SmallVector<Value> newElements;
+    int64_t dstNumElements = dstTy.getNumElements();
+    newElements.reserve(dstNumElements);
+
+    // For each element in destination row-major order, pick the corresponding
+    // source element.
+    for (int64_t lin = 0; lin < dstNumElements; ++lin) {
+      // Pick the destination element index.
+      dstIdx = delinearize(lin, dstStrides);
+      // Map the destination element index to the source element index.
+      for (int64_t j = 0; j < rank; ++j)
+        srcIdx[j] = dstIdx[inversePerm[j]];
+      // Linearize the source element index.
+      int64_t srcLin = linearize(srcIdx, srcStrides);
+      // Add the source element to the new elements.
+      newElements.push_back(elements[srcLin]);
+    }
+
+    rewriter.replaceOpWithNewOp<FromElementsOp>(transposeOp, dstTy,
+                                                newElements);
+    return success();
+  }
+};
+
 /// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
 /// 'order preserving', where 'order preserving' means the flattened
 /// inputs and outputs of the transpose have identical (numerical) values.
@@ -6823,7 +6878,8 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
-              FoldTransposeSplat, FoldTransposeBroadcast>(context);
+              FoldTransposeSplat, FoldTransposeFromElements,
+              FoldTransposeBroadcast>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 5448976f84760..5f34d144cd472 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -308,6 +308,18 @@ func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x
 
 // -----
 
+// CHECK-LABEL: transpose_from_elements_2d
+func.func @transpose_from_elements_2d(%a0: i32, %a1: i32, %a2: i32,
+                                      %a3: i32, %a4: i32, %a5: i32) -> vector<3x2xi32> {
+  %v = vector.from_elements %a0, %a1, %a2, %a3, %a4, %a5 : vector<2x3xi32>
+  %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
+  return %t : vector<3x2xi32>
+  // CHECK: %[[R:.*]] = vector.from_elements %arg0, %arg3, %arg1, %arg4, %arg2, %arg5 : vector<3x2xi32>
+  // CHECK-NOT: vector.transpose
+}
+
+// -----
+
 func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
   %0 = vector.constant_mask [2, 2] : vector<4x3xi1>
   %1 = vector.extract_strided_slice %0

@keshavvinayak01 keshavvinayak01 changed the title [Vector] Added canonicalizer for folding from_elements + transpose [MLIR] [Vector] Added canonicalizer for folding from_elements + transpose Oct 3, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants