Skip to content

Commit

Permalink
[ValueTracking] Support vscale in computeConstantRange()
Browse files Browse the repository at this point in the history
Add support for vscale in computeConstantRange(), based on
vscale_range attributes. This allows simplifying based on the
precise range, rather than a KnownBits approximation (which will
be off by a factor of two for the usual case of a power of two
upper bound).

Differential Revision: https://reviews.llvm.org/D146217
  • Loading branch information
nikic committed Mar 17, 2023
1 parent a8f6b57 commit 402dfa3
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 42 deletions.
48 changes: 25 additions & 23 deletions llvm/lib/Analysis/ValueTracking.cpp
Expand Up @@ -1152,6 +1152,25 @@ KnownBits llvm::analyzeKnownBitsFromAndXorOr(
Query(DL, AC, safeCxtI(I, CxtI), DT, UseInstrInfo, ORE));
}

static ConstantRange getVScaleRange(const Function *F, unsigned BitWidth) {
Attribute Attr = F->getFnAttribute(Attribute::VScaleRange);
// Without vscale_range, we only know that vscale is non-zero.
if (!Attr.isValid())
return ConstantRange(APInt(BitWidth, 1), APInt::getZero(BitWidth));

unsigned AttrMin = Attr.getVScaleRangeMin();
// Minimum is larger than vscale width, result is always poison.
if ((unsigned)llvm::bit_width(AttrMin) > BitWidth)
return ConstantRange::getEmpty(BitWidth);

APInt Min(BitWidth, AttrMin);
std::optional<unsigned> AttrMax = Attr.getVScaleRangeMax();
if (!AttrMax || (unsigned)llvm::bit_width(*AttrMax) > BitWidth)
return ConstantRange(Min, APInt::getZero(BitWidth));

return ConstantRange(Min, APInt(BitWidth, *AttrMax) + 1);
}

static void computeKnownBitsFromOperator(const Operator *I,
const APInt &DemandedElts,
KnownBits &Known, unsigned Depth,
Expand Down Expand Up @@ -1820,31 +1839,10 @@ static void computeKnownBitsFromOperator(const Operator *I,
Known.Zero.setBitsFrom(17);
break;
case Intrinsic::vscale: {
if (!II->getParent() || !II->getFunction() ||
!II->getFunction()->hasFnAttribute(Attribute::VScaleRange))
break;

auto Attr = II->getFunction()->getFnAttribute(Attribute::VScaleRange);
std::optional<unsigned> VScaleMax = Attr.getVScaleRangeMax();

if (!VScaleMax)
if (!II->getParent() || !II->getFunction())
break;

unsigned VScaleMin = Attr.getVScaleRangeMin();

// If vscale min = max then we know the exact value at compile time
// and hence we know the exact bits.
if (VScaleMin == VScaleMax) {
Known.One = VScaleMin;
Known.Zero = VScaleMin;
Known.Zero.flipAllBits();
break;
}

unsigned FirstZeroHighBit = llvm::bit_width(*VScaleMax);
if (FirstZeroHighBit < BitWidth)
Known.Zero.setBitsFrom(FirstZeroHighBit);

Known = getVScaleRange(II->getFunction(), BitWidth).toKnownBits();
break;
}
}
Expand Down Expand Up @@ -7773,6 +7771,10 @@ static ConstantRange getRangeForIntrinsic(const IntrinsicInst &II) {

return ConstantRange(APInt::getZero(Width),
APInt::getSignedMinValue(Width) + 1);
case Intrinsic::vscale:
if (!II.getParent() || !II.getFunction())
break;
return getVScaleRange(II.getFunction(), Width);
default:
break;
}
Expand Down
26 changes: 7 additions & 19 deletions llvm/test/Transforms/InstCombine/icmp-vscale.ll
Expand Up @@ -84,9 +84,7 @@ entry:

define i1 @vscale_ule_max() vscale_range(5,10) {
; CHECK-LABEL: @vscale_ule_max(
; CHECK-NEXT: [[VSCALE:%.*]] = call i16 @llvm.vscale.i16()
; CHECK-NEXT: [[RES:%.*]] = icmp ult i16 [[VSCALE]], 11
; CHECK-NEXT: ret i1 [[RES]]
; CHECK-NEXT: ret i1 true
;
%vscale = call i16 @llvm.vscale.i16()
%res = icmp ule i16 %vscale, 10
Expand All @@ -106,9 +104,7 @@ define i1 @vscale_ult_max() vscale_range(5,10) {

define i1 @vscale_uge_min() vscale_range(5,10) {
; CHECK-LABEL: @vscale_uge_min(
; CHECK-NEXT: [[VSCALE:%.*]] = call i16 @llvm.vscale.i16()
; CHECK-NEXT: [[RES:%.*]] = icmp ugt i16 [[VSCALE]], 4
; CHECK-NEXT: ret i1 [[RES]]
; CHECK-NEXT: ret i1 true
;
%vscale = call i16 @llvm.vscale.i16()
%res = icmp uge i16 %vscale, 5
Expand Down Expand Up @@ -146,9 +142,7 @@ define i1 @vscale_ugt_no_max() vscale_range(5) {

define i1 @vscale_uge_max_overflow() vscale_range(5,256) {
; CHECK-LABEL: @vscale_uge_max_overflow(
; CHECK-NEXT: [[VSCALE:%.*]] = call i8 @llvm.vscale.i8()
; CHECK-NEXT: [[RES:%.*]] = icmp ugt i8 [[VSCALE]], 4
; CHECK-NEXT: ret i1 [[RES]]
; CHECK-NEXT: ret i1 true
;
%vscale = call i8 @llvm.vscale.i8()
%res = icmp uge i8 %vscale, 5
Expand All @@ -168,9 +162,7 @@ define i1 @vscale_ugt_max_overflow() vscale_range(5,256) {

define i1 @vscale_eq_min_overflow() vscale_range(256,300) {
; CHECK-LABEL: @vscale_eq_min_overflow(
; CHECK-NEXT: [[VSCALE:%.*]] = call i8 @llvm.vscale.i8()
; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[VSCALE]], 42
; CHECK-NEXT: ret i1 [[RES]]
; CHECK-NEXT: ret i1 true
;
%vscale = call i8 @llvm.vscale.i8()
%res = icmp eq i8 %vscale, 42
Expand All @@ -179,9 +171,7 @@ define i1 @vscale_eq_min_overflow() vscale_range(256,300) {

define i1 @vscale_ult_min_overflow() vscale_range(256,300) {
; CHECK-LABEL: @vscale_ult_min_overflow(
; CHECK-NEXT: [[VSCALE:%.*]] = call i8 @llvm.vscale.i8()
; CHECK-NEXT: [[RES:%.*]] = icmp ult i8 [[VSCALE]], 42
; CHECK-NEXT: ret i1 [[RES]]
; CHECK-NEXT: ret i1 true
;
%vscale = call i8 @llvm.vscale.i8()
%res = icmp ult i8 %vscale, 42
Expand All @@ -190,12 +180,10 @@ define i1 @vscale_ult_min_overflow() vscale_range(256,300) {

define i1 @vscale_ugt_min_overflow() vscale_range(256,300) {
; CHECK-LABEL: @vscale_ugt_min_overflow(
; CHECK-NEXT: [[VSCALE:%.*]] = call i8 @llvm.vscale.i8()
; CHECK-NEXT: [[RES:%.*]] = icmp ult i8 [[VSCALE]], 42
; CHECK-NEXT: ret i1 [[RES]]
; CHECK-NEXT: ret i1 true
;
%vscale = call i8 @llvm.vscale.i8()
%res = icmp ult i8 %vscale, 42
%res = icmp ugt i8 %vscale, 42
ret i1 %res
}

Expand Down

0 comments on commit 402dfa3

Please sign in to comment.