Skip to content

Commit

Permalink
[mlir][bufferize] Fix repetitive region conflict detection
Browse files Browse the repository at this point in the history
This fixes a bug where a required buffer copy was not inserted.

Not only written aliases, but also read aliases should be taken into account when computing common enclosing repetitive regions. Furthermore, for writing ops, it does not matter where the destination tensor is defined, but where the op itself is located.

Differential Revision: https://reviews.llvm.org/D135420
  • Loading branch information
matthias-springer committed Oct 7, 2022
1 parent a6a0d9e commit 2e21003
Show file tree
Hide file tree
Showing 2 changed files with 183 additions and 70 deletions.
187 changes: 117 additions & 70 deletions mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
Expand Up @@ -376,21 +376,6 @@ getEnclosingRepetitiveRegion(Value value, const BufferizationOptions &options) {
return nullptr;
}

/// For each given value, find the closest enclosing repetitive region. If this
/// is the same region for each value, return it. Otherwise return None.
/// Note: If there is no enclosing repetitive region, return nullptr.
static Optional<Region *>
getCommonEnclosingRepetitiveRegion(ArrayRef<Value> values,
const BufferizationOptions &options) {
if (values.empty())
return None;
Region *r = getEnclosingRepetitiveRegion(values.front(), options);
for (Value value : values.drop_front())
if (getEnclosingRepetitiveRegion(value, options) != r)
return None;
return r;
}

/// Return `true` if the given tensor value is a memory write. Most values are
/// tensor writes, but ops that define a tensor SSA value without specifying its
/// contents (e.g., alloc_tensor) are not.
Expand All @@ -404,6 +389,118 @@ static bool isMemoryWrite(Value value, const AnalysisState &state) {
return bufferizableOp.isMemoryWrite(opResult, state);
}

/// Return `true` if op dominance can be used to rule out read-after-write
/// conflicts wrt. the given reads and writes.
///
/// Op dominance can often be used to rule out potential conflicts such as
/// "read" happens before "write". E.g., the following IR is not a RaW conflict
/// because the the read happens *before* the write.
///
/// %0 = ... : tensor<?xf32>
/// "reading_op"(%0) : tensor<?xf32>
/// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
///
/// This is no longer true inside loops (or repetitive regions). In such cases,
/// there may not be a meaningful `happensBefore` relationship because ops
/// could be executed multiple times. E.g.:
///
/// %0 = ... : tensor<?xf32>
/// scf.for ... {
/// "reading_op"(%0) : tensor<?xf32>
/// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
/// ...
/// }
///
/// In the above example, reading_op happens before writing_op according to
/// op dominance. However, both ops may happen multiple times; in
/// particular, the second execution of reading_op happens after the first
/// execution of writing_op. This is problematic because the tensor %0 they
/// operate on (i.e., the "definition") is defined outside of the loop.
///
/// Counter example:
///
/// scf.for ... {
/// %0 = ... : tensor<?xf32>
/// "reading_op"(%0) : tensor<?xf32>
/// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
/// ...
/// }
///
/// In this example, the definition %0 is in the same repetitive region as
/// "writing_op", so op dominance can be used to compute the `happensBefore`
/// relationship.
///
/// This functions finds the closest enclosing repetitive region of all buffer
/// writes wrt. the given given tensor reads and writes. If this is the same
/// region (nullptr in case of "no repetitive region" found at all), op
/// dominance can be used. Otherwise, it cannot be used.
///
/// Example: The common enclosing repetitive region is the scf.for loop.
/// Op dominance can be used.
/// scf.for ... {
/// %0 = tensor.generate
/// "read"(%0)
/// }
///
/// Example: The common enclosing repetitive region is nullptr: There is no
/// repetitive region around the tensor.generate. Op dominance can be
/// used.
/// %0 = tensor.generate
/// scf.for ... { "read"(%0) }
///
/// Example: The common enclosing repetitive regions of tensor.generate and
/// "write" differ. Op dominance cannot be used.
/// %0 = tensor.generate
/// scf.for ... {
/// "read"(%0)
/// "write"(%0)
/// }
///
/// Example: The common enclosing repetitive regions of tensor.generate and
/// "write" differ, but there is no read of %0, so op dominance can be
/// used.
/// %0 = tensor.generate
/// scf.for ... {
/// "write"(%0)
/// }
///
/// Note: iter_args of loops are not aliases of their respective block
/// arguments, so op domanice can be used when analyzing ops that operate
/// on them.
bool canUseOpDominance(const DenseSet<OpOperand *> &usesRead,
const DenseSet<OpOperand *> &usesWrite,
const AnalysisState &state) {
const BufferizationOptions &options = state.getOptions();
Optional<Region *> commonEnclosingRegion = None;

// In case of a write, take the region in which the write takes place.
for (OpOperand *uWrite : usesWrite) {
Region *r = getEnclosingRepetitiveRegion(uWrite->getOwner(), options);
if (!commonEnclosingRegion.has_value()) {
commonEnclosingRegion = r;
continue;
}
if (*commonEnclosingRegion != r)
return false;
}

// In case of a read, take the region which the read value is defined.
for (OpOperand *uRead : usesRead) {
// Optimization: Skip reads of values that have no defined contents.
if (!isMemoryWrite(uRead->get(), state))
continue;
Region *r = getEnclosingRepetitiveRegion(uRead->get(), options);
if (!commonEnclosingRegion.has_value()) {
commonEnclosingRegion = r;
continue;
}
if (*commonEnclosingRegion != r)
return false;
}

return commonEnclosingRegion.has_value();
}

/// Annotate IR with details about the detected RaW conflict.
static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
Value lastWrite) {
Expand Down Expand Up @@ -450,15 +547,8 @@ static bool hasReadAfterWriteInterference(
AnalysisState &state, const BufferizationAliasInfo &aliasInfo) {
const BufferizationOptions &options = state.getOptions();

// Gather all written aliases. Skip over aliases that are not actual writes.
SmallVector<Value> writtenAliases;
for (OpOperand *uWrite : usesWrite)
if (isMemoryWrite(uWrite->get(), state))
writtenAliases.push_back(uWrite->get());
// Find the inner-most enclosing repetitive region of each alias. If this is
// the same region for every alias, save it in `repetitiveRegionOfWrites`.
Optional<Region *> repetitiveRegionOfWrites =
getCommonEnclosingRepetitiveRegion(writtenAliases, options);
// Check if op dominance can be used to rule out read-after-write conflicts.
bool useDominance = canUseOpDominance(usesRead, usesWrite, state);

for (OpOperand *uRead : usesRead) {
Operation *readingOp = uRead->getOwner();
Expand All @@ -482,55 +572,12 @@ static bool hasReadAfterWriteInterference(
// met for uConflictingWrite to be an actual conflict.
Operation *conflictingWritingOp = uConflictingWrite->getOwner();

// Check if conflictingWritingOp is in the same repetitive region as all
// written aliases. If this is not the case, there is no meaningful
// `happensBefore` relationship because conflictingWritingOp may be
// executed multiple times. E.g.:
//
// %0 = ... : tensor<?xf32>
// scf.for ... {
// "reading_op"(%0) : tensor<?xf32>
// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
// ...
// }
//
// In the above example, reading_op happens before writing_op according to
// op dominance. However, both ops may happen multiple times; in
// particular, the second execution of reading_op happens after the first
// execution of writing_op. This is problematic if the tensor they operate
// on (%0) is defined outside of the loop.
//
// Counter example:
//
// scf.for ... {
// %0 = ... : tensor<?xf32>
// "reading_op"(%0) : tensor<?xf32>
// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
// ...
// }
//
// In this example, %0 is in the same repetitive region as
// conflictingWritingOp, so op dominance can be used to compute the
// `happensBefore` relationship.
//
// Note: iter_args of loops are not aliases of their respective block
// arguments, so op domanice can be used when analyzing ops that operate
// on them.
//
// Note: If `writtenAliases` is empty, there are no memory writes outside
// of the repetitive region of conflictingWritingOp, which means that all
// relevant aliases are inside the same repetitive region.
bool canUseOpDominance =
writtenAliases.empty() ||
repetitiveRegionOfWrites ==
getEnclosingRepetitiveRegion(conflictingWritingOp, options);

// No conflict if the readingOp dominates conflictingWritingOp, i.e., the
// write is not visible when reading.
//
// Note: If ops are executed multiple times (e.g., because they are inside
// a loop), there may be no meaningful `happensBefore` relationship.
if (canUseOpDominance &&
if (useDominance &&
happensBefore(readingOp, conflictingWritingOp, domInfo))
continue;

Expand All @@ -540,7 +587,7 @@ static bool hasReadAfterWriteInterference(
// Note: Just being the same op is not enough. It has to be the same use.
// Note: If the op is executed multiple times (e.g., because it is inside
// a loop), it may be conflicting with itself.
if (canUseOpDominance && uConflictingWrite == uRead)
if (useDominance && uConflictingWrite == uRead)
continue;

// No conflict if the op interface says so.
Expand All @@ -559,7 +606,7 @@ static bool hasReadAfterWriteInterference(
// Note: If ops are executed multiple times (e.g., because they are inside
// a loop), mutually exclusive regions may be executed multiple
// times.
if (canUseOpDominance &&
if (useDominance &&
insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp))
continue;

Expand Down
66 changes: 66 additions & 0 deletions mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir
Expand Up @@ -630,3 +630,69 @@ func.func @same_enclosing_repetitive_region(%2: tensor<320xf32>,
} {thread_dim_mapping = []}
return %4 : tensor<320xf32>
}

// -----

// CHECK-LABEL: different_repetitive_region_via_alias
func.func @different_repetitive_region_via_alias(%arg0: tensor<4xf32>,
%arg1: tensor<4xf32>,
%arg2: index,
%arg3: index,
%arg4: index)
-> (tensor<4xf32>)
{
%cst = arith.constant 0.000000e+00 : f32
%cst2 = arith.constant 1.000000e+00 : f32
%0 = bufferization.alloc_tensor() : tensor<4xf32>

// CHECK: linalg.fill {__inplace_operands_attr__ = ["none", "false"]}
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<4xf32>) -> tensor<4xf32>

%2 = scf.for %arg5 = %arg2 to %arg3 step %arg4 iter_args(%arg6 = %arg1) -> (tensor<4xf32>) {
// CHECK: tensor.extract {{.*}} {__inplace_operands_attr__ = ["true", "none"]}
%4 = tensor.extract %1[%arg4] : tensor<4xf32>
vector.print %4 : f32
// CHECK: linalg.fill {__inplace_operands_attr__ = ["none", "true"]}
%5 = linalg.fill ins(%cst2 : f32) outs(%0 : tensor<4xf32>) -> tensor<4xf32>
scf.yield %5 : tensor<4xf32>
}

return %2 : tensor<4xf32>
}

// -----

// CHECK-LABEL: no_raw_conflict_after_repetitive_use
func.func @no_raw_conflict_after_repetitive_use(%arg0: tensor<4xf32>,
%arg1: tensor<4xf32>,
%arg2: index,
%arg3: index,
%arg4: index)
-> (tensor<4xf32>, tensor<4xf32>)
{
%cst = arith.constant 0.000000e+00 : f32
%cst2 = arith.constant 1.000000e+00 : f32
%cst3 = arith.constant 2.000000e+00 : f32
%0 = bufferization.alloc_tensor() : tensor<4xf32>

// CHECK: linalg.fill {__inplace_operands_attr__ = ["none", "true"]}
%1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<4xf32>) -> tensor<4xf32>

%2 = scf.for %arg5 = %arg2 to %arg3 step %arg4 iter_args(%arg6 = %arg1) -> (tensor<4xf32>) {
// CHECK: tensor.extract {{.*}} {__inplace_operands_attr__ = ["true", "none"]}
%4 = tensor.extract %1[%arg4] : tensor<4xf32>
vector.print %4 : f32
// CHECK: linalg.fill {__inplace_operands_attr__ = ["none", "false"]}
%5 = linalg.fill ins(%cst2 : f32) outs(%1 : tensor<4xf32>) -> tensor<4xf32>
scf.yield %5 : tensor<4xf32>
}

// The following is *not* a RaW conflict.
// CHECK: tensor.extract {{.*}} {__inplace_operands_attr__ = ["true", "none"]}
%6 = tensor.extract %1[%arg4] : tensor<4xf32>
vector.print %6 : f32
// CHECK: linalg.fill {__inplace_operands_attr__ = ["none", "true"]}
%7 = linalg.fill ins(%cst3 : f32) outs(%1 : tensor<4xf32>) -> tensor<4xf32>

return %2, %7 : tensor<4xf32>, tensor<4xf32>
}

0 comments on commit 2e21003

Please sign in to comment.