-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][XeGPU][Transform] Add vectorlinearize transform pass. #158084
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu Author: Md Abdullah Shahneous Bari (mshahneo) ChangesUse upstream patterns to create a vectorlinearize pass needed for lowering to XeVM. Full diff: https://github.com/llvm/llvm-project/pull/158084.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index ddf6b4ac85a90..2db28a20935b1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -71,4 +71,13 @@ def XeGPUBlocking: Pass<"xegpu-blocking"> {
"index::IndexDialect"];
}
+def XeGPUVectorLinearize : Pass<"xegpu-vector-linearize"> {
+ let summary = "Linearize n-D vectors to 1-D vectors";
+ let description = [{
+ This pass linearizes n-D vectors to 1-D vectors for lowering to XeVM.
+ }];
+ let dependentDialects = ["arith::ArithDialect", "memref::MemRefDialect",
+ "scf::SCFDialect", "vector::VectorDialect"];
+}
+
#endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
index 9c178d1d85642..e6f76067094ce 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
XeGPUUnroll.cpp
XeGPUWgToSgDistribute.cpp
XeGPUPropagateLayout.cpp
+ XeGPUVectorLinearize.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
new file mode 100644
index 0000000000000..a6a68716547c9
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
@@ -0,0 +1,111 @@
+//===- XeGPUVectorLinearize.cpp - Linearizes n-D vectors to 1-D vectors
+//-------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include <optional>
+
+namespace mlir {
+namespace xegpu {
+#define GEN_PASS_DEF_XEGPUVECTORLINEARIZE
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
+} // namespace xegpu
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-vector-linearize"
+
+using namespace mlir;
+
+namespace {
+struct XeGPUVectorLinearizePass final
+ : public xegpu::impl::XeGPUVectorLinearizeBase<XeGPUVectorLinearizePass> {
+ void runOnOperation() override {
+ auto *context = &getContext();
+
+ // vector.broadcast and vector.gather requires progressive lowering
+ {
+ mlir::RewritePatternSet patterns(&getContext());
+ mlir::vector::populateVectorBroadcastLoweringPatterns(patterns);
+ mlir::vector::populateVectorGatherLoweringPatterns(patterns);
+ mlir::vector::populateVectorGatherToConditionalLoadPatterns(patterns);
+ // vector.transpose lowering
+ // Shuffle16x16 will fallback to Shuffle1D for non 16x16 sizes.
+ mlir::vector::populateVectorTransposeLoweringPatterns(
+ patterns, mlir::vector::VectorTransposeLowering::Shuffle16x16);
+ (void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns));
+ }
+
+ // Unroll load store from <<MxN> to M <1xN> load/stores and then linearize
+ {
+ mlir::RewritePatternSet patterns(&getContext());
+ mlir::vector::UnrollVectorOptions vectorOptions;
+ vectorOptions.setNativeShapeFn(
+ [](mlir::Operation *op) -> std::optional<mlir::SmallVector<int64_t>> {
+ // Only unroll for vector::LoadOp and vector::StoreOp
+ if (mlir::isa<mlir::vector::LoadOp>(op)) {
+ if (auto vecType = mlir::dyn_cast<mlir::VectorType>(
+ op->getResult(0).getType())) {
+ auto shape = vecType.getShape();
+ if (shape.size() == 2)
+ return mlir::SmallVector<int64_t>{1, shape[1]};
+ }
+ }
+ if (mlir::isa<mlir::vector::StoreOp>(op)) {
+ if (auto vecType = mlir::dyn_cast<mlir::VectorType>(
+ op->getOperand(0).getType())) {
+ auto shape = vecType.getShape();
+ if (shape.size() == 2)
+ return mlir::SmallVector<int64_t>{1, shape[1]};
+ }
+ }
+ return std::nullopt;
+ });
+ mlir::vector::populateVectorUnrollPatterns(patterns, vectorOptions);
+ (void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns));
+ }
+
+ // Use upstream linearization patterns
+ {
+ mlir::MLIRContext &context = getContext();
+ mlir::TypeConverter converter;
+ mlir::RewritePatternSet patterns(&context);
+ mlir::ConversionTarget target(context);
+ mlir::vector::populateForVectorLinearize(converter, target);
+ mlir::vector::populateVectorLinearizeBasePatterns(converter, target,
+ patterns);
+ mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
+ converter, target, patterns);
+ mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
+ converter, patterns, target);
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ return signalPassFailure();
+ }
+
+ mlir::TypeConverter typeConverter;
+ mlir::RewritePatternSet patterns(context);
+ mlir::ConversionTarget target(*context);
+ typeConverter.addConversion([](mlir::Type type) { return type; });
+
+ target.addIllegalOp<mlir::vector::TransposeOp>();
+ target.addLegalOp<mlir::vector::ShapeCastOp>();
+ target.addLegalOp<mlir::vector::ExtractOp>();
+ target.addLegalDialect<mlir::xegpu::XeGPUDialect>();
+ }
+};
+} // namespace
|
Can you update the desc with why vector-to-llvm is not enough to handle it? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
need to check if this is needed. We won't have 16x16 sizes at SIMT level now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you add a couple of tests in xegpu/transform?
Added :). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM % comments
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: best to fit this in one line :-)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot, fixed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If I remember correctly, vector/arith linearize tests are already there. I think we only need XeGPU-specific tests here. I suggest removing duplicated tests.
Maybe we could have a few e2e lowering test for xegpu code samples like a simple GEMM with extracts/transpose etc that require the application of linearize pass.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the suggestions. Fixed now, removed duplication and added some new ones.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 for the pass to enable e2e lowering
I still miss relevant tests (or the current ones could use documentation) that would justify its existence and really ensure no regressions.
Thanks a lot, @adam-smnk , @charithaintc , @nbpatel . |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good 👍
I'm sure it'll evolve further once we start running more e2e examples.
Use upstream patterns to create a vectorlinearize pass needed for lowering to xevm. Linearizes n-D vectors to 1-D vectors.
Add test case.
Update the test case to remove duplication with vector-linearize. Add new test cases for XeGPU, vector.broadcast, vector.gather.
Add vector unroll support for n-D laod/store.
c177c25
to
6b22d6d
Compare
Absolutely, not just e2e, with the proposed changes in vector dialect, this would definitely evolve :) |
Use upstream patterns to create a vectorlinearize pass needed for lowering to XeVM.
Linearizes n-D vectors to 1-D vectors.
This is needed because,
vector-to-llvm
does not linearize all the vectors.