From 00fa95ca15508f74ce199b4c43abb1cf26094e2b Mon Sep 17 00:00:00 2001 From: linuxlonelyeagle <2020382038@qq.com> Date: Tue, 2 Sep 2025 16:25:06 +0000 Subject: [PATCH] use SparseForwardDataFlowAnalysis to implement constant analysis --- .../DataFlow/TestDeadCodeAnalysis.cpp | 36 ++++++------------- 1 file changed, 11 insertions(+), 25 deletions(-) diff --git a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp index 2dc77c9705d35..0f94d95408f29 100644 --- a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp @@ -66,40 +66,26 @@ static void printAnalysisResults(DataFlowSolver &solver, Operation *op, namespace { /// This is a simple analysis that implements a transfer function for constant /// operations. -struct ConstantAnalysis : public DataFlowAnalysis { - using DataFlowAnalysis::DataFlowAnalysis; +struct SparseConstantAnalysis + : public SparseForwardDataFlowAnalysis> { + using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis; - LogicalResult initialize(Operation *top) override { - WalkResult result = top->walk([&](Operation *op) { - if (failed(visit(getProgramPointAfter(op)))) - return WalkResult::interrupt(); - return WalkResult::advance(); - }); - return success(!result.wasInterrupted()); - } - - LogicalResult visit(ProgramPoint *point) override { - Operation *op = point->getPrevOp(); + LogicalResult + visitOperation(Operation *op, + ArrayRef *> operands, + ArrayRef *> results) override { Attribute value; if (matchPattern(op, m_Constant(&value))) { auto *constant = getOrCreate>(op->getResult(0)); propagateIfChanged( constant, constant->join(ConstantValue(value, op->getDialect()))); - return success(); } - setAllToUnknownConstants(op->getResults()); - for (Region ®ion : op->getRegions()) - setAllToUnknownConstants(region.getArguments()); return success(); } - /// Set all given values as not constants. - void setAllToUnknownConstants(ValueRange values) { - for (Value value : values) { - auto *constant = getOrCreate>(value); - propagateIfChanged(constant, - constant->join(ConstantValue::getUnknownConstant())); - } + void setToEntryState(Lattice *lattice) override { + propagateIfChanged(lattice, + lattice->join(ConstantValue::getUnknownConstant())); } }; @@ -116,7 +102,7 @@ struct TestDeadCodeAnalysisPass DataFlowSolver solver; solver.load(); - solver.load(); + solver.load(); if (failed(solver.initializeAndRun(op))) return signalPassFailure(); printAnalysisResults(solver, op, llvm::errs());