Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][tensor] Fold producer linalg transpose with consumer tensor pack #75658

Merged
merged 2 commits into llvm:main on Jan 10, 2024

Conversation

meshtag
Copy link
Contributor

@meshtag meshtag commented Dec 15, 2023

Successor to #74206

Partial fix to iree-org/iree#15367

@llvmbot
Copy link
Collaborator

llvmbot commented Dec 15, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tensor

Author: Prathamesh Tagore (meshtag)

Changes

Successor to #74206

Partial fix to iree-org/iree#15367


Full diff: https://github.com/llvm/llvm-project/pull/75658.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp (+98-30)
  • (modified) mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir (+115)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
index e4509b331beeac..2c45cd3500fa94 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp
@@ -21,6 +21,57 @@ static bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
       ofrs, [&](OpFoldResult ofr) { return isConstantIntValue(ofr, value); });
 }
 
+/// Helper function to generate an equivalent permutation map for
+/// `linalg.transpose` and `tensor.pack` which will be used after their folding
+/// into a `tensor.pack`.
+static bool getRemappedPermutationForTransposeAndPack(
+    PackOp packOp, linalg::TransposeOp transposeOp,
+    SmallVector<int64_t> &newOuterDimsPermVec,
+    SmallVector<int64_t> &newInnerDimsPosVec,
+    SmallVector<OpFoldResult> &newMixedInnerTilesVec,
+    bool isTransposeProducer) {
+  bool foldingPossible = true;
+  auto innerDimsPos = packOp.getInnerDimsPos();
+  auto mixedInnerTiles = packOp.getMixedTiles();
+  auto outerDimsPerm = packOp.getOuterDimsPerm();
+  auto transposePerm = transposeOp.getPermutation();
+  int64_t srcRank = packOp.getSourceRank();
+
+  // Note: if isTransposeProducer = true, transposePerm.size() = srcRank, else
+  // transposePerm.size() > srcRank
+
+  // Process transpose operation for non-tiled outer dimensions
+  for (unsigned int i = 0; i < srcRank; ++i) {
+    int64_t remappedPosition =
+        isTransposeProducer ? (!outerDimsPerm.empty() ? outerDimsPerm[i] : i)
+                            : transposePerm[i];
+
+    if (remappedPosition >= srcRank) {
+      foldingPossible = false;
+      return foldingPossible;
+    }
+
+    remappedPosition =
+        isTransposeProducer
+            ? transposePerm[remappedPosition]
+            : (!outerDimsPerm.empty() ? outerDimsPerm[remappedPosition]
+                                      : remappedPosition);
+
+    newOuterDimsPermVec.push_back(remappedPosition);
+  }
+
+  // Process transpose operation for tiled inner dimensions
+  for (unsigned int i = srcRank; i < srcRank + mixedInnerTiles.size(); ++i) {
+    int64_t remappedPosition =
+        isTransposeProducer ? i - srcRank : transposePerm[i] - srcRank;
+
+    newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]);
+    newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
+  }
+
+  return foldingPossible;
+}
+
 /// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
 /// the pad op has zero low paddings, or if `pack` has no padding values.
 struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
@@ -96,39 +147,19 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
     if (!packOp)
       return failure();
 
-    auto innerDimsPos = packOp.getInnerDimsPos();
-    auto mixedInnerTiles = packOp.getMixedTiles();
-    auto outerDimsPerm = packOp.getOuterDimsPerm();
-    auto transposePerm = transposeOp.getPermutation();
     SmallVector<int64_t> newOuterDimsPermVec;
     SmallVector<int64_t> newInnerDimsPosVec;
     SmallVector<OpFoldResult> newMixedInnerTilesVec;
-    int64_t srcRank = packOp.getSourceRank();
-
-    // Process transpose operation for non-tiled outer dimensions
-    for (unsigned int i = 0; i < srcRank; ++i) {
-      int64_t remappedPosition = transposePerm[i];
-
-      // If tensor.pack has outer_dims_perm attribute, then consider it during
-      // index remapping.
-      if (!outerDimsPerm.empty()) {
-        if (transposePerm[i] >= srcRank) {
-          return rewriter.notifyMatchFailure(
-              transposeOp,
-              "Cannot fold in tensor.pack if a tile dimension was transposed "
-              "with a non-tile dimension in linalg.transpose.");
-        }
-        remappedPosition = outerDimsPerm[remappedPosition];
-      }
-
-      newOuterDimsPermVec.push_back(remappedPosition);
-    }
 
-    // Process transpose operation for tiled inner dimensions
-    for (unsigned int i = srcRank; i < transposePerm.size(); ++i) {
-      int64_t remappedPosition = transposePerm[i] - srcRank;
-      newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]);
-      newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
+    bool foldingPossible = getRemappedPermutationForTransposeAndPack(
+        packOp, transposeOp, newOuterDimsPermVec, newInnerDimsPosVec,
+        newMixedInnerTilesVec, /*isTransposeProducer*/ false);
+
+    if (!foldingPossible) {
+      return rewriter.notifyMatchFailure(
+          transposeOp,
+          "Cannot fold in tensor.pack if a tile dimension was transposed "
+          "with a non-tile dimension in linalg.transpose.");
     }
 
     Value output = packOp.createDestinationTensor(
@@ -142,11 +173,48 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
     return success();
   }
 };
+
+/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
+/// semantics.
+struct FoldConsumerPackWithProducerLinalgTransposeOp
+    : public OpRewritePattern<PackOp> {
+  using OpRewritePattern<PackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    auto transposeOp = packOp.getSource().getDefiningOp<linalg::TransposeOp>();
+
+    if (!transposeOp)
+      return failure();
+
+    SmallVector<int64_t> newOuterDimsPermVec;
+    SmallVector<int64_t> newInnerDimsPosVec;
+    SmallVector<OpFoldResult> newMixedInnerTilesVec;
+
+    bool foldingPossible = getRemappedPermutationForTransposeAndPack(
+        packOp, transposeOp, newOuterDimsPermVec, newInnerDimsPosVec,
+        newMixedInnerTilesVec, /*isTransposeProducer*/ true);
+
+    if (!foldingPossible)
+      return failure();
+
+    Value output = packOp.createDestinationTensor(
+        rewriter, packOp.getLoc(), transposeOp.getOperand(0),
+        newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
+
+    rewriter.replaceOpWithNewOp<PackOp>(
+        packOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
+        newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
+
+    return success();
+  }
+};
 } // namespace
 
 void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
   patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
-                  FoldProducerPackWithConsumerLinalgTransposeOp>(
+                  FoldProducerPackWithConsumerLinalgTransposeOp,
+                  FoldConsumerPackWithProducerLinalgTransposeOp>(
       patterns.getContext());
 }
 
diff --git a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
index ca4eb4ff679445..ed101883a40f9a 100644
--- a/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir
@@ -345,3 +345,118 @@ func.func @tensor_pack_linalg_transpose_fold_dynamic_outer_dims_tile_dims_tile_s
 //      CHECK:     %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [2, 1, 3, 0] inner_dims_pos = [3, 1, 2] inner_tiles = [%[[ARG3]], %[[ARG1]], %[[ARG2]]] into %[[INIT]] : tensor<?x?x?x?xf32> -> tensor<?x?x?x?x?x?x?xf32>
 //      CHECK:     return %[[PACK]] : tensor<?x?x?x?x?x?x?xf32>
 //      CHECK:   }
+
+// -----
+
+func.func @linalg_transpose_tensor_pack_fold(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x57x56x2x32xf32> {
+  %0 = tensor.empty() : tensor<1x56x57x64xf32>
+  %transposed = linalg.transpose
+    ins(%arg0 : tensor<56x57x1x64xf32>)
+    outs(%0 : tensor<1x56x57x64xf32>)
+    permutation = [2, 0, 1, 3]
+
+  %1 = tensor.empty() : tensor<1x57x56x2x32xf32>
+  %pack = tensor.pack %transposed
+    outer_dims_perm = [0, 2, 1, 3]
+    inner_dims_pos = [3]
+    inner_tiles = [32]
+    into %1 : tensor<1x56x57x64xf32> -> tensor<1x57x56x2x32xf32>
+  return %pack : tensor<1x57x56x2x32xf32>
+}
+//      CHECK: func @linalg_transpose_tensor_pack_fold(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
+//      CHECK:   %[[INIT:.+]] = tensor.empty() : tensor<1x57x56x2x32xf32>
+//      CHECK:   %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [2, 1, 0, 3]
+// CHECK-SAME:      inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:       into %[[INIT]]
+//      CHECK:   return %[[PACK]]
+
+// -----
+
+func.func @linalg_transpose_tensor_pack_fold_with_padding(%arg0: tensor<56x57x1x55xf32>, %padding: f32) -> tensor<1x57x56x2x32xf32> {
+  %0 = tensor.empty() : tensor<1x56x57x55xf32>
+  %transpose = linalg.transpose
+    ins(%arg0 : tensor<56x57x1x55xf32>)
+    outs(%0 : tensor<1x56x57x55xf32>)
+    permutation = [2, 0, 1, 3]
+  
+  %1 = tensor.empty() : tensor<1x57x56x2x32xf32>
+  %pack = tensor.pack %transpose padding_value(%padding : f32)
+    outer_dims_perm = [0, 2, 1, 3]
+    inner_dims_pos = [3]
+    inner_tiles = [32]
+    into %1 : tensor<1x56x57x55xf32> -> tensor<1x57x56x2x32xf32>
+  return %pack : tensor<1x57x56x2x32xf32>
+}
+//      CHECK: func @linalg_transpose_tensor_pack_fold_with_padding(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<56x57x1x55xf32>, %[[PADDING:.+]]: f32)
+//      CHECK:   %[[INIT:.+]] = tensor.empty() : tensor<1x57x56x2x32xf32>
+//      CHECK:   %[[PACK:.+]] = tensor.pack %[[ARG0]] padding_value(%[[PADDING]] : f32)
+// CHECK-SAME:      outer_dims_perm = [2, 1, 0, 3]
+// CHECK-SAME:      inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:       into %[[INIT]]
+//      CHECK:   return %[[PACK]]
+
+// -----
+
+func.func @linalg_transpose_tensor_pack_fold_no_outer_dims_perm(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x56x57x2x32xf32> {
+  %0 = tensor.empty() : tensor<1x56x57x64xf32>
+  %transposed = linalg.transpose
+    ins(%arg0 : tensor<56x57x1x64xf32>)
+    outs(%0 : tensor<1x56x57x64xf32>)
+    permutation = [2, 0, 1, 3]
+  
+  %1 = tensor.empty() : tensor<1x56x57x2x32xf32>
+  %pack = tensor.pack %transposed
+    inner_dims_pos = [3]
+    inner_tiles = [32]
+    into %1 : tensor<1x56x57x64xf32> -> tensor<1x56x57x2x32xf32>
+  return %pack : tensor<1x56x57x2x32xf32>
+}
+//      CHECK: func @linalg_transpose_tensor_pack_fold_no_outer_dims_perm(
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
+//      CHECK:   %[[INIT:.+]] = tensor.empty() : tensor<1x56x57x2x32xf32>
+//      CHECK:   %[[PACK:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:      outer_dims_perm = [2, 0, 1, 3]
+// CHECK-SAME:      inner_dims_pos = [3] inner_tiles = [32] 
+// CHECK-SAME:       into %[[INIT]]
+//      CHECK:   return %[[PACK]]
+
+// -----
+
+func.func @linalg_transpose_tensor_pack_fold_dynamic_outer_dims_tile_dims_tile_sizes(%arg0: tensor<?x?x?x?xf32>, %transpose_dest: tensor<?x?x?x?xf32>, %pack_dest: tensor<?x?x?x?x?x?x?xf32>, %tile_p : index, %tile_q : index, %tile_r : index) -> tensor<?x?x?x?x?x?x?xf32> {
+  %transposed = linalg.transpose
+    ins(%arg0 : tensor<?x?x?x?xf32>)
+    outs(%transpose_dest : tensor<?x?x?x?xf32>)
+    permutation = [2, 3, 0, 1]
+  
+  %pack = tensor.pack %transposed
+    outer_dims_perm = [3, 0, 2, 1]
+    inner_dims_pos = [1, 3, 2]
+    inner_tiles = [%tile_p, %tile_q, %tile_r]
+    into %pack_dest : tensor<?x?x?x?xf32> -> tensor<?x?x?x?x?x?x?xf32>
+  return %pack : tensor<?x?x?x?x?x?x?xf32>
+}
+//      CHECK: #[[map:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+//      CHECK: module {
+//      CHECK:   func.func @linalg_transpose_tensor_pack_fold_dynamic_outer_dims_tile_dims_tile_sizes(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?xf32>, %[[TRANSPOSE_DEST:.+]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME:   %[[PACK_DEST:.+]]: tensor<?x?x?x?x?x?x?xf32>, 
+// CHECK-SAME:   %[[ARG1:.+]]: index, %[[ARG2:.+]]: index,
+// CHECK-SAME:   %[[ARG3:.+]]: index) 
+//      CHECK:     %[[c0:.+]] = arith.constant 0 : index
+//      CHECK:     %[[c1:.+]] = arith.constant 1 : index
+//      CHECK:     %[[c2:.+]] = arith.constant 2 : index
+//      CHECK:     %[[c3:.+]] = arith.constant 3 : index
+//      CHECK:     %[[dim:.+]] = tensor.dim %[[ARG0]], %[[c0]] : tensor<?x?x?x?xf32>
+//      CHECK:     %[[dim_0:.+]] = tensor.dim %[[ARG0]], %[[c1]] : tensor<?x?x?x?xf32>
+//      CHECK:     %[[dim_1:.+]] = tensor.dim %[[ARG0]], %[[c2]] : tensor<?x?x?x?xf32>
+//      CHECK:     %[[dim_2:.+]] = tensor.dim %[[ARG0]], %[[c3]] : tensor<?x?x?x?xf32>
+//      CHECK:     %[[mapped_dim0:.+]] = affine.apply #[[map:.+]]()[%[[dim_0]], %[[ARG1]]]
+//      CHECK:     %[[mapped_dim1:.+]] = affine.apply #[[map:.+]]()[%[[dim_2]], %[[ARG2]]]
+//      CHECK:     %[[mapped_dim2:.+]] = affine.apply #[[map:.+]]()[%[[dim_1]], %[[ARG3]]]
+//      CHECK:     %[[INIT:.+]] = tensor.empty(%[[mapped_dim0]], %[[mapped_dim2]], %[[dim]], %[[mapped_dim1]], %[[ARG1]], %[[ARG2]], %[[ARG3]]) : tensor<?x?x?x?x?x?x?xf32>
+//      CHECK:     %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [1, 2, 0, 3] inner_dims_pos = [1, 3, 2] inner_tiles = [%[[ARG1]], %[[ARG2]], %[[ARG3]]] into %[[INIT]] : tensor<?x?x?x?xf32> -> tensor<?x?x?x?x?x?x?xf32>
+//      CHECK:     return %[[PACK]] : tensor<?x?x?x?x?x?x?xf32>
+//      CHECK:   }

@hanhanW hanhanW requested a review from chelini December 15, 2023 21:36
Copy link
Contributor

@chelini chelini left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pushing on this, some inline comments.

/// Helper function to generate an equivalent permutation map for
/// `linalg.transpose` and `tensor.pack` which will be used after their folding
/// into a `tensor.pack`.
static bool getRemappedPermutationForTransposeAndPack(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to understand a bit better why we have such complexity for FoldConsumerPackWithProducerLinalgTransposeOp; in such cases, to me, it is a simple applyPermutation on top of the transpose permutation based on outer dims perm; the inner tiles and inner dims pos will remain the same.

Copy link
Contributor Author

@meshtag meshtag Dec 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes you are correct.
I disregarded that function while working on this PR. I have updated the diff to incorporate this. Thanks!

Copy link
Contributor

@chelini chelini Dec 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I may have confused you with this comment. What I mentioned above is true as long as the transpose does not rearrange the loops in inner_dims_pos. I left an example in the review.

@meshtag meshtag requested a review from chelini December 22, 2023 07:21

/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
/// semantics.
struct FoldConsumerPackWithProducerLinalgTransposeOp
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can use the transpose permutation only if we do not move the inner_dims_pos loops in the transpose operation. The semantics of outer_dim_perms is an interchange on the tile loops after the packing. Consider the following example:

func.func @main(%arg0: tensor<56x57x1x64xf32>) -> tensor<64x57x1x7x8xf32> {
 %0 = tensor.empty() : tensor<64x57x1x56xf32>
 %transposed = linalg.transpose
  ins(%arg0 : tensor<56x57x1x64xf32>)
  outs(%0 : tensor<64x57x1x56xf32>)
  permutation = [3, 1, 2, 0] // here I swap dim 3 with 0.

 %1 = tensor.empty() : tensor<64x57x1x7x8xf32>
 %pack = tensor.pack %transposed
  inner_dims_pos = [3] // here I want to pack dimension 3 (the 56 not the 64).
  inner_tiles = [8]
  into %1 : tensor<64x57x1x56xf32> -> tensor<64x57x1x7x8xf32>

 return %pack : tensor<64x57x1x7x8xf32>
}

In the example above, if we fold the transpose, we are packing dimension 64 in the input tensor arg0, but we want to pack the 56 dimension; this is because outer_dims_perm is an interchange on the tile loops after tiling.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I pushed a change to address this. Let me know what you think of it.

PS: I'll fix merge conflicts once the PR is approved.

applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);

for (auto dim : innerDimsPos) {
newInnerDimsPosVec.push_back(std::find(transposePermutation.begin(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use llvm::find and avoid passing the begin and end iterators, something like:
newInnerDimsPosVec.push_back(llvm::find(transposePermutation, dim) ....

Copy link
Contributor

@chelini chelini left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, looks good to me.

auto outerDimsPerm = packOp.getOuterDimsPerm();
auto innerDimsPos = packOp.getInnerDimsPos();
SmallVector<int64_t> newInnerDimsPosVec;
SmallVector<int64_t> newOuterDimsPermVec = to_vector(transposePermutation);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please prefix with llvm namespace, llvm::to_vector.

applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);

for (auto dim : innerDimsPos) {
newInnerDimsPosVec.push_back(find(transposePermutation, dim) -
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here: llvm::find.

into %pack_dest : tensor<35x40x25x30xf32> -> tensor<3x35x5x8x5x10x5xf32>
return %pack : tensor<3x35x5x8x5x10x5xf32>
}
// CHECK: module {
Copy link
Contributor

@chelini chelini Jan 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please let's not check module here.

// CHECK-SAME: inner_tiles = [5, 10, 5]
// CHECK-SAME: into %[[VAL0]]
// CHECK: return %[[PACK]]
// CHECK: }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also this last two checks can be removed.

return %pack : tensor<?x?x?x?x?x?x?xf32>
}
// CHECK: #[[map:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
// CHECK: module {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here.

// CHECK: %[[VAL3:.+]] = tensor.empty(%[[VAL1]], %[[DIM1]], %[[VAL2]], %[[VAL0]], %[[ARG3]], %[[ARG4]], %[[ARG5]]) : tensor<?x?x?x?x?x?x?xf32>
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [1, 2, 0, 3] inner_dims_pos = [3, 1, 0] inner_tiles = [%[[ARG3]], %[[ARG4]], %[[ARG5]]] into %[[VAL3]] : tensor<?x?x?x?xf32> -> tensor<?x?x?x?x?x?x?xf32>
// CHECK: return %[[PACK]] : tensor<?x?x?x?x?x?x?xf32>
// CHECK: }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also this last two checks.

into %1 : tensor<1x56x57x64xf32> -> tensor<1x57x56x2x32xf32>
return %pack : tensor<1x57x56x2x32xf32>
}
// CHECK: func @linalg_transpose_tensor_pack_fold(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use CHECK-LABEL instead of CHECK

into %1 : tensor<1x56x57x55xf32> -> tensor<1x57x56x2x32xf32>
return %pack : tensor<1x57x56x2x32xf32>
}
// CHECK: func @linalg_transpose_tensor_pack_fold_with_padding(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use CHECK-LABEL instead of CHECK

into %1 : tensor<1x56x57x64xf32> -> tensor<1x56x57x2x32xf32>
return %pack : tensor<1x56x57x2x32xf32>
}
// CHECK: func @linalg_transpose_tensor_pack_fold_no_outer_dims_perm(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use CHECK-LABEL instead of CHECK

return %pack : tensor<3x35x5x8x5x10x5xf32>
}
// CHECK: module {
// CHECK: func.func @linalg_transpose_tensor_pack_fold_complex_inner_dims_change(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use CHECK-LABEL instead of CHECK

}
// CHECK: #[[map:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
// CHECK: module {
// CHECK: func.func @linalg_transpose_tensor_pack_fold_dynamic_outer_dims_tile_dims_tile_sizes(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use CHECK-LABEL instead of CHECK

into %1 : tensor<?x56x57x64xf32> -> tensor<1x57x56x2x32xf32>
return %pack : tensor<1x57x56x2x32xf32>
}
// CHECK: func @linalg_transpose_tensor_cast_tensor_pack_fold(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use CHECK-LABEL instead of CHECK

@meshtag meshtag force-pushed the fold_producer_transpose_with_consumer_pack branch from 0086327 to cd5491d Compare January 3, 2024 17:38
@meshtag meshtag force-pushed the fold_producer_transpose_with_consumer_pack branch from cd5491d to aedabce Compare January 4, 2024 18:13
@meshtag
Copy link
Contributor Author

meshtag commented Jan 7, 2024

Gentle ping @hanhanW

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just a suggestion on lit tests. Thanks for pushing this forward!

mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir Outdated Show resolved Hide resolved
Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is a bug about permutation or inner_dims_pos. Here is the example:

  module {
    func.func @transpose_pack(%arg0: tensor<?x32x128xbf16>) -> tensor<32x?x64x16x2xbf16> {
      %c0 = arith.constant 0 : index
      %cst = arith.constant 0.000000e+00 : bf16
      %dim = tensor.dim %arg0, %c0 : tensor<?x32x128xbf16>
      %0 = tensor.empty(%dim) : tensor<32x128x?xbf16>
      %transposed = linalg.transpose ins(%arg0 : tensor<?x32x128xbf16>) outs(%0 : tensor<32x128x?xbf16>) permutation = [1, 2, 0]
      %1 = affine.apply #map()[%dim]
      %2 = tensor.empty(%1) : tensor<32x?x64x16x2xbf16>
      %pack = tensor.pack %transposed padding_value(%cst : bf16) outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %2 : tensor<32x128x?xbf16> -> tensor<32x?x64x16x2xbf16>
      return %pack : tensor<32x?x64x16x2xbf16>
    }

EDIT:

The inner_dims_pos should be [0, 2], but [1, 0] is generated.

@meshtag
Copy link
Contributor Author

meshtag commented Jan 9, 2024

I think there is a bug about permutation or inner_dims_pos. Here is the example:

Nice catch!
I pushed a fix for this. Let me know what you think of it.

Copy link
Contributor

@Max191 Max191 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, thanks for adding these!

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pushing on this! Do you need my help to merge the PR?

@meshtag
Copy link
Contributor Author

meshtag commented Jan 10, 2024

Do you need my help to merge the PR?

Yes, thank you!

@hanhanW hanhanW merged commit 113bce0 into llvm:main Jan 10, 2024
4 checks passed
@meshtag meshtag deleted the fold_producer_transpose_with_consumer_pack branch January 10, 2024 15:10
@vitalybuka
Copy link
Collaborator

@meshtag @hanhanW
Could you please fix or revert https://lab.llvm.org/buildbot/#/builders/168/builds/17947

@vitalybuka
Copy link
Collaborator

vitalybuka commented Jan 10, 2024

@meshtag @hanhanW Could you please fix or revert https://lab.llvm.org/buildbot/#/builders/168/builds/17947

Sorry, looks like it's already fixed https://lab.llvm.org/buildbot/#/builders/168/builds/17949

justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants