Skip to content

Commit

Permalink
[mlir][Linalg] Add pattern for folding reshape by collapsing.
Browse files Browse the repository at this point in the history
Fusion of `linalg.generic` with
`tensor.expand_shape/tensor.collapse_shape` currently handles fusion
with reshape by expanding the dimensionality of the `linalg.generic`
operation. This helps fuse elementwise operations better since they
are fused at the highest dimensionality while keeping all indexing
maps involved projected permutations. The intent of these is to push
the reshape to the boundaries of functions.

The presence of named ops (or other ops across which the reshape
cannot be propagated) stops the propagation to the edges of the
function. At this stage, the converse patterns that fold the reshapes
with generic ops by collapsing the dimensions of the generic op can
push the reshape towards edges. In particular it helps the case where
reshapes exist in between named ops and generic ops.

`linalg.named_op` -> `tensor.expand_shape` -> `linalg.generic`

Pushing the reshape down will help fusion of `linalg.named_op` ->
`linalg.generic` using tile + fuse transformations.

This pattern is intended to replace the following patterns

1) FoldReshapeByLinearization : These patterns create indexing maps
that are not projected permutations that affect future
transformations. They are only useful for folding unit-dimensions.
2) PushReshapeByExpansion : This pattern has the same functionality
but has some restrictions
    a) It tries to avoid creating new reshapes that limits its
    applicability. The pattern added here can achieve the same
    functionality through use of the `controlFn` that allows clients
    of the pattern freedom to make this decision.
    b) It does not work for ops with indexing semantics.

These patterns will be deprecated in a future patch.

Differential Revision: https://reviews.llvm.org/D119365
  • Loading branch information
Mahesh Ravishankar committed Feb 16, 2022
1 parent 2e2f315 commit 2c58cde
Show file tree
Hide file tree
Showing 5 changed files with 935 additions and 3 deletions.
16 changes: 16 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,22 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
return *(indexingMaps.begin() + opOperand->getOperandNumber());
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the indexing map for a `result`.
}],
/*retTy=*/"AffineMap",
/*methodName=*/"getTiedIndexingMapForResult",
/*args=*/(ins "OpResult":$result),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(result.getOwner() == this->getOperation());
auto indexingMaps =
$_op.indexing_maps().template getAsValueRange<AffineMapAttr>();
return *(indexingMaps.begin() + getNumInputs() +
result.getResultNumber());
}]
>,
InterfaceMethod<
/*desc=*/[{
Return the result tied to `opOperand`.
Expand Down
15 changes: 15 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,30 @@ void populateFoldReshapeOpsByExpansionPatterns(
const ControlElementwiseOpsFusionFn &controlFoldingReshapes =
skipUnitDimReshape);

/// Patterns to fold an expanding tensor.expand_shape operation with its
/// producer generic operation by collapsing the dimensions of the generic op.
void populateFoldReshapeOpsByCollapsingPatterns(
RewritePatternSet &patterns,
const ControlElementwiseOpsFusionFn &controlFoldingReshapes =
[](const OpResult & /*producer*/, OpOperand & /*consumer*/) {
return true;
});

/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
/// producer (consumer) generic operation by linearizing the indexing map used
/// to access the source (target) of the reshape operation in the generic
/// operation.
/// TODO(ravishankarm): These patterns are to be deprecated in favor of using
/// the `populateFoldReshapeByCollapsingPatterns`.
void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);

/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
/// producer (consumer) generic operation by linearizing the indexing map used
/// to access the source (target) of the reshape operation in the generic
/// operation. The patterns are applied only when the tensor reshape involved is
/// collapsing (introducing) unit-extent dimensions.
/// TODO(ravishankarm): These patterns are to be deprecated in favor of using
/// the `populateFoldReshapeByCollapsingPatterns`.
void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
RewritePatternSet &patterns);

Expand Down Expand Up @@ -153,6 +166,8 @@ void populateElementwiseOpsFusionPatterns(

/// Patterns to push reshape op towards the end of the graph in order to expose
/// more fusion opportunities.
/// TODO(ravishankarm): These patterns are to be deprecated in favor of using
/// the `populateFoldReshapeByCollapsingPatterns`.
void populatePushReshapeOpsPatterns(RewritePatternSet &patterns);

/// Perform standalone tiling of a single LinalgOp by `tileSizes`.
Expand Down
Loading

0 comments on commit 2c58cde

Please sign in to comment.