Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,11 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
/// tile factors.
DenseMap<int64_t, OpFoldResult> getDimAndTileMapping();

/// Return the tile sizes as OpFoldResult.
// TODO: Return the folded result.
/// Return the tile sizes as OpFoldResult. Will return the Value
/// of the constant Op, not the constant Attribute.
/// E.g., for %size = arith.constant 1 : i32 will return %size,
/// not 1.
SmallVector<OpFoldResult> getMixedTiles();

/// Return the tile sizes as `int64_t`. If a tile size is dynamic
Expand Down
64 changes: 29 additions & 35 deletions mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1146,37 +1146,25 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
Location loc = packOp.getLoc();

Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
packOp.getDimAndTileMapping();
int64_t srcRank = packOp.getSourceRank();
int64_t destRank = packOp.getDestRank();
int64_t numTiles = destRank - srcRank;
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
int64_t numberOfTiles = innerDimsPos.size();

// 1. Extract the inner tile sizes.
// Where possible, values are replaced with constant attributes (to match the
// behaviour of `getPackOpSourceOrPaddedSource`).
SmallVector<OpFoldResult> tileSizes;
for (auto i : llvm::seq<unsigned>(0, srcRank)) {
if (dimAndTileMapping.count(i)) {
// Rather than taking the tile size as is, extact the actual constant
// value Attribute where possible, e.g.:
// [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
auto [_, tileSize] =
getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
tileSizes.push_back(tileSize);
}
}
// 1. Get the input that is going to be packed. If the input requires padding,
// add a padding operation and return that as the input.
Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);

// 2. Transpose the input to match the inner tile order:
// %init = tensor.empty()
// %transposed_tile = linalg.transpose ins(%source_or_padded_source),
// outs(%init)
// Assumptions made:
// 1. All outer dims are 1 - the corresponding transposition order doesn't
// - All outer dims are 1 - the corresponding transposition order doesn't
// matter, but requires all dim indices to be present.

// 2.1 Get the permutation for linalg.transpose
SmallVector<int64_t> srcPermForTranspose;
ArrayRef<int64_t> innerDimPos(packOp.getInnerDimsPos());
for (int64_t i = 0; i < srcRank; i++) {
// We assume the `k` dimensions of the inner dim position, where `k` is the
// rank of the inner tiling, correspond to the last `k` indices of the
Expand All @@ -1185,27 +1173,34 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
// rank of the source tensor. For example if we have a source tensor with
// indices [0, 1, 2, 3] and inner dim position of [3, 0], the remaining
// indices are [1, 2]. and the transpose will be [1, 2, 3, 0].
if (llvm::is_contained(innerDimPos, i))
if (llvm::is_contained(innerDimsPos, i))
continue;
srcPermForTranspose.push_back(i);
}
srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end());
srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());

// 2.2 Create the init tensor for linalg.transpose with the correct shape
SmallVector<OpFoldResult> shapeForEmptyOp(srcRank - numberOfTiles,
oneIdxAttr);
shapeForEmptyOp.append(packOp.getMixedTiles());

// getMixedTiles() may contain Values pointing to constant ops, not the
// constant attributes. Replace them with a true OpFoldResult.
llvm::transform(shapeForEmptyOp, shapeForEmptyOp.begin(),
[&](OpFoldResult ofr) {
if (auto val = llvm::dyn_cast<Value>(ofr))
return getAsOpFoldResult(val);
return ofr;
});

LDBG() << "Pack permutation: " << packOp;
LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
LDBG() << "Shape of empty tensor: " << llvm::interleaved(shapeForEmptyOp);

// 2.1 Create tensor.empty (init value for TransposeOp)
SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
oneIdxAttr);
transShapeForEmptyOp.append(tileSizes);

applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
srcPermForTranspose);
Value empty =
tensor::EmptyOp::create(rewriter, loc, transShapeForEmptyOp,
packOp.getSourceType().getElementType());
Value empty = tensor::EmptyOp::create(
rewriter, loc, shapeForEmptyOp, packOp.getSourceType().getElementType());

// 2.2 Create linalg.transpose
// 2.3 Create linalg.transpose
auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
srcPermForTranspose);

Expand All @@ -1214,8 +1209,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
// Outer dims are all 1s!
SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
oneIdxAttr);
SmallVector<OpFoldResult> writeSizes(destRank - numberOfTiles, oneIdxAttr);
SmallVector<int64_t> writeShape;

for (auto tileSize : packOp.getMixedTiles()) {
Expand Down
21 changes: 21 additions & 0 deletions mlir/test/Dialect/Linalg/decompose-pack.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,24 @@ func.func @pack_with_adjacent_trailing_dimensions_inner_dims_pos_and_unit_outer(
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32>
// CHECK: return %[[INSERT]]

// -----

// The following example shows a pack operation where the inner dims
// positions are non-adjacent and non-permuted.
func.func @pack_with_non_adjacent_and_non_permuted_inner_dims(%arg0: tensor<8x1x1x1xf32>, %arg1:tensor<1x1x1x1x8x1xf32>) -> tensor<1x1x1x1x8x1xf32> {
%pack = linalg.pack %arg0 outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [0, 3] inner_tiles = [8, 1] into %arg1: tensor<8x1x1x1xf32> -> tensor<1x1x1x1x8x1xf32>
return %pack : tensor<1x1x1x1x8x1xf32>
}

// CHECK-LABEL: func.func @pack_with_non_adjacent_and_non_permuted_inner_dims
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x8x1xf32>
// CHECK: %[[TRANSP:.+]] = linalg.transpose
// CHECK-SAME: ins(%[[SRC]] : tensor<8x1x1x1xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x1x8x1xf32>)
// CHECK-SAME: permutation = [1, 2, 0, 3]
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 1] [1, 1, 1, 1, 1, 1] : tensor<1x1x8x1xf32> into tensor<1x1x1x1x8x1xf32>
// CHECK: return %[[INSERT]]