Skip to content

Commit

Permalink
[mlir][linalg] Add constant padding helper to PadTensorOp
Browse files Browse the repository at this point in the history
* Add a helper function that returns the constant padding value (if applicable).
* Remove existing getConstantYieldValueFromBlock function, which does almost the same.
* Adapted from D103243.

Differential Revision: https://reviews.llvm.org/D104004
  • Loading branch information
matthias-springer committed Jun 14, 2021
1 parent 594febf commit bf5d309
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 26 deletions.
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,9 @@ def Linalg_PadTensorOp : Linalg_Op<"pad_tensor",
Type type, Value source, Value pad, ArrayRef<OpFoldResult> low,
ArrayRef<OpFoldResult> high, Location loc, OpBuilder & builder);

// Return the pad value if it is a constant. Return null value otherwise.
Value getConstantPaddingValue();

// Return a vector of all the static or dynamic values (low/high padding) of
// the op.
inline SmallVector<OpFoldResult> getMixedPadImpl(ArrayAttr staticAttrs,
Expand Down
24 changes: 24 additions & 0 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1141,6 +1141,30 @@ void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<FoldStaticZeroPadding>(context);
}

/// Return the padding value of the PadTensorOp if it constant. In this context,
/// "constant" means an actual constant or "defined outside of the block".
///
/// Values are considered constant in three cases:
/// - A ConstantLike value.
/// - A basic block argument from a different block.
/// - A value defined outside of the block.
///
/// If the padding value is not constant, an empty Value is returned.
Value PadTensorOp::getConstantPaddingValue() {
auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
if (!yieldOp || yieldOp.values().size() != 1)
return {};
Value padValue = yieldOp.values().front();
// Check if yield value is a constant.
if (matchPattern(padValue, m_Constant()))
return padValue;
// Check if yield value is defined inside the PadTensorOp block.
if (padValue.getParentBlock() == &getRegion().front())
return {};
// Else: Yield value defined outside of the PadTensorOp block.
return padValue;
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
Expand Down
27 changes: 1 addition & 26 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -650,31 +650,6 @@ mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
// Misc. vectorization patterns.
//----------------------------------------------------------------------------//

/// Given a block, return the Value that the block yields if that Value is
/// constant. In this context, "constant" means "defined outside of the block".
/// Should not be called on blocks that yield more than one value.
///
/// Values are considered constant in two cases:
/// - A basic block argument from a different block.
/// - A value defined outside of the block.
///
/// If the yielded value is not constant, an empty Value is returned.
static Value getConstantYieldValueFromBlock(Block &block) {
auto yieldOp = cast<YieldOp>(block.getTerminator());
assert(yieldOp.getNumOperands() == 1 && "expected single operand yield");
Value result = yieldOp.values().front();
Operation *definingOp = result.getDefiningOp();

// Check if yield value is defined inside the block.
if (definingOp && definingOp->getBlock() == &block)
return Value();
// Check if the yield value is a BB arg of the block.
if (!definingOp && result.cast<BlockArgument>().getOwner() == &block)
return Value();

return result;
}

/// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and
/// TransferWriteOp. For now, this only applies when all low and high paddings
/// are determined to be zero.
Expand All @@ -693,7 +668,7 @@ struct GenericPadTensorOpVectorizationPattern
// High padding must be static 0.
if (!llvm::all_of(padOp.getMixedHighPad(), isZeroInt)) return failure();
// Pad value must be a constant.
auto padValue = getConstantYieldValueFromBlock(padOp.region().front());
auto padValue = padOp.getConstantPaddingValue();
if (!padValue) return failure();

// Bail on non-static shapes.
Expand Down

0 comments on commit bf5d309

Please sign in to comment.