[mlir] Add direct vectorization lowering for tensor.pack ops #78660

Merged (8 commits, Feb 7, 2024)

Conversation

@Max191 (Contributor) commented Jan 19, 2024

This PR adds a direct vectorization lowering of `tensor.pack` into `mask(vector.transfer_read)` -> `vector.shape_cast` -> `vector.transpose` -> `vector.transfer_write`.

@llvmbot (Collaborator) commented Jan 19, 2024

@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: None (Max191)

Changes

This PR adds a direct vectorization lowering of `tensor.pack` into `mask(vector.transfer_read)` -> `vector.shape_cast` -> `vector.transpose` -> `vector.transfer_write`.


Patch is 34.61 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78660.diff

9 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+51-1)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+1-29)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp (-1)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+147)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+7-68)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir (+10)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+85)
  • (modified) mlir/test/Dialect/Linalg/vectorization.mlir (+61)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 678081837b81382..b4f18d57404cc29 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1052,6 +1052,55 @@ class PointwiseConverter : public OpRewritePattern<SrcOp> {
   }
 };
 
+class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
+public:
+  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::TransposeOp op,
+                                PatternRewriter &rewriter) const final {
+    DenseIntElementsAttr perms;
+    if (!matchPattern(op.getPerms(), m_Constant(&perms))) {
+      return rewriter.notifyMatchFailure(op, "unmatched permutation tensor");
+    }
+
+    auto loc = op.getLoc();
+    auto input = op->getOperand(0);
+    auto resultTy = cast<ShapedType>(op.getType());
+
+    SmallVector<Value> dynDims;
+    dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());
+
+    SmallVector<AffineExpr, 2> inputExprs;
+    inputExprs.resize(resultTy.getRank());
+    for (const auto &permutation : llvm::enumerate(perms.getValues<APInt>())) {
+      auto index = permutation.index();
+      auto value = permutation.value().getZExtValue();
+      if (!resultTy.hasRank() || resultTy.isDynamicDim(index)) {
+        dynDims[index] = rewriter.create<tensor::DimOp>(loc, input, value);
+      }
+      inputExprs[value] = rewriter.getAffineDimExpr(index);
+    }
+
+    SmallVector<Value> filteredDims = condenseValues(dynDims);
+
+    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
+        loc, resultTy.getShape(), resultTy.getElementType(), filteredDims);
+
+    SmallVector<AffineMap, 2> affineMaps = {
+        AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
+                       rewriter.getContext()),
+        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
+
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        op, resultTy, op.getInput1(), ValueRange{emptyTensor}, affineMaps,
+        getNParallelLoopsAttrs(resultTy.getRank()),
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+          nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
+        });
+    return success();
+  }
+};
+
 class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
 public:
   using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
@@ -2408,6 +2457,7 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
       ReverseConverter,
       RFFT2dConverter,
       TableConverter,
-      TileConverter>(patterns->getContext());
+      TileConverter,
+      TransposeConverter>(patterns->getContext());
   // clang-format on
 }
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 8dc2d27bd545ff8..b3fbc7dd0b22c19 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -19,7 +19,6 @@
 #include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
@@ -985,31 +984,6 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
   }
 };
 
-class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
-public:
-  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tosa::TransposeOp op,
-                                PatternRewriter &rewriter) const final {
-    SmallVector<int64_t> constantPerms;
-    if (failed(op.getConstantPerms(constantPerms)))
-      return failure();
-
-    Location loc = op.getLoc();
-    // The verifier should have made sure we have a valid permutation tensor.
-    assert(isPermutationVector(constantPerms) && "Expected valid permutation");
-    SmallVector<OpFoldResult> inputSizes =
-        tensor::getMixedSizes(rewriter, loc, op.getInput1());
-    auto permutedSizes =
-        applyPermutation<OpFoldResult>(inputSizes, constantPerms);
-
-    auto permutedInit = rewriter.create<tensor::EmptyOp>(
-        loc, permutedSizes, op.getInput1().getType().getElementType());
-    rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
-        op, op.getInput1(), permutedInit, constantPerms);
-    return success();
-  }
-};
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
@@ -1030,8 +1004,6 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
       MatMulConverter,
       MaxPool2dConverter,
       AvgPool2dConverter,
-      FullyConnectedConverter,
-      TransposeConverter
-  >(patterns->getContext());
+      FullyConnectedConverter>(patterns->getContext());
   // clang-format on
 }
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
index 096969391e51b9d..5312dc164c26c5e 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
@@ -60,7 +60,6 @@ struct TosaToLinalgNamed
     target.addIllegalOp<tosa::AvgPool2dOp>();
     target.addIllegalOp<tosa::MatMulOp>();
     target.addIllegalOp<tosa::FullyConnectedOp>();
-    target.addIllegalOp<tosa::TransposeOp>();
 
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5254aac976f462d..2e58eb3376a1c8e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3134,7 +3134,7 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
 
   // TODO: Check that the correct number of vectorSizes was provided.
   for (Operation *target : targets) {
-    if (!isa<linalg::LinalgOp, tensor::PadOp>(target)) {
+    if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp>(target)) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Unsupported Op, cannot vectorize";
     }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 5d99951ef09a92b..b56289b560272d0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -19,10 +19,14 @@
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
 #include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/RegionUtils.h"
@@ -30,7 +34,9 @@
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/ADT/iterator_range.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include <optional>
 #include <type_traits>
@@ -1393,6 +1399,117 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
   return success();
 }
 
+/// Given a tensor::PackOp, return the permutation from the "tiled"
+/// shape to the "packed" shape, defined as the following:
+/// The "packed" shape is the same as the `dest` shape of the pack op.
+/// The "tiled" shape is a permutation of the `dest` shape such that
+/// each outer dimension is in the original `source` order, and the
+/// inner_tile dimensions immediately follow their corresponding outer
+/// dimension.
+/// i.e. for the following tensor.pack:
+/// ```mlir
+/// %pack = tensor.pack %0 padding_value(%1) 
+///   outer_dims_perm = [0, 2, 1] 
+///   inner_dims_pos = [2, 1] 
+///   inner_tiles = [16, 2] 
+///   into %2 : tensor<32x8x16> -> tensor<32x1x4x16x2>
+/// ```
+/// The "packed" shape is `32x1x4x16x2`
+/// The "tiled" shape is `32x(4x2)x(1x16)`
+static SmallVector<int64_t> getTiledShapeToPackedShapePerm(tensor::PackOp packOp) {
+  auto innerTiles = packOp.getInnerTiles();
+  int64_t srcRank = packOp.getSourceRank();
+  auto innerDimsPos = packOp.getInnerDimsPos();
+  if (innerDimsPos.empty())
+    innerDimsPos = to_vector(llvm::seq<int64_t>(innerTiles.size()));
+  auto outerDimsPerm = packOp.getOuterDimsPerm();
+  if (outerDimsPerm.empty())
+    outerDimsPerm = to_vector(llvm::seq<int64_t>(srcRank));
+  auto packedIdxToTiledIdx = [&](int64_t idx) -> int64_t { 
+    int64_t srcIdx;
+    if (idx >= srcRank)
+      srcIdx = innerDimsPos[idx - srcRank];
+    else
+      srcIdx = outerDimsPerm[idx];
+    int64_t tiledIdx = srcIdx;
+    for (int64_t pos : innerDimsPos)
+      if (pos < srcIdx)
+        tiledIdx++;
+    if (idx >= srcRank)
+      tiledIdx++;
+    return tiledIdx;
+  };
+  SmallVector<int64_t> perm;
+  for (int i = 0; i < packOp.getDestRank(); i++) 
+    perm.push_back(packedIdxToTiledIdx(i));
+  return perm;
+}
+
+/// Given a tensor::PackOp, return the "tiled" `dest` shape as described
+/// above in `getTiledShapeToPackedShapePerm`.
+static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp) {
+  auto perm = getTiledShapeToPackedShapePerm(packOp);
+  auto destShape = packOp.getDestType().getShape();
+  return applyPermutation(destShape, invertPermutationVector(perm));
+}
+
+/// 
+static LogicalResult
+vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
+                       ArrayRef<int64_t> inputVectorSizes,
+                       SmallVectorImpl<Value> &newResults) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(packOp);
+
+  Location loc = packOp.getLoc();
+  auto padValue = packOp.getPaddingValue();
+  if (!padValue) {
+    padValue = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType()));
+  }
+  int64_t inputRank = inputVectorSizes.size();
+  int64_t outputRank = packOp.getDestRank();
+  auto maskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
+  auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
+
+  ReifiedRankedShapedTypeDims reifiedReturnShapes;
+  LogicalResult status =
+      cast<ReifyRankedShapedTypeOpInterface>(packOp.getOperation())
+          .reifyResultShapes(rewriter, reifiedReturnShapes);
+  (void)status; // prevent unused variable warning on non-assert builds
+  assert(succeeded(status) && "failed to reify result shapes");
+  auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, reifiedReturnShapes[0],
+                                                  padValue.getType());
+  SmallVector<OpFoldResult> mixedSourceDims =
+      tensor::getMixedSizes(rewriter, loc, packOp.getSource());
+  Value mask =
+      rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+  auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
+      loc,
+      /*vectorType=*/vectorType,
+      /*source=*/packOp.getSource(),
+      /*indices=*/SmallVector<Value>(inputRank, zero),
+      /*padding=*/padValue,
+      /*inBounds=*/SmallVector<bool>(inputRank, true));
+  auto maskedOp = cast<vector::MaskOp>(
+      mlir::vector::maskOperation(rewriter, transferReadOp, mask));
+  // ShapeCast
+  auto tiledPackShape = getTiledPackShape(packOp);
+  auto tiledPackType = VectorType::get(tiledPackShape, packOp.getDestType().getElementType());
+  auto shapeCastOp = rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedOp->getResult(0));
+  auto tiledShapeToPackedShapePerm = getTiledShapeToPackedShapePerm(packOp);
+  auto transposeOp = rewriter.create<vector::TransposeOp>(loc, shapeCastOp->getResult(0), tiledShapeToPackedShapePerm);
+  Operation *write = rewriter.create<vector::TransferWriteOp>(
+      loc,
+      /*vector=*/transposeOp->getResult(0),
+      /*source=*/emptyOp,
+      /*indices=*/SmallVector<Value>(outputRank, zero),
+      /*inBounds=*/SmallVector<bool>(outputRank, true));
+  newResults.push_back(write->getResult(0));
+  return success();
+}
+
 /// Vectorize a `padOp` with (1) static result type, (2) constant padding value
 /// and (3) all-zero lowPad to
 ///   `transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))`.
@@ -1585,6 +1702,30 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
   return success();
 }
 
+static LogicalResult
+vectorizePackOpPrecondition(tensor::PackOp packOp,
+                           ArrayRef<int64_t> inputVectorSizes) {
+  auto padValue = packOp.getPaddingValue();
+  if (padValue && getConstantIntValue(padValue) == std::nullopt) {
+    LDBG("pad value is not constant: " << packOp << "\n");
+    return failure();
+  }
+
+  ArrayRef<int64_t> resultTensorShape = packOp.getSourceType().getShape();
+  if (failed(isValidMaskedInputVector(resultTensorShape, inputVectorSizes)))
+    return failure();
+
+  if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) {
+        std::optional<int64_t> res = getConstantIntValue(v);
+        return !res.has_value();
+      })) {
+    LDBG("inner_tiles must be constant: " << packOp << "\n");
+    return failure();
+  }
+
+  return success();
+}
+
 static LogicalResult
 vectorizePadOpPrecondition(tensor::PadOp padOp,
                            ArrayRef<int64_t> inputVectorSizes) {
@@ -1644,6 +1785,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
       .Case<tensor::PadOp>([&](auto padOp) {
         return vectorizePadOpPrecondition(padOp, inputVectorSizes);
       })
+      .Case<tensor::PackOp>([&](auto packOp) {
+        return vectorizePackOpPrecondition(packOp, inputVectorSizes);
+      })
       .Default([](auto) { return failure(); });
 }
 
@@ -1732,6 +1876,9 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
             return vectorizeAsTensorPadOp(rewriter, padOp, inputVectorSizes,
                                           results);
           })
+          .Case<tensor::PackOp>([&](auto packOp) {
+            return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes, results);
+          })
           .Default([](auto) { return failure(); });
 
   if (failed(vectorizeResult)) {
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 6616ea7cf699fa5..aa010e759a0f201 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -88,8 +88,7 @@ func.func @matmul_dyn_output(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>)
 // CHECK-LABEL: @fully_connected
 func.func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
   // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
-  // CHECK: %[[TRANSPOSEDINIT:.+]] = tensor.empty() : tensor<3x6xf32>
-  // CHECK: %[[TRANSPOSED:.+]] = linalg.transpose ins(%arg1 : tensor<6x3xf32>) outs(%[[TRANSPOSEDINIT]] : tensor<3x6xf32>) permutation = [1, 0]
+  // CHECK: %[[TRANSPOSED:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
   // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xf32>
 
   // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<5x6xf32>) {
@@ -111,7 +110,7 @@ func.func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2
 // CHECK-LABEL: @quantized_fully_connected
 func.func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8>, %arg2: tensor<6xi32>) -> (tensor<5x6xi32>) {
   // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
-  // CHECK: %[[TRANSPOSE:.+]] =  linalg.transpose ins(%arg1 : tensor<6x3xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x6xi8>) permutation = [1, 0]
+  // CHECK: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xi8>, tensor<2xi64>) -> tensor<3x6xi8>
   // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xi32>
 
   // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xi32>) outs(%[[INIT]] : tensor<5x6xi32>) {
@@ -137,7 +136,7 @@ func.func @fully_connected_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<6x3xf32>, %
   // CHECK: %[[C0:.+]] = arith.constant 0 : index
   // CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %c0 : tensor<?x3xf32>
   // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
-  // CHECK: %[[TRANSPOSED:.+]] = linalg.transpose ins(%arg1 : tensor<6x3xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x6xf32>) permutation = [1, 0]
+  // CHECK: %[[TRANSPOSED:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x6xf32>
 
   // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<?x6xf32>) {
@@ -378,7 +377,7 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
 // CHECK-LABEL: @conv2d_i8
 func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
   // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
-  // HWCF: %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x1x1x27xi8>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<1x1x27x28xi8>) permutation = [1, 2, 3, 0]
+  // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x1x1x27xi8>, tensor<4xi64>) -> tensor<1x1x27x28xi8>
   // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xi32>
   // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xi8>) outs(%[[INIT]] : tensor<1x45x40x28xi32>) {
   // CHECK:   arith.extsi
@@ -399,7 +398,7 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
 // CHECK-LABEL: @conv2d_f32
 func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
   // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
-  // HWCF: %[[TRANSPOSE:.+]] =  linalg.transpose ins(%arg1 : tensor<28x3x3x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x3x27x28xf32>) permutation = [1, 2, 3, 0]
+  // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x3x3x27xf32>, tensor<4xi64>) -> tensor<3x3x27x28xf32>
 
   // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xf32>
   // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>) {
@@ -678,7 +677,7 @@ func.func @depthwise_conv2d_dyn_w_h(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<3x
 // CHECK-LABEL: @conv3d_f32
 func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<28xf32>) -> () {
   // CHECK-DAG:  %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
-  // CHECK-DAG:  %[[TRANSPOSE:.+]] = linalg.transpose ins(%arg1 : tensor<28x3x4x5x27xf32>) outs(%[[TRANSPOSEDINIT:.+]] : tensor<3x4x5x27x28xf32>) permutation = [1, 2, 3, 4, 0]
+  // CHECK-DAG:  %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERMS]]
   // CHECK-DAG:  %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xf32>
   // CHECK:      %[[BROADCAST:.+]] = linalg.generic
   // CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
@@ -702,7 +701,7 @@ func.func @conv3...
[truncated]
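To make the truncated diff easier to follow, here is a standalone sketch (plain C++, no MLIR dependencies; all names are ad hoc, not from the patch) that re-implements the packedIdxToTiledIdx lambda from getTiledShapeToPackedShapePerm and checks it against the example in the patch's doc comment (source tensor<32x8x16>, outer_dims_perm = [0, 2, 1], inner_dims_pos = [2, 1], inner_tiles = [16, 2], dest tensor<32x1x4x16x2>):

#include <cstdint>
#include <iostream>
#include <vector>

// Plain-C++ re-implementation of the packedIdxToTiledIdx lambda from the
// patch, for illustration only. Returns, for each packed-shape index, the
// index it occupies in the "tiled" shape.
std::vector<int64_t>
tiledToPackedPerm(int64_t srcRank, const std::vector<int64_t> &innerDimsPos,
                  const std::vector<int64_t> &outerDimsPerm, int64_t destRank) {
  std::vector<int64_t> perm;
  for (int64_t idx = 0; idx < destRank; ++idx) {
    // Map the packed index back to the source dim it came from.
    int64_t srcIdx = idx >= srcRank ? innerDimsPos[idx - srcRank]
                                    : outerDimsPerm[idx];
    // Every inner tile inserted before srcIdx shifts it right by one.
    int64_t tiledIdx = srcIdx;
    for (int64_t pos : innerDimsPos)
      if (pos < srcIdx)
        ++tiledIdx;
    // Inner-tile dims sit immediately after their corresponding outer dim.
    if (idx >= srcRank)
      ++tiledIdx;
    perm.push_back(tiledIdx);
  }
  return perm;
}

int main() {
  // Doc-comment example: tensor<32x8x16> -> tensor<32x1x4x16x2> with
  // outer_dims_perm = [0, 2, 1], inner_dims_pos = [2, 1].
  for (int64_t p : tiledToPackedPerm(3, {2, 1}, {0, 2, 1}, 5))
    std::cout << p << ' '; // prints: 0 3 1 4 2
  std::cout << '\n';
  // Inverting [0, 3, 1, 4, 2] and applying it to the packed shape
  // 32x1x4x16x2 recovers the tiled shape 32x4x2x1x16, i.e. 32x(4x2)x(1x16).
  return 0;
}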


github-actions bot commented Jan 19, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@MaheshRavishankar (Contributor) commented:

I think you have an unintended commit sneaking in.

@Max191 (Contributor, Author) commented Jan 19, 2024

> I think you have an unintended commit sneaking in.

Yep, rebased now.

@dcaballe requested a review from bviyer, January 19, 2024 02:32
@hanhanW (Contributor) left a comment:

Can we also add a test where all the dim sizes are dynamic except the inner tile sizes? That would give us a better picture of how masking works in the vectorization.

btw, I think it is better to keep this consistent with the TilingInterface implementation. When we provide sizes to the op, they are applied to the destination outer dimensions. See the comment in the lit test for examples.

Comment on lines 1723 to 1724
        std::optional<int64_t> res = getConstantIntValue(v);
        return !res.has_value();
Contributor:

nit: we can merge these two lines into `return !getConstantIntValue(v).has_value();`.

vectorizePackOpPrecondition(tensor::PackOp packOp,
                            ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = packOp.getPaddingValue();
  if (padValue && getConstantIntValue(padValue) == std::nullopt) {
Contributor:

nit: I think using `.has_value()` is better.

  return applyPermutation(destShape, invertPermutationVector(perm));
}

///
Contributor:

Please add a doc comment.

      loc, tiledPackType, maskedOp->getResult(0));
  auto tiledShapeToPackedShapePerm = getTiledShapeToPackedShapePerm(packOp);
  auto transposeOp = rewriter.create<vector::TransposeOp>(
      loc, shapeCastOp->getResult(0), tiledShapeToPackedShapePerm);
Contributor:

nit: let's use the accessor from the op, i.e., `shapeCastOp.getResult()`.

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 vector_sizes [8, 16] : !transform.any_op
Contributor:

I think it is better to keep this consistent with the TilingInterface implementation. When we provide sizes to the op, they are applied to the destination outer dimensions. In this context, the vector_sizes should be [4, 1]. What do you think?

Contributor:

This would also be aligned with how we set the lowering_config in IREE. We can look at the lowering_config when inferring input vector sizes gets tricky.

Comment on lines 1507 to 1512
  Operation *write = rewriter.create<vector::TransferWriteOp>(
      loc,
      /*vector=*/transposeOp->getResult(0),
      /*source=*/emptyOp,
      /*indices=*/SmallVector<Value>(outputRank, zero),
      /*inBounds=*/SmallVector<bool>(outputRank, true));
Contributor:

We need to mask the write if the shape and the provided input vector sizes do not match. If you follow what I suggested about the input vector size changes, the check will be something like:

  bool needMaskForWrite = llvm::any_of(
      llvm::zip_equal(inputVectorSizes, packOp.getDestType().getShape().drop_back(innerTiles.size())),
      [](auto it) { return std::get<0>(it) != std::get<1>(it); });
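For context, a rough sketch of what masking the write might then look like, mirroring how the read is masked in vectorizeAsTensorPackOp (write, emptyOp, transposeOp, rewriter, and loc are that function's locals; this follows the reviewer's suggestion and is an assumption sketch, not the PR's final code):

  if (needMaskForWrite) {
    // Hypothetical: build the write mask from the reified dest sizes, the
    // same way the read mask is built from the source sizes above.
    SmallVector<OpFoldResult> destSizes =
        tensor::getMixedSizes(rewriter, loc, emptyOp);
    auto writeVecType = cast<VectorType>(transposeOp.getResult().getType());
    auto writeMaskType =
        VectorType::get(writeVecType.getShape(), rewriter.getI1Type());
    Value writeMask =
        rewriter.create<vector::CreateMaskOp>(loc, writeMaskType, destSizes);
    write = mlir::vector::maskOperation(rewriter, write, writeMask);
  }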

@rengolin requested a review from chelini, January 19, 2024 10:33
@Max191 requested a review from hanhanW, January 19, 2024 22:27
@dcaballe (Contributor) left a comment:

Looks great, thanks! A few high-level comments. It would be great to try to refactor/generalize the utilities for creating xfer reads/writes a bit more.

  auto packedIdxToTiledIdx = [&](int64_t idx) -> int64_t {
    int64_t srcIdx;
    if (idx >= srcRank)
      srcIdx = innerDimsPos[idx - srcRank];
Contributor:

Could you add some comments to the code that help with understanding the flow of what's happening?

    int64_t tiledIdx = srcIdx;
    for (int64_t pos : innerDimsPos)
      if (pos < srcIdx)
        tiledIdx++;
Contributor:

nit: use pre-increment.

  return applyPermutation(destShape, invertPermutationVector(perm));
}

/// Create a masked TransferReadOp from `source` with shape `readShape`.
Contributor:

Can we refactor/generalize createTransferRead and createTransferWrite so that we can use them in other places where we are doing the same thing?

Contributor:

I think Max already addressed the comment. It is used by pack vectorization and pad vectorization. See the deleted lines below.

Contributor:

There are more cases than those two. I was referring to the xfer reads/writes created when vectorizing a generic op, for example.

Contributor (Author):

Vectorizing generic ops seems to hold some additional state:

    Operation *read = rewriter.create<vector::TransferReadOp>(
        loc, readType, opOperand->get(), indices, readMap);
    read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
    Value readValue = read->getResult(0);

I don't see how to easily reuse code here, but maybe I'm not fully understanding what is going on in the GenericOp vectorization.

Comment on lines 1521 to 1524
///   transfer_write_in_bounds(
///       transpose(
///           shape_cast(
///               transfer_read_masked(pack_source, pad_value))))
Contributor:

Using a before-and-after example instead might be clearer.

Contributor:

+1

  (void)status; // prevent unused variable warning on non-assert builds
  assert(succeeded(status) && "failed to reify result shapes");

  // Create masked TransferReadOp
Contributor:

super nit: period at the end of comments, per the coding standard (here and below).

@hanhanW (Contributor) left a comment:

Two big comments + a few nits.

  1. A method can be refactored into Tensor/Utils/.
  2. How to handle the padding value for transfer_read is unclear to me. @dcaballe, can you provide some guidance? See the inlined comment for more details: [mlir] Add direct vectorization lowering for tensor.pack ops #78660 (comment)

/// The "packed" shape is `32x1x4x16x2`
/// The "tiled" shape is `32x(4x2)x(1x16)`
static SmallVector<int64_t>
getTiledShapeToPackedShapePerm(tensor::PackOp packOp) {
Contributor:

I think this is doing something similar to a snippet of logic in lowerPack, which also needs the permutation for the transpose op. It would be good if we could refactor it into Tensor/Utils/Utils.[cpp|h] and use it in both places.

  // 2. Compute the permutation vector to shuffle packed shape into the shape
  // before any outer or inner permutations have been applied. The permutation
  // can be obtained from two permutations:
  //   a) Compute the permutation vector to move the last `numPackedDims` into
  //      the `innerPosDims` of a shape of rank `packedRank`.
  //   b) Compute the permutation vector to move outer dims if the pack op
  //      has outer_dims_perm.
  // Apply (b) permutation on (a) permutation to get the final permutation.
  int64_t numPackedDims = packOp.getInnerDimsPos().size();
  int64_t packedRank = packedTensorType.getRank();
  auto lastDims = llvm::to_vector(
      llvm::seq<int64_t>(packedRank - numPackedDims, packedRank));
  PackingMetadata packingMetadata = computePackingMetadata(
      packedTensorType.getRank(), packOp.getInnerDimsPos());
  SmallVector<int64_t> innerPositionsPerm = computePermutationVector(
      packedRank, lastDims, packingMetadata.insertPositions);
  SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
  ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
  if (!outerPerm.empty())
    applyPermutationToVector(outerPos, outerPerm);
  SmallVector<int64_t> outerPositionPerm = computePermutationVector(
      packedRank, packingMetadata.outerPositions, outerPos);
  SmallVector<int64_t> packedToStripMinedShapePerm = innerPositionsPerm;
  applyPermutationToVector(packedToStripMinedShapePerm, outerPositionPerm);
  // 3. Compute the stripMinedShape: this is the packed shape before any outer
  // or inner permutations have been applied.
  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

Comment on lines 1503 to 1505
  bool needMaskForWrite =
      llvm::any_of(llvm::zip(inputVectorSizes, destShape),
                   [](auto it) { return std::get<0>(it) != std::get<1>(it); });
Contributor:

There is dangerous logic behind this snippet. The sizes of inputVectorSizes and destShape are different. I think we should use llvm::zip_equal and use destShape.drop_back(innerTiles.size()) here.

Contributor:

Yes, and this would work for a while, but it's not covering all the cases we cover in the main path. That's one of the problems with generic vectorization.

@dcaballe (Contributor) commented Feb 1, 2024

What is the plan for this then? IIRC we needed direct vectorization here not only for performance but also because the vector code for this op was canonicalized away and we ended up with scalar data copies. It would be great to have some clarity, as @bviyer also has the unpack counterpart here: #76087

@Max191 (Contributor, Author) commented Feb 1, 2024

> What is the plan for this then? IIRC we needed direct vectorization here not only for performance but also because the vector code for this op was canonicalized away and we ended up with scalar data copies. It would be great to have some clarity, as @bviyer also has the unpack counterpart here: #76087

I'm going to pick this back up tomorrow and address comments so we can land it soon.

@dcaballe (Contributor) commented Feb 1, 2024

Ok, I'm anticipating a problem here that I'm hitting right now: the direct vectorization pattern has to generate good code when the target doesn't support masking (e.g., Arm Neon). Currently, if we try to vectorize a tensor.pad op without masking, we generate transfer reads with the in_bounds flags set to false, which causes many other problems down the road (e.g., some other canonicalization patterns not triggering). We have to make sure that whatever direct ops we generate, we have a good path for targets without masking. I haven't thought too much about this, but perhaps we can try to apply peeling to the pad op when masking is not supported.

@Max191 (Contributor, Author) commented Feb 1, 2024

I rebased and addressed most of the comments. Let me know what you guys think.

#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
Contributor (Author):

Sorry, I'll fix these extra includes in the next round of comments.

Comment on lines +1730 to +1734
/// TODO: Use a matcher to check for a constant padding value.
static LogicalResult
vectorizePackOpPrecondition(tensor::PackOp packOp,
                            ArrayRef<int64_t> inputVectorSizes) {
  auto padValue = packOp.getPaddingValue();
  if (padValue && !padValue.getDefiningOp<arith::ConstantOp>()) {
    LDBG("pad value is not constant: " << packOp << "\n");
    return failure();
  }
Contributor (Author):

I realized there is no matcher for constant float values like there is for constant int values, so I didn't know exactly how this should be done.

/// If ofr is a constant integer or an IntegerAttr, return the integer.
std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
  // Case 1: Check for Constant integer.
  if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) {
    APSInt intVal;
    if (matchPattern(val, m_ConstantInt(&intVal)))
      return intVal.getSExtValue();
    return std::nullopt;
  }
  // Case 2: Check for IntegerAttr.
  Attribute attr = llvm::dyn_cast_if_present<Attribute>(ofr);
  if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr))
    return intAttr.getValue().getSExtValue();
  return std::nullopt;
}

For int, getConstantIntValue passes an uninitialized APSInt to the matcher. For an analogous float matcher, I'm not sure of the best way to implement a similar function, since there are several possible floating-point semantics. I suppose there could be a util function that simply checks whether the value is constant and does not return the value; an arbitrary float semantics could then be used.

Contributor:

Good catch. I think the current implementation is okay, since it checks whether the value comes from an arith.constant.

Contributor:

Hey Max, I just found something that might address the concern:

// Return the constant attribute, or null if the Operation isn't a constant.
Attribute getConstantAttr(Value value) {
  Attribute constant;
  matchPattern(value.getDefiningOp(), m_Constant(&constant));
  return constant;
}

https://mlir.llvm.org/getting_started/Faq/#many-dialects-define-a-constant-operation-how-do-i-get-a-constant-value-generically
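Building on that FAQ pointer, a minimal sketch of a type-agnostic constant check that would also cover float padding values (isConstantPadValue is a hypothetical helper, not code from this PR):

// Hypothetical helper: true if `padValue` is produced by any ConstantLike
// op, regardless of whether the element type is integer or float.
static bool isConstantPadValue(Value padValue) {
  return matchPattern(padValue, m_Constant());
}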

@hanhanW (Contributor) left a comment:

LGTM, just a few nits.

Comment on lines 1431 to 1435
  auto sourceShape = llvm::dyn_cast<ShapedType>(source.getType()).getShape();
  if (sourceShape.size() == readShape.size() &&
      llvm::all_of(llvm::zip_equal(readShape, sourceShape), [](auto it) {
        return std::get<0>(it) != ShapedType::kDynamic &&
               std::get<0>(it) == std::get<1>(it);
      })) {
    return transferReadOp;
  }
Contributor:

A few nits here:

  1. Drop llvm:: for dyn_cast. It is more common in the MLIR codebase.
  2. Should sourceShape.size() == readShape.size() be an assertion? It looks wrong to me when the sizes differ.
  3. std::get<0>(it) != ShapedType::kDynamic is asserted in the entry point, so we don't need it here. In this case, we can probably update the check to an element-wise comparison of readShape and sourceShape, e.g. with llvm::zip_equal and std::equal_to<int64_t>() (see the sketch below).
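Applied, the early exit might look like this (a sketch that keeps the rank check rather than turning it into an assertion, and uses llvm::equal for the element-wise comparison):

  auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
  if (sourceShape.size() == readShape.size() &&
      llvm::equal(readShape, sourceShape))
    return transferReadOp;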

Comment on lines 1473 to 1475
      llvm::zip_equal(inputVectorSizes,
                      destShape.take_front(inputVectorSizes.size())),
      [](auto it) { return std::get<0>(it) != std::get<1>(it); });
Contributor:

Can we replace the lambda with std::not_equal_to<int64_t>()?

Comment on lines 1492 to 1506
/// As in the following example:
/// ```mlir
/// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
///     into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
/// ```
/// This pack would be vectorized to:
/// ```mlir
/// %load = vector.mask %mask {
///     vector.transfer_read %arg0[%c0, %c0, %c0], %cst
///         {in_bounds = [true, true, true]} :
///         tensor<32x7x16xf32>, vector<32x8x16xf32>
/// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
/// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
///     to vector<32x4x2x1x16xf32>
/// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
///     : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
/// %write = vector.transfer_write %transpose,
///     %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
///     {in_bounds = [true, true, true, true, true]}
///     : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
Contributor:

[optional] I think adding a blank line before and after the IR is more readable. Also, maybe we should follow the convention, which does not use ``` in examples. E.g.:

/// Materialize a buffer allocation for the given vector.mask op and bufferize
/// the op, including its region. E.g.:
///
/// %0 = vector.mask {
///   vector.transfer_write %v, %t : vector<16xf32>, tensor<?xf32>
/// } : vector<16xi1> -> tensor<?xf32>
///
/// is lowered to:
///
/// %alloc = memref.alloc
/// bufferization.materialize_in_destination %t in %subview
/// vector.mask {
///   vector.transfer_write %arg0, %alloc : vector<16xf32>, memref<?xf32>
/// } : vector<16xi1>
/// %0 = bufferization.to_tensor %alloc restrict writable
///
/// In addition to rewriting the IR as shown above, this function returns the
/// newly allocated buffer. The `insertionPoint` parameter can be used to
/// specify a custom insertion point for the buffer allocation.

So perhaps we can update the comment as below:

Suggested change

Old:

/// As in the following example:
/// ```mlir
/// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
///     into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
/// ```
/// This pack would be vectorized to:
/// ```mlir
/// %load = vector.mask %mask {
///     vector.transfer_read %arg0[%c0, %c0, %c0], %cst
///         {in_bounds = [true, true, true]} :
///         tensor<32x7x16xf32>, vector<32x8x16xf32>
/// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
/// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
///     to vector<32x4x2x1x16xf32>
/// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
///     : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
/// %write = vector.transfer_write %transpose,
///     %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
///     {in_bounds = [true, true, true, true, true]}
///     : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>

New:

/// As in the following example:
///
/// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
///     into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
///
/// This pack would be vectorized to:
///
/// %load = vector.mask %mask {
///     vector.transfer_read %arg0[%c0, %c0, %c0], %cst
///         {in_bounds = [true, true, true]} :
///         tensor<32x7x16xf32>, vector<32x8x16xf32>
/// } : vector<32x8x16xi1> -> vector<32x8x16xf32>
/// %shape_cast = vector.shape_cast %load : vector<32x8x16xf32>
///     to vector<32x4x2x1x16xf32>
/// %transpose = vector.transpose %shape_cast, [0, 1, 3, 4, 2]
///     : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
/// %write = vector.transfer_write %transpose,
///     %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
///     {in_bounds = [true, true, true, true, true]}
///     : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>

@Max191 merged commit 7880b2c into llvm:main, Feb 7, 2024
4 checks passed