-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][vector] Refine Vector to LLVM lowering options #159553
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Refine Vector to LLVM lowering options #159553
Conversation
@llvm/pr-subscribers-mlir-vector Author: Andrzej Warzyński (banach-space) Changes
Full diff: https://github.com/llvm/llvm-project/pull/159553.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 1a37d057776e2..aca0963478e63 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1489,8 +1489,8 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
VectorContractLoweringAttr.summary, [{::llvm::cl::values(
clEnumValN(::mlir::vector::VectorContractLowering::Dot, "dot",
"Progressively lower to finer grained `vector.contract` and dot-products. (default)"),
- clEnumValN(::mlir::vector::VectorContractLowering::Matmul, "matmul",
- "Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics."),
+ clEnumValN(::mlir::vector::VectorContractLowering::LLVM, "llvm",
+ "Lower directly to `llvm.intr.matrix.multiply`."),
clEnumValN(::mlir::vector::VectorContractLowering::OuterProduct, "outerproduct",
"Lower to `vector.outerproduct`."),
clEnumValN(::mlir::vector::VectorContractLowering::ParallelArith, "parallelarith",
@@ -1502,8 +1502,8 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
VectorTransposeLoweringAttr.summary, [{::llvm::cl::values(
clEnumValN(::mlir::vector::VectorTransposeLowering::EltWise, "eltwise",
"Lower transpose into element-wise extract and inserts (default)"),
- clEnumValN(::mlir::vector::VectorTransposeLowering::Flat, "flat",
- "Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix intrinsics"),
+ clEnumValN(::mlir::vector::VectorTransposeLowering::LLVM, "llvm",
+ "Lower 2-D transpose directly to `llvm.intr.matrix.transpose`"),
clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle1D, "shuffle1d",
"Lower 2-D transpose to `vector.shuffle` on 1-D vector."),
clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle16x16, "shuffle16x16",
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td
index ef0951ab1d166..cbba44ae4dc8a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td
@@ -14,10 +14,9 @@ include "mlir/IR/EnumAttr.td"
// Lower transpose into element-wise extract and inserts.
def VectorTransposeLowering_Elementwise:
I32EnumAttrCase<"EltWise", 0, "eltwise">;
-// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
-// intrinsics.
-def VectorTransposeLowering_FlatTranspose:
- I32EnumAttrCase<"Flat", 1, "flat_transpose">;
+// Lower directly to LLVM matrix intrinsics.
+def VectorTransposeLowering_LLVM:
+ I32EnumAttrCase<"LLVM", 1, "llvm">;
// Lower 2-D transpose to `vector.shuffle` on 1-D vector.
def VectorTransposeLowering_Shuffle1D:
I32EnumAttrCase<"Shuffle1D", 2, "shuffle_1d">;
@@ -27,7 +26,7 @@ def VectorTransposeLowering_Shuffle16x16:
def VectorTransposeLoweringAttr : I32EnumAttr<
"VectorTransposeLowering",
"control the lowering of `vector.transpose` operations.",
- [VectorTransposeLowering_Elementwise, VectorTransposeLowering_FlatTranspose,
+ [VectorTransposeLowering_Elementwise, VectorTransposeLowering_LLVM,
VectorTransposeLowering_Shuffle1D, VectorTransposeLowering_Shuffle16x16]> {
let cppNamespace = "::mlir::vector";
}
@@ -48,9 +47,9 @@ def VectorMultiReductionLoweringAttr: I32EnumAttr<
// Progressively lower to finer grained `vector.contract` and dot-products.
def VectorContractLowering_Dot: I32EnumAttrCase<"Dot", 0, "dot">;
-// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics.
-def VectorContractLowering_Matmul:
- I32EnumAttrCase<"Matmul", 1, "matmulintrinsics">;
+// Lower directly to LLVM intrinsics.
+def VectorContractLowering_LLVM:
+ I32EnumAttrCase<"LLVM", 1, "llvm">;
// Lower to `vector.outerproduct`.
def VectorContractLowering_OuterProduct:
I32EnumAttrCase<"OuterProduct", 2, "outerproduct">;
@@ -61,7 +60,7 @@ def VectorContractLowering_ParallelArith:
def VectorContractLoweringAttr: I32EnumAttr<
"VectorContractLowering",
"control the lowering of `vector.contract` operations.",
- [VectorContractLowering_Dot, VectorContractLowering_Matmul,
+ [VectorContractLowering_Dot, VectorContractLowering_LLVM,
VectorContractLowering_OuterProduct, VectorContractLowering_ParallelArith]> {
let cppNamespace = "::mlir::vector";
}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index e7266740894b1..cbce7b006edd0 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1987,16 +1987,12 @@ struct VectorScalableStepOpLowering
/// %e = add %c, %d
/// ```
/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
-//
-/// This only kicks in when vectorContractLowering is set to Matmul and
-/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
: public vector::MaskableOpRewritePattern<vector::ContractionOp> {
public:
using MaskableOpRewritePattern::MaskableOpRewritePattern;
ContractionOpToMatmulOpLowering(
- vector::VectorContractLowering vectorContractLowering,
MLIRContext *context, PatternBenefit benefit = 100)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit) {}
@@ -2005,23 +2001,22 @@ class ContractionOpToMatmulOpLowering
PatternRewriter &rewriter) const override;
};
-/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
-/// semantics to:
+/// Lower a qualifying `vector.contract %a, %b, %c` (with row-major matmul
+/// semantics directly into `llvm.intr.matrix.multiply`:
+/// BEFORE:
/// ```
-/// %mta = maybe_transpose
-/// %mtb = maybe_transpose
-/// %flattened_a = vector.shape_cast %mta
-/// %flattened_b = vector.shape_cast %mtb
-/// %flattened_d = llvm.intr.matrix.multiply %flattened_a, %flattened_b
-/// %mtd = vector.shape_cast %flattened_d
-/// %d = maybe_untranspose %mtd
-/// %e = add %c, %d
+/// %res = vector.contract #matmat_trait %lhs, %rhs, %acc
+/// : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
/// ```
-//
-/// This only kicks in when vectorContractLowering is set to `Matmul`.
-/// vector.transpose operations are inserted if the vector.contract op is not a
-/// row-major matrix multiply.
///
+/// AFTER:
+/// ```
+/// %lhs = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32>
+/// %rhs = vector.shape_cast %arg1 : vector<4x3xf32> to vector<12xf32>
+/// %matmul = llvm.intr.matrix.multiply %lhs, %rhs
+/// %res = arith.addf %acc, %matmul : vector<2x3xf32>
+/// ```
+//
/// Scalable vectors are not supported.
FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
vector::ContractionOp op, MaskingOpInterface maskOp,
@@ -2116,7 +2111,19 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
return res;
}
-/// Lowers vector.transpose to llvm.intr.matrix.transpose
+/// Lowers vector.transpose directly to llvm.intr.matrix.transpose
+///
+/// BEFORE:
+/// ```
+/// %tr = vector.transpose %vec, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+/// ```
+/// AFTER:
+/// ```
+/// %vec_cs = vector.shape_cast %vec : vector<2x4xf32> to vector<8xf32>
+/// %tr = llvm.intr.matrix.transpose %vec_sc
+/// {columns = 2 : i32, rows = 4 : i32} : vector<8xf32> into vector<8xf32>
+/// %res = vector.shape_cast %tr : vector<8xf32> to vector<4x2xf32>
+/// ```
class TransposeOpToMatrixTransposeOpLowering
: public OpRewritePattern<vector::TransposeOp> {
public:
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 0b44ca7ceee42..a65f8ba233b76 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -70,7 +70,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorBitCastLoweringPatterns(patterns);
populateVectorBroadcastLoweringPatterns(patterns);
populateVectorContractLoweringPatterns(patterns, vectorContractLowering);
- if (vectorContractLowering == vector::VectorContractLowering::Matmul) {
+ if (vectorContractLowering == vector::VectorContractLowering::LLVM) {
// This pattern creates a dependency on the LLVM dialect, hence we don't
// include it in `populateVectorContractLoweringPatterns` that is part of
// the Vector dialect (and should not depend on LLVM).
@@ -80,7 +80,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorShapeCastLoweringPatterns(patterns);
populateVectorInterleaveLoweringPatterns(patterns);
populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering);
- if (vectorTransposeLowering == vector::VectorTransposeLowering::Flat) {
+ if (vectorTransposeLowering == vector::VectorTransposeLowering::LLVM) {
// This pattern creates a dependency on the LLVM dialect, hence we don't
// include it in `populateVectorTransposeLoweringPatterns` that is part of
// the Vector dialect (and should not depend on LLVM).
diff --git a/mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir b/mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir
index 323d86ac40988..9b18d67e037db 100644
--- a/mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir
@@ -13,7 +13,7 @@
// RUN: mlir-opt --convert-vector-to-llvm --dump-pass-pipeline %s 2>&1 | FileCheck %s --check-prefix=DEFAULT
-// RUN: mlir-opt --convert-vector-to-llvm='vector-contract-lowering=matmul vector-transpose-lowering=flat' \
+// RUN: mlir-opt --convert-vector-to-llvm='vector-contract-lowering=llvm vector-transpose-lowering=llvm' \
// RUN: --dump-pass-pipeline %s 2>&1 | FileCheck %s --check-prefix=NON-DEFAULT
// CHECK: builtin.module(
@@ -26,5 +26,5 @@
// CHECK-SAME: reassociate-fp-reductions={{[aA-zZ0-9]+}}
// DEFAULT: vector-contract-lowering=dot
// DEFAULT: vector-transpose-lowering=eltwise
-// NON-DEFAULT: vector-contract-lowering=matmul
-// NON-DEFAULT: vector-transpose-lowering=flat
+// NON-DEFAULT: vector-contract-lowering=llvm
+// NON-DEFAULT: vector-transpose-lowering=llvm
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
index 3950e54006eec..fd5bbaaadc331 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --convert-vector-to-llvm='vector-contract-lowering=matmul' | FileCheck %s
+// RUN: mlir-opt %s --convert-vector-to-llvm='vector-contract-lowering=llvm' | FileCheck %s
#matmat_accesses = [
affine_map<(i, j, k) -> (i, k)>,
diff --git a/mlir/test/Dialect/Vector/vector-transpose-to-matrix-intrinsics-transform.mlir b/mlir/test/Dialect/Vector/vector-transpose-to-matrix-intrinsics-transform.mlir
index 94689fa0dfb88..66032014d9307 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-to-matrix-intrinsics-transform.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-to-matrix-intrinsics-transform.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --convert-vector-to-llvm='vector-transpose-lowering=flat' --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --convert-vector-to-llvm='vector-transpose-lowering=llvm' --split-input-file | FileCheck %s
// CHECK-LABEL: func @transpose(
func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
|
@llvm/pr-subscribers-mlir Author: Andrzej Warzyński (banach-space) Changes
Full diff: https://github.com/llvm/llvm-project/pull/159553.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 1a37d057776e2..aca0963478e63 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1489,8 +1489,8 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
VectorContractLoweringAttr.summary, [{::llvm::cl::values(
clEnumValN(::mlir::vector::VectorContractLowering::Dot, "dot",
"Progressively lower to finer grained `vector.contract` and dot-products. (default)"),
- clEnumValN(::mlir::vector::VectorContractLowering::Matmul, "matmul",
- "Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics."),
+ clEnumValN(::mlir::vector::VectorContractLowering::LLVM, "llvm",
+ "Lower directly to `llvm.intr.matrix.multiply`."),
clEnumValN(::mlir::vector::VectorContractLowering::OuterProduct, "outerproduct",
"Lower to `vector.outerproduct`."),
clEnumValN(::mlir::vector::VectorContractLowering::ParallelArith, "parallelarith",
@@ -1502,8 +1502,8 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
VectorTransposeLoweringAttr.summary, [{::llvm::cl::values(
clEnumValN(::mlir::vector::VectorTransposeLowering::EltWise, "eltwise",
"Lower transpose into element-wise extract and inserts (default)"),
- clEnumValN(::mlir::vector::VectorTransposeLowering::Flat, "flat",
- "Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix intrinsics"),
+ clEnumValN(::mlir::vector::VectorTransposeLowering::LLVM, "llvm",
+ "Lower 2-D transpose directly to `llvm.intr.matrix.transpose`"),
clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle1D, "shuffle1d",
"Lower 2-D transpose to `vector.shuffle` on 1-D vector."),
clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle16x16, "shuffle16x16",
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td
index ef0951ab1d166..cbba44ae4dc8a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td
@@ -14,10 +14,9 @@ include "mlir/IR/EnumAttr.td"
// Lower transpose into element-wise extract and inserts.
def VectorTransposeLowering_Elementwise:
I32EnumAttrCase<"EltWise", 0, "eltwise">;
-// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix
-// intrinsics.
-def VectorTransposeLowering_FlatTranspose:
- I32EnumAttrCase<"Flat", 1, "flat_transpose">;
+// Lower directly to LLVM matrix intrinsics.
+def VectorTransposeLowering_LLVM:
+ I32EnumAttrCase<"LLVM", 1, "llvm">;
// Lower 2-D transpose to `vector.shuffle` on 1-D vector.
def VectorTransposeLowering_Shuffle1D:
I32EnumAttrCase<"Shuffle1D", 2, "shuffle_1d">;
@@ -27,7 +26,7 @@ def VectorTransposeLowering_Shuffle16x16:
def VectorTransposeLoweringAttr : I32EnumAttr<
"VectorTransposeLowering",
"control the lowering of `vector.transpose` operations.",
- [VectorTransposeLowering_Elementwise, VectorTransposeLowering_FlatTranspose,
+ [VectorTransposeLowering_Elementwise, VectorTransposeLowering_LLVM,
VectorTransposeLowering_Shuffle1D, VectorTransposeLowering_Shuffle16x16]> {
let cppNamespace = "::mlir::vector";
}
@@ -48,9 +47,9 @@ def VectorMultiReductionLoweringAttr: I32EnumAttr<
// Progressively lower to finer grained `vector.contract` and dot-products.
def VectorContractLowering_Dot: I32EnumAttrCase<"Dot", 0, "dot">;
-// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics.
-def VectorContractLowering_Matmul:
- I32EnumAttrCase<"Matmul", 1, "matmulintrinsics">;
+// Lower directly to LLVM intrinsics.
+def VectorContractLowering_LLVM:
+ I32EnumAttrCase<"LLVM", 1, "llvm">;
// Lower to `vector.outerproduct`.
def VectorContractLowering_OuterProduct:
I32EnumAttrCase<"OuterProduct", 2, "outerproduct">;
@@ -61,7 +60,7 @@ def VectorContractLowering_ParallelArith:
def VectorContractLoweringAttr: I32EnumAttr<
"VectorContractLowering",
"control the lowering of `vector.contract` operations.",
- [VectorContractLowering_Dot, VectorContractLowering_Matmul,
+ [VectorContractLowering_Dot, VectorContractLowering_LLVM,
VectorContractLowering_OuterProduct, VectorContractLowering_ParallelArith]> {
let cppNamespace = "::mlir::vector";
}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index e7266740894b1..cbce7b006edd0 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1987,16 +1987,12 @@ struct VectorScalableStepOpLowering
/// %e = add %c, %d
/// ```
/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`.
-//
-/// This only kicks in when vectorContractLowering is set to Matmul and
-/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToMatmulOpLowering
: public vector::MaskableOpRewritePattern<vector::ContractionOp> {
public:
using MaskableOpRewritePattern::MaskableOpRewritePattern;
ContractionOpToMatmulOpLowering(
- vector::VectorContractLowering vectorContractLowering,
MLIRContext *context, PatternBenefit benefit = 100)
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit) {}
@@ -2005,23 +2001,22 @@ class ContractionOpToMatmulOpLowering
PatternRewriter &rewriter) const override;
};
-/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
-/// semantics to:
+/// Lower a qualifying `vector.contract %a, %b, %c` (with row-major matmul
+/// semantics directly into `llvm.intr.matrix.multiply`:
+/// BEFORE:
/// ```
-/// %mta = maybe_transpose
-/// %mtb = maybe_transpose
-/// %flattened_a = vector.shape_cast %mta
-/// %flattened_b = vector.shape_cast %mtb
-/// %flattened_d = llvm.intr.matrix.multiply %flattened_a, %flattened_b
-/// %mtd = vector.shape_cast %flattened_d
-/// %d = maybe_untranspose %mtd
-/// %e = add %c, %d
+/// %res = vector.contract #matmat_trait %lhs, %rhs, %acc
+/// : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
/// ```
-//
-/// This only kicks in when vectorContractLowering is set to `Matmul`.
-/// vector.transpose operations are inserted if the vector.contract op is not a
-/// row-major matrix multiply.
///
+/// AFTER:
+/// ```
+/// %lhs = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32>
+/// %rhs = vector.shape_cast %arg1 : vector<4x3xf32> to vector<12xf32>
+/// %matmul = llvm.intr.matrix.multiply %lhs, %rhs
+/// %res = arith.addf %acc, %matmul : vector<2x3xf32>
+/// ```
+//
/// Scalable vectors are not supported.
FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
vector::ContractionOp op, MaskingOpInterface maskOp,
@@ -2116,7 +2111,19 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
return res;
}
-/// Lowers vector.transpose to llvm.intr.matrix.transpose
+/// Lowers vector.transpose directly to llvm.intr.matrix.transpose
+///
+/// BEFORE:
+/// ```
+/// %tr = vector.transpose %vec, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+/// ```
+/// AFTER:
+/// ```
+/// %vec_cs = vector.shape_cast %vec : vector<2x4xf32> to vector<8xf32>
+/// %tr = llvm.intr.matrix.transpose %vec_sc
+/// {columns = 2 : i32, rows = 4 : i32} : vector<8xf32> into vector<8xf32>
+/// %res = vector.shape_cast %tr : vector<8xf32> to vector<4x2xf32>
+/// ```
class TransposeOpToMatrixTransposeOpLowering
: public OpRewritePattern<vector::TransposeOp> {
public:
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 0b44ca7ceee42..a65f8ba233b76 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -70,7 +70,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorBitCastLoweringPatterns(patterns);
populateVectorBroadcastLoweringPatterns(patterns);
populateVectorContractLoweringPatterns(patterns, vectorContractLowering);
- if (vectorContractLowering == vector::VectorContractLowering::Matmul) {
+ if (vectorContractLowering == vector::VectorContractLowering::LLVM) {
// This pattern creates a dependency on the LLVM dialect, hence we don't
// include it in `populateVectorContractLoweringPatterns` that is part of
// the Vector dialect (and should not depend on LLVM).
@@ -80,7 +80,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorShapeCastLoweringPatterns(patterns);
populateVectorInterleaveLoweringPatterns(patterns);
populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering);
- if (vectorTransposeLowering == vector::VectorTransposeLowering::Flat) {
+ if (vectorTransposeLowering == vector::VectorTransposeLowering::LLVM) {
// This pattern creates a dependency on the LLVM dialect, hence we don't
// include it in `populateVectorTransposeLoweringPatterns` that is part of
// the Vector dialect (and should not depend on LLVM).
diff --git a/mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir b/mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir
index 323d86ac40988..9b18d67e037db 100644
--- a/mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir
@@ -13,7 +13,7 @@
// RUN: mlir-opt --convert-vector-to-llvm --dump-pass-pipeline %s 2>&1 | FileCheck %s --check-prefix=DEFAULT
-// RUN: mlir-opt --convert-vector-to-llvm='vector-contract-lowering=matmul vector-transpose-lowering=flat' \
+// RUN: mlir-opt --convert-vector-to-llvm='vector-contract-lowering=llvm vector-transpose-lowering=llvm' \
// RUN: --dump-pass-pipeline %s 2>&1 | FileCheck %s --check-prefix=NON-DEFAULT
// CHECK: builtin.module(
@@ -26,5 +26,5 @@
// CHECK-SAME: reassociate-fp-reductions={{[aA-zZ0-9]+}}
// DEFAULT: vector-contract-lowering=dot
// DEFAULT: vector-transpose-lowering=eltwise
-// NON-DEFAULT: vector-contract-lowering=matmul
-// NON-DEFAULT: vector-transpose-lowering=flat
+// NON-DEFAULT: vector-contract-lowering=llvm
+// NON-DEFAULT: vector-transpose-lowering=llvm
diff --git a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
index 3950e54006eec..fd5bbaaadc331 100644
--- a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --convert-vector-to-llvm='vector-contract-lowering=matmul' | FileCheck %s
+// RUN: mlir-opt %s --convert-vector-to-llvm='vector-contract-lowering=llvm' | FileCheck %s
#matmat_accesses = [
affine_map<(i, j, k) -> (i, k)>,
diff --git a/mlir/test/Dialect/Vector/vector-transpose-to-matrix-intrinsics-transform.mlir b/mlir/test/Dialect/Vector/vector-transpose-to-matrix-intrinsics-transform.mlir
index 94689fa0dfb88..66032014d9307 100644
--- a/mlir/test/Dialect/Vector/vector-transpose-to-matrix-intrinsics-transform.mlir
+++ b/mlir/test/Dialect/Vector/vector-transpose-to-matrix-intrinsics-transform.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --convert-vector-to-llvm='vector-transpose-lowering=flat' --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --convert-vector-to-llvm='vector-transpose-lowering=llvm' --split-input-file | FileCheck %s
// CHECK-LABEL: func @transpose(
func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> {
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
3439e0a
to
1b1bc23
Compare
This is a follow-up to llvm#144307, where we removed `vector.matrix_multiply` and `vector.flat_transpose` from the Vector dialect. This PR: * Updates comments that were missed in the previous change. * Renames relevant `-convert-vector-to-llvm=` options: - `vector-contract-lowering=matmul` → `vector-contract-lowering=llvm` - `vector-transpose-lowering=flat_transpose` → `vector-transpose-lowering=llvm` These new names better reflect the actual transformation target — LLVM intrinsics — rather than the now-removed abstract operations.
1b1bc23
to
060f098
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like you also need to update python bindings
Just had to update a Python test: commit |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM but let's wait for some other review
/// semantics to: | ||
/// Lower a qualifying `vector.contract %a, %b, %c` (with row-major matmul | ||
/// semantics directly into `llvm.intr.matrix.multiply`: | ||
/// BEFORE: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit/optional: you can tag code blocks with mlir, e.g.:
```mlir
%res = ...
```
also below
clEnumValN(::mlir::vector::VectorContractLowering::Matmul, "matmul", | ||
"Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics."), | ||
clEnumValN(::mlir::vector::VectorContractLowering::LLVM, "llvm", | ||
"Lower directly to `llvm.intr.matrix.multiply`."), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
llvm
only could be misleading as that's not the only option we have to lower to llvm. What about ~llvm-matmul
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point!
Since this is about lowering to "LLVM intrinsics", and the intrinsics are:
llvm.intr.matrix.multiply
(i.e. llvm.intr.<...>)llvm.intr.matrix.transpose
(i.e. llvm.intr.<...>),
let me rename this as llvmintr
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LLVM intrinsic sounds like the best name for me
clEnumValN(::mlir::vector::VectorTransposeLowering::Flat, "flat", | ||
"Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix intrinsics"), | ||
clEnumValN(::mlir::vector::VectorTransposeLowering::LLVM, "llvm", | ||
"Lower 2-D transpose directly to `llvm.intr.matrix.transpose`"), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same, in general, throughout the file
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SG, thanks!
This is a follow-up to #144307, where we removed
vector.matrix_multiply
andvector.flat_transpose
from the Vectordialect.
This PR:
-convert-vector-to-llvm=
options:vector-contract-lowering=matmul
→vector-contract-lowering=llvm
vector-transpose-lowering=flat_transpose
→vector-transpose-lowering=llvm
These new names better reflect the actual transformation target — LLVM
intrinsics — rather than the now-removed abstract operations.