Skip to content

Commit

Permalink
[mlir][bufferize][NFC] Optimize read-only tensor detection
Browse files Browse the repository at this point in the history
Check alias sets instead of traversing the IR.

Differential Revision: https://reviews.llvm.org/D143500
  • Loading branch information
matthias-springer committed Feb 9, 2023
1 parent cedfd27 commit dc7ad19
Showing 1 changed file with 27 additions and 47 deletions.
74 changes: 27 additions & 47 deletions mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
Expand Up @@ -724,62 +724,42 @@ static void annotateNonWritableTensor(Value value) {
}
}

/// Check the reverse SSA use-def chain (following aliasing OpOperands) for
/// non-writable tensor values. Stop searching when an out-of-place bufferized
/// OpOperand was found (or when the OpOperand was not bufferized yet).
/// `currentOpOperand` is assumed to be in-place, even if that decision was not
/// materialized in `aliasInfo` yet.
static bool
hasPrecedingAliasingNonWritableTensor(Value value, OpOperand *currentOpOperand,
const OneShotAnalysisState &state) {
SmallVector<Value> worklist;
worklist.push_back(value);
while (!worklist.empty()) {
Value nextVal = worklist.pop_back_val();
if (!state.isWritable(nextVal)) {
if (state.getOptions().printConflicts)
annotateNonWritableTensor(nextVal);
return true;
}

// If `nextVal` is not a BlockArgument: End of use-def chain reached.
auto opResult = nextVal.dyn_cast<OpResult>();
if (!opResult)
continue;

// Follow reverse SSA use-def chain.
AliasingOpOperandList aliasingOpOperands =
state.getAliasingOpOperands(opResult);
for (OpOperand *opOperand : aliasingOpOperands)
if (state.isInPlace(*opOperand) || currentOpOperand == opOperand)
worklist.push_back(opOperand->get());
}
return false;
}

/// Return true if bufferizing `operand` inplace would create a write to a
/// non-writable buffer.
static bool
wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
OneShotAnalysisState &state,
bool checkConsistencyOnly = false) {
// Collect writes of all aliases of OpOperand and OpResult.
DenseSet<OpOperand *> usesWrite;
getAliasingInplaceWrites(usesWrite, operand.get(), state);
for (OpResult result : state.getAliasingOpResults(operand)) {
getAliasingInplaceWrites(usesWrite, result, state);
bool foundWrite =
!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand);

if (!foundWrite) {
// Collect writes of all aliases of OpOperand and OpResult.
DenseSet<OpOperand *> usesWrite;
getAliasingInplaceWrites(usesWrite, operand.get(), state);
for (OpResult result : state.getAliasingOpResults(operand))
getAliasingInplaceWrites(usesWrite, result, state);
foundWrite = !usesWrite.empty();
}
if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
usesWrite.insert(&operand);

// Assuming that `operand` bufferizes in-place: For each write (to each
// alias), check if there is a non-writable tensor in the reverse SSA use-def
// chain.
for (OpOperand *uWrite : usesWrite) {
if (hasPrecedingAliasingNonWritableTensor(uWrite->get(), &operand, state)) {
LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n");
return true;
if (!foundWrite)
return false;

// Look for a read-only tensor among all aliases.
bool foundReadOnly = false;
auto checkReadOnly = [&](Value v) {
if (!state.isWritable(v)) {
foundReadOnly = true;
if (state.getOptions().printConflicts)
annotateNonWritableTensor(v);
}
};
state.applyOnAliases(operand.get(), checkReadOnly);
for (OpResult result : state.getAliasingOpResults(operand))
state.applyOnAliases(result, checkReadOnly);
if (foundReadOnly) {
LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n");
return true;
}

return false;
Expand Down

0 comments on commit dc7ad19

Please sign in to comment.