diff --git a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp index 2dc77c9705d35..0f94d95408f29 100644 --- a/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestDeadCodeAnalysis.cpp @@ -66,40 +66,26 @@ static void printAnalysisResults(DataFlowSolver &solver, Operation *op, namespace { /// This is a simple analysis that implements a transfer function for constant /// operations. -struct ConstantAnalysis : public DataFlowAnalysis { - using DataFlowAnalysis::DataFlowAnalysis; +struct SparseConstantAnalysis + : public SparseForwardDataFlowAnalysis> { + using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis; - LogicalResult initialize(Operation *top) override { - WalkResult result = top->walk([&](Operation *op) { - if (failed(visit(getProgramPointAfter(op)))) - return WalkResult::interrupt(); - return WalkResult::advance(); - }); - return success(!result.wasInterrupted()); - } - - LogicalResult visit(ProgramPoint *point) override { - Operation *op = point->getPrevOp(); + LogicalResult + visitOperation(Operation *op, + ArrayRef *> operands, + ArrayRef *> results) override { Attribute value; if (matchPattern(op, m_Constant(&value))) { auto *constant = getOrCreate>(op->getResult(0)); propagateIfChanged( constant, constant->join(ConstantValue(value, op->getDialect()))); - return success(); } - setAllToUnknownConstants(op->getResults()); - for (Region ®ion : op->getRegions()) - setAllToUnknownConstants(region.getArguments()); return success(); } - /// Set all given values as not constants. - void setAllToUnknownConstants(ValueRange values) { - for (Value value : values) { - auto *constant = getOrCreate>(value); - propagateIfChanged(constant, - constant->join(ConstantValue::getUnknownConstant())); - } + void setToEntryState(Lattice *lattice) override { + propagateIfChanged(lattice, + lattice->join(ConstantValue::getUnknownConstant())); } }; @@ -116,7 +102,7 @@ struct TestDeadCodeAnalysisPass DataFlowSolver solver; solver.load(); - solver.load(); + solver.load(); if (failed(solver.initializeAndRun(op))) return signalPassFailure(); printAnalysisResults(solver, op, llvm::errs());