[mlir][Vector] Add a rewrite pattern for better low-precision ext(bit… #65774
fde4574 to 11caf7c
@llvm/pr-subscribers-mlir-vector
Changes
…cast) expansion
This revision adds a rewrite for sequences of vector ext(maybe_broadcast(bitcast)) to use a more efficient sequence of vector operations comprising shuffles, shifts and bitwise logical ops. The rewrite uses an intermediate bitwidth equal to the licm of the element types of the source and result types of
11caf7c to 6f1997f
@@ -301,6 +302,25 @@ void populateVectorNarrowTypeEmulationPatterns(
    arith::NarrowTypeEmulationConverter &typeConverter,
    RewritePatternSet &patterns);

/// Rewrite vector ext(maybe_broadcast(bitcast)) to use a more efficient
/// sequence of vector operations comprising shuffles, shifts and bitwise
/// logical ops. The rewrite uses an intermediate bitwidth equal to the licm of
Nit: licm->lcm.
Same in the commit description.
(Did you teach autocorrect about loop transformation names? 😆 )
// RewriteExtOfBitCast
//===----------------------------------------------------------------------===//

/// Create a vector of bit masks: `idx .. idx + step - 1` and broadcast it
It's unclear what `idx` means in this comment.
A second example could also be helpful to understand how the logic generalizes.
    int64_t numOccurrences) {
  assert(bitwidth % step == 0 && "step must divide bitwidth evenly");
  SmallVector<int64_t> shuffles;
  int64_t n = floorDiv(bitwidth, step);
Why `floorDiv` if `step` is known to evenly divide `bitwidth` given the assert above?
    vector::BitCastOp bitCastOp,
    vector::BroadcastOp maybeBroadcastOp) {
  assert(
      (llvm::isa<arith::ExtSIOp>(extOp) || llvm::isa<arith::ExtUIOp>(extOp)) &&
Nit: no need to prefix with `llvm::`, this name is re-exported.
Also nit: `isa` is a variadic template, `isa<arith::ExtSIOp, ExtUIOp>()` should work.
      .setElementType(IntegerType::get(ctx, interimBitWidth));
  LDBG("interimVectorType: " << interimVectorType);

  IntegerType interimIntType = IntegerType::get(ctx, interimBitWidth);
Nit: put above `interimVectorType` and use in its construction.
  LDBG("shiftConstant: " << shiftConstantOp);
  Value newResult =
      TypeSwitch<Operation *, Value>(extOp)
          .template Case<arith::ExtSIOp>([&](arith::ExtSIOp op) {
Nit: I don't think `.template` is needed here.
                : rewriter.create<arith::TruncIOp>(loc, extVt, shifted);
            return res->getResult(0);
          })
          .template Case<arith::ExtUIOp>([&](arith::ExtUIOp op) {
If you wanted to go full template here, you could do something like
template <typename ExtOp,
          typename ShiftOp = std::conditional_t<std::is_same_v<ExtOp, arith::ExtUIOp>, arith::ShRUIOp, arith::ShRSIOp>>
static FailureOr<Value>
rewriteExtOfBitCastImpl(RewriterBase &rewriter, ExtOp op, ...) {
  // ...
  Value shifted = rewriter.template create<ShiftOp>(...);
  // ...
}
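The compile-time selection suggested above can be tried outside of MLIR. The following is only a sketch: the `ExtUIOp`, `ExtSIOp`, `ShRUIOp` and `ShRSIOp` structs and the `shiftFor` helper are hypothetical stand-ins that operate on plain integers, not the real `arith` ops or the rewriter API.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Hypothetical stand-ins for arith::ExtUIOp / arith::ExtSIOp. The real MLIR
// ops are not plain structs; these only serve as tags for the template.
struct ExtUIOp {};
struct ExtSIOp {};

// Hypothetical stand-in for arith::ShRUIOp: logical (zero-filling) right shift.
struct ShRUIOp {
  static int64_t apply(int64_t value, unsigned amount) {
    return static_cast<int64_t>(static_cast<uint64_t>(value) >> amount);
  }
};

// Hypothetical stand-in for arith::ShRSIOp: arithmetic (sign-filling) right
// shift. Guaranteed since C++20; mainstream compilers behave this way earlier.
struct ShRSIOp {
  static int64_t apply(int64_t value, unsigned amount) {
    return value >> amount;
  }
};

// The shift kind is derived from the extension kind at compile time, as in
// the review suggestion: unsigned extension pairs with the unsigned shift.
template <typename ExtOp,
          typename ShiftOp = std::conditional_t<std::is_same_v<ExtOp, ExtUIOp>,
                                                ShRUIOp, ShRSIOp>>
int64_t shiftFor(int64_t value, unsigned amount) {
  static_assert(std::is_same_v<ExtOp, ExtUIOp> || std::is_same_v<ExtOp, ExtSIOp>,
                "ExtOp must be one of the two extension kinds");
  assert(amount < 64 && "shift amount must be smaller than the bit width");
  return ShiftOp::apply(value, amount);
}
```

Calling `shiftFor<ExtUIOp>(v, n)` picks the logical shift and `shiftFor<ExtSIOp>(v, n)` the arithmetic one, so the two `Case` lambdas in the diff could share a single body.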
            return nullptr;
          });

  if (maybeBroadcastOp) {
Nit: return early instead.
    return rewriter.notifyMatchFailure(extOp, "not a vector type");

  int64_t elementalBitWidth = resultTy.getElementTypeBitWidth();
  if (elementalBitWidth & (elementalBitWidth - 1)) {
Nit: I'd rather use llvm::isPowerOf2_64
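For reference, a minimal standalone sketch of the difference between the two spellings. `isPowerOf2` below is a stand-in written to mirror the documented behaviour of `llvm::isPowerOf2_64` (true only for powers of two greater than zero); `hasMoreThanOneBit` and `isSupportedElementBitWidth` are hypothetical names for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Stand-in mirroring llvm::isPowerOf2_64: true only for a power of two > 0.
constexpr bool isPowerOf2(uint64_t value) {
  return value != 0 && (value & (value - 1)) == 0;
}

// The bare bit trick used as the condition in the diff: non-zero result means
// that more than one bit is set. Note that it does not flag zero.
constexpr bool hasMoreThanOneBit(uint64_t value) {
  return (value & (value - 1)) != 0;
}

// Hypothetical helper: accept an element bit width only if it is a power of two.
inline bool isSupportedElementBitWidth(int64_t bitWidth) {
  assert(bitWidth >= 0 && "bit widths are non-negative");
  return isPowerOf2(static_cast<uint64_t>(bitWidth));
}
```

The two predicates disagree only on zero: the bit trick lets a width of `0` through, while the named helper rejects it and also states the intent directly.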
func.func @f1(%m: !mst, %idx : index, %mf: !mtt) {

  // CHECK: %[[MASK:.*]] = arith.constant dense<[
  // CHECK-SAME-COUNT-6: 7, 56, 448, 3584, 28672, 229376, 1835008, 14680064, 117440512, 939524096, 7516192768, 60129542144, 481036337152, 3848290697216, 30786325577728, -35184372088832
A comment with the hex version of this would be helpful.
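The decimal constants in the CHECK line are a 3-bit mask (`0x7`) shifted left by successive multiples of 3. The sketch below regenerates them and their hex form, assuming the elements are 48-bit integers, which is consistent with the final value printing as negative. The helper names `fieldMasks`, `asSigned` and `toHex` are made up for this illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

// Build `count` masks of `fieldBits` consecutive ones each; the i-th mask is
// shifted left by i * fieldBits.
std::vector<uint64_t> fieldMasks(unsigned fieldBits, unsigned count) {
  assert(fieldBits > 0 && fieldBits * count <= 64 && "masks must fit in 64 bits");
  uint64_t ones = (fieldBits == 64) ? ~0ULL : ((1ULL << fieldBits) - 1);
  std::vector<uint64_t> masks;
  for (unsigned i = 0; i < count; ++i)
    masks.push_back(ones << (i * fieldBits));
  return masks;
}

// Reinterpret the low `width` bits of `value` as a signed two's complement
// integer, which is how a signless integer constant gets printed in decimal.
int64_t asSigned(uint64_t value, unsigned width) {
  assert(width > 0 && width <= 64 && "width must be in [1, 64]");
  if (width == 64)
    return static_cast<int64_t>(value);
  uint64_t signBit = 1ULL << (width - 1);
  value &= (1ULL << width) - 1;
  return static_cast<int64_t>(value ^ signBit) - static_cast<int64_t>(signBit);
}

// Format a value as 12 hex digits, i.e. 48 bits.
std::string toHex(uint64_t value) {
  char buffer[32];
  std::snprintf(buffer, sizeof buffer, "0x%012llX",
                static_cast<unsigned long long>(value));
  return buffer;
}
```

With `fieldMasks(3, 16)`, the first entries are `7`, `56`, `448` and the last is `0xE00000000000`, which `asSigned(..., 48)` maps to `-35184372088832`, matching the end of the CHECK line.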
LGTM.
+1 on @ftynse comments plus a few of my own :).
/// Create a vector of bit shuffles: `numOccurrences * idx` and broadcast it
/// `bitwidth/step` times.
/// `step` must divide `bitwidth` evenly.
/// Example: (4, 2, 3) -> [0x0, 0x1, 0x0, 0x1, 0x0, 0x1].
I have a hard time reconciling the comment and the example.
For instance, if we broadcast `bitwidth/step` times, we should only have 2 reps, not 3.
@@ -29,33 +29,12 @@ transform.sequence failures(propagate) {
  // lowering TD macros.
  transform.apply_patterns to %f {
    transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
  } : !transform.any_op

  transform.apply_patterns to %f {
These changes are unrelated, right?
Looks like the review is in good hands, so I'll just +1 having this. Thank you!
…(trunci) expansion This revision adds a rewrite for sequences of vector `bitcast(trunci)` to use a more efficient sequence of vector operations comprising `shuffle` and `bitwise` ops. Such patterns appear naturally when writing quantization / dequantization functionality with the vector dialect. The rewrite performs a simple enumeration of each of the bits in the result vector and determines its provenance in the pre-trunci vector. The enumeration is used to generate the proper sequence of `shuffle`, `andi`, `ori` followed by an optional final `trunci`/`extui`. The rewrite currently only applies to 1-D non-scalable vectors and bails out if the final vector element type is not a multiple of 8. This is a failsafe heuristic determined empirically: if the resulting type is not an even number of bytes, further complexities arise that are not improved by this pattern: the heavy lifting still needs to be done by LLVM.
I need to rework this to reuse the same implementation as #66387, which is both more general and more performant overall.
…cast) expansion This revision adds a rewrite for sequences of vector `ext(bitcast)` to use a more efficient sequence of vector operations comprising `shuffle` and `bitwise` ops. Such patterns appear naturally when writing quantization / dequantization functionality with the vector dialect. The rewrite performs a simple enumeration of each of the bits in the result vector and determines its provenance in the source vector. The enumeration is used to generate the proper sequence of `shuffle`, `andi`, `ori` with shifts. The rewrite currently only applies to 1-D non-scalable vectors and bails out if the final vector element type is not a multiple of 8. This is a failsafe heuristic determined empirically: if the resulting type is not an even number of bytes, further complexities arise that are not improved by this pattern: the heavy lifting still needs to be done by LLVM.
6f1997f to 1ad1f91
This is now reimplemented in terms of #66387, closing this and starting another PR.
…cast) expansion
This revision adds a rewrite for sequences of vector `ext(maybe_broadcast(bitcast))` to use a more efficient sequence of vector operations comprising `shuffle`, `shift` and `bitwise` ops.
The rewrite uses an intermediate bitwidth equal to the licm of the element types of the source and result types of `bitCastOp`. This intermediate type may be smaller or greater than the desired elemental type of the `ext`, in which case appropriate `ext` or `trunc` operations are inserted.
The rewrite fails if the intermediate type is greater than `64` or if the involved vector types fail to meet basic divisibility requirements. In other words, this rewrite does not handle partial vector boundaries and leaves this part of the heavy-lifting to LLVM.
In the future, it may be relevant to give control on the size of the intermediate type. For now, it is empirically determined that taking `64` results in much better assembly being produced when piping through `llvm-mca`.
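As a rough illustration of the intermediate-width rule described above, reading "licm" as the least common multiple per the review nit: the sketch below computes the LCM of the two element bit widths and gives up when it exceeds 64. The function name `interimBitWidth` and the `maxBits` parameter are assumptions made for this example, not names taken from the patch.

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <optional>

// Hypothetical helper: pick the intermediate element bit width as the least
// common multiple of the source and result element bit widths of the bitcast.
// Returns std::nullopt when that width exceeds `maxBits` (64 in the
// description), mirroring the case where the rewrite fails.
std::optional<int64_t> interimBitWidth(int64_t srcBits, int64_t dstBits,
                                       int64_t maxBits = 64) {
  assert(srcBits > 0 && dstBits > 0 && "element bit widths must be positive");
  int64_t width = std::lcm(srcBits, dstBits);
  if (width > maxBits)
    return std::nullopt;
  return width;
}
```

For example, 16-bit and 3-bit elements give an intermediate width of 48, whereas 32-bit and 3-bit elements would need 96 bits, so the helper returns no value.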