diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h index 9820a91291fdb..2ebf63fb8833b 100644 --- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h @@ -63,13 +63,19 @@ class IntegerRangeAnalysis /// Visit block arguments or operation results of an operation with region /// control-flow for which values are not defined by region control-flow. This - /// function calls `InferIntRangeInterface` to provide values for block - /// arguments or tries to reduce the range on loop induction variables with + /// function tries to reduce the range on loop induction variables with /// known bounds. void visitNonControlFlowArguments( Operation *op, const RegionSuccessor &successor, ValueRange nonSuccessorInputs, ArrayRef nonSuccessorInputLattices) override; + + /// This function calls `InferIntRangeInterface` to provide values for entry + /// block arguments where the parentOp does not implement + /// `RegionBranchOpInterface` (e.g., gpu.launch). + void visitNonControlFlowArguments( + Operation *op, Region *const region, ValueRange arguments, + ArrayRef argLattices) override; }; /// Succeeds if an op can be converted to its unsigned equivalent without diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h index df50d8d193aeb..fb21c5bbb1310 100644 --- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h @@ -218,6 +218,13 @@ class AbstractSparseForwardDataFlowAnalysis : public DataFlowAnalysis { ValueRange nonSuccessorInputs, ArrayRef nonSuccessorInputLattices) = 0; + /// Given an operation with region non-control-flow, the lattices of the entry + /// block arguments, compute the lattice values for block arguments.(ex. the + /// block arguments of gpu.launch). + virtual void visitNonControlFlowArgumentsImpl( + Operation *op, Region *const region, ValueRange arguments, + ArrayRef argLattices) = 0; + /// Get the lattice element of a value. virtual AbstractSparseLattice *getLatticeElement(Value value) = 0; @@ -335,6 +342,16 @@ class SparseForwardDataFlowAnalysis setAllToEntryStates(nonSuccessorInputLattices); } + /// Given an operation with region non-control-flow, the lattices of the entry + /// block arguments, compute the lattice values for block arguments.(ex. the + /// block argument of gpu.launch). By default, this method marks all lattice + /// elements as having reached a pessimistic fixpoint. + virtual void visitNonControlFlowArguments(Operation *op, Region *const region, + ValueRange arguments, + ArrayRef argLattices) { + setAllToEntryStates(argLattices); + } + protected: /// Get the lattice element for a value. StateT *getLatticeElement(Value value) override { @@ -391,6 +408,15 @@ class SparseForwardDataFlowAnalysis nonSuccessorInputLattices.size()}); } + virtual void visitNonControlFlowArgumentsImpl( + Operation *op, Region *const region, ValueRange arguments, + ArrayRef argLattices) override { + visitNonControlFlowArguments( + op, region, arguments, + {reinterpret_cast(argLattices.begin()), + argLattices.size()}); + } + void setToEntryState(AbstractSparseLattice *lattice) override { return setToEntryState(reinterpret_cast(lattice)); } diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp index 7b567f043577a..79f31ea311211 100644 --- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp @@ -143,50 +143,6 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments( ArrayRef nonSuccessorInputLattices) { assert(nonSuccessorInputs.size() == nonSuccessorInputLattices.size() && "size mismatch"); - if (auto inferrable = dyn_cast(op)) { - LDBG() << "Inferring ranges for " - << OpWithFlags(op, OpPrintingFlags().skipRegions()); - - auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) { - return getLatticeElementFor(getProgramPointAfter(op), value)->getValue(); - }); - - auto joinCallback = [&](Value v, const IntegerValueRange &attrs) { - auto arg = dyn_cast(v); - if (!arg) - return; - if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg)) - return; - - LDBG() << "Inferred range " << attrs; - auto it = llvm::find(successor.getSuccessor()->getArguments(), arg); - unsigned nonSuccessorInputIdx = - std::distance(successor.getSuccessor()->getArguments().begin(), it); - IntegerValueRangeLattice *lattice = - nonSuccessorInputLattices[nonSuccessorInputIdx]; - IntegerValueRange oldRange = lattice->getValue(); - - ChangeResult changed = lattice->join(attrs); - - // Catch loop results with loop variant bounds and conservatively make - // them [-inf, inf] so we don't circle around infinitely often (because - // the dataflow analysis in MLIR doesn't attempt to work out trip counts - // and often can't). - bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) { - return op->hasTrait(); - }); - if (isYieldedValue && !oldRange.isUninitialized() && - !(lattice->getValue() == oldRange)) { - LDBG() << "Loop variant loop result detected"; - changed |= lattice->join(IntegerValueRange::getMaxRange(v)); - } - propagateIfChanged(lattice, changed); - }; - - inferrable.inferResultRangesFromOptional(argRanges, joinCallback); - return; - } - /// Given a lower bound, upper bound, or step from a LoopLikeInterface return /// the lower/upper bound for that result if possible. auto getLoopBoundFromFold = [&](OpFoldResult loopBound, Type boundType, @@ -251,7 +207,51 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments( } return; } - return SparseForwardDataFlowAnalysis::visitNonControlFlowArguments( op, successor, nonSuccessorInputs, nonSuccessorInputLattices); } + +void IntegerRangeAnalysis::visitNonControlFlowArguments( + Operation *op, Region *const region, ValueRange arguments, + ArrayRef argLattices) { + assert(arguments.size() == argLattices.size() && "size mismatch"); + if (auto inferrable = dyn_cast(op)) { + LDBG() << "Inferring ranges for " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); + + auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) { + return getLatticeElementFor(getProgramPointAfter(op), value)->getValue(); + }); + + auto joinCallback = [&](Value v, const IntegerValueRange &attrs) { + auto arg = dyn_cast(v); + if (!arg) + return; + if (!llvm::is_contained(arguments, arg)) + return; + + LDBG() << "Inferred range " << attrs; + auto it = llvm::find(arguments, arg); + unsigned argIndex = std::distance(arguments.begin(), it); + IntegerValueRangeLattice *lattice = argLattices[argIndex]; + IntegerValueRange oldRange = lattice->getValue(); + + ChangeResult changed = lattice->join(attrs); + + // Catch loop results with loop variant bounds and conservatively make + // them [-inf, inf] so we don't circle around infinitely often (because + // the dataflow analysis in MLIR doesn't attempt to work out trip counts + // and often can't). + bool isYieldedValue = llvm::any_of(v.getUsers(), [](Operation *op) { + return op->hasTrait(); + }); + if (isYieldedValue && !oldRange.isUninitialized() && + !(lattice->getValue() == oldRange)) { + LDBG() << "Loop variant loop result detected"; + changed |= lattice->join(IntegerValueRange::getMaxRange(v)); + } + propagateIfChanged(lattice, changed); + }; + inferrable.inferResultRangesFromOptional(argRanges, joinCallback); + } +} diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp index 90f2a588d1ca4..b583231aca9af 100644 --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -187,7 +187,7 @@ void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) { // All block arguments are non-successor-inputs. return visitNonControlFlowArgumentsImpl(block->getParentOp(), - RegionSuccessor(block->getParent()), + block->getParent(), block->getArguments(), argLattices); }