[mlir][Affine] Add support for multi-store producer fusion
This patch adds support for producer-consumer fusion scenarios with
multiple producer stores to the AffineLoopFusion pass. The patch
introduces some changes to the producer-consumer algorithm, including:

* For a given consumer loop, producer-consumer fusion iterates over its
producer candidates until a fixed point is reached.

* Producer candidates are gathered beforehand for each iteration of the
consumer loop and visited in reverse program order (not strictly guaranteed)
to maximize the number of loops fused per iteration.

In general, these changes were needed to simplify the multi-store producer
support and remove some of the workarounds that were introduced in the past
to support more fusion cases under the single-store producer limitation.
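
The driver loop described in the two bullets above can be sketched as follows. This is a minimal, self-contained model and not the pass's actual code: `ToyGraph`, `fuseProducersIntoConsumer`, and the `canFuse` callback are hypothetical names, node ids stand in for program order, and every producer is assumed to have no other consumers.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <map>
#include <set>
#include <vector>

// Hypothetical toy dependence graph: node ids follow program order, and
// producersOf[c] holds the loops that write a memref read by loop c.
struct ToyGraph {
  std::map<int, std::set<int>> producersOf;
};

// Fuses producers into 'consumer' until a full sweep fuses nothing.
// 'canFuse' stands in for the legality and profitability checks.
// Returns the ids of the fused producers, in fusion order.
std::vector<int>
fuseProducersIntoConsumer(ToyGraph &graph, int consumer,
                          const std::function<bool(int, int)> &canFuse) {
  std::vector<int> fused;
  bool changed = true;
  while (changed) {
    changed = false;
    // Gather the candidates up front for this sweep.
    std::vector<int> candidates(graph.producersOf[consumer].begin(),
                                graph.producersOf[consumer].end());
    // Visit them in reverse program order (largest id first).
    std::sort(candidates.rbegin(), candidates.rend());
    for (int producer : candidates) {
      if (!canFuse(producer, consumer))
        continue;
      // The consumer inherits the fused producer's own producers, which
      // become candidates in the next sweep. Assumption of this sketch: the
      // producer has no other consumers, so it can be dropped entirely.
      std::set<int> inherited = graph.producersOf[producer];
      graph.producersOf[consumer].erase(producer);
      graph.producersOf[consumer].insert(inherited.begin(), inherited.end());
      graph.producersOf.erase(producer);
      fused.push_back(producer);
      changed = true;
    }
  }
  return fused;
}
```

In this model the loop stops after the first sweep in which no candidate is fused, which is the fixed point the message refers to.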

This patch also preserves the existing functionality of AffineLoopFusion with
one minor change in behavior. Producer-consumer fusion didn't fuse scenarios
with escaping memrefs and multiple outgoing edges (from a single store).
Multi-store producer scenarios will usually (always?) have multiple outgoing
edges so we couldn't fuse any with escaping memrefs, which would greatly limit
the applicability of this new feature. Therefore, the patch enables fusion for
these scenarios. Please see the modified tests for specific details.

Reviewed By: andydavis1, bondhugula

Differential Revision: https://reviews.llvm.org/D92876
dcaballe committed Jan 20, 2021
1 parent fd70f70 commit 7dd1988
Showing 9 changed files with 898 additions and 404 deletions.
15 changes: 15 additions & 0 deletions mlir/include/mlir/Analysis/AffineStructures.h
@@ -234,6 +234,21 @@ class FlatAffineConstraints {
// TODO: add support for non-unit strides.
LogicalResult addAffineForOpDomain(AffineForOp forOp);

/// Adds constraints (lower and upper bounds) for each loop in the loop nest
/// described by the bound maps 'lbMaps' and 'ubMaps' of a computation slice.
/// Every pair ('lbMaps[i]', 'ubMaps[i]') describes the bounds of a loop in
/// the nest, sorted outer-to-inner. 'operands' contains the bound operands
/// for a single bound map. All the bound maps will use the same bound
/// operands. Note that some loops described by a computation slice might not
/// exist yet in the IR so the Value attached to those dimension identifiers
/// might be empty. For that reason, this method doesn't perform Value
/// look-ups to retrieve the dimension identifier positions. Instead, it
/// assumes the position of the dim identifiers in the constraint system is
/// the same as the position of the loop in the loop nest.
LogicalResult addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps,
ArrayRef<AffineMap> ubMaps,
ArrayRef<Value> operands);

/// Adds constraints imposed by the `affine.if` operation. These constraints
/// are collected from the IntegerSet attached to the given `affine.if`
/// instance argument (`ifOp`). It is asserted that:
17 changes: 16 additions & 1 deletion mlir/include/mlir/Analysis/Utils.h
@@ -83,10 +83,25 @@ struct ComputationSliceState {
// Clears all bounds and operands in slice state.
void clearBounds();

/// Return true if the computation slice is empty.
/// Returns true if the computation slice is empty.
bool isEmpty() const { return ivs.empty(); }

/// Returns true if the computation slice encloses all the iterations of the
/// sliced loop nest. Returns false if it does not. Returns llvm::None if it
/// cannot determine if the slice is maximal or not.
// TODO: Cache 'isMaximal' so that we don't recompute it when the slice
// information hasn't changed.
Optional<bool> isMaximal() const;

void dump() const;

private:
/// Fast check to determine if the computation slice is maximal. Returns true
/// if each slice dimension maps to an existing dst dimension and both the src
/// and the dst loops for those dimensions have the same bounds. Returns false
/// if both the src and the dst loops don't have the same bounds. Returns
/// llvm::None if none of the above can be proven.
Optional<bool> isSliceMaximalFastCheck() const;
};

/// Computes the computation slice loop bounds for one loop nest as affine maps
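
The three-valued result documented for `isMaximal` and `isSliceMaximalFastCheck` can be illustrated with a small standalone model. This is only a sketch under stated assumptions: `ToyLoop`, `ToySliceDim`, and `isSliceMaximalFastCheckSketch` are hypothetical, loops have constant bounds, and `std::optional<bool>` stands in for `Optional<bool>`.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical stand-in for a loop with constant bounds [lb, ub).
struct ToyLoop {
  long lb;
  long ub;
};

// Hypothetical slice dimension: the source loop being sliced and, when the
// slice maps onto an existing destination loop, that destination loop.
struct ToySliceDim {
  ToyLoop srcLoop;
  std::optional<ToyLoop> dstLoop;
};

// Returns true if every dimension maps to a dst loop with the same bounds as
// its src loop, false if a mapped pair has different bounds, and
// std::nullopt if a dimension has no dst loop to compare against.
std::optional<bool>
isSliceMaximalFastCheckSketch(const std::vector<ToySliceDim> &dims) {
  for (const ToySliceDim &dim : dims) {
    if (!dim.dstLoop)
      return std::nullopt; // Cannot prove anything with the fast check.
    if (dim.srcLoop.lb != dim.dstLoop->lb || dim.srcLoop.ub != dim.dstLoop->ub)
      return false;
  }
  return true;
}
```

A caller of such an API has to distinguish "no answer" from "false": testing the optional directly in a condition only checks whether a value is present, so the contained `bool` needs to be read explicitly.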
49 changes: 38 additions & 11 deletions mlir/include/mlir/Transforms/LoopFusionUtils.h
@@ -50,7 +50,8 @@ struct FusionResult {
// TODO: Generalize utilities so that producer-consumer and sibling fusion
// strategies can be used without the assumptions made in the AffineLoopFusion
// pass.
struct FusionStrategy {
class FusionStrategy {
public:
enum StrategyEnum {
// Generic loop fusion: Arbitrary loops are considered for fusion. No
// assumptions about a specific fusion strategy from AffineLoopFusion pass
@@ -69,13 +70,34 @@ struct FusionStrategy {
// implementation in AffineLoopFusion pass are made. See pass for specific
// details.
Sibling
} strategy;
};

// Target memref for this fusion transformation.
Value memref;
/// Construct a generic or producer-consumer fusion strategy.
FusionStrategy(StrategyEnum strategy) : strategy(strategy) {
assert(strategy != Sibling &&
"Sibling fusion strategy requires a specific memref");
}

/// Construct a sibling fusion strategy targeting 'memref'. This construct
/// should only be used for sibling fusion.
FusionStrategy(Value memref) : strategy(Sibling), memref(memref) {}

/// Returns the fusion strategy.
StrategyEnum getStrategy() const { return strategy; };

FusionStrategy(StrategyEnum strategy, Value memref)
: strategy(strategy), memref(memref) {}
/// Returns the memref attached to this sibling fusion strategy.
Value getSiblingFusionMemRef() const {
assert(strategy == Sibling && "Memref is only valid for sibling fusion");
return memref;
}

private:
/// Fusion strategy.
StrategyEnum strategy;

/// Target memref for this fusion transformation. Only used for sibling
/// fusion.
Value memref;
};
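
The constructor split introduced by this hunk can be tried out in isolation with the sketch below. It is a standalone model, not the MLIR class: `ToyFusionStrategy` and `usesSiblingChecks` are hypothetical, a `std::string` stands in for `Value`, and the `ProducerConsumer` enumerator name is an assumption because the middle of the enum is elided in this diff.

```cpp
#include <cassert>
#include <string>
#include <utility>

// Hypothetical model of the reworked strategy class.
class ToyFusionStrategy {
public:
  // 'ProducerConsumer' is an assumed name; the diff elides that enumerator.
  enum StrategyEnum { Generic, ProducerConsumer, Sibling };

  // Generic or producer-consumer strategy: no memref is attached.
  ToyFusionStrategy(StrategyEnum strategy) : strategy(strategy) {
    assert(strategy != Sibling &&
           "Sibling fusion strategy requires a specific memref");
  }

  // Sibling strategy: always constructed with its target memref.
  ToyFusionStrategy(std::string memref)
      : strategy(Sibling), memref(std::move(memref)) {}

  StrategyEnum getStrategy() const { return strategy; }

  const std::string &getSiblingFusionMemRef() const {
    assert(strategy == Sibling && "Memref is only valid for sibling fusion");
    return memref;
  }

private:
  StrategyEnum strategy;
  std::string memref;
};

// Hypothetical call site with a default argument in the style of the new
// 'canFuseLoops' declaration: the enum converts implicitly to the class.
bool usesSiblingChecks(
    ToyFusionStrategy fusionStrategy = ToyFusionStrategy::Generic) {
  return fusionStrategy.getStrategy() == ToyFusionStrategy::Sibling;
}
```

With this shape, a sibling strategy cannot exist without a memref, and a non-sibling strategy never exposes one.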

/// Checks the feasibility of fusing the loop nest rooted at 'srcForOp' into the
@@ -86,11 +108,10 @@ struct FusionStrategy {
/// NOTE: This function is not feature complete and should only be used in
/// testing.
/// TODO: Update comments when this function is fully implemented.
FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
unsigned dstLoopDepth,
ComputationSliceState *srcSlice,
FusionStrategy fusionStrategy = {
FusionStrategy::Generic, Value()});
FusionResult
canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth,
ComputationSliceState *srcSlice,
FusionStrategy fusionStrategy = FusionStrategy::Generic);

/// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point
/// and source slice loop bounds specified in 'srcSlice'.
@@ -134,6 +155,12 @@ bool getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
const ComputationSliceState &slice,
int64_t *computeCost);

/// Returns in 'producerConsumerMemrefs' the memrefs involved in a
/// producer-consumer dependence between write ops in 'srcOps' and read ops in
/// 'dstOps'.
void gatherProducerConsumerMemrefs(ArrayRef<Operation *> srcOps,
ArrayRef<Operation *> dstOps,
DenseSet<Value> &producerConsumerMemrefs);
} // end namespace mlir

#endif // MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H
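
The contract documented for `gatherProducerConsumerMemrefs` (memrefs written in the source ops and read in the destination ops) can be modeled without MLIR types. In this sketch `ToyMemOp` and `gatherProducerConsumerMemrefsSketch` are hypothetical, and memrefs are identified by name instead of by `Value`.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical memory operation: a store or a load on a named memref.
struct ToyMemOp {
  bool isStore;
  std::string memref;
};

// Collects the memrefs that are stored to by 'srcOps' and loaded by
// 'dstOps', i.e. the memrefs carrying a producer-consumer dependence.
void gatherProducerConsumerMemrefsSketch(
    const std::vector<ToyMemOp> &srcOps, const std::vector<ToyMemOp> &dstOps,
    std::set<std::string> &producerConsumerMemrefs) {
  std::set<std::string> srcStoreMemrefs;
  for (const ToyMemOp &op : srcOps)
    if (op.isStore)
      srcStoreMemrefs.insert(op.memref);

  for (const ToyMemOp &op : dstOps)
    if (!op.isStore && srcStoreMemrefs.count(op.memref) > 0)
      producerConsumerMemrefs.insert(op.memref);
}
```

For a producer that stores to two memrefs and a consumer that loads only one of them, only the loaded memref ends up in the result set.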
105 changes: 105 additions & 0 deletions mlir/include/mlir/Transforms/Passes.td
@@ -17,6 +17,111 @@ include "mlir/Pass/PassBase.td"

def AffineLoopFusion : FunctionPass<"affine-loop-fusion"> {
let summary = "Fuse affine loop nests";
let description = [{
This pass performs fusion of loop nests using a slicing-based approach. It
combines two fusion strategies: producer-consumer fusion and sibling fusion.
Producer-consumer fusion is aimed at fusing pairs of loops where the first
one writes to a memref that the second reads. Sibling fusion targets pairs
of loops that share no dependences between them but that load from the same
memref. The fused loop nests, when possible, are rewritten to access
significantly smaller local buffers instead of the original memref's, and
the latter are often either completely optimized away or contracted. This
transformation leads to enhanced locality and lower memory footprint through
the elimination or contraction of temporaries/intermediate memref's. These
benefits are sometimes achieved at the expense of redundant computation
through a cost model that evaluates available choices such as the depth at
which a source slice should be materialized in the designation slice.

Example 1: Producer-consumer fusion.
Input:
```mlir
func @producer_consumer_fusion(%arg0: memref<10xf32>, %arg1: memref<10xf32>) {
%0 = alloc() : memref<10xf32>
%1 = alloc() : memref<10xf32>
%cst = constant 0.000000e+00 : f32
affine.for %arg2 = 0 to 10 {
affine.store %cst, %0[%arg2] : memref<10xf32>
affine.store %cst, %1[%arg2] : memref<10xf32>
}
affine.for %arg2 = 0 to 10 {
%2 = affine.load %0[%arg2] : memref<10xf32>
%3 = addf %2, %2 : f32
affine.store %3, %arg0[%arg2] : memref<10xf32>
}
affine.for %arg2 = 0 to 10 {
%2 = affine.load %1[%arg2] : memref<10xf32>
%3 = mulf %2, %2 : f32
affine.store %3, %arg1[%arg2] : memref<10xf32>
}
return
}
```
Output:
```mlir
func @producer_consumer_fusion(%arg0: memref<10xf32>, %arg1: memref<10xf32>) {
%0 = alloc() : memref<1xf32>
%1 = alloc() : memref<1xf32>
%cst = constant 0.000000e+00 : f32
affine.for %arg2 = 0 to 10 {
affine.store %cst, %0[0] : memref<1xf32>
affine.store %cst, %1[0] : memref<1xf32>
%2 = affine.load %1[0] : memref<1xf32>
%3 = mulf %2, %2 : f32
affine.store %3, %arg1[%arg2] : memref<10xf32>
%4 = affine.load %0[0] : memref<1xf32>
%5 = addf %4, %4 : f32
affine.store %5, %arg0[%arg2] : memref<10xf32>
}
return
}
```

Example 2: Sibling fusion.
Input:
```mlir
func @sibling_fusion(%arg0: memref<10x10xf32>, %arg1: memref<10x10xf32>,
%arg2: memref<10x10xf32>, %arg3: memref<10x10xf32>,
%arg4: memref<10x10xf32>) {
affine.for %arg5 = 0 to 3 {
affine.for %arg6 = 0 to 3 {
%0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32>
%1 = affine.load %arg1[%arg5, %arg6] : memref<10x10xf32>
%2 = mulf %0, %1 : f32
affine.store %2, %arg3[%arg5, %arg6] : memref<10x10xf32>
}
}
affine.for %arg5 = 0 to 3 {
affine.for %arg6 = 0 to 3 {
%0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32>
%1 = affine.load %arg2[%arg5, %arg6] : memref<10x10xf32>
%2 = addf %0, %1 : f32
affine.store %2, %arg4[%arg5, %arg6] : memref<10x10xf32>
}
}
return
}
```
Output:
```mlir
func @sibling_fusion(%arg0: memref<10x10xf32>, %arg1: memref<10x10xf32>,
%arg2: memref<10x10xf32>, %arg3: memref<10x10xf32>,
%arg4: memref<10x10xf32>) {
affine.for %arg5 = 0 to 3 {
affine.for %arg6 = 0 to 3 {
%0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32>
%1 = affine.load %arg1[%arg5, %arg6] : memref<10x10xf32>
%2 = mulf %0, %1 : f32
affine.store %2, %arg3[%arg5, %arg6] : memref<10x10xf32>
%3 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32>
%4 = affine.load %arg2[%arg5, %arg6] : memref<10x10xf32>
%5 = addf %3, %4 : f32
affine.store %5, %arg4[%arg5, %arg6] : memref<10x10xf32>
}
}
return
}
```
}];
let constructor = "mlir::createLoopFusionPass()";
let options = [
Option<"computeToleranceThreshold", "fusion-compute-tolerance", "double",
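
For readers less familiar with the affine dialect, the effect of Example 1 in the pass description can be mirrored in plain C++. This is only an illustration of the transformation's result, not output of the pass: `unfusedLoops` and `fusedLoops` are hypothetical names, and the stored constant is a parameter here so that the two results are distinguishable (the MLIR example stores 0.0).

```cpp
#include <array>
#include <cassert>

using Buffer = std::array<float, 10>;

// Before fusion: one producer loop with two stores into full-size
// temporaries, followed by two separate consumer loops.
void unfusedLoops(Buffer &out0, Buffer &out1, float cst) {
  Buffer tmp0{};
  Buffer tmp1{};
  for (int i = 0; i < 10; ++i) {
    tmp0[i] = cst;
    tmp1[i] = cst;
  }
  for (int i = 0; i < 10; ++i)
    out0[i] = tmp0[i] + tmp0[i];
  for (int i = 0; i < 10; ++i)
    out1[i] = tmp1[i] * tmp1[i];
}

// After fusion: a single loop, with each temporary contracted to one
// element (modeled as a scalar) because it is written and read in the same
// iteration.
void fusedLoops(Buffer &out0, Buffer &out1, float cst) {
  for (int i = 0; i < 10; ++i) {
    float tmp0 = cst;
    float tmp1 = cst;
    out1[i] = tmp1 * tmp1;
    out0[i] = tmp0 + tmp0;
  }
}
```

Both functions compute the same outputs; the fused version touches each temporary only within one iteration, which is what allows the `memref<10xf32>` buffers to shrink to `memref<1xf32>` in the MLIR output.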
64 changes: 64 additions & 0 deletions mlir/lib/Analysis/AffineStructures.cpp
@@ -708,6 +708,70 @@ LogicalResult FlatAffineConstraints::addAffineForOpDomain(AffineForOp forOp) {
/*eq=*/false, /*lower=*/false);
}

/// Adds constraints (lower and upper bounds) for each loop in the loop nest
/// described by the bound maps 'lbMaps' and 'ubMaps' of a computation slice.
/// Every pair ('lbMaps[i]', 'ubMaps[i]') describes the bounds of a loop in
/// the nest, sorted outer-to-inner. 'operands' contains the bound operands
/// for a single bound map. All the bound maps will use the same bound
/// operands. Note that some loops described by a computation slice might not
/// exist yet in the IR so the Value attached to those dimension identifiers
/// might be empty. For that reason, this method doesn't perform Value
/// look-ups to retrieve the dimension identifier positions. Instead, it
/// assumes the position of the dim identifiers in the constraint system is
/// the same as the position of the loop in the loop nest.
LogicalResult
FlatAffineConstraints::addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps,
ArrayRef<AffineMap> ubMaps,
ArrayRef<Value> operands) {
assert(lbMaps.size() == ubMaps.size());
assert(lbMaps.size() <= getNumDimIds());

for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
AffineMap lbMap = lbMaps[i];
AffineMap ubMap = ubMaps[i];
assert(!lbMap || lbMap.getNumInputs() == operands.size());
assert(!ubMap || ubMap.getNumInputs() == operands.size());

// Check if this slice is just an equality along this dimension. If so,
// retrieve the existing loop it equates to and add it to the system.
if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
ubMap.getNumResults() == 1 &&
lbMap.getResult(0) + 1 == ubMap.getResult(0) &&
// The condition above will be true for maps describing a single
// iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
// Make sure we skip those cases by checking that the lb result is not
// just a constant.
!lbMap.getResult(0).isa<AffineConstantExpr>()) {
// Limited support: we expect the lb result to be just a loop dimension.
// Not supported otherwise for now.
AffineDimExpr result = lbMap.getResult(0).dyn_cast<AffineDimExpr>();
if (!result)
return failure();

AffineForOp loop =
getForInductionVarOwner(operands[result.getPosition()]);
if (!loop)
return failure();

if (failed(addAffineForOpDomain(loop)))
return failure();
continue;
}

// This slice refers to a loop that doesn't exist in the IR yet. Add its
// bounds to the system assuming its dimension identifier position is the
// same as the position of the loop in the loop nest.
if (lbMap && failed(addLowerOrUpperBound(i, lbMap, operands, /*eq=*/false,
/*lower=*/true)))
return failure();

if (ubMap && failed(addLowerOrUpperBound(i, ubMap, operands, /*eq=*/false,
/*lower=*/false)))
return failure();
}
return success();
}

void FlatAffineConstraints::addAffineIfOpDomain(AffineIfOp ifOp) {
// Create the base constraints from the integer set attached to ifOp.
FlatAffineConstraints cst(ifOp.getIntegerSet());
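
The per-dimension case analysis in `addDomainFromSliceMaps` (single-iteration equality with an existing loop versus bounds of a loop that is not in the IR yet) can be sketched on a simplified expression type. Everything here is hypothetical: `ToyExpr` models only expressions of the form `d_k + c` or a constant `c`, and `classifySliceDim` reports which branch of the real function such a bound pair would take under that simplification.

```cpp
#include <cassert>

// Hypothetical one-term affine expression: dims[dimPos] + offset, or just
// 'offset' when dimPos is negative (a constant expression).
struct ToyExpr {
  int dimPos;
  long offset;
};

enum class SliceDimKind {
  EqualsExistingLoop, // lb is a plain loop dimension and ub == lb + 1.
  NewLoopBounds,      // Bounds are added positionally for a new loop.
  Unsupported         // Single iteration, but lb is not a plain dimension.
};

SliceDimKind classifySliceDim(const ToyExpr &lb, const ToyExpr &ub) {
  bool singleIteration = lb.dimPos == ub.dimPos && lb.offset + 1 == ub.offset;
  bool lbIsConstant = lb.dimPos < 0;
  // Constant single-iteration bounds such as [0, 1) are not equalities with
  // an existing loop, so they fall through to the positional case.
  if (singleIteration && !lbIsConstant) {
    // Limited support: lb must be exactly a loop dimension (no offset).
    return lb.offset == 0 ? SliceDimKind::EqualsExistingLoop
                          : SliceDimKind::Unsupported;
  }
  return SliceDimKind::NewLoopBounds;
}
```

The constant check matters because a pair such as `lb = 0`, `ub = 1` also satisfies `lb + 1 == ub` without referring to any existing loop.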
