diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index 47145e36c55cf..dc132b22c7c94 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/Debug.h" namespace mlir { #define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION @@ -29,10 +30,66 @@ using namespace mlir::linalg; namespace { +// The struct contains the infomation about mapping packing information to +// the iteration domain of Linalg ops. +struct PackInfo { + int64_t getNumTiledLoops() const { return tileToPointMapping.size(); }; + // InnerDimsPos on iteration domain, which follows the order in pack ops. + SmallVector tiledDimsPos; + // The sizes of tiling data dimensions on iteration domain. + llvm::DenseMap domainDimAndTileMapping; + // The mapping from a dimension of iteration domain to the corresponding inner + // tiling dimension on iteration domain. + llvm::DenseMap tileToPointMapping; + // The permutation of outer dims (on domain). + SmallVector outerDimsOnDomainPerm; + Optional paddingValue; +}; + +static PackInfo getPackingInfoFromConsumer( + AffineMap indexingMap, ArrayRef innerTileSizes, + ArrayRef innerDimsPos, ArrayRef outerDimsPerm, + Optional paddingValue = llvm::None) { + LLVM_DEBUG( + { llvm::dbgs() << "--- Construct PackInfo From A Consumer ---\n"; }); + PackInfo packInfo; + packInfo.paddingValue = paddingValue; + int64_t origNumDims = indexingMap.getNumDims(); + SmallVector exprs(indexingMap.getResults()); + for (auto [index, innerDimPos, tileSize] : + llvm::zip_equal(llvm::seq(0, innerDimsPos.size()), + innerDimsPos, innerTileSizes)) { + int64_t domainDimPos = + exprs[innerDimPos].cast().getPosition(); + packInfo.tiledDimsPos.push_back(domainDimPos); + packInfo.domainDimAndTileMapping[domainDimPos] = tileSize; + packInfo.tileToPointMapping[domainDimPos] = origNumDims + index; + LLVM_DEBUG({ + llvm::dbgs() << "map innerDimPos=" << innerDimPos + << " to iteration dimension (d" << domainDimPos << ", d" + << packInfo.tileToPointMapping[domainDimPos] + << "), which has size=(" + << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n"; + }); + } + + for (auto dim : outerDimsPerm) + packInfo.outerDimsOnDomainPerm.push_back(indexingMap.getDimPosition(dim)); + if (!packInfo.outerDimsOnDomainPerm.empty()) { + LLVM_DEBUG({ + llvm::dbgs() << "map outer dimsDimsPerm to "; + for (auto dim : packInfo.outerDimsOnDomainPerm) + llvm::dbgs() << dim << " "; + llvm::dbgs() << "\n"; + }); + } + + return packInfo; +} + /// Returns a tuple for packed operand and indexing_map with the assumptions: /// 1) The generic op is the producer of the pack op. /// 2) The generic op has only one result. -/// 3) The indexing map of the output operand is identity. /// If the operand is a scalar or packing dimensions are all irrelevant to the /// operand, the opreand and the updated indexing map will be returned. /// Otherwise, it returns the packed operand and the updated indexing map. E.g., @@ -62,62 +119,57 @@ namespace { /// inner_tiles = [8] /// into %init : tensor -> tensor static std::tuple -getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, - tensor::PackOp packOp, GenericOp genericOp, - OpOperand *opOperand) { - int numOrigLoops = genericOp.getNumLoops(); - int64_t numInnerLoops = packOp.getInnerDimsPos().size(); +getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo, + GenericOp genericOp, OpOperand *opOperand) { + int64_t numOrigLoops = genericOp.getNumLoops(); + int64_t numInnerLoops = packInfo.getNumTiledLoops(); int64_t numLoops = numOrigLoops + numInnerLoops; AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand); + llvm::DenseMap domainDimToOperandDim; SmallVector exprs(origIndexingMap.getResults()); - if (genericOp.isScalar(opOperand)) - return std::make_tuple( - opOperand->get(), - AffineMap::get(numLoops, 0, exprs, packOp.getContext())); - - llvm::SetVector innerDimsPosSet(packOp.getInnerDimsPos().begin(), - packOp.getInnerDimsPos().end()); - // Mapping from AffinDimExpr of indexing maps to the operand shape dimension. - DenseMap iterMapToDim; - for (auto [index, expr] : llvm::enumerate(origIndexingMap.getResults())) { + return std::make_tuple(opOperand->get(), + AffineMap::get(numLoops, 0, exprs, b.getContext())); + + // Step 1. Construct the information of packing data dimensions; append inner + // dimensions to the indexing maps for the operand. + for (auto [index, expr] : llvm::enumerate(exprs)) { int64_t dimPos = expr.cast().getPosition(); - if (!innerDimsPosSet.contains(dimPos)) - continue; - iterMapToDim[dimPos] = index; + domainDimToOperandDim[dimPos] = index; } - - // Construct the information of packing data dimensions and new indexing maps - // for the operand. SmallVector innerDimsPos; SmallVector innerTileSizes; - for (auto [index, value] : llvm::enumerate( - llvm::zip(packOp.getInnerDimsPos(), packOp.getMixedTiles()))) { - int64_t dimPos = std::get<0>(value); - if (!iterMapToDim.count(dimPos)) + for (auto dimPos : packInfo.tiledDimsPos) { + if (!domainDimToOperandDim.count(dimPos)) continue; - innerDimsPos.push_back(iterMapToDim[dimPos]); - innerTileSizes.push_back(std::get<1>(value)); - exprs.push_back(b.getAffineDimExpr(numOrigLoops + index)); + int64_t index = domainDimToOperandDim[dimPos]; + innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]); + innerDimsPos.push_back(index); + exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos])); } - auto indexingMap = AffineMap::get(numLoops, 0, exprs, packOp.getContext()); + // Step 2. Fold transpose variants (i.e., outerDimsPerm) into generic op. + // TODO: should we propagate the permutation of outer dims to the pack op? SmallVector outerDimsPerm; - for (auto outDim : packOp.getOuterDimsPerm()) { - if (!iterMapToDim.count(outDim)) - continue; - outerDimsPerm.push_back(iterMapToDim[outDim]); + if (!packInfo.outerDimsOnDomainPerm.empty()) { + SmallVector inversedOuterPerm = + invertPermutationVector(packInfo.outerDimsOnDomainPerm); + for (auto i : llvm::seq(0, origIndexingMap.getNumResults())) { + int64_t dimPos = exprs[i].cast().getPosition(); + exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]); + } } + auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext()); // The operand does not have dimensions that relates to pack op. - if (innerDimsPos.empty() && outerDimsPerm.empty()) + if (innerDimsPos.empty()) return std::make_tuple(opOperand->get(), indexingMap); auto empty = tensor::PackOp::createDestinationTensor( b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm); auto packedOperand = b.create( loc, opOperand->get(), empty, innerDimsPos, innerTileSizes, - packOp.getPaddingValue(), outerDimsPerm); + packInfo.paddingValue, outerDimsPerm); return std::make_tuple(packedOperand, indexingMap); } @@ -187,34 +239,45 @@ bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter, return failure(); OpOperand *opOperand = genericOp.getDpsInitOperand(0); - // TODO: Add support for all permutation indexing maps. - if (!genericOp.getMatchingIndexingMap(opOperand).isIdentity()) - return rewriter.notifyMatchFailure( - packOp, "the result of generic op does not have identity indexing_map"); + auto packInfo = getPackingInfoFromConsumer( + genericOp.getMatchingIndexingMap(opOperand), packOp.getMixedTiles(), + packOp.getInnerDimsPos(), packOp.getOuterDimsPerm(), + packOp.getPaddingValue()); Location loc = packOp.getLoc(); SmallVector inputOperands; SmallVector indexingMaps; for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) { auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand( - rewriter, loc, packOp, genericOp, inputOperand); + rewriter, loc, packInfo, genericOp, inputOperand); inputOperands.push_back(packedOperand); indexingMaps.push_back(packedIndexingMap); } int64_t numLoops = genericOp.getNumLoops(); - int64_t numInnerLoops = packOp.getInnerDimsPos().size(); + int64_t numInnerLoops = packInfo.getNumTiledLoops(); int64_t newNumLoops = numLoops + numInnerLoops; SmallVector iterTypes = genericOp.getIteratorTypesArray(); iterTypes.append(numInnerLoops, utils::IteratorType::parallel); + // Rebuild the indexing map for the corresponding init operand. + auto [packedOutOperand, packedOutIndexingMap] = + getOrCreatePackedViewOfOperand(rewriter, loc, packInfo, genericOp, + opOperand); SmallVector outExprs( - genericOp.getMatchingIndexingMap(opOperand).getResults()); + packedOutIndexingMap.getResults().drop_back(numInnerLoops)); + // Apply transpose to the indexing map, because we'll replace the init operand + // with the destination of pack op. + auto outerDimsPerm = packOp.getOuterDimsPerm(); + if (!outerDimsPerm.empty()) { + applyPermutationToVector(outExprs, outerDimsPerm); + } for (int i = 0; i < numInnerLoops; ++i) outExprs.push_back(rewriter.getAffineDimExpr(numLoops + i)); - indexingMaps.push_back( - AffineMap::get(newNumLoops, 0, outExprs, rewriter.getContext())); + AffineMap outMap = + AffineMap::get(newNumLoops, 0, outExprs, rewriter.getContext()); + indexingMaps.push_back(outMap); auto newGenericOp = rewriter.create( loc, packOp.getDestType(), inputOperands, packOp.getDest(), indexingMaps, diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir index a5488d28b20c9..bb84272bf8b02 100644 --- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir +++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir @@ -96,16 +96,17 @@ func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %dest: ten into %dest : tensor<128x256xi32> -> tensor<16x4x32x16xi32> return %pack : tensor<16x4x32x16xi32> } -// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> // CHECK: func.func @elem_pack_transpose_outer_dims // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] -// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32> +// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x32x16xi32> // CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]] -// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16] -// CHECK-SAME: into %[[ARG0_EMPTY]] +// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16] +// CHECK-SAME: into %[[ARG0_EMPTY]] : tensor<128x256xi32> -> tensor<4x16x32x16xi32> // CHECK: %[[ELEM:.+]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]] +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"] // CHECK-SAME: ins(%[[PACK_ARG0]] // CHECK-SAME: outs(%[[DEST]] @@ -130,16 +131,17 @@ func.func @elem_pack_transpose_inner_and_outer_dims(%arg0: tensor<128x256xi32>, into %dest : tensor<128x256xi32> -> tensor<16x4x16x32xi32> return %pack : tensor<16x4x16x32xi32> } -// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> // CHECK: func.func @elem_pack_transpose_inner_and_outer_dims // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] -// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x16x32xi32> +// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x16x32xi32> // CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]] -// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 32] +// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32] // CHECK-SAME: into %[[ARG0_EMPTY]] // CHECK: %[[ELEM:.+]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]] +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"] // CHECK-SAME: ins(%[[PACK_ARG0]] // CHECK-SAME: outs(%[[DEST]] @@ -200,6 +202,37 @@ func.func @dynamic_broadcast_pack(%arg0: tensor, %arg1: tensor, %d // ----- +#map = affine_map<(d0, d1, d2, d3) -> (d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +func.func @elem_pack_transpose_inner_and_outer_dims2(%arg0: tensor<64xf32>, %dest: tensor<1x2x56x57x32xf32>) -> tensor<1x2x56x57x32xf32> { + %0 = tensor.empty() : tensor<1x56x57x64xf32> + %1 = linalg.generic { + indexing_maps = [#map, #map1], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%arg0 : tensor<64xf32>) + outs(%0 : tensor<1x56x57x64xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<1x56x57x64xf32> + %2 = tensor.pack %1 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %dest : tensor<1x56x57x64xf32> -> tensor<1x2x56x57x32xf32> + return %2 : tensor<1x2x56x57x32xf32> +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d4)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> +// CHECK: func.func @elem_pack_transpose_inner_and_outer_dims2 +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<2x32xf32> +// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]] +// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [32] +// CHECK-SAME: into %[[ARG0_EMPTY]] +// CHECK: %[[RES:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] +// CHECK-SAME: ins(%[[PACKED_ARG0]] +// CHECK-SAME: outs(%[[DEST]] + +// ----- + #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> (d0)> #map2 = affine_map<(d0, d1) -> (d1)> @@ -225,6 +258,53 @@ func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100x into %dest : tensor<100x200x128x256xi32> -> tensor<100x200x4x16x16x32xi32> return %4 : tensor<100x200x4x16x16x32xi32> } -// CHECK: func.func @transpose_pack -// CHECK: linalg.generic -// CHECK: tensor.pack +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d5)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d1, d3, d4, d5)> +// CHECK: func.func @transpose_pack +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] +// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<100x4x200x16x16x32xi32> +// CHECK: %[[PACKED_ARG0:.+]] = tensor.pack %[[ARG0]] +// CHECK-SAME: inner_dims_pos = [3, 1] inner_tiles = [16, 32] +// CHECK-SAME: into %[[ARG0_EMPTY]] +// CHECK: %[[ARG2_EMPTY:.+]] = tensor.empty() : tensor<4x32xi32> +// CHECK: %[[PACKED_ARG2:.+]] = tensor.pack %[[ARG2]] +// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [32] +// CHECK-SAME: into %[[ARG2_EMPTY]] +// CHECK: %[[RES:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]] +// CHECK-SAME: ins(%[[PACKED_ARG0]], %[[ARG1]], %[[PACKED_ARG2]] +// CHECK-SAME: outs(%[[DEST]] + +// ----- + +#map0 = affine_map<(d0, d1) -> (d0, d1)> +#map1 = affine_map<(d0, d1) -> (d0)> +#map2 = affine_map<(d0, d1) -> (d1)> +func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, %arg2: tensor<128xi32>, %dest: tensor<200x4x16x100x16x32xi32>) -> tensor<200x4x16x100x16x32xi32> +{ + %init_transpose = tensor.empty() : tensor<100x200x128x256xi32> + %transpose = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0)>, + affine_map<(d0, d1, d2, d3) -> (d1)>, + affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%arg0, %arg1, %arg2 : tensor<100x128x200x256xi32>, tensor<100xi32>, tensor<128xi32>) + outs(%init_transpose : tensor<100x200x128x256xi32>) { + ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32): + %0 = arith.addi %b0, %b1 : i32 + %1 = arith.addi %0, %b2 : i32 + linalg.yield %1 : i32 + } -> tensor<100x200x128x256xi32> + %4 = tensor.pack %transpose + outer_dims_perm = [1, 2, 3, 0] + inner_dims_pos = [3, 2] + inner_tiles = [16, 32] + into %dest : tensor<100x200x128x256xi32> -> tensor<200x4x16x100x16x32xi32> + return %4 : tensor<200x4x16x100x16x32xi32> +}