diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h index 748248d45df26..633950e8d54dc 100644 --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h @@ -14,6 +14,9 @@ #define MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES_H #include "mlir/Pass/Pass.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/StringRef.h" namespace mlir { @@ -45,6 +48,14 @@ namespace memref { #define GEN_PASS_DECL #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" +/// Additional construction for FoldMemrefAliasOps to allow disabling +/// patterns by name, and controlling folding via a callback function. +/// `controlFn(Operation* userOp)` will be passed the user operation of the +/// aliasing op (e.g., a load/store that uses the result of a memref.subview). +std::unique_ptr createFoldMemRefAliasOpsPass( + ArrayRef excludedPatterns, + function_ref controlFn = nullptr); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h index 8b76930aed35a..91d84d1c1d9ff 100644 --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h @@ -21,6 +21,7 @@ namespace mlir { class OpBuilder; class RewritePatternSet; class RewriterBase; +class Operation; class Value; class ValueRange; class ReifyRankedShapedTypeOpInterface; @@ -43,8 +44,13 @@ class DeallocOp; void populateExpandOpsPatterns(RewritePatternSet &patterns); /// Appends patterns for folding memref aliasing ops into consumer load/store -/// ops into `patterns`. -void populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns); +/// ops into `patterns`. If `controlFn` is provided, each pattern invokes it and +/// bails out when it returns false. +/// `controlFn(Operation* userOp)` will be passed the user operation of the +/// aliasing op (e.g., a load/store that uses the result of a memref.subview). +void populateFoldMemRefAliasOpPatterns( + RewritePatternSet &patterns, + function_ref controlFn = nullptr); /// Appends patterns that resolve `memref.dim` operations with values that are /// defined by operations that implement the diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp index 3cacb7e29263b..50cfebbdc66ce 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -21,11 +21,14 @@ #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/AffineMap.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" +#include +#include #define DEBUG_TYPE "fold-memref-alias-ops" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") @@ -82,11 +85,29 @@ static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) { //===----------------------------------------------------------------------===// namespace { +using ControlFunction = std::function; + +template +class FoldMemRefAliasPattern : public OpRewritePattern { +public: + FoldMemRefAliasPattern(MLIRContext *context, + ControlFunction controlFn = ControlFunction()) + : OpRewritePattern(context), controlFn(std::move(controlFn)) {} + +protected: + bool shouldRewrite(Operation *op) const { + return !controlFn || controlFn(op); + } + +private: + ControlFunction controlFn; +}; + /// Merges subview operation with load/transferRead operation. template -class LoadOpOfSubViewOpFolder final : public OpRewritePattern { +class LoadOpOfSubViewOpFolder final : public FoldMemRefAliasPattern { public: - using OpRewritePattern::OpRewritePattern; + using FoldMemRefAliasPattern::FoldMemRefAliasPattern; LogicalResult matchAndRewrite(OpTy loadOp, PatternRewriter &rewriter) const override; @@ -94,9 +115,9 @@ class LoadOpOfSubViewOpFolder final : public OpRewritePattern { /// Merges expand_shape operation with load/transferRead operation. template -class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern { +class LoadOpOfExpandShapeOpFolder final : public FoldMemRefAliasPattern { public: - using OpRewritePattern::OpRewritePattern; + using FoldMemRefAliasPattern::FoldMemRefAliasPattern; LogicalResult matchAndRewrite(OpTy loadOp, PatternRewriter &rewriter) const override; @@ -104,9 +125,10 @@ class LoadOpOfExpandShapeOpFolder final : public OpRewritePattern { /// Merges collapse_shape operation with load/transferRead operation. template -class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern { +class LoadOpOfCollapseShapeOpFolder final + : public FoldMemRefAliasPattern { public: - using OpRewritePattern::OpRewritePattern; + using FoldMemRefAliasPattern::FoldMemRefAliasPattern; LogicalResult matchAndRewrite(OpTy loadOp, PatternRewriter &rewriter) const override; @@ -114,9 +136,9 @@ class LoadOpOfCollapseShapeOpFolder final : public OpRewritePattern { /// Merges subview operation with store/transferWriteOp operation. template -class StoreOpOfSubViewOpFolder final : public OpRewritePattern { +class StoreOpOfSubViewOpFolder final : public FoldMemRefAliasPattern { public: - using OpRewritePattern::OpRewritePattern; + using FoldMemRefAliasPattern::FoldMemRefAliasPattern; LogicalResult matchAndRewrite(OpTy storeOp, PatternRewriter &rewriter) const override; @@ -124,9 +146,9 @@ class StoreOpOfSubViewOpFolder final : public OpRewritePattern { /// Merges expand_shape operation with store/transferWriteOp operation. template -class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern { +class StoreOpOfExpandShapeOpFolder final : public FoldMemRefAliasPattern { public: - using OpRewritePattern::OpRewritePattern; + using FoldMemRefAliasPattern::FoldMemRefAliasPattern; LogicalResult matchAndRewrite(OpTy storeOp, PatternRewriter &rewriter) const override; @@ -134,21 +156,26 @@ class StoreOpOfExpandShapeOpFolder final : public OpRewritePattern { /// Merges collapse_shape operation with store/transferWriteOp operation. template -class StoreOpOfCollapseShapeOpFolder final : public OpRewritePattern { +class StoreOpOfCollapseShapeOpFolder final + : public FoldMemRefAliasPattern { public: - using OpRewritePattern::OpRewritePattern; + using FoldMemRefAliasPattern::FoldMemRefAliasPattern; LogicalResult matchAndRewrite(OpTy storeOp, PatternRewriter &rewriter) const override; }; /// Folds subview(subview(x)) to a single subview(x). -class SubViewOfSubViewFolder : public OpRewritePattern { +class SubViewOfSubViewFolder + : public FoldMemRefAliasPattern { public: - using OpRewritePattern::OpRewritePattern; + using FoldMemRefAliasPattern::FoldMemRefAliasPattern; LogicalResult matchAndRewrite(memref::SubViewOp subView, PatternRewriter &rewriter) const override { + if (!this->shouldRewrite(subView)) + return failure(); + auto srcSubView = subView.getSource().getDefiningOp(); if (!srcSubView) return failure(); @@ -188,9 +215,10 @@ class SubViewOfSubViewFolder : public OpRewritePattern { /// Folds nvgpu.device_async_copy subviews into the copy itself. This pattern /// is folds subview on src and dst memref of the copy. class NVGPUAsyncCopyOpSubViewOpFolder final - : public OpRewritePattern { + : public FoldMemRefAliasPattern { public: - using OpRewritePattern::OpRewritePattern; + using FoldMemRefAliasPattern< + nvgpu::DeviceAsyncCopyOp>::FoldMemRefAliasPattern; LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const override; @@ -234,6 +262,8 @@ static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter, template LogicalResult LoadOpOfSubViewOpFolder::matchAndRewrite( OpTy loadOp, PatternRewriter &rewriter) const { + if (!this->shouldRewrite(loadOp)) + return failure(); auto subViewOp = getMemRefOperand(loadOp).template getDefiningOp(); @@ -290,6 +320,8 @@ LogicalResult LoadOpOfSubViewOpFolder::matchAndRewrite( template LogicalResult LoadOpOfExpandShapeOpFolder::matchAndRewrite( OpTy loadOp, PatternRewriter &rewriter) const { + if (!this->shouldRewrite(loadOp)) + return failure(); auto expandShapeOp = getMemRefOperand(loadOp).template getDefiningOp(); @@ -351,6 +383,8 @@ LogicalResult LoadOpOfExpandShapeOpFolder::matchAndRewrite( template LogicalResult LoadOpOfCollapseShapeOpFolder::matchAndRewrite( OpTy loadOp, PatternRewriter &rewriter) const { + if (!this->shouldRewrite(loadOp)) + return failure(); auto collapseShapeOp = getMemRefOperand(loadOp) .template getDefiningOp(); @@ -383,6 +417,8 @@ LogicalResult LoadOpOfCollapseShapeOpFolder::matchAndRewrite( template LogicalResult StoreOpOfSubViewOpFolder::matchAndRewrite( OpTy storeOp, PatternRewriter &rewriter) const { + if (!this->shouldRewrite(storeOp)) + return failure(); auto subViewOp = getMemRefOperand(storeOp).template getDefiningOp(); @@ -435,6 +471,8 @@ LogicalResult StoreOpOfSubViewOpFolder::matchAndRewrite( template LogicalResult StoreOpOfExpandShapeOpFolder::matchAndRewrite( OpTy storeOp, PatternRewriter &rewriter) const { + if (!this->shouldRewrite(storeOp)) + return failure(); auto expandShapeOp = getMemRefOperand(storeOp).template getDefiningOp(); @@ -470,6 +508,8 @@ LogicalResult StoreOpOfExpandShapeOpFolder::matchAndRewrite( template LogicalResult StoreOpOfCollapseShapeOpFolder::matchAndRewrite( OpTy storeOp, PatternRewriter &rewriter) const { + if (!this->shouldRewrite(storeOp)) + return failure(); auto collapseShapeOp = getMemRefOperand(storeOp) .template getDefiningOp(); @@ -501,6 +541,8 @@ LogicalResult StoreOpOfCollapseShapeOpFolder::matchAndRewrite( LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite( nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const { + if (!this->shouldRewrite(copyOp)) + return failure(); LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n"); @@ -550,7 +592,12 @@ LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite( return success(); } -void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) { +void memref::populateFoldMemRefAliasOpPatterns( + RewritePatternSet &patterns, function_ref controlFn) { + ControlFunction controlFnStorage; + if (controlFn) + controlFnStorage = controlFn; + patterns.add, LoadOpOfSubViewOpFolder, LoadOpOfSubViewOpFolder, @@ -576,7 +623,7 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) { StoreOpOfCollapseShapeOpFolder, StoreOpOfCollapseShapeOpFolder, SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>( - patterns.getContext()); + patterns.getContext(), controlFnStorage); } //===----------------------------------------------------------------------===// @@ -587,13 +634,37 @@ namespace { struct FoldMemRefAliasOpsPass final : public memref::impl::FoldMemRefAliasOpsPassBase { + FoldMemRefAliasOpsPass() = default; + FoldMemRefAliasOpsPass(ArrayRef disabledPatterns, + function_ref controlFn = nullptr) + : disabledPatternNames(disabledPatterns.begin(), disabledPatterns.end()) { + if (controlFn) + controlFunction = controlFn; + } + void runOnOperation() override; + +private: + SmallVector disabledPatternNames; + ControlFunction controlFunction; }; } // namespace void FoldMemRefAliasOpsPass::runOnOperation() { - RewritePatternSet patterns(&getContext()); - memref::populateFoldMemRefAliasOpPatterns(patterns); - (void)applyPatternsGreedily(getOperation(), std::move(patterns)); + RewritePatternSet owningPatterns(&getContext()); + function_ref controlFnRef; + if (controlFunction) + controlFnRef = controlFunction; + memref::populateFoldMemRefAliasOpPatterns(owningPatterns, controlFnRef); + + FrozenRewritePatternSet patterns(std::move(owningPatterns), + disabledPatternNames); + (void)applyPatternsGreedily(getOperation(), patterns); +} + +std::unique_ptr mlir::memref::createFoldMemRefAliasOpsPass( + ArrayRef excludedPatterns, + function_ref controlFn) { + return std::make_unique(excludedPatterns, controlFn); } diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops-options.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops-options.mlir new file mode 100644 index 0000000000000..bd4edd9ed65eb --- /dev/null +++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops-options.mlir @@ -0,0 +1,34 @@ +// RUN: mlir-opt --test-fold-memref-alias-options="exclude-pattern=load-subview" -split-input-file %s | FileCheck %s --check-prefix=EXCLUDE +// RUN: mlir-opt --test-fold-memref-alias-options="control-attr=no_fold" -split-input-file %s | FileCheck %s --check-prefix=CONTROL + +// ----- + +// Excluding the load-subview pattern keeps the subview + load untouched. +func.func @exclude_load_subview(%arg0: memref<4xf32>) -> f32 { + %c0 = arith.constant 0 : index + %sv = memref.subview %arg0[0] [4] [1] : memref<4xf32> to memref<4xf32, strided<[1], offset: 0>> + %v = memref.load %sv[%c0] : memref<4xf32, strided<[1], offset: 0>> + return %v : f32 +} + +// EXCLUDE-LABEL: func.func @exclude_load_subview +// EXCLUDE: %[[SV:.*]] = memref.subview +// EXCLUDE: memref.load %[[SV]] +// EXCLUDE-NOT: memref.load %arg0 + +// ----- + +// Control callback rejects ops carrying the attribute; the plain load is still +// folded through the subview. +func.func @control_attr(%arg0: memref<4xf32>) -> (f32, f32) { + %c0 = arith.constant 0 : index + %sv = memref.subview %arg0[0] [4] [1] : memref<4xf32> to memref<4xf32, strided<[1], offset: 0>> + %blocked = memref.load %sv[%c0] {no_fold} : memref<4xf32, strided<[1], offset: 0>> + %folded = memref.load %sv[%c0] : memref<4xf32, strided<[1], offset: 0>> + return %blocked, %folded : f32, f32 +} + +// CONTROL-LABEL: func.func @control_attr +// CONTROL: %[[SV:.*]] = memref.subview +// CONTROL: %[[A:.*]] = memref.load %[[SV]][%c0] {no_fold} +// CONTROL: %[[B:.*]] = memref.load %arg0[%c0] diff --git a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt index 39457ab2d0bf7..4a707f719a317 100644 --- a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt +++ b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_library(MLIRMemRefTestPasses TestComposeSubView.cpp TestEmulateNarrowType.cpp + TestFoldMemRefAliasOptions.cpp TestMultiBuffer.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/Dialect/MemRef/TestFoldMemRefAliasOptions.cpp b/mlir/test/lib/Dialect/MemRef/TestFoldMemRefAliasOptions.cpp new file mode 100644 index 0000000000000..aa4d366262ffa --- /dev/null +++ b/mlir/test/lib/Dialect/MemRef/TestFoldMemRefAliasOptions.cpp @@ -0,0 +1,101 @@ +//===- TestFoldMemRefAliasOptions.cpp - Test FoldMemRefAlias options ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a test pass to exercise the optional arguments of +// FoldMemRefAliasOps (excluded patterns and control callback). +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/MemRef/Transforms/Transforms.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/TypeName.h" + +using namespace mlir; + +namespace { +struct TestFoldMemRefAliasOptionsPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFoldMemRefAliasOptionsPass) + + TestFoldMemRefAliasOptionsPass() = default; + TestFoldMemRefAliasOptionsPass(const TestFoldMemRefAliasOptionsPass &pass) + : PassWrapper(pass) {} + + StringRef getArgument() const final { + return "test-fold-memref-alias-options"; + } + StringRef getDescription() const final { + return "Test FoldMemRefAliasOps optional arguments"; + } + + ListOption excludedPatternTokens{ + *this, "exclude-pattern", + llvm::cl::desc("Comma-separated tokens to exclude certain patterns " + "(e.g., load-subview)")}; + Option controlAttr{ + *this, "control-attr", + llvm::cl::desc( + "Attribute name that disables rewrites when present on the " + "matched operation"), + llvm::cl::init("")}; + + void runOnOperation() override; +}; + +void TestFoldMemRefAliasOptionsPass::runOnOperation() { + // Custom version of "FoldMemRefAliasOps" to test its options, by: + // 1) Excluding patterns that fold memref.subview into load ops + // 2) Ignoring user ops that have a specific attribute. + + // Map friendly tokens to concrete pattern names expected by the exclusion + // mechanism. + SmallVector disabledPatternNames; + if (llvm::is_contained(excludedPatternTokens, "load-subview")) { + // Resolve pattern debug names from a populated set. + RewritePatternSet patternsSet(&getContext()); + memref::populateFoldMemRefAliasOpPatterns(patternsSet); + for (auto &pattern : patternsSet.getNativePatterns()) { + std::optional rootKind = pattern->getRootKind(); + if (rootKind && + rootKind->getStringRef() == memref::LoadOp::getOperationName()) { + disabledPatternNames.push_back(pattern->getDebugName().str()); + break; + } + } + } + + std::function controlFnStorage; + function_ref controlFnRef; + if (!controlAttr.empty()) { + StringAttr attrName = StringAttr::get(&getContext(), controlAttr); + controlFnStorage = [attrName](Operation *op) { + return !op->hasAttr(attrName); + }; + controlFnRef = controlFnStorage; + } + + RewritePatternSet owningPatterns(&getContext()); + memref::populateFoldMemRefAliasOpPatterns(owningPatterns, controlFnRef); + FrozenRewritePatternSet patterns(std::move(owningPatterns), + disabledPatternNames); + (void)applyPatternsGreedily(getOperation(), patterns); +} +} // namespace + +namespace mlir { +namespace test { +void registerTestFoldMemRefAliasOptionsPass() { + PassRegistration(); +} +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index a427132247e6d..d674114d4a18b 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -133,6 +133,7 @@ void registerTestMemRefToLLVMWithTransforms(); void registerTestReshardingPartitionPass(); void registerTestShardSimplificationsPass(); void registerTestMultiBuffering(); +void registerTestFoldMemRefAliasOptionsPass(); void registerTestNextAccessPass(); void registerTestNVGPULowerings(); void registerTestOpenACC(); @@ -248,6 +249,7 @@ static void registerTestPasses() { mlir::test::registerTestDominancePass(); mlir::test::registerTestDynamicPipelinePass(); mlir::test::registerTestRemarkPass(); + mlir::test::registerTestFoldMemRefAliasOptionsPass(); mlir::test::registerTestEmulateNarrowTypePass(); mlir::test::registerTestFooAnalysisPass(); mlir::test::registerTestComposeSubView();