
[mlir][Vectorizer] Added support to Vectorize tensor.unpack #76087

Merged 12 commits into llvm:main on Feb 20, 2024

Conversation

@bviyer (Contributor) commented Dec 20, 2023

Added support to vectorize `tensor.unpack`. The unpack op is split into a `vector.transfer_read`, a `vector.transpose`, a `vector.shape_cast`, and a `vector.transfer_write`.
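The decomposition can be illustrated end to end in plain numpy; the shapes below mirror the test case added in this PR, and the final slicing is a stand-in for the masked `vector.transfer_write` into the destination:

```python
import numpy as np

# tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [16, 16]
#   : tensor<7x1136x16x16xf32> -> tensor<100x18176xf32>
src = np.arange(7 * 1136 * 16 * 16, dtype=np.float32).reshape(7, 1136, 16, 16)

# Step 1 (vector.transpose): move each inner tile next to its outer dim
# -> 7x16x1136x16
t = src.transpose(0, 2, 1, 3)

# Step 2 (vector.shape_cast): collapse each (outer, tile) pair -> 112x18176
c = t.reshape(7 * 16, 1136 * 16)

# Step 3 (vector.transfer_write): only the destination extent is written
dst = c[:100, :18176]
```

With this lowering, `dst[i, j] == src[i // 16, j // 16, i % 16, j % 16]`, which is exactly the unpack semantics for these tiles.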

@llvmbot (Collaborator) commented Dec 20, 2023

@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Balaji V. Iyer. (bviyer)

Changes

Added support to vectorize `tensor.unpack`. The unpack op is split into a `vector.transfer_read`, a `vector.transpose`, a `vector.shape_cast`, and a `vector.transfer_write`.


Full diff: https://github.com/llvm/llvm-project/pull/76087.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+96)
  • (modified) mlir/test/Dialect/Linalg/vectorization.mlir (+25)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 14404d837ff748..4c456a5e671f5d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3078,7 +3078,7 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
 
   // TODO: Check that the correct number of vectorSizes was provided.
   for (Operation *target : targets) {
-    if (!isa<linalg::LinalgOp, tensor::PadOp>(target)) {
+    if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::UnPackOp>(target)) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Unsupported Op, cannot vectorize";
     }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index f9a53a8451a601..7a9846154bf34b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
@@ -1385,6 +1386,88 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
 
   return success();
 }
+/// Vectorize a `tensor::UnPackOp` without `outer_dims_perm` into these 4 ops:
+///   vector::TransferReadOp - reads a vector from the source tensor
+///   vector::TransposeOp - transposes the read vector
+///   vector::ShapeCastOp - reshapes the data based on the destination shape
+///   vector::TransferWriteOp - writes the result vector back
+
+static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
+                                         tensor::UnPackOp unpackOp,
+                                         ArrayRef<int64_t> inputVectorSizes,
+                                         SmallVectorImpl<Value> &newResults) {
+
+  if (!unpackOp.getOuterDimsPerm().empty()) {
+    LDBG("outer dimensions perms NYI for: " << unpackOp);
+    return failure();
+  }
+
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(unpackOp);
+
+  RankedTensorType packTensorType = unpackOp.getSourceType();
+  auto maskType =
+      VectorType::get(packTensorType.getShape(), rewriter.getI1Type());
+  auto vectorType = VectorType::get(packTensorType.getShape(),
+                                    packTensorType.getElementType());
+  ReifiedRankedShapedTypeDims reifiedRetShapes;
+  LogicalResult status =
+      cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
+          .reifyResultShapes(rewriter, reifiedRetShapes);
+  if (status.failed()) {
+    LDBG("Unable to reify result shapes of " << unpackOp);
+    return failure();
+  }
+
+  arith::ConstantIndexOp zeroOp =
+      rewriter.create<arith::ConstantIndexOp>(unpackOp->getLoc(), 0);
+  Value mask = rewriter.create<vector::CreateMaskOp>(
+      unpackOp.getLoc(), maskType,
+      tensor::getMixedSizes(rewriter, unpackOp.getLoc(), unpackOp.getSource()));
+
+  vector::TransferReadOp readOp = rewriter.create<vector::TransferReadOp>(
+      unpackOp.getLoc(), vectorType, unpackOp.getSource(),
+      SmallVector<Value>(packTensorType.getRank(), zeroOp),
+      rewriter.getMultiDimIdentityMap(packTensorType.getRank()));
+
+  vector::MaskOp maskedOp =
+      cast<vector::MaskOp>(mlir::vector::maskOperation(rewriter, readOp, mask));
+
+  int64_t numPackedDim = unpackOp.getInnerDimsPos().size();
+  int64_t packRank = packTensorType.getRank();
+  auto lastDims =
+      llvm::to_vector(llvm::seq<int64_t>(packRank - numPackedDim, packRank));
+  PackingMetadata packMetadata =
+      computePackingMetadata(packRank, unpackOp.getInnerDimsPos());
+  SmallVector<int64_t> lastDimToInsertPosPerm = computePermutationVector(
+      packRank, lastDims, packMetadata.insertPositions);
+  SmallVector<int64_t> stripMineShape(packTensorType.getShape());
+  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
+
+  RankedTensorType stripMineTensorType =
+      RankedTensorType::Builder(packTensorType).setShape(stripMineShape);
+
+  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+      stripMineTensorType, packMetadata.reassociations);
+  auto vecCollapsedType =
+      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+
+  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
+      unpackOp.getLoc(), maskedOp.getResult(0), lastDimToInsertPosPerm);
+
+  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
+      unpackOp.getLoc(), vecCollapsedType, transposeOp->getResult(0));
+  tensor::EmptyOp emptyOp = rewriter.create<tensor::EmptyOp>(
+      unpackOp.getLoc(), reifiedRetShapes[0], packTensorType.getElementType());
+
+  vector::TransferWriteOp writeOp = rewriter.create<vector::TransferWriteOp>(
+      unpackOp.getLoc(), shapeCastOp->getResult(0), emptyOp,
+      SmallVector<Value>(lastDims.size(), zeroOp),
+      SmallVector<bool>(lastDims.size(), true));
+
+  newResults.push_back(writeOp->getResult(0));
+  return success();
+}
 
 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
 /// and (3) all-zero lowPad to
@@ -1578,6 +1661,12 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
   return success();
 }
 
+static LogicalResult
+vectorizeUnpackOpPrecondition(tensor::UnPackOp unpackOp,
+                              ArrayRef<int64_t> inputVectorSizes) {
+  return success();
+}
+
 static LogicalResult
 vectorizePadOpPrecondition(tensor::PadOp padOp,
                            ArrayRef<int64_t> inputVectorSizes) {
@@ -1637,6 +1726,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
       .Case<tensor::PadOp>([&](auto padOp) {
         return vectorizePadOpPrecondition(padOp, inputVectorSizes);
       })
+      .Case<tensor::UnPackOp>([&](auto unpackOp) {
+        return vectorizeUnpackOpPrecondition(unpackOp, inputVectorSizes);
+      })
       .Default([](auto) { return failure(); });
 }
 
@@ -1724,6 +1816,10 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
             return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
                                           results);
           })
+          .Case<tensor::UnPackOp>([&](auto unpackOp) {
+            return vectorizeAsUnpackOp(rewriter, unpackOp, inputVectorSizes,
+                                       results);
+          })
           .Default([](auto) { return failure(); });
 
   if (failed(vectorizeResult)) {
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 610339405d1c2c..acf7276626a4c5 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -419,6 +419,31 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: func @test_vectorize_unpack
+func.func @test_vectorize_unpack(%0 : tensor<7x1136x16x16xf32>) -> tensor<100x18176xf32>  {
+    // CHECK: %[[c0:.*]] = arith.constant 0 : index
+    // CHECK: %[[tr0:.*]] = vector.mask %[[m0:.*]] {{.*}} vector.transfer_read %{{.*}} : tensor<7x1136x16x16xf32>, vector<7x1136x16x16xf32> } : vector<7x1136x16x16xi1> -> vector<7x1136x16x16xf32>
+    // CHECK: %[[trans0:.*]] = vector.transpose %[[tr0]], [0, 2, 1, 3] : vector<7x1136x16x16xf32> to vector<7x16x1136x16xf32>
+    // CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<7x16x1136x16xf32> to vector<112x18176xf32>
+    // CHECK: %[[empt0:.*]] = tensor.empty() : tensor<100x18176xf32>
+    // CHECK: %[[tw0:.*]] = vector.transfer_write %[[sc0]], %[[empt0]]
+    // CHECK: return %[[tw0]]
+    %8 = tensor.empty() : tensor<100x18176xf32>
+    %unpack = tensor.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %8 : tensor<7x1136x16x16xf32> -> tensor<100x18176xf32>
+    return %unpack : tensor<100x18176xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.unpack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [2, 4] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func @test_masked_vectorize_pad
 func.func @test_masked_vectorize_pad(
   %0 : tensor<?x?xf32>, %h0 : index, %h1 : index)

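The transpose permutation built in `vectorizeAsUnpackOp` (from `computePackingMetadata` and `computePermutationVector`) can be sketched in Python. This is a simplified re-derivation for illustration, not the MLIR implementation, and like the patch it ignores `outer_dims_perm`:

```python
def compute_packing_metadata(rank, inner_dims_pos):
    # Simplified stand-in for MLIR's computePackingMetadata: each inner tile
    # dim is inserted directly after its outer dim in the strip-mined
    # ordering, shifting later positions by the tiles already inserted.
    return [p + i + 1 for i, p in enumerate(sorted(inner_dims_pos))]

def compute_permutation_vector(rank, src_positions, dst_positions):
    # Build a permutation sending src_positions[i] -> dst_positions[i]
    # while keeping the relative order of all other dimensions.
    perm = [None] * rank
    for s, d in zip(src_positions, dst_positions):
        perm[d] = s
    rest = iter(x for x in range(rank) if x not in src_positions)
    return [p if p is not None else next(rest) for p in perm]

# The test case: packed source 7x1136x16x16 with inner_dims_pos = [0, 1].
rank, inner_dims_pos = 4, [0, 1]
last_dims = list(range(rank - len(inner_dims_pos), rank))      # [2, 3]
insert_pos = compute_packing_metadata(rank, inner_dims_pos)    # [1, 3]
perm = compute_permutation_vector(rank, last_dims, insert_pos)
strip_mined = [[7, 1136, 16, 16][p] for p in perm]
```

Here `perm` is `[0, 2, 1, 3]` and `strip_mined` is `[7, 16, 1136, 16]`, matching the `vector.transpose` in the test.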
@dcaballe (Contributor) left a comment


The approach looks good to me (% Hanhan's comments). Dropped a few comments.

@rengolin rengolin requested a review from chelini January 19, 2024 10:34
@dcaballe (Contributor) left a comment


It looks like this needs more discussion. Let's hold on a bit

@hanhanW (Contributor) previously requested changes, Feb 7, 2024:

A high-level comment: IMO, inputVectorSizes should be in the unpacked domain (i.e., it describes the dest shape). In this context, the behavior will be aligned with tiling: the tile_sizes in TilingInterface implementations are also provided in the unpacked domain (i.e., the dest shape). E.g.,

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 floordiv 32)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 mod 32)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> ((d0 + 1) floordiv 32 - d0 floordiv 32 + 1)>
// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0) -> (d0 floordiv 16)>
// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0) -> (d0 mod 16)>
// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0) -> ((d0 + 3) floordiv 16 - d0 floordiv 16 + 1)>
// CHECK: func.func @NCnc_to_NC
// CHECK-SAME: %[[IN:[A-Za-z0-9]+]]:
// CHECK-SAME: %[[OUT:[A-Za-z0-9]+]]:
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index
// CHECK-DAG: %[[C256:.*]] = arith.constant 256 : index
// CHECK: %{{.+}} = scf.for %[[I:.+]] = %[[C0]] to %[[C256]] step %[[C2]]
// CHECK: %{{.+}} = scf.for %[[J:.+]] = %[[C0]] to %[[C128]] step %[[C4]]
// CHECK-DAG: %[[IN_I:.+]] = affine.apply #[[MAP0]](%[[I]])
// CHECK-DAG: %[[OFFSET_I:.+]] = affine.apply #[[MAP1]](%[[I]])
// CHECK-DAG: %[[IN_I_SZ:.+]] = affine.apply #[[MAP2]](%[[I]])
// CHECK-DAG: %[[IN_J:.+]] = affine.apply #[[MAP4]](%[[J]])
// CHECK-DAG: %[[OFFSET_J:.+]] = affine.apply #[[MAP5]](%[[J]])
// CHECK-DAG: %[[IN_J_SZ:.+]] = affine.apply #[[MAP6]](%[[J]])
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[IN]]
// CHECK-SAME: [%[[IN_I]], %[[IN_J]], 0, 0] [%[[IN_I_SZ]], %[[IN_J_SZ]], 32, 16]
// CHECK-SAME: : tensor<8x8x32x16xf32> to tensor<?x?x32x16xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty
// CHECK: %[[UNPACK:.+]] = tensor.unpack
// CHECK-SAME: %[[SLICE]] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
// CHECK-SAME: into %[[EMPTY]]
// CHECK: %[[UNPACK_SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
// CHECK-SAME: [%[[OFFSET_I]], %[[OFFSET_J]]] [2, 4]
// CHECK: %[[RES:.+]] = tensor.insert_slice %[[UNPACK_SLICE]]
// CHECK-SAME: into %{{.+}}[%[[I]], %[[J]]] [2, 4]
// CHECK: scf.yield %[[RES]]
func.func @NCnc_to_NC(%source: tensor<8x8x32x16xf32>, %dest: tensor<256x128xf32>) -> tensor<256x128xf32> {
%0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %dest : tensor<8x8x32x16xf32> -> tensor<256x128xf32>
return %0 : tensor<256x128xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["tensor.unpack"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %loops:2 = transform.structured.tile_using_for %0 [2, 4] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
}

I think we really need them to be aligned. Otherwise it leads to a lot of corner cases when interacting with other passes. E.g., IREE's CPU backend has lowering_config, and we use the tile_sizes to decide input_vector_sizes. If they don't match, you will have to implement extra logic to align them.
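One way to realize the alignment hanhanW describes is to accept `inputVectorSizes` in the unpacked (dest) domain and derive the packed-source read sizes from the inner tiles. This is a hedged sketch of that translation, not code from this PR; the helper name is hypothetical:

```python
def source_read_sizes(dest_vector_sizes, inner_dims_pos, inner_tiles):
    # Translate dest-domain vector sizes to the packed-source domain:
    # each tiled dim shrinks to ceil(size / tile) outer iterations, and the
    # inner tile sizes are appended as the trailing packed dims.
    sizes = list(dest_vector_sizes)
    for pos, tile in zip(inner_dims_pos, inner_tiles):
        sizes[pos] = -(-sizes[pos] // tile)  # ceiling division
    return sizes + list(inner_tiles)

# Matching the tiling example above: tile_sizes [2, 4] with inner_tiles
# [32, 16] reads 1x1x32x16 slices of the tensor<8x8x32x16xf32> source.
read_sizes = source_read_sizes([2, 4], [0, 1], [32, 16])
```

This mirrors how the `tensor.extract_slice` sizes fall out of the tiling example, so the same `tile_sizes` could drive both passes.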

@dcaballe (Contributor) left a comment


I started the review and then realized that I couldn't find the code for some of the comments marked as fixed, so I'm assuming that perhaps you didn't upload the latest version of the code? :)

@hanhanW hanhanW dismissed their stale review February 10, 2024 15:01

I will be out for a week. @Max191 can you help on the review? (Feel free to ping me if you have any questions)

@dcaballe (Contributor) left a comment


Thanks! I took a first look and dropped some comments!

@@ -1559,6 +1571,90 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
return success();
}

/// Vectorize a `tensor::UnPackOp` without OuterDimsPerms to these 4 Ops:
Reviewer:

what happens for cases with OuterDimsPerms?

Author:

Added this in c7ed75e

Reviewer:

update doc accordingly?

Author:

Fixed.

@Max191 (Contributor) left a comment

Overall, looks good. I think we should try to support outer_dims_perm cases though. I added a comment describing how you can do that.

One of the tests looks wrong to me, although the logic for vectorization seems good to me. Maybe the test is incorrect?

Also, I think we can just use a zero constant as the padding_value for unpack. The padding_value should not show up in the result of the unpack, so it doesn't matter what value it is.

// CHECK: %[[CNST2:.*]] = arith.constant 2 : index
// CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM4]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1>
// CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<2x1x16x2xf32> } : vector<2x1x16x2xi1> -> vector<2x1x16x2xf32>
// CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 2, 3, 1] : vector<2x1x16x2xf32> to vector<2x16x2x1xf32>
Reviewer:

This looks like the inverse of the correct permutation. I think this transpose should be [0, 3, 1, 2].

Author:

Not sure I understand you. It is taking a 2x1x16x2 and converting it to 2x16x2x1. This should correspond to [0(2), 2(16), 3(2), 1(1)] (note: shapes in parentheses).

Reviewer:

In this case, the result type of the transpose should be vector<2x(2)x1x(16)>, where the inner_tiles are next to their corresponding inner_dim. This would correspond to a permutation of [0, 3, 1, 2], which is the inverse of the permutation in the test.

I think you just need to use the inverse of the permutation that you are currently using (see the above comment).
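The disagreement above is just a permutation versus its inverse; a small Python check with the shapes from the quoted test makes it concrete:

```python
def invert_permutation(perm):
    # inv[perm[i]] = i : applying perm and then inv is the identity.
    inv = [0] * len(perm)
    for i, p in enumerate(perm):
        inv[p] = i
    return inv

shape = (2, 1, 16, 2)            # vector<2x1x16x2xf32> from the test
perm = [0, 2, 3, 1]              # permutation currently in the test
inv = invert_permutation(perm)   # the one the reviewer suggests

transposed_as_tested = tuple(shape[p] for p in perm)
transposed_expected = tuple(shape[p] for p in inv)
```

Here `inv` is `[0, 3, 1, 2]`: the test's permutation yields 2x16x2x1, while its inverse yields 2x2x1x16, the layout with each inner tile next to its corresponding inner dim.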

(Except handling of outer Dimensions attribute)
@bviyer bviyer requested a review from dcaballe February 15, 2024 23:38
@bviyer bviyer requested a review from dcaballe February 16, 2024 18:40
@rengolin (Member) commented:
Fly by comment: this is really important, thanks for working on this!

@dcaballe (Contributor) left a comment


LGTM! Just a few nits. Please wait for Max's approval before landing.

@bviyer bviyer requested a review from Max191 February 17, 2024 05:16
@Max191 (Contributor) left a comment

Just a couple comments about some of the permutation logic. I think I gave you some slightly out of order logic for the outer_dims_perm, so I am sorry about that, but the fixes should be pretty quick.

@bviyer bviyer requested a review from Max191 February 20, 2024 21:49
@Max191 (Contributor) left a comment

Looks good, thanks for pushing on this!

@bviyer bviyer merged commit adf838d into llvm:main Feb 20, 2024
3 of 4 checks passed