diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index b0132e889302f..7f6313c11ea18 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -6723,6 +6723,61 @@ class FoldTransposeShapeCast final : public OpRewritePattern { } }; +/// Folds transpose(from_elements(...)) into a new from_elements with permuted +/// operands matching the transposed shape. +class FoldTransposeFromElements final : public OpRewritePattern { +public: + using Base::Base; + LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, + PatternRewriter &rewriter) const override { + auto fromElementsOp = + transposeOp.getVector().getDefiningOp(); + if (!fromElementsOp) + return failure(); + + VectorType srcTy = fromElementsOp.getDest().getType(); + VectorType dstTy = transposeOp.getType(); + + ArrayRef permutation = transposeOp.getPermutation(); + int64_t rank = srcTy.getRank(); + + // Build inverse permutation to map destination indices back to source. + SmallVector inversePerm(rank, 0); + for (int64_t i = 0; i < rank; ++i) + inversePerm[permutation[i]] = i; + + ArrayRef srcShape = srcTy.getShape(); + ArrayRef dstShape = dstTy.getShape(); + SmallVector srcIdx(rank, 0); + SmallVector dstIdx(rank, 0); + SmallVector srcStrides = computeStrides(srcShape); + SmallVector dstStrides = computeStrides(dstShape); + + auto elements = fromElementsOp.getElements(); + SmallVector newElements; + int64_t dstNumElements = dstTy.getNumElements(); + newElements.reserve(dstNumElements); + + // For each element in destination row-major order, pick the corresponding + // source element. + for (int64_t lin = 0; lin < dstNumElements; ++lin) { + // Pick the destination element index. + dstIdx = delinearize(lin, dstStrides); + // Map the destination element index to the source element index. + for (int64_t j = 0; j < rank; ++j) + srcIdx[j] = dstIdx[inversePerm[j]]; + // Linearize the source element index. + int64_t srcLin = linearize(srcIdx, srcStrides); + // Add the source element to the new elements. + newElements.push_back(elements[srcLin]); + } + + rewriter.replaceOpWithNewOp(transposeOp, dstTy, + newElements); + return success(); + } +}; + /// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is /// 'order preserving', where 'order preserving' means the flattened /// inputs and outputs of the transpose have identical (numerical) values. @@ -6823,7 +6878,8 @@ class FoldTransposeBroadcast : public OpRewritePattern { void vector::TransposeOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { results.add(context); + FoldTransposeSplat, FoldTransposeFromElements, + FoldTransposeBroadcast>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 5448976f84760..5f34d144cd472 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -308,6 +308,18 @@ func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x // ----- +// CHECK-LABEL: transpose_from_elements_2d +func.func @transpose_from_elements_2d(%a0: i32, %a1: i32, %a2: i32, + %a3: i32, %a4: i32, %a5: i32) -> vector<3x2xi32> { + %v = vector.from_elements %a0, %a1, %a2, %a3, %a4, %a5 : vector<2x3xi32> + %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to vector<3x2xi32> + return %t : vector<3x2xi32> + // CHECK: %[[R:.*]] = vector.from_elements %arg0, %arg3, %arg1, %arg4, %arg2, %arg5 : vector<3x2xi32> + // CHECK-NOT: vector.transpose +} + +// ----- + func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) { %0 = vector.constant_mask [2, 2] : vector<4x3xi1> %1 = vector.extract_strided_slice %0