diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h index 8d44d579fbc8d..2663e480f281b 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h @@ -167,6 +167,9 @@ class OneShotAnalysisState : public AnalysisState { /// not be called for values inside not yet analyzed functions. bool isValueWritten(Value value) const; + /// Return true if the buffer of the given tensor value is writable. + bool isWritable(Value value) const; + private: /// `aliasInfo` keeps track of aliasing and equivalent values. Only internal /// functions and `runOneShotBufferize` may access this object. diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp index f29d9a96b0b37..5447f6b0bdc23 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -305,6 +305,21 @@ bool OneShotAnalysisState::isValueWritten(Value value) const { return isWritten; } +bool OneShotAnalysisState::isWritable(Value value) const { + // TODO: Out-of-place bufferized value could be considered writable. + if (auto bufferizableOp = getOptions().dynCastBufferizableOp(value)) + return bufferizableOp.isWritable(value, *this); + + // Query BufferizableOpInterface to see if the BlockArgument is writable. + if (auto bbArg = value.dyn_cast()) + if (auto bufferizableOp = + getOptions().dynCastBufferizableOp(bbArg.getOwner()->getParentOp())) + return bufferizableOp.isWritable(bbArg, *this); + + // Not a bufferizable op: The conservative answer is "not writable". + return false; +} + //===----------------------------------------------------------------------===// // Bufferization-specific alias analysis. //===----------------------------------------------------------------------===// @@ -312,7 +327,7 @@ bool OneShotAnalysisState::isValueWritten(Value value) const { /// Return true if opOperand has been decided to bufferize in-place. static bool isInplaceMemoryWrite(OpOperand &opOperand, const BufferizationAliasInfo &aliasInfo, - AnalysisState &state) { + const AnalysisState &state) { // OpOperands that do not bufferize to a memory write do not write in-place. if (!state.bufferizesToMemoryWrite(opOperand)) return false; @@ -320,49 +335,6 @@ static bool isInplaceMemoryWrite(OpOperand &opOperand, return aliasInfo.isInPlace(opOperand); } -/// Return true if, under current bufferization decisions, the buffer of `value` -/// is not writable. -static bool aliasesNonWritableBuffer(Value value, - const BufferizationAliasInfo &aliasInfo, - AnalysisState &state) { - bool foundNonWritableBuffer = false; - aliasInfo.applyOnAliases(value, [&](Value v) { - // Query BufferizableOpInterface to see if the value is writable. - // TODO: Out-of-place bufferized value could be considered writable. - if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(v)) - if (bufferizableOp && bufferizableOp.isWritable(v, state)) - return; - - // Query BufferizableOpInterface to see if the BlockArgument is writable. - if (auto bbArg = v.dyn_cast()) - if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp( - bbArg.getOwner()->getParentOp())) - if (bufferizableOp.isWritable(bbArg, state)) - return; - - foundNonWritableBuffer = true; - }); - - return foundNonWritableBuffer; -} - -/// Return true if the buffer to which `operand` would bufferize is equivalent -/// to some buffer write. -static bool aliasesInPlaceWrite(Value value, - const BufferizationAliasInfo &aliasInfo, - AnalysisState &state) { - bool foundInplaceWrite = false; - aliasInfo.applyOnAliases(value, [&](Value v) { - for (auto &use : v.getUses()) { - if (isInplaceMemoryWrite(use, aliasInfo, state)) { - foundInplaceWrite = true; - return; - } - } - }); - return foundInplaceWrite; -} - /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors /// properly dominates `b` and `b` is not inside `a`. static bool happensBefore(Operation *a, Operation *b, @@ -604,6 +576,30 @@ static bool hasReadAfterWriteInterference( return false; } +// Helper function to iterate on aliases of `root` and capture the writes. +static void getAliasingInplaceWrites(DenseSet &res, Value root, + const BufferizationAliasInfo &aliasInfo, + const AnalysisState &state) { + aliasInfo.applyOnAliases(root, [&](Value alias) { + for (auto &use : alias.getUses()) + // Inplace write to a value that aliases root. + if (isInplaceMemoryWrite(use, aliasInfo, state)) + res.insert(&use); + }); +} + +// Helper function to iterate on aliases of `root` and capture the reads. +static void getAliasingReads(DenseSet &res, Value root, + const BufferizationAliasInfo &aliasInfo, + const AnalysisState &state) { + aliasInfo.applyOnAliases(root, [&](Value alias) { + for (auto &use : alias.getUses()) + // Read to a value that aliases root. + if (state.bufferizesToMemoryRead(use)) + res.insert(&use); + }); +} + /// Return true if bufferizing `operand` inplace would create a conflict. A read /// R and a write W of the same alias set is a conflict if inplace bufferization /// of W changes the value read by R to a value different from the one that @@ -637,33 +633,13 @@ static bool wouldCreateReadAfterWriteInterference( OpOperand &operand, const DominanceInfo &domInfo, AnalysisState &state, const BufferizationAliasInfo &aliasInfo, bool checkConsistencyOnly = false) { - // Helper function to iterate on aliases of `root` and capture the reads. - auto getAliasingReads = [&](DenseSet &res, Value root) { - aliasInfo.applyOnAliases(root, [&](Value alias) { - for (auto &use : alias.getUses()) - // Read to a value that aliases root. - if (state.bufferizesToMemoryRead(use)) - res.insert(&use); - }); - }; - - // Helper function to iterate on aliases of `root` and capture the writes. - auto getAliasingInplaceWrites = [&](DenseSet &res, Value root) { - aliasInfo.applyOnAliases(root, [&](Value alias) { - for (auto &use : alias.getUses()) - // Inplace write to a value that aliases root. - if (isInplaceMemoryWrite(use, aliasInfo, state)) - res.insert(&use); - }); - }; - // Collect reads and writes of all aliases of OpOperand and OpResult. DenseSet usesRead, usesWrite; - getAliasingReads(usesRead, operand.get()); - getAliasingInplaceWrites(usesWrite, operand.get()); + getAliasingReads(usesRead, operand.get(), aliasInfo, state); + getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state); for (OpResult result : state.getAliasingOpResult(operand)) { - getAliasingReads(usesRead, result); - getAliasingInplaceWrites(usesWrite, result); + getAliasingReads(usesRead, result, aliasInfo, state); + getAliasingInplaceWrites(usesWrite, result, aliasInfo, state); } if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand)) usesWrite.insert(&operand); @@ -672,28 +648,60 @@ static bool wouldCreateReadAfterWriteInterference( aliasInfo); } -/// Return true if bufferizing `opOperand` inplace would create a write to a -/// non-writable buffer. +/// Check the reverse SSA use-def chain (following aliasing OpOperands) for +/// non-writable tensor values. Stop searching when an out-of-place bufferized +/// OpOperand was found (or when the OpOperand was not bufferized yet). +/// `currentOpOperand` is assumed to be in-place, even if that decision was not +/// materialized in `aliasInfo` yet. static bool -wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, - const BufferizationAliasInfo &aliasInfo, - AnalysisState &state) { - // Certain buffers are not writeable: - // 1. A function bbArg that is not inplaceable or - // 2. A constant op. - bool nonWritable = - aliasesNonWritableBuffer(opOperand.get(), aliasInfo, state); - if (!nonWritable) - return false; +hasPrecedingAliasingNonWritableTensor(Value value, OpOperand *currentOpOperand, + const BufferizationAliasInfo &aliasInfo, + const OneShotAnalysisState &state) { + SmallVector worklist; + worklist.push_back(value); + while (!worklist.empty()) { + Value nextVal = worklist.pop_back_val(); + if (!state.isWritable(nextVal)) + return true; + + // If `nextVal` is not a BlockArgument: End of use-def chain reached. + auto opResult = nextVal.dyn_cast(); + if (!opResult) + continue; + + // Follow reverse SSA use-def chain. + SmallVector aliasingOpOperands = + state.getAliasingOpOperand(opResult); + for (OpOperand *opOperand : aliasingOpOperands) + if (aliasInfo.isInPlace(*opOperand) || currentOpOperand == opOperand) + worklist.push_back(opOperand->get()); + } + return false; +} - // This is a problem only if the buffer is written to via some alias. - bool hasWrite = aliasesInPlaceWrite(opOperand.get(), aliasInfo, state) || - state.bufferizesToMemoryWrite(opOperand); +/// Return true if bufferizing `operand` inplace would create a write to a +/// non-writable buffer. +static bool wouldCreateWriteToNonWritableBuffer( + OpOperand &operand, const BufferizationAliasInfo &aliasInfo, + OneShotAnalysisState &state, bool checkConsistencyOnly = false) { + // Collect writes of all aliases of OpOperand and OpResult. + DenseSet usesWrite; + getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state); + for (OpResult result : state.getAliasingOpResult(operand)) { + getAliasingInplaceWrites(usesWrite, result, aliasInfo, state); + } + if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand)) + usesWrite.insert(&operand); - for (OpResult opResult : state.getAliasingOpResult(opOperand)) - hasWrite |= aliasesInPlaceWrite(opResult, aliasInfo, state); + // Assuming that `operand` bufferizes in-place: For each write (to each + // alias), check if there is a non-writable tensor in the reverse SSA use-def + // chain. + for (OpOperand *uWrite : usesWrite) + if (hasPrecedingAliasingNonWritableTensor(uWrite->get(), &operand, + aliasInfo, state)) + return true; - return hasWrite; + return false; } //===----------------------------------------------------------------------===// @@ -702,8 +710,8 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, /// Determine if `operand` can be bufferized in-place. static LogicalResult bufferizableInPlaceAnalysisImpl( - OpOperand &operand, BufferizationAliasInfo &aliasInfo, AnalysisState &state, - const DominanceInfo &domInfo) { + OpOperand &operand, BufferizationAliasInfo &aliasInfo, + OneShotAnalysisState &state, const DominanceInfo &domInfo) { bool foundInterference = wouldCreateWriteToNonWritableBuffer(operand, aliasInfo, state) || wouldCreateReadAfterWriteInterference(operand, domInfo, state, aliasInfo); @@ -736,7 +744,7 @@ static LogicalResult bufferizableInPlaceAnalysisImpl( /// RaW dependence violations. static LogicalResult inPlaceAnalysis(SmallVector &ops, BufferizationAliasInfo &aliasInfo, - AnalysisState &state, + OneShotAnalysisState &state, const DominanceInfo &domInfo, unsigned analysisFuzzerSeed = 0) { if (analysisFuzzerSeed) { @@ -769,7 +777,7 @@ static bool hasTensorSemantics(Operation *op) { /// Analyze all ops that are contained in `op`. static LogicalResult inPlaceAnalysis(Operation *op, BufferizationAliasInfo &aliasInfo, - AnalysisState &state, + OneShotAnalysisState &state, const DominanceInfo &domInfo, unsigned analysisFuzzerSeed = 0) { // Collect ops so we can build our own reverse traversal. diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir index 9fff7f990b391..3145959e767e3 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir @@ -31,3 +31,34 @@ func.func @buffer_not_deallocated(%t : tensor, %c : i1) -> tensor // CHECK: return %[[r_tensor]] return %r : tensor } + +// ----- + +// CHECK-LABEL: func @write_to_alloc_tensor_or_readonly_tensor( +// CHECK-SAME: %[[arg0:.*]]: tensor +func.func @write_to_alloc_tensor_or_readonly_tensor(%arg0: tensor, + %cond: i1, %val: i32) + -> tensor +{ + // CHECK: %[[r:.*]] = scf.if {{.*}} { + // CHECK: %[[arg0_m:.*]] = bufferization.to_memref %[[arg0]] + // CHECK: %[[clone:.*]] = bufferization.clone %[[arg0_m]] + // CHECK: scf.yield %[[clone]] + // CHECK: } else { + // CHECK: %[[alloc:.*]] = memref.alloc + // CHECK: memref.store %{{.*}}, %[[alloc]] + // CHECK: %[[casted:.*]] = memref.cast %[[alloc]] + // CHECK: scf.yield %[[casted]] + // CHECK: } + // CHECK: %[[r_t:.*]] = bufferization.to_tensor %[[r]] + // CHECK: memref.dealloc %[[r]] + // CHECK: return %[[r_t]] + %3 = scf.if %cond -> (tensor) { + scf.yield %arg0 : tensor + } else { + %7 = bufferization.alloc_tensor() : tensor + %8 = tensor.insert %val into %7[] : tensor + scf.yield %8 : tensor + } + return %3 : tensor +}