[mlir][linalg] Add control to pad-slice swap pattern
The pad-slice swap pattern generates `scf.if` and `tensor.generate` ops
to guard against zero-sized slices when it cannot prove that the slice is
always non-empty. This is safe but quite conservative: the guard can be
unnecessary when we know by problem definition that such cases do not
exist, even with dynamically shaped ops or unknown tile/slice sizes, e.g.,
a convolution with padding size 1 and kernel dimension size 3.

So this commit introduces a control function on the pattern to specify
whether to generate the guarding `scf.if` constructs, which lets such cases
be handled better: once the `scf.if` constructs are materialized, they are
very hard to analyze and simplify.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D117017
antiagainst committed Feb 16, 2022
1 parent 27cd2a6 commit 0edb412
Showing 6 changed files with 257 additions and 200 deletions.
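For context, here is a minimal sketch (editor-added, not part of this commit) of how a pipeline might wire up the new control to always skip the zero-slice guard when the problem definition rules out empty slices; the helper name and the unconditional `return false` policy are illustrative only.

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical driver: swap extract_slice(tensor.pad) into
// tensor.pad(extract_slice) without emitting the scf.if/tensor.generate
// guard. Only valid when every slice is known to overlap the source.
static LogicalResult swapPadSlicesWithoutGuard(Operation *root) {
  MLIRContext *context = root->getContext();
  RewritePatternSet patterns(context);
  patterns.add<linalg::ExtractSliceOfPadTensorSwapPattern>(
      context,
      /*controlFn=*/[](tensor::ExtractSliceOp) -> llvm::Optional<bool> {
        // Zero-sized slices cannot occur by construction, so apply the
        // pattern and skip the guard.
        return false;
      });
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}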
19 changes: 18 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1399,10 +1399,27 @@ LogicalResult applyStagedPatterns(
/// Rewrite extract_slice(pad_tensor(x)) into pad_tensor(extract_slice(x)).
struct ExtractSliceOfPadTensorSwapPattern
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  /// A function to control pattern application and rewrite logic.
  ///
  /// The function will be given the slice op and should return:
  /// - None: to fail the match and not apply the pattern;
  /// - true: to apply the pattern with the zero-slice guard;
  /// - false: to apply the pattern without the zero-slice guard.
  ///
  /// See the documentation for tensor::bubbleUpPadSlice regarding the
  /// zero-slice guard.
  using ControlFn = std::function<llvm::Optional<bool>(tensor::ExtractSliceOp)>;

  ExtractSliceOfPadTensorSwapPattern(MLIRContext *context,
                                     ControlFn controlFn = nullptr,
                                     PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override;

private:
  ControlFn controlFn;
};
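
As an illustration of the contract documented above, a control function could look like the following sketch (editor-added, not part of this diff); the `my.no_zero_slice` attribute is a hypothetical marker a pipeline might attach when it knows the slice can never be empty, and `using namespace mlir;` is assumed.

// Sketch of a ControlFn exercising the three documented outcomes.
linalg::ExtractSliceOfPadTensorSwapPattern::ControlFn controlFn =
    [](tensor::ExtractSliceOp sliceOp) -> llvm::Optional<bool> {
  // Fail the match and leave non-unit-stride slices to other patterns.
  if (!sliceOp.hasUnitStride())
    return llvm::None;
  // "my.no_zero_slice" is a hypothetical attribute set by the pipeline when
  // the problem definition (e.g. conv padding = 1, kernel dim = 3) rules out
  // an empty slice of the source: apply the pattern without the guard.
  if (sliceOp->hasAttr("my.no_zero_slice"))
    return false;
  // Otherwise stay conservative and keep the zero-slice guard.
  return true;
};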

//===----------------------------------------------------------------------===//
26 changes: 26 additions & 0 deletions mlir/include/mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h
@@ -18,6 +18,32 @@
namespace mlir {
namespace tensor {

class PadOp;

/// Bubbles up a slice of this pad by taking the slice first and then
/// performing the padding. `offsets` and `sizes` specify each dimension's
/// start offset and size for the slice. The slice has unit strides along all
/// dimensions.
///
/// Specifically, this function converts:
/// ```
/// %0 = tensor.pad %source low[...] high[...] { linalg.yield %cst }
/// %1 = tensor.extract_slice %0 offsets=[...], sizes=[...]
/// ```
/// into
/// ```
/// %0 = tensor.extract_slice %source ...
/// %1 = tensor.pad %0 low[...] high[...] { linalg.yield %cst }
/// ```
///
/// If `generateZeroSliceGuard` is true, the generated IR will contain logic
/// to guard against the case where the slice taken from the original source
/// would be zero-sized. In that case a `tensor.generate` op is used to
/// produce the full padded tensor instead.
Operation *bubbleUpPadSlice(OpBuilder &b, tensor::PadOp padOp,
                            ArrayRef<OpFoldResult> offsets,
                            ArrayRef<OpFoldResult> sizes,
                            bool generateZeroSliceGuard = true);
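
For instance (an editor-added sketch, assuming the caller has already proven that its tile always overlaps the source), the helper can be invoked directly with the guard disabled; the wrapper name is illustrative.

// Hypothetical wrapper: bubble the pad above a slice whose offsets/sizes are
// known to keep the sliced source non-empty, so no scf.if/tensor.generate
// guard is emitted.
static Operation *bubbleKnownNonEmptySlice(OpBuilder &b, tensor::PadOp padOp,
                                           ArrayRef<OpFoldResult> offsets,
                                           ArrayRef<OpFoldResult> sizes) {
  return tensor::bubbleUpPadSlice(b, padOp, offsets, sizes,
                                  /*generateZeroSliceGuard=*/false);
}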

/// Registers external models for Tiling interface for tensor ops.
/// Currently, it registers:
///
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -54,6 +54,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRStandardOpsTransforms
MLIRStandardToLLVM
MLIRTensor
MLIRTensorTilingInterfaceImpl
MLIRTensorTransforms
MLIRTransforms
MLIRTransformUtils
26 changes: 15 additions & 11 deletions mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -911,23 +912,26 @@ GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  if (!sliceOp.hasUnitStride())
    return failure();

  auto padOp = sliceOp.source().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();

  bool zeroSliceGuard = true;
  if (controlFn) {
    if (Optional<bool> control = controlFn(sliceOp))
      zeroSliceGuard = control.getValue();
    else
      return failure();
  }

  Operation *tiledPadOp = tensor::bubbleUpPadSlice(
      rewriter, padOp, sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
      zeroSliceGuard);
  // All shapes are static and the data source is actually used. Rewrite into
  // pad(extract_slice(x)).
  rewriter.replaceOp(sliceOp, tiledPadOp->getResults());
  return success();
}
