From 892fdc923f06adbef507ebe594fa7b48224d93f0 Mon Sep 17 00:00:00 2001
From: Mahesh Ravishankar <ravishankarm@google.com>
Date: Tue, 29 Sep 2020 16:14:49 -0700
Subject: [PATCH] [mlir][Linalg] Generalize the logic to compute reassociation
 maps while folding tensor_reshape op.

While folding reshapes that introduce unit extent dims, the logic to
compute the reassociation maps can be generalized to handle some
corner cases, for example, when the folded shape still has unit-extent
dims but corresponds to folded unit extent dims of the expanded shape.

Differential Revision: https://reviews.llvm.org/D88521
---
 .../Linalg/Transforms/DropUnitDims.cpp        | 87 +++++++++----------
 .../Dialect/Linalg/drop-unit-extent-dims.mlir | 16 ++++
 2 files changed, 58 insertions(+), 45 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 08e7e352d63e9b..611c938ab542fd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -403,61 +403,58 @@ struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
         srcType.getRank() < dstType.getRank() ||
         parentSrcType.getRank() == dstType.getRank())
       return failure();
+
     // Check if the result tensor_reshape after folding the reshapeOp and
     // parentReshapeOp are combined.
     // If the final tensor_reshape is folding, the parentReshapeOp is
     // introducing unit-dims, and the reshapeOp does an actual reshape.
-    // If the final tensor_reshape op is expanding, the reshapeOp is introducing
-    // unit-dims, and the parentReshapeOp does an actual reshape.
+    // If the final tensor_reshape op is expanding, the reshapeOp is
+    // introducing unit-dims, and the parentReshapeOp does an actual reshape.
     bool isFoldingPattern = parentSrcType.getRank() > dstType.getRank();
-    auto reassociationMaps = isFoldingPattern
-                                 ? reshapeOp.getReassociationMaps()
-                                 : parentReshapeOp.getReassociationMaps();
-    DenseSet<unsigned> conservedDimensions;
-    for (auto &map : reassociationMaps) {
-      if (map.getNumResults() == 1) {
-        conservedDimensions.insert(
-            map.getResult(0).cast<AffineDimExpr>().getPosition());
-      }
-    }
-
-    // Find positions at which the unit-dims exist.
-    int64_t nonUnitDimPos = 0;
-    DenseMap<int64_t, int64_t> nonUnitSrcDims;
-    ArrayRef<int64_t> nonUnitShape =
+    ArrayRef<int64_t> expandedShape =
         isFoldingPattern ? parentSrcType.getShape() : dstType.getShape();
-    for (auto shape : enumerate(srcType.getShape())) {
-      // Case 1 : It is a conserved dimension.
-      if (conservedDimensions.count(shape.index())) {
-        nonUnitSrcDims[shape.index()] = nonUnitDimPos++;
-        continue;
+    ArrayRef<int64_t> foldedShape =
+        isFoldingPattern ? dstType.getShape() : parentSrcType.getShape();
+
+    unsigned expandedDim = 0, foldedDim = 0;
+    SmallVector<SmallVector<AffineExpr, 4>, 4> reassociationExprs(
+        foldedShape.size());
+    while (expandedDim < expandedShape.size() &&
+           foldedDim < foldedShape.size()) {
+      int64_t dstSize = foldedShape[foldedDim];
+      int64_t srcSize = expandedShape[expandedDim];
+      while (srcSize < dstSize && expandedDim < expandedShape.size()) {
+        reassociationExprs[foldedDim].push_back(
+            rewriter.getAffineDimExpr(expandedDim++));
+        srcSize *= expandedShape[expandedDim];
       }
-      // Case 2 : Dimensions dont match but the intermediate tensor is unit-dim.
-      if (shape.value() == 1)
-        continue;
-      // Case 3 : Dimensions match, treat it as a non-unit src dim.
-      if (nonUnitDimPos < static_cast<int64_t>(nonUnitShape.size()) &&
-          nonUnitShape[nonUnitDimPos] == shape.value()) {
-        nonUnitSrcDims[shape.index()] = nonUnitDimPos++;
-        continue;
+      if (srcSize == dstSize) {
+        reassociationExprs[foldedDim].push_back(
+            rewriter.getAffineDimExpr(expandedDim++));
+        // If the next dim in foldedShape is not 1, treat subsequent dims in
+        // expandedShape which are 1 to be collapsed.
+        if (foldedDim == foldedShape.size() - 1 ||
+            foldedShape[foldedDim + 1] != 1) {
+          while (expandedDim < expandedShape.size() &&
+                 expandedShape[expandedDim] == 1) {
+            reassociationExprs[foldedDim].push_back(
+                rewriter.getAffineDimExpr(expandedDim++));
+          }
+        }
+      } else {
+        return failure();
       }
-      return failure();
+      foldedDim++;
     }
+    if (expandedDim != expandedShape.size())
+      return failure();
 
-    // Compute reassociation maps for the final operation. Use the reassociation
-    // maps that is actually doing a reshape (and not just introducing
-    // unit-dims). From these maps, prune the unit-extent dimensions.
-    for (AffineMap &map : reassociationMaps) {
-      SmallVector<AffineExpr, 4> exprs;
-      exprs.reserve(nonUnitSrcDims.size());
-      for (auto result : map.getResults()) {
-        unsigned dim = result.cast<AffineDimExpr>().getPosition();
-        if (nonUnitSrcDims.count(dim))
-          exprs.push_back(rewriter.getAffineDimExpr(nonUnitSrcDims[dim]));
-      }
-      map = AffineMap::get(nonUnitSrcDims.size(), 0, exprs,
-                           rewriter.getContext());
-    }
+    SmallVector<AffineMap, 4> reassociationMaps =
+        llvm::to_vector<4>(llvm::map_range(
+            reassociationExprs, [&](ArrayRef<AffineExpr> exprs) -> AffineMap {
+              return AffineMap::get(expandedShape.size(), 0, exprs,
+                                    rewriter.getContext());
+            }));
     rewriter.replaceOpWithNewOp<TensorReshapeOp>(
         reshapeOp, dstType, parentReshapeOp.src(),
         rewriter.getAffineMapArrayAttr(reassociationMaps));
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 06e56c5cb7d2a2..1793d2b59b706d 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -240,3 +240,19 @@ func @fold_reshape(%arg0 : tensor<2048x1x2048xf32>) -> tensor<4x512x1x512x4xf32>
     : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
   return %1 : tensor<4x512x1x512x4xf32>
 }
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: func @fold_reshape
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]
+// CHECK-SAME: tensor<2xf32> into tensor<2x1xf32>
+func @fold_reshape(%arg0: tensor<2xf32>) -> tensor<2x1xf32>
+{
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] : tensor<2xf32> into tensor<2x1x1xf32>
+  %1 = linalg.tensor_reshape %0
+    [affine_map<(d0, d1, d2) -> (d0)>,
+     affine_map<(d0, d1, d2) -> (d1, d2)>
+    ] : tensor<2x1x1xf32> into tensor<2x1xf32>
+  return %1 : tensor<2x1xf32>
+}
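--

Appendix, not part of the patch: the new grouping loop can be read in
isolation. Below is a minimal standalone C++ sketch of the same
computation, with std::vector and std::optional standing in for MLIR's
ArrayRef, SmallVector, and AffineExpr; the function name
computeReassociation is hypothetical, and the inner loop's bound is
written as e + 1 < expanded.size() so the lookahead read stays in range.

#include <cstdint>
#include <optional>
#include <vector>

// Standalone version of the grouping loop: assign each dim of `expanded`
// to a group of `folded` so that the product of every group equals the
// corresponding folded dim. Returns std::nullopt for incompatible shapes.
std::optional<std::vector<std::vector<unsigned>>>
computeReassociation(const std::vector<int64_t> &expanded,
                     const std::vector<int64_t> &folded) {
  std::vector<std::vector<unsigned>> groups(folded.size());
  unsigned e = 0, f = 0;
  while (e < expanded.size() && f < folded.size()) {
    int64_t dstSize = folded[f];
    int64_t srcSize = expanded[e];
    // Accumulate expanded dims until their running product reaches the
    // folded size (guarding on e + 1 keeps the lookahead read in range).
    while (srcSize < dstSize && e + 1 < expanded.size()) {
      groups[f].push_back(e++);
      srcSize *= expanded[e];
    }
    if (srcSize != dstSize)
      return std::nullopt;
    groups[f].push_back(e++);
    // Absorb trailing unit dims into this group, unless the next folded
    // dim is itself 1 and must claim them (the corner case the patch
    // handles).
    if (f + 1 == folded.size() || folded[f + 1] != 1)
      while (e < expanded.size() && expanded[e] == 1)
        groups[f].push_back(e++);
    ++f;
  }
  // Every expanded and folded dim must have been consumed.
  if (e != expanded.size() || f != folded.size())
    return std::nullopt;
  return groups;
}

For the new test case the final reshape has expanded = {2, 1} and
folded = {2}, producing the single group {0, 1}, i.e. the
affine_map<(d0, d1) -> (d0, d1)> the CHECK lines expect; for, say,
expanded = {2, 1, 1} against folded = {2, 1}, the next-folded-dim-is-1
check hands the trailing unit dim to the second group, giving
{{0}, {1, 2}} where the old logic would have failed.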