Skip to content

Commit

Permalink
test composition
Browse files Browse the repository at this point in the history
  • Loading branch information
Mogball committed Jun 30, 2022
1 parent 1ed1e8c commit 81b151b
Showing 1 changed file with 76 additions and 3 deletions.
79 changes: 76 additions & 3 deletions mlir/test/lib/Analysis/TestDataFlowFramework.cpp
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Analysis/SparseDataFlowAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"

Expand Down Expand Up @@ -182,8 +183,80 @@ void TestFooAnalysisPass::runOnOperation() {
});
}

namespace {
struct AugmentSCP : public DataFlowAnalysis {
using DataFlowAnalysis::DataFlowAnalysis;

LogicalResult initialize(Operation *top) override {
top->walk([&](Operation *op) {
if (op->getName().getStringRef() == "test.scp_region")
(void)visit(op);
});
return success();
}

LogicalResult visit(ProgramPoint point) override {
auto *op = point.get<Operation *>();
assert(op->getName().getStringRef() == "test.scp_region");

auto *rhs = getOrCreateFor<ConstantValueState>(op, op->getOperand(0));
if (rhs->isUninitialized()) return success();

for (Region &region : op->getRegions()) {
for (Value value : region.getArguments()) {
assert(staticallyProvides(TypeID::get<ConstantValueState>(), value));
update<ConstantValueState>(
value, [rhs](ConstantValueState *lhs) { return lhs->join(*rhs); });
}
}
return success();
}

bool staticallyProvides(TypeID stateID, ProgramPoint point) const override {
if (stateID != TypeID::get<ConstantValueState>())
return false;

auto value = point.dyn_cast<Value>();
if (!value || !value.isa<BlockArgument>() ||
value.getParentBlock() != &value.getParentRegion()->front())
return false;

return value.getParentRegion()->getParentOp()->getName().getStringRef() ==
"test.scp_region";
}
};

struct AugmentSCPPass : public PassWrapper<AugmentSCPPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AugmentSCPPass)

StringRef getArgument() const override { return "test-augment-scp"; }

void runOnOperation() override {
DataFlowSolver solver;
solver.load<DeadCodeAnalysis>();
solver.load<SparseConstantPropagation>();
solver.load<AugmentSCP>();
if (failed(solver.initializeAndRun(getOperation())))
return signalPassFailure();

getOperation()->walk([&](Operation *op) {
for (auto &result : llvm::enumerate(op->getResults())) {
auto *cv = solver.lookup<ConstantValueState>(result.value());
if (!cv || cv->isUninitialized() || !cv->getValue().getConstantValue())
continue;
llvm::errs() << "op " << op->getName() << " result #" << result.index()
<< " -> " << cv->getValue().getConstantValue() << "\n";
}
});
}
};
} // end anonymous namespace

namespace mlir {
namespace test {
void registerTestFooAnalysisPass() { PassRegistration<TestFooAnalysisPass>(); }
} // namespace test
} // namespace mlir
void registerTestFooAnalysisPass() {
PassRegistration<TestFooAnalysisPass>();
PassRegistration<AugmentSCPPass>();
}
} // end namespace test
} // end namespace mlir

0 comments on commit 81b151b

Please sign in to comment.