Skip to content

Commit

Permalink
[mlir][Linalg] Generalize the logic to compute reassociation maps
Browse files Browse the repository at this point in the history
while folding tensor_reshape op.

While folding reshapes that introduce unit extent dims, the logic to
compute the reassociation maps can be generalized to handle some
corner cases, for example, when the folded shape still has unit-extent
dims but corresponds to folded unit extent dims of the expanded shape.

Differential Revision: https://reviews.llvm.org/D88521
  • Loading branch information
Mahesh Ravishankar committed Sep 30, 2020
1 parent 3a7487f commit 892fdc9
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 45 deletions.
87 changes: 42 additions & 45 deletions mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
Expand Up @@ -403,61 +403,58 @@ struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
srcType.getRank() < dstType.getRank() ||
parentSrcType.getRank() == dstType.getRank())
return failure();

// Check if the result tensor_reshape after folding the reshapeOp and
// parentReshapeOp are combined.
// If the final tensor_reshape is folding, the parentReshapeOp is
// introducing unit-dims, and the reshapeOp does an actual reshape.
// If the final tensor_reshape op is expanding, the reshapeOp is introducing
// unit-dims, and the parentReshapeOp does an actual reshape.
// If the final tensor_reshape op is expanding, the reshapeOp is
// introducing unit-dims, and the parentReshapeOp does an actual reshape.
bool isFoldingPattern = parentSrcType.getRank() > dstType.getRank();
auto reassociationMaps = isFoldingPattern
? reshapeOp.getReassociationMaps()
: parentReshapeOp.getReassociationMaps();
DenseSet<unsigned> conservedDimensions;
for (auto &map : reassociationMaps) {
if (map.getNumResults() == 1) {
conservedDimensions.insert(
map.getResult(0).cast<AffineDimExpr>().getPosition());
}
}

// Find positions at which the unit-dims exist.
int64_t nonUnitDimPos = 0;
DenseMap<unsigned, unsigned> nonUnitSrcDims;
ArrayRef<int64_t> nonUnitShape =
ArrayRef<int64_t> expandedShape =
isFoldingPattern ? parentSrcType.getShape() : dstType.getShape();
for (auto shape : enumerate(srcType.getShape())) {
// Case 1 : It is a conserved dimension.
if (conservedDimensions.count(shape.index())) {
nonUnitSrcDims[shape.index()] = nonUnitDimPos++;
continue;
ArrayRef<int64_t> foldedShape =
isFoldingPattern ? dstType.getShape() : parentSrcType.getShape();

unsigned expandedDim = 0, foldedDim = 0;
SmallVector<SmallVector<AffineExpr, 4>, 4> reassociationExprs(
foldedShape.size());
while (expandedDim < expandedShape.size() &&
foldedDim < foldedShape.size()) {
int64_t dstSize = foldedShape[foldedDim];
int64_t srcSize = expandedShape[expandedDim];
while (srcSize < dstSize && expandedDim < expandedShape.size()) {
reassociationExprs[foldedDim].push_back(
rewriter.getAffineDimExpr(expandedDim++));
srcSize *= expandedShape[expandedDim];
}
// Case 2 : Dimensions dont match but the intermediate tensor is unit-dim.
if (shape.value() == 1)
continue;
// Case 3 : Dimensions match, treat it as a non-unit src dim.
if (nonUnitDimPos < static_cast<int64_t>(nonUnitShape.size()) &&
nonUnitShape[nonUnitDimPos] == shape.value()) {
nonUnitSrcDims[shape.index()] = nonUnitDimPos++;
continue;
if (srcSize == dstSize) {
reassociationExprs[foldedDim].push_back(
rewriter.getAffineDimExpr(expandedDim++));
// If the next dim in foldedShape is not 1, treat subsequent dims in
// expandedShape which are 1 to be collapsed.
if (foldedDim == foldedShape.size() - 1 ||
foldedShape[foldedDim + 1] != 1) {
while (expandedDim < expandedShape.size() &&
expandedShape[expandedDim] == 1) {
reassociationExprs[foldedDim].push_back(
rewriter.getAffineDimExpr(expandedDim++));
}
}
} else {
return failure();
}
return failure();
foldedDim++;
}
if (expandedDim != expandedShape.size())
return failure();

// Compute reassociation maps for the final operation. Use the reassociation
// maps that is actually doing a reshape (and not just introducing
// unit-dims). From these maps, prune the unit-extent dimensions.
for (AffineMap &map : reassociationMaps) {
SmallVector<AffineExpr, 4> exprs;
exprs.reserve(nonUnitSrcDims.size());
for (auto result : map.getResults()) {
unsigned dim = result.cast<AffineDimExpr>().getPosition();
if (nonUnitSrcDims.count(dim))
exprs.push_back(rewriter.getAffineDimExpr(nonUnitSrcDims[dim]));
}
map = AffineMap::get(nonUnitSrcDims.size(), 0, exprs,
rewriter.getContext());
}
SmallVector<AffineMap, 4> reassociationMaps =
llvm::to_vector<4>(llvm::map_range(
reassociationExprs, [&](ArrayRef<AffineExpr> exprs) -> AffineMap {
return AffineMap::get(expandedShape.size(), 0, exprs,
rewriter.getContext());
}));
rewriter.replaceOpWithNewOp<TensorReshapeOp>(
reshapeOp, dstType, parentReshapeOp.src(),
rewriter.getAffineMapArrayAttr(reassociationMaps));
Expand Down
16 changes: 16 additions & 0 deletions mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
Expand Up @@ -240,3 +240,19 @@ func @fold_reshape(%arg0 : tensor<2048x1x2048xf32>) -> tensor<4x512x1x512x4xf32>
: tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
return %1 : tensor<4x512x1x512x4xf32>
}

// -----

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: func @fold_reshape
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]
// CHECK-SAME: tensor<2xf32> into tensor<2x1xf32>
func @fold_reshape(%arg0: tensor<2xf32>) -> tensor<2x1xf32>
{
%0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] : tensor<2xf32> into tensor<2x1x1xf32>
%1 = linalg.tensor_reshape %0
[affine_map<(d0, d1, d2) -> (d0)>,
affine_map<(d0, d1, d2) -> (d1, d2)>
] : tensor<2x1x1xf32> into tensor<2x1xf32>
return %1 : tensor<2x1xf32>
}

0 comments on commit 892fdc9

Please sign in to comment.