[mlir][Vector] Add patterns for efficient i4 -> i8 conversion emulation #79494

Merged
merged 2 commits into llvm:main on Jan 30, 2024

Conversation

dcaballe
Contributor

This PR adds new patterns to improve the generated vector code for the emulation of any conversion that has to go through an i4 -> i8 type extension (only signed extensions are supported for now). This will impact any i4 -> i8/i16/i32/i64 signed extension as well as sitofp i4 -> f8/f16/f32/f64.

The asm code generated for the supported cases is significantly better after this PR for both x86 and aarch64.
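
For illustration, this is roughly the rewrite for the aligned extsi case, sketched by hand (the value names are illustrative, not actual output from the patch):

// Input:
%0 = arith.extsi %a : vector<8xi4> to vector<8xi32>

// Rewritten (sketch):
%shift = arith.constant dense<4> : vector<4xi8>
%bc = vector.bitcast %a : vector<8xi4> to vector<4xi8>
%shl = arith.shli %bc, %shift : vector<4xi8>
%low = arith.shrsi %shl, %shift : vector<4xi8>
%high = arith.shrsi %bc, %shift : vector<4xi8>
%ileave = vector.shuffle %low, %high [0, 4, 1, 5, 2, 6, 3, 7] : vector<4xi8>, vector<4xi8>
%res = arith.extsi %ileave : vector<8xi8> to vector<8xi32>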

@llvmbot

llvmbot commented Jan 25, 2024

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

Changes

This PR adds new patterns to improve the generated vector code for the emulation of any conversion that has to go through an i4 -> i8 type extension (only signed extensions are supported for now). This will impact any i4 -> i8/i16/i32/i64 signed extension as well as sitofp i4 -> f8/f16/f32/f64.

The asm code generated for the supported cases is significantly better after this PR for both x86 and aarch64.


Full diff: https://github.com/llvm/llvm-project/pull/79494.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+143-20)
  • (modified) mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir (+33)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index ead7d645cb5bb3d..fdc2d2d7e0f7fa6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -642,9 +642,9 @@ struct BitCastRewriter {
 
   BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
 
-  /// Verify that the preconditions for the rewrite are met.
-  LogicalResult precondition(PatternRewriter &rewriter,
-                             VectorType preconditionVectorType, Operation *op);
+  /// Verify that general preconditions for the rewrite are met.
+  LogicalResult commonPrecondition(PatternRewriter &rewriter,
+                                   VectorType preconditionType, Operation *op);
 
   /// Precompute the metadata for the rewrite.
   SmallVector<BitCastRewriter::Metadata>
@@ -652,9 +652,9 @@ struct BitCastRewriter {
 
   /// Rewrite one step of the sequence:
   ///   `(shuffle -> and -> shiftright -> shiftleft -> or)`.
-  Value rewriteStep(PatternRewriter &rewriter, Location loc, Value initialValue,
-                    Value runningResult,
-                    const BitCastRewriter::Metadata &metadata);
+  Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
+                           Value initialValue, Value runningResult,
+                           const BitCastRewriter::Metadata &metadata);
 
 private:
   /// Underlying enumerator that encodes the provenance of the bits in each
@@ -719,21 +719,54 @@ BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
   LDBG("\n" << enumerator.sourceElementRanges);
 }
 
-LogicalResult BitCastRewriter::precondition(PatternRewriter &rewriter,
-                                            VectorType precondition,
-                                            Operation *op) {
-  if (precondition.getRank() != 1 || precondition.isScalable())
+/// Verify that the precondition type meets the common preconditions for any
+/// conversion.
+static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
+                                                  VectorType preconditionType,
+                                                  Operation *op) {
+  if (preconditionType.getRank() != 1 || preconditionType.isScalable())
     return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");
 
   // TODO: consider relaxing this restriction in the future if we find ways
   // to really work with subbyte elements across the MLIR/LLVM boundary.
-  int64_t resultBitwidth = precondition.getElementTypeBitWidth();
+  unsigned resultBitwidth = preconditionType.getElementTypeBitWidth();
   if (resultBitwidth % 8 != 0)
     return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
 
   return success();
 }
 
+LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
+                                                  VectorType preconditionType,
+                                                  Operation *op) {
+  if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
+    return rewriter.notifyMatchFailure(op, "types are not vector");
+
+  return commonConversionPrecondition(rewriter, preconditionType, op);
+}
+
+/// Verify that source and destination element types meet the precondition for
+/// the supported aligned conversion cases. Alignment means that either the
+/// source element type is a multiple of the destination element type or the
+/// other way around.
+///
+/// NOTE: This method assumes that common conversion preconditions are met.
+static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
+                                                   VectorType srcType,
+                                                   VectorType dstType,
+                                                   Operation *op) {
+  unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
+  unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
+  unsigned byteBitwidth = 8;
+
+  // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
+  if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
+      (dstElemBitwidth % srcElemBitwidth) != 0)
+    return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
+
+  return success();
+}
+
 SmallVector<BitCastRewriter::Metadata>
 BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
   SmallVector<BitCastRewriter::Metadata> result;
@@ -775,9 +808,9 @@ BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
   return result;
 }
 
-Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
-                                   Value initialValue, Value runningResult,
-                                   const BitCastRewriter::Metadata &metadata) {
+Value BitCastRewriter::genericRewriteStep(
+    PatternRewriter &rewriter, Location loc, Value initialValue,
+    Value runningResult, const BitCastRewriter::Metadata &metadata) {
   // Create vector.shuffle from the metadata.
   auto shuffleOp = rewriter.create<vector::ShuffleOp>(
       loc, initialValue, initialValue, metadata.shuffles);
@@ -810,6 +843,44 @@ Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
   return runningResult;
 }
 
+/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
+                                    Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(4) &&
+         "Expected i4 type");
+
+  // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
+  int64_t vecDimSize = srcVecType.getShape().back();
+  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+  constexpr int64_t i4Toi8BitwidthFactor = 2;
+  i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
+  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+  Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+  // 2. Extend i4 elements to i8 elements using shifts. Low i4 elements of each
+  // byte are placed in one vector and the high i4 elements in another vector.
+  constexpr int8_t bitsToShift = 4;
+  auto shiftValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(i8VecType, bitsToShift));
+  Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
+  Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
+  Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
+
+  // 3. Interleave low and high i8 elements using a shuffle.
+  SmallVector<int64_t> interleaveMaskValues;
+  interleaveMaskValues.reserve(vecDimSize);
+  for (int i = 0, end = vecDimSize / 2; i < end; ++i) {
+    interleaveMaskValues.push_back(i);
+    interleaveMaskValues.push_back(i + (vecDimSize / 2));
+  }
+
+  return rewriter.create<vector::ShuffleOp>(
+      loc, low, high, rewriter.getI64ArrayAttr(interleaveMaskValues));
+}
+
 namespace {
 /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
 /// advantage of high-level information to avoid leaving LLVM to scramble with
@@ -829,7 +900,7 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
     VectorType sourceVectorType = bitCastOp.getSourceVectorType();
     VectorType targetVectorType = bitCastOp.getResultVectorType();
     BitCastRewriter bcr(sourceVectorType, targetVectorType);
-    if (failed(bcr.precondition(rewriter, targetVectorType, bitCastOp)))
+    if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
       return failure();
 
     // Perform the rewrite.
@@ -839,8 +910,8 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
     Value runningResult;
     for (const BitCastRewriter ::Metadata &metadata :
          bcr.precomputeMetadata(shuffledElementType)) {
-      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(), truncValue,
-                                      runningResult, metadata);
+      runningResult = bcr.genericRewriteStep(
+          rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
     }
 
     // Finalize the rewrite.
@@ -885,7 +956,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
     VectorType sourceVectorType = bitCastOp.getSourceVectorType();
     VectorType targetVectorType = bitCastOp.getResultVectorType();
     BitCastRewriter bcr(sourceVectorType, targetVectorType);
-    if (failed(bcr.precondition(
+    if (failed(bcr.commonPrecondition(
             rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
       return failure();
 
@@ -896,8 +967,8 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
         cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
     for (const BitCastRewriter::Metadata &metadata :
          bcr.precomputeMetadata(shuffledElementType)) {
-      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(),
-                                      sourceValue, runningResult, metadata);
+      runningResult = bcr.genericRewriteStep(
+          rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
     }
 
     // Finalize the rewrite.
@@ -915,6 +986,52 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
     return success();
   }
 };
+
+/// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+///
+/// For example:
+///   extsi vector<8xi4> -> vector<8xi32>
+///     is rewritten as
+///   a sequence of shuffles and bitwise ops for i4 -> i8
+///   extsi vector<8xi8> -> vector<8xi32>
+///
+///   sitofp vector<8xi4> -> vector<8xf32>
+///     is rewritten as
+///   a sequence of shuffles and bitwise ops for i4 -> i8
+///   sitofp vector<8xi8> -> vector<8xf32>
+///
+template <typename ConversionOpType>
+struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
+  using OpRewritePattern<ConversionOpType>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConversionOpType conversionOp,
+                                PatternRewriter &rewriter) const override {
+    // Set up the BitCastRewriter and verify the preconditions.
+    Value srcValue = conversionOp.getIn();
+    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
+    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+    if (failed(
+            commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
+      return failure();
+
+    // Check general alignment preconditions.
+    if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
+                                             conversionOp)))
+      return failure();
+
+    // Perform the rewrite.
+    Value subByteExt =
+        rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+
+    // Finalize the rewrite.
+    rewriter.replaceOpWithNewOp<ConversionOpType>(
+        conversionOp, conversionOp.getType(), subByteExt);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -936,4 +1053,10 @@ void vector::populateVectorNarrowTypeRewritePatterns(
   patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
                RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
                                                     benefit);
+
+  // Patterns for aligned cases. We set a higher benefit as they are expected
+  // to generate better code for the aligned cases.
+  patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
+               RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>(
+      patterns.getContext(), benefit.getBenefit() + 1);
 }
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index a600fa955b17003..c4fbb4c219b9170 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -193,6 +193,39 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
   return %1 : vector<8xi17>
 }
 
+// CHECK-LABEL: func.func @aligned_extsi(
+func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> {
+  // CHECK: arith.shli
+  // CHECK: arith.shrsi
+  // CHECK: arith.shrsi
+  // CHECK: vector.shuffle
+  // CHECK: arith.extsi %{{.*}} : vector<8xi8> to vector<8xi32>
+  %0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
+  return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_base_case(
+func.func @aligned_extsi_base_case(%a: vector<8xi4>) -> vector<8xi8> {
+  // CHECK: arith.shli
+  // CHECK: arith.shrsi
+  // CHECK: arith.shrsi
+  // CHECK: vector.shuffle
+  // CHECK-NOT: arith.extsi
+  %0 = arith.extsi %a : vector<8xi4> to vector<8xi8>
+  return %0 : vector<8xi8>
+}
+
+// CHECK-LABEL: func.func @aligned_sitofp(
+func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
+  // CHECK: arith.shli
+  // CHECK: arith.shrsi
+  // CHECK: arith.shrsi
+  // CHECK: shuffle
+  // CHECK: arith.sitofp %{{.*}} : vector<8xi8> to vector<8xf32>
+  %0 = arith.sitofp %a : vector<8xi4> to vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
     %f = transform.structured.match ops{["func.func"]} in %module_op

@llvmbot

llvmbot commented Jan 25, 2024

@llvm/pr-subscribers-mlir-vector

Contributor

@nicolasvasilache left a comment

Makes sense to have simpler patterns for the simpler aligned cases.

Out of curiosity, how hard would it be to have foldings from the existing patterns in MLIR to get to a similar form to what you have now?

If you could paste some before / after IR in the comments (or the commit message), this would also be useful.

Thanks for improving this!

@banach-space
Contributor

Thanks!

Any thoughts/plans for extending this to scalable vectors? Related discussion here: #79270

@dcaballe
Contributor Author

Out of curiosity, how hard would it be to have foldings from the existing patterns in MLIR to get to a similar form to what you have now?

It seems complicated, as the approach is slightly different. We would have to look at multiple ops to realize that the first shuffle is redundant for cases that are a multiple of 8 bits ("aligned"). Then realize that some of the shifts are actually implementing the interleave of two registers... I don't see a clear path...

Any thoughts/plans for extending this to scalable vectors? Related discussion here: #79270

This is mostly a workaround to keep things moving but ultimately we may want these simpler cases to be implemented in the backend (there were already a few comments about that in this file). It gets difficult to get this working for scalable vectors at this level, as we would have to introduce SVE or LLVM intrinsics to model the interleave in a scalable way. The current implementation also does not work for multi-dim vectors (multi-dim vectors are not supported by shuffle), which is another limitation that we are hitting at this level with this PR.

@MacDue
Member

MacDue commented Jan 26, 2024

It gets difficult to get this working for scalable vectors at this level, as we would have to introduce SVE or LLVM intrinsics to model the interleave in a scalable way.

There already are LLVM intrinsics for that, so I don't think it'd be hard to extend to support SVE.

I wrote this little test, which seemed to build fine, and generate reasonable looking code:

func.func @test_sve_i4_extend(%inMem: memref<?xi4>) -> vector<[8]xi32> {
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : i8
  %in = vector.load %inMem[%c0] : memref<?xi4>, vector<[8]xi4>
  %shift = vector.splat %c4 : vector<[4]xi8>
  %0 = vector.bitcast %in : vector<[8]xi4> to vector<[4]xi8>
  %1 = arith.shli %0, %shift : vector<[4]xi8>
  %2 = arith.shrsi %1, %shift : vector<[4]xi8>
  %3 = arith.shrsi %0, %shift : vector<[4]xi8>
  %4 = "llvm.intr.experimental.vector.interleave2"(%2, %3) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
  %5 = arith.extsi %4 : vector<[8]xi8> to vector<[8]xi32>
  return %5 : vector<[8]xi32>
}

->

test_sve_i4_extend: 
	ptrue	p0.s
	ld1sb	{ z0.s }, p0/z, [x1]
	lsl	z1.s, z0.s, #28
	asr	z0.s, z0.s, #4
	asr	z1.s, z1.s, #28
	zip2	z2.s, z1.s, z0.s
	zip1	z0.s, z1.s, z0.s
	movprfx	z1, z2
	sxtb	z1.s, p0/m, z2.s
	sxtb	z0.s, p0/m, z0.s
	ret

I think in the vector dialect "llvm.intr.experimental.vector.interleave2" could nicely become vector.scalable.interleave 🙂

@dcaballe
Copy link
Contributor Author

Thanks for the info! I think making the interleave op at the Vector level available to fixed vectors would also make sense. There is a point in knowing that a shuffle is actually implementing an interleave pattern.

I guess we should also be fine with this LLVM limitation for now:

While this intrinsic supports all vector types the recommended way to express this operation for fixed-width vectors is still to use a shufflevector, as that may allow for more optimization opportunities.

Again, it looks like we are building a small ad-hoc backend here. Ultimately we may want this to be properly supported in LLVM.

@dcaballe merged commit a694104 into llvm:main on Jan 30, 2024
4 checks passed
KoolJBlack added a commit that referenced this pull request on May 1, 2024
…n emulation (#89131)

This PR builds on #79494 with an additional path for efficient unsigned `i4 -> i8` type extension for 1D/2D operations. This will impact any i4 -> i8/i16/i32/i64 unsigned extensions as well as sitofp i4 -> f8/f16/f32/f64.
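
For context, a hypothetical sketch of what that unsigned i4 -> i8 path could look like in MLIR (not taken from #89131; the andi/shrui sequence and the value names are assumptions):

%mask = arith.constant dense<15> : vector<4xi8>
%shift = arith.constant dense<4> : vector<4xi8>
%bc = vector.bitcast %a : vector<8xi4> to vector<4xi8>
%low = arith.andi %bc, %mask : vector<4xi8>     // keep the low nibble of each byte
%high = arith.shrui %bc, %shift : vector<4xi8>  // logical shift brings down the high nibble
%ileave = vector.shuffle %low, %high [0, 4, 1, 5, 2, 6, 3, 7] : vector<4xi8>, vector<4xi8>
%res = arith.extui %ileave : vector<8xi8> to vector<8xi32>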