[mlir][tensor][linalg] Enhance pack op propagation across generic ops.
Consider the case of generic + pack (with outer_dims_perm): it is
equivalent to generic + pack + transpose. There are two steps to bubble
up the pack op across the generic op.

Step 1. Swap generic + pack -> pack + generic.

In this step, we bind the packing information to dimensions of the
iteration domain. With that information, we can pack the operands with
the corresponding data tile sizes; the packed inner dimensions are
appended to the indexing_maps. Note that the outer dimensions of the
indexing maps are not changed at all.
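
A minimal sketch of step 1 on an elementwise op, reusing the 128x256 shapes
from the tests below (our illustration, not IR emitted by the pass):

  #map = affine_map<(d0, d1) -> (d0, d1)>
  %elem = linalg.generic {indexing_maps = [#map, #map],
                          iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<128x256xi32>) outs(%init : tensor<128x256xi32>) {
    ^bb0(%in: i32, %out: i32):
      linalg.yield %in : i32
  } -> tensor<128x256xi32>
  %pack = tensor.pack %elem inner_dims_pos = [0, 1] inner_tiles = [32, 16]
      into %dest : tensor<128x256xi32> -> tensor<4x16x32x16xi32>

After step 1, the pack moves onto the input and the generic runs on the
packed layout; the inner tile loops d2 and d3 are appended to the maps:

  #pmap = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
  %packed = tensor.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 16]
      into %empty : tensor<128x256xi32> -> tensor<4x16x32x16xi32>
  %elem = linalg.generic {indexing_maps = [#pmap, #pmap],
                          iterator_types = ["parallel", "parallel",
                                            "parallel", "parallel"]}
      ins(%packed : tensor<4x16x32x16xi32>)
      outs(%dest : tensor<4x16x32x16xi32>) {
    ^bb0(%in: i32, %out: i32):
      linalg.yield %in : i32
  } -> tensor<4x16x32x16xi32>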

Step 2. Fold the transpose into the generic op.

Step 2 only updates the indexing maps, so we no longer have to handle
outer_dims_perm after that point.
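
Concretely, in the outer_dims_perm = [1, 0] test below, the consumer pack

  %pack = tensor.pack %elem outer_dims_perm = [1, 0]
      inner_dims_pos = [0, 1] inner_tiles = [32, 16]
      into %dest : tensor<128x256xi32> -> tensor<16x4x32x16xi32>

bubbles up to a pack without outer_dims_perm, and the permutation survives
only in the generic's input indexing map (a sketch, not pass output):

  #in  = affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)>
  #out = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>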

There could be a step 3 that extracts the transpose op (i.e., generic ->
transpose + generic), after which the transpose could be folded into the
pack op. That step is not done in this revision; a sketch follows.
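
A hedged sketch of what that follow-up could look like; the shapes, names,
and the use of linalg.transpose are our illustration, not part of this patch:

  #id = affine_map<(d0, d1) -> (d0, d1)>
  #tr = affine_map<(d0, d1) -> (d1, d0)>
  // After step 2, the permutation lives in the generic's output map:
  %0 = linalg.generic {indexing_maps = [#id, #tr],
                       iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<128x256xi32>) outs(%init : tensor<256x128xi32>) {
    ^bb0(%in: i32, %out: i32):
      linalg.yield %in : i32
  } -> tensor<256x128xi32>

  // Step 3 would peel the permutation off into an explicit transpose, which
  // a later pattern could fold into tensor.pack as outer_dims_perm:
  %1 = linalg.generic {indexing_maps = [#id, #id],
                       iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<128x256xi32>) outs(%empty : tensor<128x256xi32>) {
    ^bb0(%in: i32, %out: i32):
      linalg.yield %in : i32
  } -> tensor<128x256xi32>
  %2 = linalg.transpose ins(%1 : tensor<128x256xi32>)
         outs(%init : tensor<256x128xi32>) permutation = [1, 0]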

Co-authored-by: Lorenzo Chelini <l.chelini@icloud.com>

Reviewed By: chelini

Differential Revision: https://reviews.llvm.org/D139680
hanhanW committed Dec 13, 2022
1 parent 6e6fe27 commit d38d606
Showing 2 changed files with 200 additions and 57 deletions.
153 changes: 108 additions & 45 deletions mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
@@ -29,10 +30,66 @@ using namespace mlir::linalg;

namespace {

// This struct contains the information needed to map a pack op's packing
// onto the iteration domain of Linalg ops.
struct PackInfo {
int64_t getNumTiledLoops() const { return tileToPointMapping.size(); };
// InnerDimsPos on the iteration domain, following the order in the pack op.
SmallVector<int64_t> tiledDimsPos;
// The mapping from a dimension of the iteration domain to its tile size.
llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
// The mapping from a dimension of the iteration domain to the corresponding
// inner tiling dimension on the iteration domain.
llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
// The permutation of outer dims (on domain).
SmallVector<int64_t> outerDimsOnDomainPerm;
Optional<Value> paddingValue;
};
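// Illustrative example (ours, not from this patch): for a consumer pack with
// inner_dims_pos = [0, 1] and inner_tiles = [32, 16] over an identity 2-D
// output indexing map, getPackingInfoFromConsumer below produces
// tiledDimsPos = [0, 1], domainDimAndTileMapping = {0 -> 32, 1 -> 16}, and
// tileToPointMapping = {0 -> 2, 1 -> 3}; i.e., the two inner tiles become
// the new iteration dimensions d2 and d3.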

static PackInfo getPackingInfoFromConsumer(
AffineMap indexingMap, ArrayRef<OpFoldResult> innerTileSizes,
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm,
Optional<Value> paddingValue = llvm::None) {
LLVM_DEBUG(
{ llvm::dbgs() << "--- Construct PackInfo From A Consumer ---\n"; });
PackInfo packInfo;
packInfo.paddingValue = paddingValue;
int64_t origNumDims = indexingMap.getNumDims();
SmallVector<AffineExpr> exprs(indexingMap.getResults());
for (auto [index, innerDimPos, tileSize] :
llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
innerDimsPos, innerTileSizes)) {
int64_t domainDimPos =
exprs[innerDimPos].cast<AffineDimExpr>().getPosition();
packInfo.tiledDimsPos.push_back(domainDimPos);
packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
LLVM_DEBUG({
llvm::dbgs() << "map innerDimPos=" << innerDimPos
<< " to iteration dimension (d" << domainDimPos << ", d"
<< packInfo.tileToPointMapping[domainDimPos]
<< "), which has size=("
<< packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
});
}

for (auto dim : outerDimsPerm)
packInfo.outerDimsOnDomainPerm.push_back(indexingMap.getDimPosition(dim));
if (!packInfo.outerDimsOnDomainPerm.empty()) {
LLVM_DEBUG({
llvm::dbgs() << "map outer dimsDimsPerm to ";
for (auto dim : packInfo.outerDimsOnDomainPerm)
llvm::dbgs() << dim << " ";
llvm::dbgs() << "\n";
});
}

return packInfo;
}

/// Returns a tuple for packed operand and indexing_map with the assumptions:
/// 1) The generic op is the producer of the pack op.
/// 2) The generic op has only one result.
/// 3) The indexing map of the output operand is identity.
/// If the operand is a scalar or the packing dimensions are all irrelevant to
/// the operand, the operand and the updated indexing map will be returned.
/// Otherwise, it returns the packed operand and the updated indexing map. E.g.,
@@ -62,62 +119,57 @@ namespace {
/// inner_tiles = [8]
/// into %init : tensor<?xf32> -> tensor<?x8xf32>
static std::tuple<Value, AffineMap>
getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc,
tensor::PackOp packOp, GenericOp genericOp,
OpOperand *opOperand) {
int numOrigLoops = genericOp.getNumLoops();
int64_t numInnerLoops = packOp.getInnerDimsPos().size();
getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
GenericOp genericOp, OpOperand *opOperand) {
int64_t numOrigLoops = genericOp.getNumLoops();
int64_t numInnerLoops = packInfo.getNumTiledLoops();
int64_t numLoops = numOrigLoops + numInnerLoops;
AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
SmallVector<AffineExpr> exprs(origIndexingMap.getResults());

if (genericOp.isScalar(opOperand))
return std::make_tuple(
opOperand->get(),
AffineMap::get(numLoops, 0, exprs, packOp.getContext()));

llvm::SetVector<int64_t> innerDimsPosSet(packOp.getInnerDimsPos().begin(),
packOp.getInnerDimsPos().end());
// Mapping from AffineDimExpr of indexing maps to the operand shape dimension.
DenseMap<int64_t, int64_t> iterMapToDim;
for (auto [index, expr] : llvm::enumerate(origIndexingMap.getResults())) {
return std::make_tuple(opOperand->get(),
AffineMap::get(numLoops, 0, exprs, b.getContext()));

// Step 1. Construct the information of packing data dimensions; append inner
// dimensions to the indexing maps for the operand.
for (auto [index, expr] : llvm::enumerate(exprs)) {
int64_t dimPos = expr.cast<AffineDimExpr>().getPosition();
if (!innerDimsPosSet.contains(dimPos))
continue;
iterMapToDim[dimPos] = index;
domainDimToOperandDim[dimPos] = index;
}

// Construct the information of packing data dimensions and new indexing maps
// for the operand.
SmallVector<int64_t> innerDimsPos;
SmallVector<OpFoldResult> innerTileSizes;
for (auto [index, value] : llvm::enumerate(
llvm::zip(packOp.getInnerDimsPos(), packOp.getMixedTiles()))) {
int64_t dimPos = std::get<0>(value);
if (!iterMapToDim.count(dimPos))
for (auto dimPos : packInfo.tiledDimsPos) {
if (!domainDimToOperandDim.count(dimPos))
continue;
innerDimsPos.push_back(iterMapToDim[dimPos]);
innerTileSizes.push_back(std::get<1>(value));
exprs.push_back(b.getAffineDimExpr(numOrigLoops + index));
int64_t index = domainDimToOperandDim[dimPos];
innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
innerDimsPos.push_back(index);
exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
}
auto indexingMap = AffineMap::get(numLoops, 0, exprs, packOp.getContext());

// Step 2. Fold transpose variants (i.e., outerDimsPerm) into generic op.
// TODO: should we propagate the permutation of outer dims to the pack op?
SmallVector<int64_t> outerDimsPerm;
for (auto outDim : packOp.getOuterDimsPerm()) {
if (!iterMapToDim.count(outDim))
continue;
outerDimsPerm.push_back(iterMapToDim[outDim]);
if (!packInfo.outerDimsOnDomainPerm.empty()) {
SmallVector<int64_t> inversedOuterPerm =
invertPermutationVector(packInfo.outerDimsOnDomainPerm);
for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
int64_t dimPos = exprs[i].cast<AffineDimExpr>().getPosition();
exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
}
}
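// Illustrative note (ours): with outerDimsOnDomainPerm = [0, 3, 1, 2], as in
// the elem_pack_transpose_inner_and_outer_dims2 test below, the inverse
// permutation is [0, 2, 3, 1], so an operand expression d3 is rewritten to
// d1; this is why that test expects the packed input map
// (d0, d1, d2, d3, d4) -> (d1, d4).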
auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());

// The operand does not have dimensions that relate to the pack op.
if (innerDimsPos.empty() && outerDimsPerm.empty())
if (innerDimsPos.empty())
return std::make_tuple(opOperand->get(), indexingMap);

auto empty = tensor::PackOp::createDestinationTensor(
b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
auto packedOperand = b.create<tensor::PackOp>(
loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
packOp.getPaddingValue(), outerDimsPerm);
packInfo.paddingValue, outerDimsPerm);
return std::make_tuple(packedOperand, indexingMap);
}
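
A minimal sketch of what this helper returns for a single 1-D operand,
consistent with the doc comment's (truncated) example above; assuming a
two-loop generic in which only d1 is tiled, by 8 (names are ours):

  // Operand %arg0 : tensor<?xf32> with original indexing map (d0, d1) -> (d1).
  // The tiled loop d1 gets the new point loop d2, so the helper returns the
  // packed operand below plus the updated map (d0, d1, d2) -> (d1, d2).
  %pack = tensor.pack %arg0
      inner_dims_pos = [0]
      inner_tiles = [8]
      into %init : tensor<?xf32> -> tensor<?x8xf32>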

@@ -187,34 +239,45 @@ bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
return failure();

OpOperand *opOperand = genericOp.getDpsInitOperand(0);
// TODO: Add support for all permutation indexing maps.
if (!genericOp.getMatchingIndexingMap(opOperand).isIdentity())
return rewriter.notifyMatchFailure(
packOp, "the result of generic op does not have identity indexing_map");
auto packInfo = getPackingInfoFromConsumer(
genericOp.getMatchingIndexingMap(opOperand), packOp.getMixedTiles(),
packOp.getInnerDimsPos(), packOp.getOuterDimsPerm(),
packOp.getPaddingValue());

Location loc = packOp.getLoc();
SmallVector<Value> inputOperands;
SmallVector<AffineMap> indexingMaps;
for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
rewriter, loc, packOp, genericOp, inputOperand);
rewriter, loc, packInfo, genericOp, inputOperand);
inputOperands.push_back(packedOperand);
indexingMaps.push_back(packedIndexingMap);
}

int64_t numLoops = genericOp.getNumLoops();
int64_t numInnerLoops = packOp.getInnerDimsPos().size();
int64_t numInnerLoops = packInfo.getNumTiledLoops();
int64_t newNumLoops = numLoops + numInnerLoops;
SmallVector<utils::IteratorType> iterTypes =
genericOp.getIteratorTypesArray();
iterTypes.append(numInnerLoops, utils::IteratorType::parallel);

// Rebuild the indexing map for the corresponding init operand.
auto [packedOutOperand, packedOutIndexingMap] =
getOrCreatePackedViewOfOperand(rewriter, loc, packInfo, genericOp,
opOperand);
SmallVector<AffineExpr> outExprs(
genericOp.getMatchingIndexingMap(opOperand).getResults());
packedOutIndexingMap.getResults().drop_back(numInnerLoops));
// Apply the transpose to the indexing map, because we'll replace the init
// operand with the destination of the pack op.
auto outerDimsPerm = packOp.getOuterDimsPerm();
if (!outerDimsPerm.empty()) {
applyPermutationToVector<AffineExpr>(outExprs, outerDimsPerm);
}
for (int i = 0; i < numInnerLoops; ++i)
outExprs.push_back(rewriter.getAffineDimExpr(numLoops + i));
indexingMaps.push_back(
AffineMap::get(newNumLoops, 0, outExprs, rewriter.getContext()));
AffineMap outMap =
AffineMap::get(newNumLoops, 0, outExprs, rewriter.getContext());
indexingMaps.push_back(outMap);

auto newGenericOp = rewriter.create<linalg::GenericOp>(
loc, packOp.getDestType(), inputOperands, packOp.getDest(), indexingMaps,
104 changes: 92 additions & 12 deletions mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -96,16 +96,17 @@ func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %dest: ten
into %dest : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
return %pack : tensor<16x4x32x16xi32>
}
// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: func.func @elem_pack_transpose_outer_dims
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x32x16xi32>
// CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME: into %[[ARG0_EMPTY]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME: into %[[ARG0_EMPTY]] : tensor<128x256xi32> -> tensor<4x16x32x16xi32>
// CHECK: %[[ELEM:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[PACK_ARG0]]
// CHECK-SAME: outs(%[[DEST]]
@@ -130,16 +131,17 @@ func.func @elem_pack_transpose_inner_and_outer_dims(%arg0: tensor<128x256xi32>,
into %dest : tensor<128x256xi32> -> tensor<16x4x16x32xi32>
return %pack : tensor<16x4x16x32xi32>
}
// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: func.func @elem_pack_transpose_inner_and_outer_dims
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x16x32xi32>
// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x16x32xi32>
// CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 32]
// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32]
// CHECK-SAME: into %[[ARG0_EMPTY]]
// CHECK: %[[ELEM:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[PACK_ARG0]]
// CHECK-SAME: outs(%[[DEST]]
@@ -200,6 +202,37 @@ func.func @dynamic_broadcast_pack(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %d

// -----

#map = affine_map<(d0, d1, d2, d3) -> (d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func.func @elem_pack_transpose_inner_and_outer_dims2(%arg0: tensor<64xf32>, %dest: tensor<1x2x56x57x32xf32>) -> tensor<1x2x56x57x32xf32> {
%0 = tensor.empty() : tensor<1x56x57x64xf32>
%1 = linalg.generic {
indexing_maps = [#map, #map1],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
ins(%arg0 : tensor<64xf32>)
outs(%0 : tensor<1x56x57x64xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<1x56x57x64xf32>
%2 = tensor.pack %1 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %dest : tensor<1x56x57x64xf32> -> tensor<1x2x56x57x32xf32>
return %2 : tensor<1x2x56x57x32xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d4)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK: func.func @elem_pack_transpose_inner_and_outer_dims2
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<2x32xf32>
// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [32]
// CHECK-SAME: into %[[ARG0_EMPTY]]
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
// CHECK-SAME: ins(%[[PACKED_ARG0]]
// CHECK-SAME: outs(%[[DEST]]

// -----

#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
#map2 = affine_map<(d0, d1) -> (d1)>
@@ -225,6 +258,53 @@ func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100x
into %dest : tensor<100x200x128x256xi32> -> tensor<100x200x4x16x16x32xi32>
return %4 : tensor<100x200x4x16x16x32xi32>
}
// CHECK: func.func @transpose_pack
// CHECK: linalg.generic
// CHECK: tensor.pack
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d5)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d3, d4, d5)>
// CHECK: func.func @transpose_pack
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<100x4x200x16x16x32xi32>
// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]]
// CHECK-SAME: inner_dims_pos = [3, 1] inner_tiles = [16, 32]
// CHECK-SAME: into %[[ARG0_EMPTY]]
// CHECK: %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<4x32xi32>
// CHECK: %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]]
// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [32]
// CHECK-SAME: into %[[ARG2_EMPTY]]
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
// CHECK-SAME: ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]]
// CHECK-SAME: outs(%[[DEST]]

// -----

#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
#map2 = affine_map<(d0, d1) -> (d1)>
func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, %arg2: tensor<128xi32>, %dest: tensor<200x4x16x100x16x32xi32>) -> tensor<200x4x16x100x16x32xi32>
{
%init_transpose = tensor.empty() : tensor<100x200x128x256xi32>
%transpose = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0)>,
affine_map<(d0, d1, d2, d3) -> (d1)>,
affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
ins(%arg0, %arg1, %arg2 : tensor<100x128x200x256xi32>, tensor<100xi32>, tensor<128xi32>)
outs(%init_transpose : tensor<100x200x128x256xi32>) {
^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
%0 = arith.addi %b0, %b1 : i32
%1 = arith.addi %0, %b2 : i32
linalg.yield %1 : i32
} -> tensor<100x200x128x256xi32>
%4 = tensor.pack %transpose
outer_dims_perm = [1, 2, 3, 0]
inner_dims_pos = [3, 2]
inner_tiles = [16, 32]
into %dest : tensor<100x200x128x256xi32> -> tensor<200x4x16x100x16x32xi32>
return %4 : tensor<200x4x16x100x16x32xi32>
}

2 comments on commit d38d606

@kazutakahirata (Contributor):

I've fixed a deprecation warning with 19c37da2109ff0b3903e05ed443d51b83521575a.

@hanhanW (Contributor, Author):

> I've fixed a deprecation warning with 19c37da2109ff0b3903e05ed443d51b83521575a.

ah thanks!