
[mlir][Vector] Add a rewrite pattern for better low-precision ext(bit… #65774

Closed
wants to merge 3 commits

Conversation

nicolasvasilache
Contributor

@nicolasvasilache nicolasvasilache commented Sep 8, 2023

…cast) expansion

This revision adds a rewrite for sequences of vector ext(maybe_broadcast(bitcast))
to use a more efficient sequence of vector operations comprising shuffle, shift and
bitwise ops.

The rewrite uses an intermediate bitwidth equal to the lcm of
the element type bitwidths of the source and result types of bitCastOp. This
intermediate type may be smaller or greater than the desired elemental type of
the ext, in which case appropriate ext or trunc operations are inserted.

The rewrite fails if the intermediate bitwidth is greater than 64 or if the
involved vector types fail to meet basic divisibility requirements. In other
words, this rewrite does not handle partial vector boundaries and leaves
that part of the heavy lifting to LLVM.

In the future, it may be relevant to give control over the size of the intermediate
type. For now, it has been empirically determined that capping it at 64 results in
much better assembly being produced when piping through llvm-mca.
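
To make the targeted pattern concrete, here is a minimal, hypothetical MLIR sketch
(the element types and shapes are illustrative and not taken from the patch's tests;
here lcm(8, 3) = 24, which the heuristic scales up to a 48-bit interim type under the
64-bit cap):

%bc  = vector.bitcast %src : vector<6xi8> to vector<16xi3>
%ext = arith.extui %bc : vector<16xi3> to vector<16xi32>
// After the rewrite (schematically): vector.bitcast to the i48 interim type, then
// vector.shuffle + arith.andi (masking) + arith.shrui (shifting), followed by a
// final arith.extui or arith.trunci to reach the i32 result element type.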

@llvmbot
Collaborator

llvmbot commented Sep 12, 2023

@llvm/pr-subscribers-mlir-vector

Changes

…cast) expansion

This revision adds a rewrite for sequences of vector ext(maybe_broadcast(bitcast)) to use a more efficient sequence of vector operations comprising shuffles, shifts and bitwise logical ops. The rewrite uses an intermediate bitwidth equal to the lcm of the element type bitwidths of the source and result types of bitCastOp. This intermediate type may be smaller or greater than the desired elemental type of the extOp, in which case appropriate ext or trunc operations are inserted. The rewrite fails if the intermediate bitwidth is greater than 64 or if the involved vector types fail to meet basic divisibility requirements. In other words, this rewrite does not handle partial vector boundaries and leaves that part of the heavy lifting to LLVM.

Patch is 29.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/65774.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td (+13)
  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+21-1)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (+10)
  • (modified) mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp (+5)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+276-4)
  • (modified) mlir/test/Dialect/LLVM/transform-e2e.mlir (-21)
  • (added) mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir (+205)
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 2b8c95a94257e6c..d1cef91f8e27525 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -281,6 +281,19 @@ def ApplyLowerTransposePatternsOp : Op]> {
+def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.rewrite_narrow_types",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that vector narrow rewrite operations should be applied.
+
+    This is usually a late step that is run after bufferization as part of the
+    process of lowering to e.g. LLVM or NVVM.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplySplitTransferFullPartialPatternsOp : Op]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index c644090d8c78cd0..20c33921f9de24e 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -23,6 +23,7 @@ namespace mlir {
 class RewritePatternSet;
 
 namespace arith {
+class AndIOp;
 class NarrowTypeEmulationConverter;
 } // namespace arith
 
@@ -143,7 +144,7 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
 
 /// Patterns that remove redundant vector broadcasts.
 void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
-                                          PatternBenefit benefit = 1);
+                                         PatternBenefit benefit = 1);
 
 /// Populate `patterns` with the following patterns.
 ///
@@ -301,6 +302,25 @@ void populateVectorNarrowTypeEmulationPatterns(
     arith::NarrowTypeEmulationConverter &typeConverter,
     RewritePatternSet &patterns);
 
+/// Rewrite vector ext(maybe_broadcast(bitcast)) to use a more efficient
+/// sequence of vector operations comprising shuffles, shifts and bitwise
+/// logical ops. The rewrite uses an intermediate bitwidth equal to the lcm of
+/// the element type bitwidths of the source and result types of `bitCastOp`.
+/// This intermediate type may be smaller or greater than the desired elemental
+/// type of the extOp, in which case appropriate ext or trunc operations are
+/// inserted. The rewrite fails if the intermediate bitwidth is greater than 64
+/// or if the involved vector types fail to meet basic divisibility
+/// requirements. In other words, this rewrite does not handle partial vector
+/// boundaries and leaves that part of the heavy lifting to LLVM.
+FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
+                                     vector::BitCastOp bitCastOp,
+                                     vector::BroadcastOp maybeBroadcastOp);
+
+/// Appends patterns for rewriting vector operations over narrow types with
+/// ops over wider types.
+void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
+                                             PatternBenefit benefit = 1);
+
 } // namespace vector
 } // namespace mlir
 
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index f031eb0a5c30ce9..9df5548cd5d939c 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -357,6 +357,16 @@ class VectorType::Builder {
     return *this;
   }
 
+  /// Set a dim in shape @pos to val.
+  Builder &setDim(unsigned pos, int64_t val) {
+    if (storage.empty())
+      storage.append(shape.begin(), shape.end());
+    assert(pos < storage.size() && "overflow");
+    storage[pos] = val;
+    shape = {storage.data(), storage.size()};
+    return *this;
+  }
+
   operator VectorType() {
     return VectorType::get(shape, elementType, scalableDims);
   }
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 94f19e59669eafd..0fdeded436a9773 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -154,6 +154,11 @@ void transform::ApplyLowerTransposePatternsOp::populatePatterns(
   }
 }
 
+void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  populateVectorNarrowTypeRewritePatterns(patterns);
+}
+
 void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::VectorTransformsOptions vectorTransformOptions;
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index b2b7bfc5e4437c1..d10994f3709e390 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -7,7 +7,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
@@ -15,13 +14,29 @@
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/MathExtras.h"
-#include 
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include 
+#include 
+#include 
 
 using namespace mlir;
 
+#define DEBUG_TYPE "vector-narrow-type-emulation"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define DBGSNL() (llvm::dbgs() << "\n")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -155,6 +170,256 @@ struct ConvertVectorTransferRead final
 };
 } // end anonymous namespace
 
+//===----------------------------------------------------------------------===//
+// RewriteExtOfBitCast
+//===----------------------------------------------------------------------===//
+
+/// Create a vector of bit masks: `idx .. idx + step - 1` and broadcast it
+/// `numOccurrences` times.
+/// `step` must divide `bitwidth` evenly.
+/// Example: (4, 2, 3) -> [0xc, 0x3, 0xc, 0x3, 0xc, 0x3].
+static SmallVector<Attribute> computeExtOfBitCastMasks(MLIRContext *ctx,
+                                                       int64_t bitwidth,
+                                                       int64_t step,
+                                                       int64_t numOccurrences) {
+  assert(bitwidth % step == 0 && "step must divide bitwidth evenly");
+  IntegerType interimIntType = IntegerType::get(ctx, bitwidth);
+  SmallVector<Attribute> tmpMasks;
+  tmpMasks.reserve(bitwidth / step);
+  // Create a vector of bit masks: `idx .. idx + step - 1`.
+  for (int64_t idx = 0; idx < bitwidth; idx += step) {
+    LDBG("Mask bits " << idx << " .. " << idx + step - 1 << " out of "
+                      << bitwidth);
+    IntegerAttr mask = IntegerAttr::get(
+        interimIntType, llvm::APInt::getBitsSet(bitwidth, idx, idx + step));
+    tmpMasks.push_back(mask);
+  }
+  // Replicate the vector of bit masks to the desired size.
+  SmallVector<Attribute> masks;
+  masks.reserve(numOccurrences * tmpMasks.size());
+  for (int64_t idx = 0; idx < numOccurrences; ++idx)
+    llvm::append_range(masks, tmpMasks);
+  return masks;
+}
+
+/// Create a vector of bit shifts by `k * idx` and broadcast it `numOccurrences`
+/// times.
+/// `step` must divide `bitwidth` evenly.
+/// Example: (4, 2, 3) -> [0x0, 0x2, 0x0, 0x2, 0x0, 0x2].
+static SmallVector<Attribute>
+computeExtOfBitCastShifts(MLIRContext *ctx, int64_t bitwidth, int64_t step,
+                          int64_t numOccurrences) {
+  assert(bitwidth % step == 0 && "step must divide bitwidth evenly");
+  IntegerType interimIntType = IntegerType::get(ctx, bitwidth);
+  SmallVector<Attribute> tmpShifts;
+  for (int64_t idx = 0; idx < bitwidth; idx += step) {
+    IntegerAttr shift = IntegerAttr::get(interimIntType, idx);
+    tmpShifts.push_back(shift);
+  }
+  SmallVector<Attribute> shifts;
+  for (int64_t idx = 0; idx < numOccurrences; ++idx)
+    llvm::append_range(shifts, tmpShifts);
+  return shifts;
+}
+
+/// Create a vector of bit shuffles: `numOccurrences * idx` and broadcast it
+/// `bitwidth/step` times.
+/// `step` must divide `bitwidth` evenly.
+/// Example: (4, 2, 3) -> [0x0, 0x1, 0x0, 0x1, 0x0, 0x1].
+static SmallVector<int64_t>
+computeExtOfBitCastShuffles(MLIRContext *ctx, int64_t bitwidth, int64_t step,
+                            int64_t numOccurrences) {
+  assert(bitwidth % step == 0 && "step must divide bitwidth evenly");
+  SmallVector<int64_t> shuffles;
+  int64_t n = floorDiv(bitwidth, step);
+  for (int64_t idx = 0; idx < numOccurrences; ++idx)
+    llvm::append_range(shuffles, SmallVector<int64_t>(n, idx));
+  return shuffles;
+}
+
+/// Compute the intermediate vector type, its elemental type must be an integer
+/// with bitwidth that:
+///   1. is smaller than 64 (TODO: in the future we may want target-specific
+///   control).
+///   2. divides sourceBitWidth * mostMinorSourceDim
+static int64_t computeExtOfBitCastBitWidth(int64_t sourceBitWidth,
+                                           int64_t mostMinorSourceDim,
+                                           int64_t targetBitWidth) {
+  for (int64_t mult : {32, 16, 8, 4, 2, 1}) {
+    int64_t interimBitWidth =
+        std::lcm(mult, std::lcm(sourceBitWidth, targetBitWidth));
+    if (interimBitWidth > 64)
+      continue;
+    if ((sourceBitWidth * mostMinorSourceDim) % interimBitWidth != 0)
+      continue;
+    return interimBitWidth;
+  }
+  return 0;
+}
+
+FailureOr<Value>
+mlir::vector::rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
+                                  vector::BitCastOp bitCastOp,
+                                  vector::BroadcastOp maybeBroadcastOp) {
+  assert(
+      (llvm::isa<arith::ExtSIOp>(extOp) || llvm::isa<arith::ExtUIOp>(extOp)) &&
+      "unsupported op");
+
+  // The bitcast op is the load-bearing part, capture the source and bitCast
+  // types as well as bitwidth and most minor dimension.
+  VectorType sourceVectorType = bitCastOp.getSourceVectorType();
+  int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth();
+  int64_t mostMinorSourceDim = sourceVectorType.getShape().back();
+  LDBG("sourceVectorType: " << sourceVectorType);
+
+  VectorType bitCastVectorType = bitCastOp.getResultVectorType();
+  int64_t targetBitWidth = bitCastVectorType.getElementTypeBitWidth();
+  LDBG("bitCastVectorType: " << bitCastVectorType);
+
+  int64_t interimBitWidth = computeExtOfBitCastBitWidth(
+      sourceBitWidth, mostMinorSourceDim, targetBitWidth);
+  LDBG("interimBitWidth: " << interimBitWidth);
+  if (!interimBitWidth) {
+    return rewriter.notifyMatchFailure(
+        extOp, "heuristic could not find a reasonable interim bitwidth");
+  }
+  if (sourceBitWidth == interimBitWidth || targetBitWidth == interimBitWidth) {
+    return rewriter.notifyMatchFailure(
+        extOp, "interim bitwidth is equal to source or target, nothing to do");
+  }
+
+  int64_t interimMostMinorDim =
+      sourceBitWidth * mostMinorSourceDim / interimBitWidth;
+  LDBG("interimMostMinorDim: " << interimMostMinorDim);
+
+  Location loc = extOp->getLoc();
+  MLIRContext *ctx = extOp->getContext();
+
+  VectorType interimVectorType =
+      VectorType::Builder(sourceVectorType)
+          .setDim(sourceVectorType.getRank() - 1, interimMostMinorDim)
+          .setElementType(IntegerType::get(ctx, interimBitWidth));
+  LDBG("interimVectorType: " << interimVectorType);
+
+  IntegerType interimIntType = IntegerType::get(ctx, interimBitWidth);
+  VectorType vt =
+      VectorType::Builder(bitCastVectorType).setElementType(interimIntType);
+
+  // Rewrite the original bitcast to the interim vector type and shuffle to
+  // broadcast to the desired size.
+  auto newBitCastOp = rewriter.create<vector::BitCastOp>(loc, interimVectorType,
+                                                         bitCastOp.getSource());
+  SmallVector<int64_t> shuffles = computeExtOfBitCastShuffles(
+      ctx, interimBitWidth, targetBitWidth, interimMostMinorDim);
+  auto shuffleOp = rewriter.create<vector::ShuffleOp>(loc, newBitCastOp,
+                                                      newBitCastOp, shuffles);
+  LDBG("shuffle: " << shuffleOp);
+
+  // Compute the constants for masking.
+  SmallVector<Attribute> masks = computeExtOfBitCastMasks(
+      ctx, interimBitWidth, targetBitWidth, interimMostMinorDim);
+  auto maskConstantOp = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(vt, masks));
+  LDBG("maskConstant: " << maskConstantOp);
+  auto andOp = rewriter.create<arith::AndIOp>(loc, shuffleOp, maskConstantOp);
+  LDBG("andOp: " << andOp);
+
+  // Preserve the intermediate type: this may have serious consequences on the
+  // backend's ability to generate efficient vector operations.
+  // For instance on x86, converting to f16 without going through i32 has severe
+  // performance implications.
+  // As a consequence, this pattern must preserve the original behavior.
+  VectorType resultType = cast<VectorType>(extOp->getResultTypes().front());
+  Type resultElementType = getElementTypeOrSelf(resultType);
+  SmallVector<Attribute> shifts = computeExtOfBitCastShifts(
+      ctx, interimBitWidth, targetBitWidth, interimMostMinorDim);
+  auto shiftConstantOp = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(vt, shifts));
+  LDBG("shiftConstant: " << shiftConstantOp);
+  Value newResult =
+      TypeSwitch<Operation *, Value>(extOp)
+          .template Case<arith::ExtSIOp>([&](arith::ExtSIOp op) {
+            Value shifted =
+                rewriter.create<arith::ShRSIOp>(loc, andOp, shiftConstantOp);
+            auto vt = shifted.getType().cast<VectorType>();
+            VectorType extVt =
+                VectorType::Builder(vt).setElementType(resultElementType);
+            Operation *res =
+                (resultElementType.getIntOrFloatBitWidth() > interimBitWidth)
+                    ? rewriter.create<arith::ExtSIOp>(loc, extVt, shifted)
+                    : rewriter.create<arith::TruncIOp>(loc, extVt, shifted);
+            return res->getResult(0);
+          })
+          .template Case<arith::ExtUIOp>([&](arith::ExtUIOp op) {
+            Value shifted =
+                rewriter.create<arith::ShRUIOp>(loc, andOp, shiftConstantOp);
+            auto vt = shifted.getType().cast<VectorType>();
+            VectorType extVt =
+                VectorType::Builder(vt).setElementType(resultElementType);
+            Operation *res =
+                (resultElementType.getIntOrFloatBitWidth() > interimBitWidth)
+                    ? rewriter.create<arith::ExtUIOp>(loc, extVt, shifted)
+                    : rewriter.create<arith::TruncIOp>(loc, extVt, shifted);
+            return res->getResult(0);
+          })
+          .Default([&](Operation *op) {
+            llvm_unreachable("unexpected op type");
+            return nullptr;
+          });
+
+  if (maybeBroadcastOp) {
+    newResult =
+        rewriter.create<vector::BroadcastOp>(loc, resultType, newResult);
+  }
+
+  return newResult;
+}
+
+namespace {
+template <typename ExtOpType>
+struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
+  using OpRewritePattern<ExtOpType>::OpRewritePattern;
+
+  RewriteExtOfBitCast(MLIRContext *context, PatternBenefit benefit)
+      : OpRewritePattern<ExtOpType>(context, benefit) {}
+
+  LogicalResult matchAndRewrite(ExtOpType extOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultTy = dyn_cast<VectorType>(extOp.getType());
+    if (!resultTy)
+      return rewriter.notifyMatchFailure(extOp, "not a vector type");
+
+    int64_t elementalBitWidth = resultTy.getElementTypeBitWidth();
+    if (elementalBitWidth & (elementalBitWidth - 1)) {
+      return rewriter.notifyMatchFailure(
+          extOp, "result bitwidth must be a power of 2");
+    }
+
+    // Provision for a potential broadcast op that will be rewritten late.
+    auto maybeBroadcastOp =
+        extOp.getIn().template getDefiningOp<vector::BroadcastOp>();
+
+    // The source must be a bitcast op.
+    auto bitCastOp =
+        maybeBroadcastOp
+            ? maybeBroadcastOp.getSource()
+                  .template getDefiningOp<vector::BitCastOp>()
+            : extOp.getIn().template getDefiningOp<vector::BitCastOp>();
+    if (!bitCastOp)
+      return rewriter.notifyMatchFailure(extOp, "not a bitcast source");
+
+    // Try to rewrite.
+    FailureOr<Value> result =
+        rewriteExtOfBitCast(rewriter, extOp, bitCastOp, maybeBroadcastOp);
+    if (failed(result))
+      return failure();
+
+    rewriter.replaceOp(extOp, *result);
+    return success();
+  }
+};
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // Public Interface Definition
 //===----------------------------------------------------------------------===//
@@ -167,3 +432,10 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
   patterns.add(
       typeConverter, patterns.getContext());
 }
+
+void vector::populateVectorNarrowTypeRewritePatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<RewriteExtOfBitCast<arith::ExtSIOp>,
+               RewriteExtOfBitCast<arith::ExtUIOp>>(patterns.getContext(),
+                                                    benefit);
+}
diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index 777de75b1a47acc..2cb753a3d7fb8f3 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -29,33 +29,12 @@ transform.sequence failures(propagate) {
   // lowering TD macros.
   transform.apply_patterns to %f {
     transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
-  } : !transform.any_op
-
-  transform.apply_patterns to %f {
     transform.apply_patterns.vector.transfer_permutation_patterns
-  } : !transform.any_op
-
-  transform.apply_patterns to %f {
     transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel"
-  } : !transform.any_op
-
-  transform.apply_patterns to %f {
     transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
-  } : !transform.any_op
-
-  transform.apply_patterns to %f {
     transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true
-  } : !transform.any_op
-
-  transform.apply_patterns to %f {
     transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1
-  } : !transform.any_op
-
-  transform.apply_patterns to %f {
     transform.apply_patterns.vector.lower_shape_cast
-  } : !transform.any_op
-
-  transform.apply_pattern...

@@ -301,6 +302,25 @@ void populateVectorNarrowTypeEmulationPatterns(
arith::NarrowTypeEmulationConverter &typeConverter,
RewritePatternSet &patterns);

/// Rewrite vector ext(maybe_broadcast(bitcast)) to use a more efficient
/// sequence of vector operations comprising shuffles, shifts and bitwise
/// logical ops. The rewrite uses an intermediate bitwidth equal to the licm of
Member

Nit: licm->lcm.
Same in the commit description.

(Did you teach autocorrect about loop transformation names? 😆 )

// RewriteExtOfBitCast
//===----------------------------------------------------------------------===//

/// Create a vector of bit masks: `idx .. idx + step - 1` and broadcast it
Member

It's unclear what idx means in this comment.

A second example could also be helpful to understand the logic generalization.

int64_t numOccurrences) {
assert(bitwidth % step == 0 && "step must divide bitwidth evenly");
SmallVector<int64_t> shuffles;
int64_t n = floorDiv(bitwidth, step);
Member

Why floorDiv if step is known to evenly divide bitwidth given the assert above?

vector::BitCastOp bitCastOp,
vector::BroadcastOp maybeBroadcastOp) {
assert(
(llvm::isa<arith::ExtSIOp>(extOp) || llvm::isa<arith::ExtUIOp>(extOp)) &&
Member

Nit: no need to prefix with llvm::, this name is re-exported.
Also nit: isa is a variadic template, isa<arith::ExtSIOp, ExtUIOp>() should work.

.setElementType(IntegerType::get(ctx, interimBitWidth));
LDBG("interimVectorType: " << interimVectorType);

IntegerType interimIntType = IntegerType::get(ctx, interimBitWidth);
Member

Nit: put above interimVectorType and use in its construction.

LDBG("shiftConstant: " << shiftConstantOp);
Value newResult =
TypeSwitch<Operation *, Value>(extOp)
.template Case<arith::ExtSIOp>([&](arith::ExtSIOp op) {
Member

Nit: I don't think .template is needed here.

: rewriter.create<arith::TruncIOp>(loc, extVt, shifted);
return res->getResult(0);
})
.template Case<arith::ExtUIOp>([&](arith::ExtUIOp op) {
Member

If you wanted to go full template here, you could do something like

template <typename ExtOp,
    typename ShiftOp = std::conditional_t<std::is_same_v<ExtOp, arith::ExtUIOp>, arith::ShRUIOp, arith::ShRSIOp>>
static FailureOr<Value>
rewriteExtOfBitCastImpl (RewriterBase &rewriter, ExtOp op, ...) {
  // ...
  Value shifted = rewriter.template create<ShiftOp>(...);
  // ...
}

return nullptr;
});

if (maybeBroadcastOp) {
Member

Nit: return early instead.

return rewriter.notifyMatchFailure(extOp, "not a vector type");

int64_t elementalBitWidth = resultTy.getElementTypeBitWidth();
if (elementalBitWidth & (elementalBitWidth - 1)) {
Member

Nit: I'd rather use llvm::isPowerOf2_64

func.func @f1(%m: !mst, %idx : index, %mf: !mtt) {

// CHECK: %[[MASK:.*]] = arith.constant dense<[
// CHECK-SAME-COUNT-6: 7, 56, 448, 3584, 28672, 229376, 1835008, 14680064, 117440512, 939524096, 7516192768, 60129542144, 481036337152, 3848290697216, 30786325577728, -35184372088832
Member

A comment with the hex version of this would be helpful.

Collaborator

@qcolombet qcolombet left a comment

LGTM.
+1 on @ftynse comments plus a few of my own :).

/// Create a vector of bit shuffles: `numOccurrences * idx` and broadcast it
/// `bitwidth/step` times.
/// `step` must divide `bitwidth` evenly.
/// Example: (4, 2, 3) -> [0x0, 0x1, 0x0, 0x1, 0x0, 0x1].
Collaborator

I have a hard time reconciling the comment and the example.

For instance if we broadcast bitwidth/step times, we should only have 2 reps, not 3.

@@ -29,33 +29,12 @@ transform.sequence failures(propagate) {
// lowering TD macros.
transform.apply_patterns to %f {
transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
} : !transform.any_op

transform.apply_patterns to %f {
Collaborator

These changes are unrelated, right?

Contributor

@stellaraccident stellaraccident left a comment

Looks like the review is in good hands, so I'll just +1 having this. Thank you!

…(trunci) expansion

This revision adds a rewrite for sequences of vector `bitcast(trunci)` to use a more efficient sequence of
vector operations comprising `shuffle` and `bitwise` ops.

Such patterns appear naturally when writing quantization / dequantization functionality with the vector dialect.

The rewrite performs a simple enumeration of each of the bits in the result vector and determines its provenance
in the pre-trunci vector. The enumeration is used to generate the proper sequence of `shuffle`, `andi`, `ori`
followed by an optional final `trunci`/`extui`.

The rewrite currently only applies to 1-D non-scalable vectors and bails out if the final vector element type's
bitwidth is not a multiple of 8. This is a failsafe heuristic determined empirically: if the resulting type
does not span a whole number of bytes, further complexities arise that are not improved by this pattern;
the heavy lifting still needs to be done by LLVM.
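
For illustration only, a hypothetical instance of the targeted sequence (1-D vectors
with a byte-multiple final element type, per the constraints above):

%t  = arith.trunci %in : vector<8xi32> to vector<8xi6>
%bc = vector.bitcast %t : vector<8xi6> to vector<6xi8>
// The rewrite would replace this pair with vector.shuffle, arith.andi and arith.ori
// ops on %in (plus an optional final trunci/extui), so the i6 intermediate never
// needs to be materialized.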
@nicolasvasilache
Contributor Author

I need to rework this to reuse the same implementation as #66387 which is both more general and more performant overall.

@nicolasvasilache nicolasvasilache marked this pull request as draft September 15, 2023 13:40
…cast) expansion

This revision adds a rewrite for sequences of vector `ext(bitcast)` to use a more efficient sequence of
vector operations comprising `shuffle` and `bitwise` ops.

Such patterns appear naturally when writing quantization / dequantization functionality with the vector dialect.

The rewrite performs a simple enumeration of each of the bits in the result vector and determines its provenance
in the source vector. The enumeration is used to generate the proper sequence of `shuffle`, `andi`, `ori`
with shifts.

The rewrite currently only applies to 1-D non-scalable vectors and bails out if the final vector element type's
bitwidth is not a multiple of 8. This is a failsafe heuristic determined empirically: if the resulting type
does not span a whole number of bytes, further complexities arise that are not improved by this pattern;
the heavy lifting still needs to be done by LLVM.
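
As a hedged sketch of the bit-provenance enumeration on hypothetical types (assuming
the usual low-order-bits-first element layout of vector.bitcast; neither the types nor
the formula below come from the patch itself):

%bc  = vector.bitcast %src : vector<3xi8> to vector<8xi3>
%ext = arith.extui %bc : vector<8xi3> to vector<8xi32>
// Under that layout, bit j of result element k originates from bit (3 * k + j) of the
// flattened 24-bit source, i.e. bit ((3 * k + j) mod 8) of source element
// ((3 * k + j) / 8); the enumeration turns this mapping into the shuffle/andi/ori
// sequence with shifts described above.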
@nicolasvasilache
Contributor Author

This is now reimplemented in terms of #66387, closing this and starting another PR
