From c94bbb7846d0885a63d01b07ed7c8e362fd49689 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Fri, 3 Oct 2025 05:52:21 -0700 Subject: [PATCH 1/2] Added canonicalization (vector.from_elements + vector.transpose -> vector.transpose) Signed-off-by: Keshav Vinayak Jha --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 61 +++++++++++++++++++++- mlir/test/Dialect/Vector/canonicalize.mlir | 12 +++++ 2 files changed, 72 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index b0132e889302f..31246f5da49b1 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2499,6 +2499,7 @@ static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp, return DenseElementsAttr::get(destVecType, convertedElements); } + OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) { if (auto res = foldFromElementsToElements(*this)) return res; @@ -6723,6 +6724,63 @@ class FoldTransposeShapeCast final : public OpRewritePattern { } }; +/// Folds transpose(from_elements(...)) into a new from_elements with permuted +/// operands matching the transposed shape. +class FoldTransposeFromElements final + : public OpRewritePattern { +public: + +using Base::Base; + LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, + PatternRewriter &rewriter) const override { + auto fromElementsOp = + transposeOp.getVector().getDefiningOp(); + if (!fromElementsOp) + return failure(); + + VectorType srcTy = fromElementsOp.getDest().getType(); + VectorType dstTy = transposeOp.getType(); + + ArrayRef permutation = transposeOp.getPermutation(); + int64_t rank = srcTy.getRank(); + + // Build inverse permutation to map destination indices back to source. + SmallVector inversePerm(rank, 0); + for (int64_t i = 0; i < rank; ++i) + inversePerm[permutation[i]] = i; + + ArrayRef srcShape = srcTy.getShape(); + ArrayRef dstShape = dstTy.getShape(); + SmallVector srcIdx(rank, 0); + SmallVector dstIdx(rank, 0); + SmallVector srcStrides = computeStrides(srcShape); + SmallVector dstStrides = computeStrides(dstShape); + + auto elements = fromElementsOp.getElements(); + SmallVector newElements; + int64_t dstNumElements = dstTy.getNumElements(); + newElements.reserve(dstNumElements); + + // For each element in destination row-major order, pick the corresponding + // source element. + for (int64_t lin = 0; lin < dstNumElements; ++lin) { + // Pick the destination element index. + dstIdx = delinearize(lin, dstStrides); + // Map the destination element index to the source element index. + for (int64_t j = 0; j < rank; ++j) + srcIdx[j] = dstIdx[inversePerm[j]]; + // Linearize the source element index. + int64_t srcLin = linearize(srcIdx, srcStrides); + // Add the source element to the new elements. + newElements.push_back(elements[srcLin]); + } + + rewriter.replaceOpWithNewOp(transposeOp, dstTy, + newElements); + return success(); + } +}; + /// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is /// 'order preserving', where 'order preserving' means the flattened /// inputs and outputs of the transpose have identical (numerical) values. @@ -6823,7 +6881,8 @@ class FoldTransposeBroadcast : public OpRewritePattern { void vector::TransposeOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { results.add(context); + FoldTransposeSplat, FoldTransposeFromElements, + FoldTransposeBroadcast>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 5448976f84760..5f34d144cd472 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -308,6 +308,18 @@ func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x // ----- +// CHECK-LABEL: transpose_from_elements_2d +func.func @transpose_from_elements_2d(%a0: i32, %a1: i32, %a2: i32, + %a3: i32, %a4: i32, %a5: i32) -> vector<3x2xi32> { + %v = vector.from_elements %a0, %a1, %a2, %a3, %a4, %a5 : vector<2x3xi32> + %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32> + return %t : vector<3x2xi32> + // CHECK: %[[R:.*]] = vector.from_elements %arg0, %arg3, %arg1, %arg4, %arg2, %arg5 : vector<3x2xi32> + // CHECK-NOT: vector.transpose +} + +// ----- + func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) { %0 = vector.constant_mask [2, 2] : vector<4x3xi1> %1 = vector.extract_strided_slice %0 From 6bef6d259f8abf82d48092eae1404d6a2ebbfac7 Mon Sep 17 00:00:00 2001 From: Keshav Vinayak Jha Date: Fri, 3 Oct 2025 05:54:15 -0700 Subject: [PATCH 2/2] Formatted Signed-off-by: Keshav Vinayak Jha --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 31246f5da49b1..7f6313c11ea18 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2499,7 +2499,6 @@ static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp, return DenseElementsAttr::get(destVecType, convertedElements); } - OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) { if (auto res = foldFromElementsToElements(*this)) return res; @@ -6726,11 +6725,9 @@ class FoldTransposeShapeCast final : public OpRewritePattern { /// Folds transpose(from_elements(...)) into a new from_elements with permuted /// operands matching the transposed shape. -class FoldTransposeFromElements final - : public OpRewritePattern { +class FoldTransposeFromElements final : public OpRewritePattern { public: - -using Base::Base; + using Base::Base; LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, PatternRewriter &rewriter) const override { auto fromElementsOp = @@ -6776,7 +6773,7 @@ using Base::Base; } rewriter.replaceOpWithNewOp(transposeOp, dstTy, - newElements); + newElements); return success(); } };