Skip to content

Commit 5466610

Browse files
committed
add DecomposeAggregatedOps
1 parent ca1373e commit 5466610

File tree

4 files changed

+59
-1
lines changed

4 files changed

+59
-1
lines changed

include/gc/Transforms/Passes.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,12 @@ def VerifyTargetDescription : Pass<"verify-target-description", "ModuleOp"> {
9292
/*default=*/"\"CPU\"",
9393
"The device to verify. Supported device: CPU, ">,
9494
];
95+
96+
def DecomposeAggregatedOps : Pass<"decompose-aggregated-ops", "func::FuncOp"> {
97+
let summary = "Decompose aggregated operations.";
98+
let description = [{
99+
Decompose operations that implement the `AggregatedOpInterface`.
100+
}];
95101
}
96102

97103
#endif // GC_DIALECT_GC_PASSES

lib/gc/Transforms/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ gc_add_mlir_library(GcPasses
1616
IterativeTilingAndFusion.cpp
1717
TilingUsingInterfaceX.cpp
1818
VerifyTargetDescription.cpp
19-
19+
DecomposeAggregatedOps.cpp
20+
2021
DEPENDS
2122
GraphCompilerPassIncGen
2223

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
//===- DecomposeAggregatedOps.cpp --------------------------------*- C++-*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "gc/Transforms/Passes.h"
10+
#include "mlir/Dialect/Func/IR/FuncOps.h"
11+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
12+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13+
14+
using namespace mlir;
15+
namespace mlir {
16+
namespace gc {
17+
#define GEN_PASS_DEF_DECOMPOSEAGGREGATEDOPS
18+
#include "gc/Transforms/Passes.h.inc"
19+
} // namespace gc
20+
} // namespace mlir
21+
22+
namespace {
23+
24+
struct DecomposeAggregateOpsImpl : public OpRewritePattern<linalg::SoftmaxOp> {
25+
using OpRewritePattern<linalg::SoftmaxOp>::OpRewritePattern;
26+
27+
LogicalResult matchAndRewrite(linalg::SoftmaxOp softmaxOp,
28+
PatternRewriter &rewriter) const override {
29+
auto decomposableOp =
30+
cast<linalg::AggregatedOpInterface>(softmaxOp.getOperation());
31+
FailureOr<SmallVector<Value>> maybeNewResult =
32+
decomposableOp.decomposeOperation(rewriter);
33+
if (failed(maybeNewResult))
34+
return failure();
35+
rewriter.replaceOp(softmaxOp, *maybeNewResult);
36+
return success();
37+
}
38+
};
39+
40+
struct DecomposeAggregatedOps
41+
: public gc::impl::DecomposeAggregatedOpsBase<DecomposeAggregatedOps> {
42+
void runOnOperation() override {
43+
RewritePatternSet patterns(getOperation().getContext());
44+
patterns.add<DecomposeAggregateOpsImpl>(patterns.getContext());
45+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
46+
}
47+
};
48+
49+
} // namespace

lib/gc/Transforms/Pipeline.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ void populateTensorPasses(mlir::OpPassManager &pm) {
5454
// REMOVE this pass after the above passes are added. Currently we add this
5555
// pass to make the pipeline work properly
5656
pm.addNestedPass<func::FuncOp>(createLinalgGeneralizeNamedOpsPass());
57+
// copied from tpp project
58+
pm.addNestedPass<func::FuncOp>(createDecomposeAggregatedOps());
5759
}
5860

5961
// scf + arith + math + vector + tensor + linalg.brgemm

0 commit comments

Comments
 (0)