[mlir][Vector] Add patterns for efficient i4 -> i8 conversion emulation #79494
Conversation
@llvm/pr-subscribers-mlir
Author: Diego Caballero (dcaballe)

Changes: This PR adds new patterns to improve the generated vector code for the emulation of any conversion that has to go through an i4 -> i8 type extension (only signed extensions are supported for now). This will impact any i4 -> i8/i16/i32/i64 signed extensions as well as sitofp i4 -> f8/f16/f32/f64. The asm code generated for the supported cases is significantly better after this PR for both x86 and aarch64.

Full diff: https://github.com/llvm/llvm-project/pull/79494.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index ead7d645cb5bb3d..fdc2d2d7e0f7fa6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -642,9 +642,9 @@ struct BitCastRewriter {
BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
- /// Verify that the preconditions for the rewrite are met.
- LogicalResult precondition(PatternRewriter &rewriter,
- VectorType preconditionVectorType, Operation *op);
+ /// Verify that general preconditions for the rewrite are met.
+ LogicalResult commonPrecondition(PatternRewriter &rewriter,
+ VectorType preconditionType, Operation *op);
/// Precompute the metadata for the rewrite.
SmallVector<BitCastRewriter::Metadata>
@@ -652,9 +652,9 @@ struct BitCastRewriter {
/// Rewrite one step of the sequence:
/// `(shuffle -> and -> shiftright -> shiftleft -> or)`.
- Value rewriteStep(PatternRewriter &rewriter, Location loc, Value initialValue,
- Value runningResult,
- const BitCastRewriter::Metadata &metadata);
+ Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
+ Value initialValue, Value runningResult,
+ const BitCastRewriter::Metadata &metadata);
private:
/// Underlying enumerator that encodes the provenance of the bits in the each
@@ -719,21 +719,54 @@ BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
LDBG("\n" << enumerator.sourceElementRanges);
}
-LogicalResult BitCastRewriter::precondition(PatternRewriter &rewriter,
- VectorType precondition,
- Operation *op) {
- if (precondition.getRank() != 1 || precondition.isScalable())
+/// Verify that the precondition type meets the common preconditions for any
+/// conversion.
+static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
+ VectorType preconditionType,
+ Operation *op) {
+ if (preconditionType.getRank() != 1 || preconditionType.isScalable())
return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");
// TODO: consider relaxing this restriction in the future if we find ways
// to really work with subbyte elements across the MLIR/LLVM boundary.
- int64_t resultBitwidth = precondition.getElementTypeBitWidth();
+ unsigned resultBitwidth = preconditionType.getElementTypeBitWidth();
if (resultBitwidth % 8 != 0)
return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
return success();
}
+LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
+ VectorType preconditionType,
+ Operation *op) {
+ if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
+ return rewriter.notifyMatchFailure(op, "types are not vector");
+
+ return commonConversionPrecondition(rewriter, preconditionType, op);
+}
+
+/// Verify that source and destination element types meet the precondition for
+/// the supported aligned conversion cases. Alignment means that either the
+/// source element type is a multiple of the destination element type or the
+/// other way around.
+///
+/// NOTE: This method assumes that common conversion preconditions are met.
+static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
+ VectorType srcType,
+ VectorType dstType,
+ Operation *op) {
+ unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
+ unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
+ unsigned byteBitwidth = 8;
+
+ // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
+ if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
+ (dstElemBitwidth % srcElemBitwidth) != 0)
+ return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
+
+ return success();
+}
+
SmallVector<BitCastRewriter::Metadata>
BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
SmallVector<BitCastRewriter::Metadata> result;
@@ -775,9 +808,9 @@ BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
return result;
}
-Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
- Value initialValue, Value runningResult,
- const BitCastRewriter::Metadata &metadata) {
+Value BitCastRewriter::genericRewriteStep(
+ PatternRewriter &rewriter, Location loc, Value initialValue,
+ Value runningResult, const BitCastRewriter::Metadata &metadata) {
// Create vector.shuffle from the metadata.
auto shuffleOp = rewriter.create<vector::ShuffleOp>(
loc, initialValue, initialValue, metadata.shuffles);
@@ -810,6 +843,44 @@ Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
return runningResult;
}
+/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
+ Value srcValue) {
+ VectorType srcVecType = cast<VectorType>(srcValue.getType());
+ assert(srcVecType.getElementType().isSignlessInteger(4) &&
+ "Expected i4 type");
+
+ // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
+ int64_t vecDimSize = srcVecType.getShape().back();
+ SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+ constexpr int64_t i4Toi8BitwidthFactor = 2;
+ i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
+ auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+ Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+ // 2. Extend i4 elements to i8 elements using shifts. Low i4 elements of each
+ // byte are placed in one vector and the high i4 elements in another vector.
+ constexpr int8_t bitsToShift = 4;
+ auto shiftValues = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(i8VecType, bitsToShift));
+ Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
+ Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
+ Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
+
+ // 3. Interleave low and high i8 elements using a shuffle.
+ SmallVector<int64_t> interleaveMaskValues;
+ interleaveMaskValues.reserve(vecDimSize);
+ for (int i = 0, end = vecDimSize / 2; i < end; ++i) {
+ interleaveMaskValues.push_back(i);
+ interleaveMaskValues.push_back(i + (vecDimSize / 2));
+ }
+
+ return rewriter.create<vector::ShuffleOp>(
+ loc, low, high, rewriter.getI64ArrayAttr(interleaveMaskValues));
+}
+
namespace {
/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
/// advantage of high-level information to avoid leaving LLVM to scramble with
@@ -829,7 +900,7 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
VectorType sourceVectorType = bitCastOp.getSourceVectorType();
VectorType targetVectorType = bitCastOp.getResultVectorType();
BitCastRewriter bcr(sourceVectorType, targetVectorType);
- if (failed(bcr.precondition(rewriter, targetVectorType, bitCastOp)))
+ if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
return failure();
// Perform the rewrite.
@@ -839,8 +910,8 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
Value runningResult;
for (const BitCastRewriter ::Metadata &metadata :
bcr.precomputeMetadata(shuffledElementType)) {
- runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(), truncValue,
- runningResult, metadata);
+ runningResult = bcr.genericRewriteStep(
+ rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
}
// Finalize the rewrite.
@@ -885,7 +956,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
VectorType sourceVectorType = bitCastOp.getSourceVectorType();
VectorType targetVectorType = bitCastOp.getResultVectorType();
BitCastRewriter bcr(sourceVectorType, targetVectorType);
- if (failed(bcr.precondition(
+ if (failed(bcr.commonPrecondition(
rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
return failure();
@@ -896,8 +967,8 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
for (const BitCastRewriter::Metadata &metadata :
bcr.precomputeMetadata(shuffledElementType)) {
- runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(),
- sourceValue, runningResult, metadata);
+ runningResult = bcr.genericRewriteStep(
+ rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
}
// Finalize the rewrite.
@@ -915,6 +986,52 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
return success();
}
};
+
+/// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+///
+/// For example:
+/// extsi vector<8xi4> -> vector<8xi32>
+/// is rewritten as
+/// sequence of shuffles and bitwise ops for i4 -> i8
+/// extsi vector<8xi8> -> vector<8xi32>
+///
+/// sitofp vector<8xi4> -> vector<8xf32>
+/// is rewritten as
+/// sequence of shuffles and bitwise ops for i4 -> i8
+/// sitofp vector<8xi8> -> vector<8xf32>
+///
+template <typename ConversionOpType>
+struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
+ using OpRewritePattern<ConversionOpType>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ConversionOpType conversionOp,
+ PatternRewriter &rewriter) const override {
+ // Set up the BitCastRewriter and verify the preconditions.
+ Value srcValue = conversionOp.getIn();
+ auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
+ auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+ if (failed(
+ commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
+ return failure();
+
+ // Check general alignment preconditions.
+ if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
+ conversionOp)))
+ return failure();
+
+ // Perform the rewrite.
+ Value subByteExt =
+ rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+
+ // Finalize the rewrite.
+ rewriter.replaceOpWithNewOp<ConversionOpType>(
+ conversionOp, conversionOp.getType(), subByteExt);
+ return success();
+ }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -936,4 +1053,10 @@ void vector::populateVectorNarrowTypeRewritePatterns(
patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
benefit);
+
+ // Patterns for aligned cases. We set higher priority as they are expected to
+ // generate better performance for aligned cases.
+ patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
+ RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>(
+ patterns.getContext(), benefit.getBenefit() + 1);
}
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index a600fa955b17003..c4fbb4c219b9170 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -193,6 +193,39 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
return %1 : vector<8xi17>
}
+// CHECK-LABEL: func.func @aligned_extsi(
+func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> {
+ // CHECK: arith.shli
+ // CHECK: arith.shrsi
+ // CHECK: arith.shrsi
+ // CHECK: vector.shuffle
+ // CHECK: arith.extsi %{{.*}} : vector<8xi8> to vector<8xi32>
+ %0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
+ return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_base_case(
+func.func @aligned_extsi_base_case(%a: vector<8xi4>) -> vector<8xi8> {
+ // CHECK: arith.shli
+ // CHECK: arith.shrsi
+ // CHECK: arith.shrsi
+ // CHECK: vector.shuffle
+ // CHECK-NOT: arith.extsi
+ %0 = arith.extsi %a : vector<8xi4> to vector<8xi8>
+ return %0 : vector<8xi8>
+}
+
+// CHECK-LABEL: func.func @aligned_sitofp(
+func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
+ // CHECK: arith.shli
+ // CHECK: arith.shrsi
+ // CHECK: arith.shrsi
+ // CHECK: shuffle
+ // CHECK: arith.sitofp %{{.*}} : vector<8xi8> to vector<8xf32>
+ %0 = arith.sitofp %a : vector<8xi4> to vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
%f = transform.structured.match ops{["func.func"]} in %module_op
Makes sense to have simpler patterns for the simpler aligned cases.
Out of curiosity, how hard would it be to have foldings from the existing patterns in MLIR to get to a similar form to what you have now?
If you could paste some before / after IR in the comments (or the commit message), this would also be useful.
Thanks for improving this!
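For reference, a rough before/after IR sketch for the aligned extsi case, reconstructed from the new @aligned_extsi test and the rewriteI4ToI8SignedExt helper in this PR (SSA value names and constant placement are illustrative, not actual pattern output):

// Before:
func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> {
  %0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
  return %0 : vector<8xi32>
}

// After (roughly):
func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> {
  %cst = arith.constant dense<4> : vector<4xi8>
  // Reinterpret 8 x i4 as 4 x i8; each byte now holds two i4 values.
  %0 = vector.bitcast %a : vector<8xi4> to vector<4xi8>
  // Low nibbles: shift left then arithmetic shift right to sign-extend.
  %1 = arith.shli %0, %cst : vector<4xi8>
  %low = arith.shrsi %1, %cst : vector<4xi8>
  // High nibbles: a single arithmetic shift right sign-extends them.
  %high = arith.shrsi %0, %cst : vector<4xi8>
  // Interleave low/high results back into the original element order.
  %2 = vector.shuffle %low, %high [0, 4, 1, 5, 2, 6, 3, 7] : vector<4xi8>, vector<4xi8>
  %3 = arith.extsi %2 : vector<8xi8> to vector<8xi32>
  return %3 : vector<8xi32>
}

The final extsi now starts from i8, which backends handle well, and in the @aligned_extsi_base_case test (i4 -> i8) it disappears entirely.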
Thanks! Any thoughts/plans for extending this to scalable vectors? Related discussion here: #79270
It seems complicated, as the approach is slightly different. We would have to look at multiple ops to realize that the first shuffle is redundant for cases that are a multiple of 8 bits ("aligned"), and then realize that some of the shifts are actually implementing the interleave of two registers... I don't see a clear path...
This is mostly a workaround to keep things moving, but ultimately we may want these simpler cases to be implemented in the backend (there were already a few comments about that in this file). It gets difficult to get this working for scalable vectors at this level, as we would have to introduce SVE or LLVM intrinsics to model the interleave in a scalable way. The current implementation also does not work for multi-dim vectors (multi-dim is not supported by shuffle), which is another limitation we are hitting at this level with this PR.
There already are LLVM intrinsics for that, so I don't think it'd be hard to extend this to support SVE. I wrote this little test, which seemed to build fine and generate reasonable-looking code:

func.func @test_sve_i4_extend(%inMem: memref<?xi4>) -> vector<[8]xi32> {
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : i8
  %in = vector.load %inMem[%c0] : memref<?xi4>, vector<[8]xi4>
  %shift = vector.splat %c4 : vector<[4]xi8>
  %0 = vector.bitcast %in : vector<[8]xi4> to vector<[4]xi8>
  %1 = arith.shli %0, %shift : vector<[4]xi8>
  %2 = arith.shrsi %1, %shift : vector<[4]xi8>
  %3 = arith.shrsi %0, %shift : vector<[4]xi8>
  %4 = "llvm.intr.experimental.vector.interleave2"(%2, %3) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
  %5 = arith.extsi %4 : vector<[8]xi8> to vector<[8]xi32>
  return %5 : vector<[8]xi32>
}
I think in the vector dialect
Thanks for the info! I think making the interleave op at Vector level available to fixed vectors would also make sense. There is a point in knowing that a shuffle is actually implementing an interleave pattern. I guess we should also be fine with these LLVM limitations for now:
Again, it looks like we are building a small ad-hoc backend in here. Ultimately we may want this to be properly supported in LLVM.
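To make that comparison concrete, here is a minimal sketch (value names are illustrative) showing the fixed-width interleave this PR expresses as a vector.shuffle next to the scalable intrinsic used in the SVE snippet above:

func.func @interleave_forms(%low: vector<4xi8>, %high: vector<4xi8>,
                            %slow: vector<[4]xi8>, %shigh: vector<[4]xi8>)
    -> (vector<8xi8>, vector<[8]xi8>) {
  // Fixed-width interleave written as a shuffle (what this PR generates).
  %fixed = vector.shuffle %low, %high [0, 4, 1, 5, 2, 6, 3, 7] : vector<4xi8>, vector<4xi8>
  // Scalable interleave via the LLVM intrinsic, as in the SVE example.
  %scalable = "llvm.intr.experimental.vector.interleave2"(%slow, %shigh)
      : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
  return %fixed, %scalable : vector<8xi8>, vector<[8]xi8>
}

A vector-level interleave op covering both forms would let the pattern here and a future scalable variant share the same lowering.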