Skip to content

Commit

Permalink
[mlir][Transform] Make FuseIntoContainingOp support rank-reducing extract slices
Browse files Browse the repository at this point in the history

This fixes an issue where rank-reducing + fusion would not interop properly.

Differential Revision: https://reviews.llvm.org/D139844
  • Loading branch information
nicolasvasilache committed Dec 12, 2022
1 parent cde2cc9 commit 93bbcff
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 7 deletions.
9 changes: 9 additions & 0 deletions mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
Expand Up @@ -435,6 +435,15 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
/// Return the dimensions of the source that are dropped in the
/// result when the result is rank-reduced.
llvm::SmallBitVector getDroppedDims();

/// Given a `value`, asserted to be of RankedTensorType, build an
/// ExtractSliceOp that results in a rank-reducing extract to the desired
/// tensor shape and return the new value created.
/// If the shape of `value` is already the `desiredShape`, just return
/// `value`.
/// If the shape of `value` cannot be rank-reduced to `desiredShape`, fail.
static FailureOr<Value> rankReduceIfNeeded(
OpBuilder &b, Location loc, Value value, ArrayRef<int64_t> desiredShape);
}];

let hasCanonicalizer = 1;
Expand Down
19 changes: 17 additions & 2 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Expand Up @@ -20,6 +20,7 @@
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
Expand Down Expand Up @@ -299,7 +300,14 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,

// Replace the extract op.
Operation *fusedOp = tiledProducer->getDefiningOp();
rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(resultNumber));
auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber),
sliceOpToTile->getResult(0)
.getType()
.cast<RankedTensorType>()
.getShape());
assert(succeeded(maybeRankReduced) && "unexpected shape");
rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
return fusedOp;
}

Expand Down Expand Up @@ -399,7 +407,14 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(

// Replace the extract op.
Operation *fusedOp = tiledProducer->getDefiningOp();
rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(resultNumber));
auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber),
sliceOpToTile->getResult(0)
.getType()
.cast<RankedTensorType>()
.getShape());
assert(succeeded(maybeRankReduced) && "unexpected shape");
rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);

// Replace the use in containingOp.
rewriter.updateRootInPlace(containingOp, [&]() {
Expand Down
26 changes: 21 additions & 5 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Expand Up @@ -17,7 +17,9 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Support/MathExtras.h"
Expand Down Expand Up @@ -1754,6 +1756,23 @@ llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
return droppedDims;
}

/// Rank-reduce `value` (asserted to be a ranked tensor) to `desiredShape`.
/// Returns `value` unchanged when it already has the desired shape, a newly
/// built canonical rank-reducing extract_slice otherwise, or failure when the
/// two shapes cannot be reconciled by dropping unit dimensions.
FailureOr<Value>
ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
                                   ArrayRef<int64_t> desiredShape) {
  auto rankedTy = value.getType().dyn_cast<RankedTensorType>();
  assert(rankedTy && "not a ranked tensor type");
  ArrayRef<int64_t> currentShape = rankedTy.getShape();
  // Already in the requested shape: nothing to do.
  if (currentShape.equals(desiredShape))
    return value;
  // A rank reduction is only legal when dropping unit dims turns the current
  // shape into the desired one.
  if (!mlir::computeRankReductionMask(currentShape, desiredShape))
    return failure();
  RankedTensorType resultTy =
      RankedTensorType::Builder(rankedTy).setShape(desiredShape);
  return createCanonicalRankReducingExtractSliceOp(b, loc, value, resultTy);
}

LogicalResult ExtractSliceOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
reifiedReturnShapes.resize(1);
Expand Down Expand Up @@ -2375,7 +2394,6 @@ struct InsertSliceOpSourceCastInserter final
insertSliceOp, cast, insertSliceOp.getDest(),
insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
insertSliceOp.getMixedStrides());
cast.getDefiningOp()->getParentOfType<ModuleOp>().dump();
return success();
}
};
Expand Down Expand Up @@ -2475,8 +2493,7 @@ RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,

SmallVector<int64_t, 4> inferredShape;
for (auto i : llvm::seq<unsigned>(0, rank)) {
if (sourceType.isDynamicDim(i) ||
staticLow[i] == ShapedType::kDynamic ||
if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
staticHigh[i] == ShapedType::kDynamic) {
inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
: resultShape[i]);
Expand Down Expand Up @@ -2525,8 +2542,7 @@ void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
// This will grow staticLow and staticHigh with 1 value. If the config is
// dynamic (ie not a constant), dynamicLow and dynamicHigh will grow with 1
// value as well.
dispatchIndexOpFoldResults(low, dynamicLow, staticLow,
ShapedType::kDynamic);
dispatchIndexOpFoldResults(low, dynamicLow, staticLow, ShapedType::kDynamic);
dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh,
ShapedType::kDynamic);
if (!resultType) {
Expand Down
46 changes: 46 additions & 0 deletions mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
Expand Up @@ -96,6 +96,52 @@ module {

// -----

// Regression test: fusing a tileable producer (linalg.fill) into an
// scf.foreach_thread whose extract_slice of the shared output is
// rank-reducing (tensor<?xf32> -> tensor<f32>). The fused, tiled fill
// produces a tensor<1xf32>, so a rank-reducing extract_slice back to
// tensor<f32> must be inserted before the use in @foo.
module {
func.func @foo(%0: tensor<f32>) -> tensor<f32> {
return %0: tensor<f32>
}

// CHECK-LABEL: func.func @fuse_tileable_op_rank_reducing
// CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
// CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
// CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<?xf32>
func.func @fuse_tileable_op_rank_reducing(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
%cst = arith.constant 4.200000e+01 : f32
%c0 = arith.constant 0 : index
%0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<?xf32>) -> tensor<?xf32>
%d0 = tensor.dim %arg1, %c0 : tensor<?xf32>

// CHECK: scf.foreach_thread {{.*}} -> (tensor<?xf32>) {
%2 = scf.foreach_thread (%arg3) in (%d0) shared_outs(%o = %0) -> (tensor<?xf32>) {
%5 = tensor.extract_slice %o[%arg3] [1] [1] : tensor<?xf32> to tensor<f32>

// CHECK: tensor.extract_slice %{{.*}}[%{{.*}}] [1] [1] : tensor<?xf32> to tensor<1xf32>
// CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<1xf32>) -> tensor<1xf32>
// CHECK: tensor.extract_slice %{{.*}}[0] [1] [1] : tensor<1xf32> to tensor<f32>
// CHECK: func.call @foo(%{{.*}}) : (tensor<f32>) -> tensor<f32>
%7 = func.call @foo(%5) : (tensor<f32>) -> tensor<f32>

scf.foreach_thread.perform_concurrently {
// CHECK: tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%{{.*}}] [1] [1] : tensor<f32> into tensor<?xf32>
tensor.parallel_insert_slice %7 into %o[%arg3] [1] [1] : tensor<f32> into tensor<?xf32>
}
}
// CHECK: }
func.return %2 : tensor<?xf32>
}

transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.fill"]} in %arg1
%1 = transform.structured.match ops{["scf.foreach_thread"]} in %arg1

// linalg.fill is tileable. The op is tiled and fused.
transform.structured.fuse_into_containing_op %0 into %1
}
}

// -----

#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
Expand Down

0 comments on commit 93bbcff

Please sign in to comment.