Skip to content

Commit 1b1bc23

Browse files
committed
[mlir][vector] Refine Vector to LLVM lowering options
This is a follow-up to #144307, where we removed `vector.matrix_multiply` and `vector.flat_transpose` from the Vector dialect. This PR: * Updates comments that were missed in the previous change. * Renames relevant `-convert-vector-to-llvm=` options: - `vector-contract-lowering=matmul` → `vector-contract-lowering=llvm` - `vector-transpose-lowering=flat_transpose` → `vector-transpose-lowering=llvm` These new names better reflect the actual transformation target — LLVM intrinsics — rather than the now-removed abstract operations.
1 parent 573b377 commit 1b1bc23

File tree

7 files changed

+45
-39
lines changed

7 files changed

+45
-39
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1489,8 +1489,8 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
14891489
VectorContractLoweringAttr.summary, [{::llvm::cl::values(
14901490
clEnumValN(::mlir::vector::VectorContractLowering::Dot, "dot",
14911491
"Progressively lower to finer grained `vector.contract` and dot-products. (default)"),
1492-
clEnumValN(::mlir::vector::VectorContractLowering::Matmul, "matmul",
1493-
"Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics."),
1492+
clEnumValN(::mlir::vector::VectorContractLowering::LLVM, "llvm",
1493+
"Lower directly to `llvm.intr.matrix.multiply`."),
14941494
clEnumValN(::mlir::vector::VectorContractLowering::OuterProduct, "outerproduct",
14951495
"Lower to `vector.outerproduct`."),
14961496
clEnumValN(::mlir::vector::VectorContractLowering::ParallelArith, "parallelarith",
@@ -1502,8 +1502,8 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
15021502
VectorTransposeLoweringAttr.summary, [{::llvm::cl::values(
15031503
clEnumValN(::mlir::vector::VectorTransposeLowering::EltWise, "eltwise",
15041504
"Lower transpose into element-wise extract and inserts (default)"),
1505-
clEnumValN(::mlir::vector::VectorTransposeLowering::Flat, "flat",
1506-
"Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix intrinsics"),
1505+
clEnumValN(::mlir::vector::VectorTransposeLowering::LLVM, "llvm",
1506+
"Lower 2-D transpose directly to `llvm.intr.matrix.transpose`"),
15071507
clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle1D, "shuffle1d",
15081508
"Lower 2-D transpose to `vector.shuffle` on 1-D vector."),
15091509
clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle16x16, "shuffle16x16",

mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,9 @@ include "mlir/IR/EnumAttr.td"
1414
// Lower transpose into element-wise extract and inserts.
1515
def VectorTransposeLowering_Elementwise:
1616
I32EnumAttrCase<"EltWise", 0, "eltwise">;
17-
// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
18-
// intrinsics.
19-
def VectorTransposeLowering_FlatTranspose:
20-
I32EnumAttrCase<"Flat", 1, "flat_transpose">;
17+
// Lower directly to LLVM matrix intrinsics.
18+
def VectorTransposeLowering_LLVM:
19+
I32EnumAttrCase<"LLVM", 1, "llvm">;
2120
// Lower 2-D transpose to `vector.shuffle` on 1-D vector.
2221
def VectorTransposeLowering_Shuffle1D:
2322
I32EnumAttrCase<"Shuffle1D", 2, "shuffle_1d">;
@@ -27,7 +26,7 @@ def VectorTransposeLowering_Shuffle16x16:
2726
def VectorTransposeLoweringAttr : I32EnumAttr<
2827
"VectorTransposeLowering",
2928
"control the lowering of `vector.transpose` operations.",
30-
[VectorTransposeLowering_Elementwise, VectorTransposeLowering_FlatTranspose,
29+
[VectorTransposeLowering_Elementwise, VectorTransposeLowering_LLVM,
3130
VectorTransposeLowering_Shuffle1D, VectorTransposeLowering_Shuffle16x16]> {
3231
let cppNamespace = "::mlir::vector";
3332
}
@@ -48,9 +47,9 @@ def VectorMultiReductionLoweringAttr: I32EnumAttr<
4847

4948
// Progressively lower to finer grained `vector.contract` and dot-products.
5049
def VectorContractLowering_Dot: I32EnumAttrCase<"Dot", 0, "dot">;
51-
// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics.
52-
def VectorContractLowering_Matmul:
53-
I32EnumAttrCase<"Matmul", 1, "matmulintrinsics">;
50+
// Lower directly to LLVM intrinsics.
51+
def VectorContractLowering_LLVM:
52+
I32EnumAttrCase<"LLVM", 1, "llvm">;
5453
// Lower to `vector.outerproduct`.
5554
def VectorContractLowering_OuterProduct:
5655
I32EnumAttrCase<"OuterProduct", 2, "outerproduct">;
@@ -61,7 +60,7 @@ def VectorContractLowering_ParallelArith:
6160
def VectorContractLoweringAttr: I32EnumAttr<
6261
"VectorContractLowering",
6362
"control the lowering of `vector.contract` operations.",
64-
[VectorContractLowering_Dot, VectorContractLowering_Matmul,
63+
[VectorContractLowering_Dot, VectorContractLowering_LLVM,
6564
VectorContractLowering_OuterProduct, VectorContractLowering_ParallelArith]> {
6665
let cppNamespace = "::mlir::vector";
6766
}

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1987,16 +1987,12 @@ struct VectorScalableStepOpLowering
19871987
/// %e = add %c, %d
19881988
/// ```
19891989
/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
1990-
//
1991-
/// This only kicks in when vectorContractLowering is set to Matmul and
1992-
/// the vector.contract op is a row-major matrix multiply.
19931990
class ContractionOpToMatmulOpLowering
19941991
: public vector::MaskableOpRewritePattern<vector::ContractionOp> {
19951992
public:
19961993
using MaskableOpRewritePattern::MaskableOpRewritePattern;
19971994

19981995
ContractionOpToMatmulOpLowering(
1999-
vector::VectorContractLowering vectorContractLowering,
20001996
MLIRContext *context, PatternBenefit benefit = 100)
20011997
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit) {}
20021998

@@ -2005,23 +2001,22 @@ class ContractionOpToMatmulOpLowering
20052001
PatternRewriter &rewriter) const override;
20062002
};
20072003

2008-
/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
2009-
/// semantics to:
2004+
/// Lower a qualifying `vector.contract %a, %b, %c` (with row-major matmul
2005+
/// semantics directly into `llvm.intr.matrix.multiply`:
2006+
/// BEFORE:
20102007
/// ```
2011-
/// %mta = maybe_transpose
2012-
/// %mtb = maybe_transpose
2013-
/// %flattened_a = vector.shape_cast %mta
2014-
/// %flattened_b = vector.shape_cast %mtb
2015-
/// %flattened_d = llvm.intr.matrix.multiply %flattened_a, %flattened_b
2016-
/// %mtd = vector.shape_cast %flattened_d
2017-
/// %d = maybe_untranspose %mtd
2018-
/// %e = add %c, %d
2008+
/// %res = vector.contract #matmat_trait %lhs, %rhs, %acc
2009+
/// : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
20192010
/// ```
2020-
//
2021-
/// This only kicks in when vectorContractLowering is set to `Matmul`.
2022-
/// vector.transpose operations are inserted if the vector.contract op is not a
2023-
/// row-major matrix multiply.
20242011
///
2012+
/// AFTER:
2013+
/// ```
2014+
/// %lhs = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32>
2015+
/// %rhs = vector.shape_cast %arg1 : vector<4x3xf32> to vector<12xf32>
2016+
/// %matmul = llvm.intr.matrix.multiply %lhs, %rhs
2017+
/// %res = arith.addf %acc, %matmul : vector<2x3xf32>
2018+
/// ```
2019+
//
20252020
/// Scalable vectors are not supported.
20262021
FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
20272022
vector::ContractionOp op, MaskingOpInterface maskOp,
@@ -2116,7 +2111,19 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
21162111
return res;
21172112
}
21182113

2119-
/// Lowers vector.transpose to llvm.intr.matrix.transpose
2114+
/// Lowers vector.transpose directly to llvm.intr.matrix.transpose
2115+
///
2116+
/// BEFORE:
2117+
/// ```
2118+
/// %tr = vector.transpose %vec, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
2119+
/// ```
2120+
/// AFTER:
2121+
/// ```
2122+
/// %vec_cs = vector.shape_cast %vec : vector<2x4xf32> to vector<8xf32>
2123+
/// %tr = llvm.intr.matrix.transpose %vec_sc
2124+
/// {columns = 2 : i32, rows = 4 : i32} : vector<8xf32> into vector<8xf32>
2125+
/// %res = vector.shape_cast %tr : vector<8xf32> to vector<4x2xf32>
2126+
/// ```
21202127
class TransposeOpToMatrixTransposeOpLowering
21212128
: public OpRewritePattern<vector::TransposeOp> {
21222129
public:

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
7070
populateVectorBitCastLoweringPatterns(patterns);
7171
populateVectorBroadcastLoweringPatterns(patterns);
7272
populateVectorContractLoweringPatterns(patterns, vectorContractLowering);
73-
if (vectorContractLowering == vector::VectorContractLowering::Matmul) {
73+
if (vectorContractLowering == vector::VectorContractLowering::LLVM) {
7474
// This pattern creates a dependency on the LLVM dialect, hence we don't
7575
// include it in `populateVectorContractLoweringPatterns` that is part of
7676
// the Vector dialect (and should not depend on LLVM).
@@ -80,7 +80,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
8080
populateVectorShapeCastLoweringPatterns(patterns);
8181
populateVectorInterleaveLoweringPatterns(patterns);
8282
populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering);
83-
if (vectorTransposeLowering == vector::VectorTransposeLowering::Flat) {
83+
if (vectorTransposeLowering == vector::VectorTransposeLowering::LLVM) {
8484
// This pattern creates a dependency on the LLVM dialect, hence we don't
8585
// include it in `populateVectorTransposeLoweringPatterns` that is part of
8686
// the Vector dialect (and should not depend on LLVM).

mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
// RUN: mlir-opt --convert-vector-to-llvm --dump-pass-pipeline %s 2>&1 | FileCheck %s --check-prefix=DEFAULT
1515

16-
// RUN: mlir-opt --convert-vector-to-llvm='vector-contract-lowering=matmul vector-transpose-lowering=flat' \
16+
// RUN: mlir-opt --convert-vector-to-llvm='vector-contract-lowering=llvm vector-transpose-lowering=llvm' \
1717
// RUN: --dump-pass-pipeline %s 2>&1 | FileCheck %s --check-prefix=NON-DEFAULT
1818

1919
// CHECK: builtin.module(
@@ -26,5 +26,5 @@
2626
// CHECK-SAME: reassociate-fp-reductions={{[aA-zZ0-9]+}}
2727
// DEFAULT: vector-contract-lowering=dot
2828
// DEFAULT: vector-transpose-lowering=eltwise
29-
// NON-DEFAULT: vector-contract-lowering=matmul
30-
// NON-DEFAULT: vector-transpose-lowering=flat
29+
// NON-DEFAULT: vector-contract-lowering=llvm
30+
// NON-DEFAULT: vector-transpose-lowering=llvm

mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s --convert-vector-to-llvm='vector-contract-lowering=matmul' | FileCheck %s
1+
// RUN: mlir-opt %s --convert-vector-to-llvm='vector-contract-lowering=llvm' | FileCheck %s
22

33
#matmat_accesses = [
44
affine_map<(i, j, k) -> (i, k)>,

mlir/test/Dialect/Vector/vector-transpose-to-matrix-intrinsics-transform.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s --convert-vector-to-llvm='vector-transpose-lowering=flat' --split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s --convert-vector-to-llvm='vector-transpose-lowering=llvm' --split-input-file | FileCheck %s
22

33
// CHECK-LABEL: func @transpose(
44
func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {

0 commit comments

Comments
 (0)