/
TestDynamicPipeline.cpp
112 lines (100 loc) · 3.81 KB
/
TestDynamicPipeline.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
//===------ TestDynamicPipeline.cpp --- dynamic pipeline test pass --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to test the dynamic pipeline feature.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"
using namespace mlir;
namespace {
/// Test pass exercising the dynamic pipeline feature.
///
/// The pass parses the textual pipeline given through the `dynamic-pipeline`
/// option and schedules it from within `runOnOperation` via `runPipeline`.
/// Depending on the options, the pipeline is run on the current operation, on
/// the isolated-from-above operations nested under it, or on its parent
/// operation (the latter being expected to fail).
class TestDynamicPipelinePass
    : public PassWrapper<TestDynamicPipelinePass, OperationPass<>> {
public:
  /// Collects into `registry` the dialects that the passes named by the
  /// `dynamic-pipeline` option depend on.
  void getDependentDialects(DialectRegistry &registry) const override {
    OpPassManager pm(ModuleOp::getOperationName(), false);
    // This hook returns void and so has no way to report a failure. Parse
    // diagnostics are still printed to llvm::errs(); the same pipeline string
    // is parsed again (and checked) in runOnOperation.
    (void)parsePassPipeline(pipeline, pm, llvm::errs());
    pm.getDependentDialects(registry);
  }

  TestDynamicPipelinePass() {}

  /// The copy intentionally starts with an empty `pm`: the pass manager is
  /// rebuilt lazily on the first call to `runOnOperation`.
  TestDynamicPipelinePass(const TestDynamicPipelinePass &) {}

  /// Runs the dynamic pipeline according to the pass options.
  ///
  /// Nothing is run when the pipeline string is empty, when the current
  /// operation does not implement SymbolOpInterface, or when `op-name` is
  /// non-empty and does not contain the symbol name of the current operation.
  void runOnOperation() override {
    Operation *currentOp = getOperation();

    llvm::errs() << "Dynamic execute '" << pipeline << "' on "
                 << currentOp->getName() << "\n";
    if (pipeline.empty()) {
      llvm::errs() << "Empty pipeline\n";
      return;
    }

    // The name filter below works on symbol names, so only symbol operations
    // can be handled.
    auto symbolOp = dyn_cast<SymbolOpInterface>(currentOp);
    if (!symbolOp) {
      currentOp->emitWarning()
          << "Ignoring because not implementing SymbolOpInterface\n";
      return;
    }

    auto opName = symbolOp.getName();
    if (!opNames.empty() && !llvm::is_contained(opNames, opName)) {
      llvm::errs() << "dynamic-pipeline skip op name: " << opName << "\n";
      return;
    }

    // Build the pass manager once and reuse it for later invocations.
    if (!pm) {
      pm = std::make_unique<OpPassManager>(
          currentOp->getName().getIdentifier(), false);
      if (failed(parsePassPipeline(pipeline, *pm, llvm::errs()))) {
        // Do not cache (or run) a pass manager built from a pipeline that
        // failed to parse; the parser already printed the error.
        pm.reset();
        signalPassFailure();
        return;
      }
    }

    // Check that running on the parent operation always immediately fails.
    if (runOnParent) {
      if (Operation *parentOp = currentOp->getParentOp())
        if (!failed(runPipeline(*pm, parentOp)))
          signalPassFailure();
      return;
    }

    if (runOnNestedOp) {
      llvm::errs() << "Run on nested op\n";
      currentOp->walk([&](Operation *op) {
        // Skip the root itself and any operation that is not known to be
        // isolated from above.
        if (op == currentOp || !op->isKnownIsolatedFromAbove())
          return;
        llvm::errs() << "Run on " << *op << "\n";
        // Run on the nested operation.
        if (failed(runPipeline(*pm, op)))
          signalPassFailure();
      });
    } else {
      // Run on the current operation.
      if (failed(runPipeline(*pm, currentOp)))
        signalPassFailure();
    }
  }

  /// Pass manager holding the parsed dynamic pipeline; created on first use in
  /// `runOnOperation`.
  std::unique_ptr<OpPassManager> pm;

  Option<bool> runOnNestedOp{
      *this, "run-on-nested-operations",
      llvm::cl::desc("This will apply the pipeline on nested operations under "
                     "the visited operation.")};
  Option<bool> runOnParent{
      *this, "run-on-parent",
      llvm::cl::desc("This will apply the pipeline on the parent operation if "
                     "it exist, this is expected to fail.")};
  Option<std::string> pipeline{
      *this, "dynamic-pipeline",
      llvm::cl::desc("The pipeline description that "
                     "will run on the filtered function.")};
  ListOption<std::string> opNames{
      *this, "op-name", llvm::cl::MiscFlags::CommaSeparated,
      llvm::cl::desc("List of function name to apply the pipeline to")};
};
} // end namespace
namespace mlir {
/// Registers `TestDynamicPipelinePass` under the `test-dynamic-pipeline`
/// command-line argument.
void registerTestDynamicPipelinePass() {
  const char *const passArgument = "test-dynamic-pipeline";
  const char *const passDescription =
      "Tests the dynamic pipeline feature by applying a pipeline on a "
      "selected set of functions";
  PassRegistration<TestDynamicPipelinePass>(passArgument, passDescription);
}
} // namespace mlir