Skip to content

Commit

Permalink
[mlir][vector] Generalize vector.transpose lowering to n-D vectors
Browse files Browse the repository at this point in the history
The existing vector.transpose lowering patterns only triggers if the
input vector is 2D. The revision extends the pattern to handle n-D
vectors which are effectively 2-D vectors (e.g., vector<1x4x1x8x1).

It refactors a common check about 2-D vectors from X86Vector
lowering to VectorUtils.h so it can be reused by both sides.

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D149908
  • Loading branch information
hanhanW committed May 8, 2023
1 parent 9896d72 commit 25cc5a7
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 77 deletions.
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
Expand Up @@ -37,6 +37,11 @@ namespace vector {
/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
/// the type of `source`.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);

/// Returns two dims that are greater than one if the transposition is applied
/// on a 2D slice. Otherwise, returns a failure.
FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);

} // namespace vector

/// Constructs a permutation map of invariant memref indices to vector
Expand Down
53 changes: 27 additions & 26 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
Expand Up @@ -332,7 +332,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
transp.push_back(attr.cast<IntegerAttr>().getInt());

if (isShuffleLike(vectorTransformOptions.vectorTransposeLowering) &&
resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
succeeded(isTranspose2DSlice(op)))
return rewriter.notifyMatchFailure(
op, "Options specifies lowering to shuffle");

Expand Down Expand Up @@ -411,36 +411,37 @@ class TransposeOp2DToShuffleLowering

LogicalResult matchAndRewrite(vector::TransposeOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
if (!isShuffleLike(vectorTransformOptions.vectorTransposeLowering))
return rewriter.notifyMatchFailure(
op, "not using vector shuffle based lowering");

auto srcGtOneDims = isTranspose2DSlice(op);
if (failed(srcGtOneDims))
return rewriter.notifyMatchFailure(
op, "expected transposition on a 2D slice");

VectorType srcType = op.getSourceVectorType();
if (srcType.getRank() != 2)
return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));

SmallVector<int64_t> transp;
for (auto attr : op.getTransp())
transp.push_back(attr.cast<IntegerAttr>().getInt());
if (transp[0] != 1 && transp[1] != 0)
return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
// Reshape the n-D input vector with only two dimensions greater than one
// to a 2-D vector.
Location loc = op.getLoc();
auto flattenedType = VectorType::get({n * m}, srcType.getElementType());
auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
auto reshInput = rewriter.create<vector::ShapeCastOp>(loc, flattenedType,
op.getVector());

Value res;
int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
switch (vectorTransformOptions.vectorTransposeLowering) {
case VectorTransposeLowering::Shuffle1D: {
Value casted = rewriter.create<vector::ShapeCastOp>(
loc, VectorType::get({m * n}, srcType.getElementType()),
op.getVector());
res = transposeToShuffle1D(rewriter, casted, m, n);
break;
}
case VectorTransposeLowering::Shuffle16x16:
if (m != 16 || n != 16)
return failure();
res = transposeToShuffle16x16(rewriter, op.getVector(), m, n);
break;
case VectorTransposeLowering::EltWise:
case VectorTransposeLowering::Flat:
return failure();
if (vectorTransformOptions.vectorTransposeLowering ==
VectorTransposeLowering::Shuffle16x16 &&
m == 16 && n == 16) {
reshInput =
rewriter.create<vector::ShapeCastOp>(loc, reshInputType, reshInput);
res = transposeToShuffle16x16(rewriter, reshInput, m, n);
} else {
// Fallback to shuffle on 1D approach.
res = transposeToShuffle1D(rewriter, reshInput, m, n);
}

rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
Expand Down
57 changes: 57 additions & 0 deletions mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
Expand Up @@ -43,6 +43,63 @@ Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
llvm_unreachable("Expected MemRefType or TensorType");
}

/// Given the n-D transpose pattern 'transp', return true if 'dim0' and 'dim1'
/// should be transposed with each other within the context of their 2D
/// transposition slice.
///
/// Example 1: dim0 = 0, dim1 = 2, transp = [2, 1, 0]
/// Return true: dim0 and dim1 are transposed within the context of their 2D
/// transposition slice ([1, 0]).
///
/// Example 2: dim0 = 0, dim1 = 1, transp = [2, 1, 0]
/// Return true: dim0 and dim1 are transposed within the context of their 2D
/// transposition slice ([1, 0]). Paradoxically, note how dim1 (1) is *not*
/// transposed within the full context of the transposition.
///
/// Example 3: dim0 = 0, dim1 = 1, transp = [2, 0, 1]
/// Return false: dim0 and dim1 are *not* transposed within the context of
/// their 2D transposition slice ([0, 1]). Paradoxically, note how dim0 (0)
/// and dim1 (1) are transposed within the full context of the of the
/// transposition.
static bool areDimsTransposedIn2DSlice(int64_t dim0, int64_t dim1,
ArrayRef<int64_t> transp) {
// Perform a linear scan along the dimensions of the transposed pattern. If
// dim0 is found first, dim0 and dim1 are not transposed within the context of
// their 2D slice. Otherwise, 'dim1' is found first and they are transposed.
for (int64_t permDim : transp) {
if (permDim == dim0)
return false;
if (permDim == dim1)
return true;
}

llvm_unreachable("Ill-formed transpose pattern");
}

FailureOr<std::pair<int, int>>
mlir::vector::isTranspose2DSlice(vector::TransposeOp op) {
VectorType srcType = op.getSourceVectorType();
SmallVector<int64_t> srcGtOneDims;
for (auto [index, size] : llvm::enumerate(srcType.getShape()))
if (size > 1)
srcGtOneDims.push_back(index);

if (srcGtOneDims.size() != 2)
return failure();

SmallVector<int64_t> transp;
for (auto attr : op.getTransp())
transp.push_back(attr.cast<IntegerAttr>().getInt());

// Check whether the two source vector dimensions that are greater than one
// must be transposed with each other so that we can apply one of the 2-D
// transpose pattens. Otherwise, these patterns are not applicable.
if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1], transp))
return failure();

return std::pair<int, int>(srcGtOneDims[0], srcGtOneDims[1]);
}

/// Constructs a permutation map from memref indices to vector dimension.
///
/// The implementation uses the knowledge of the mapping of enclosing loop to
Expand Down
57 changes: 6 additions & 51 deletions mlir/lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
Expand Down Expand Up @@ -187,39 +188,6 @@ void mlir::x86vector::avx2::transpose8x8xf32(ImplicitLocOpBuilder &ib,
vs[7] = mm256Permute2f128Ps(ib, s3, s7, MaskHelper::permute<3, 1>());
}

/// Given the n-D transpose pattern 'transp', return true if 'dim0' and 'dim1'
/// should be transposed with each other within the context of their 2D
/// transposition slice.
///
/// Example 1: dim0 = 0, dim1 = 2, transp = [2, 1, 0]
/// Return true: dim0 and dim1 are transposed within the context of their 2D
/// transposition slice ([1, 0]).
///
/// Example 2: dim0 = 0, dim1 = 1, transp = [2, 1, 0]
/// Return true: dim0 and dim1 are transposed within the context of their 2D
/// transposition slice ([1, 0]). Paradoxically, note how dim1 (1) is *not*
/// transposed within the full context of the transposition.
///
/// Example 3: dim0 = 0, dim1 = 1, transp = [2, 0, 1]
/// Return false: dim0 and dim1 are *not* transposed within the context of
/// their 2D transposition slice ([0, 1]). Paradoxically, note how dim0 (0)
/// and dim1 (1) are transposed within the full context of the of the
/// transposition.
static bool areDimsTransposedIn2DSlice(int64_t dim0, int64_t dim1,
ArrayRef<int64_t> transp) {
// Perform a linear scan along the dimensions of the transposed pattern. If
// dim0 is found first, dim0 and dim1 are not transposed within the context of
// their 2D slice. Otherwise, 'dim1' is found first and they are transposed.
for (int64_t permDim : transp) {
if (permDim == dim0)
return false;
if (permDim == dim1)
return true;
}

llvm_unreachable("Ill-formed transpose pattern");
}

/// Rewrite AVX2-specific vector.transpose, for the supported cases and
/// depending on the `TransposeLoweringOptions`. The lowering supports 2-D
/// transpose cases and n-D cases that have been decomposed into 2-D
Expand Down Expand Up @@ -256,29 +224,16 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
if (!srcType.getElementType().isF32())
return rewriter.notifyMatchFailure(op, "Unsupported vector element type");

SmallVector<int64_t> srcGtOneDims;
for (auto [index, size] : llvm::enumerate(srcType.getShape()))
if (size > 1)
srcGtOneDims.push_back(index);

if (srcGtOneDims.size() != 2)
return rewriter.notifyMatchFailure(op, "Unsupported vector type");

SmallVector<int64_t, 4> transp;
for (auto attr : op.getTransp())
transp.push_back(attr.cast<IntegerAttr>().getInt());

// Check whether the two source vector dimensions that are greater than one
// must be transposed with each other so that we can apply one of the 2-D
// AVX2 transpose pattens. Otherwise, these patterns are not applicable.
if (!areDimsTransposedIn2DSlice(srcGtOneDims[0], srcGtOneDims[1], transp))
auto srcGtOneDims = mlir::vector::isTranspose2DSlice(op);
if (failed(srcGtOneDims))
return rewriter.notifyMatchFailure(
op, "Not applicable to this transpose permutation");
op, "expected transposition on a 2D slice");

// Retrieve the sizes of the two dimensions greater than one to be
// transposed.
auto srcShape = srcType.getShape();
int64_t m = srcShape[srcGtOneDims[0]], n = srcShape[srcGtOneDims[1]];
int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));
int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));

auto applyRewrite = [&]() {
ImplicitLocOpBuilder ib(loc, rewriter);
Expand Down
79 changes: 79 additions & 0 deletions mlir/test/Dialect/Vector/vector-transpose-lowering.mlir
Expand Up @@ -687,3 +687,82 @@ transform.sequence failures(propagate) {
lowering_strategy = "shuffle_16x16"
: (!pdl.operation) -> !pdl.operation
}

// -----

// CHECK-LABEL: func @transpose021_shuffle16x16xf32
func.func @transpose021_shuffle16x16xf32(%arg0: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 16, 1, 17, 4, 20, 5, 21, 8, 24, 9, 25, 12, 28, 13, 29] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [2, 18, 3, 19, 6, 22, 7, 23, 10, 26, 11, 27, 14, 30, 15, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
// CHECK: vector.shuffle {{.*}} [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
%0 = vector.transpose %arg0, [0, 2, 1] : vector<1x16x16xf32> to vector<1x16x16xf32>
return %0 : vector<1x16x16xf32>
}

transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
transform.vector.lower_transpose %module_op
lowering_strategy = "shuffle_16x16"
: (!pdl.operation) -> !pdl.operation
}
1 change: 1 addition & 0 deletions utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Expand Up @@ -1920,6 +1920,7 @@ cc_library(
":LLVMCommonConversion",
":LLVMDialect",
":VectorDialect",
":VectorUtils",
":X86VectorDialect",
"//llvm:Core",
"//llvm:Support",
Expand Down

0 comments on commit 25cc5a7

Please sign in to comment.