Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,14 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
ArrayRef<int64_t> outerDimsPerm,
ArrayRef<OpFoldResult> innerTiles);

// Same as above function but here dynamic dimensions are assumed
// to require padding.
static bool requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outputShape,
ArrayRef<int64_t> outerDimsPerm,
ArrayRef<OpFoldResult> innerTiles);

static Value createDestinationTensor(OpBuilder &b, Location loc,
Value source, ArrayRef<OpFoldResult> innerTileSizes,
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
Expand Down
5 changes: 4 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -1914,9 +1914,12 @@ void populateElementwiseOpsFusionPatterns(
using ControlPropagationFn = std::function<bool(OpOperand *opOperand)>;

/// Patterns to bubble up or down data layout ops across other operations.
/// The function also has an option to allow the patterns to propagate with
/// poison padding if requested by the caller.
void populateDataLayoutPropagationPatterns(
RewritePatternSet &patterns,
const ControlPropagationFn &controlPackUnPackPropagation);
const ControlPropagationFn &controlPackUnPackPropagation,
bool PoisonPaddingOk = false);

/// Patterns to sink extract slice across other operations.
void populateExtractSliceSinkingPatterns(
Expand Down
26 changes: 26 additions & 0 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5310,6 +5310,32 @@ bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
return false;
}

bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outputShape,
ArrayRef<int64_t> outerDimsPerm,
ArrayRef<OpFoldResult> innerTiles) {
SmallVector<int64_t> outputTileSizes(
outputShape.take_front(inputShape.size()));
if (!outerDimsPerm.empty()) {
assert(outerDimsPerm.size() == outputTileSizes.size() &&
"expected output and outer_dims_perm to have same size");
applyPermutationToVector(outputTileSizes,
invertPermutationVector(outerDimsPerm));
}
for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
if (ShapedType::isDynamic(inputShape[pos]) ||
ShapedType::isDynamic(outputTileSizes[pos]))
return true;
std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
if (!constantTile)
return true;
if (inputShape[pos] % (*constantTile) != 0)
return true;
}
return false;
}

LogicalResult PackOp::verify() {
if (failed(commonVerifierPackAndUnPackOp(*this)))
return failure();
Expand Down
Loading