Skip to content

Commit

Permalink
[ValueTracking] Handle range attributes (#85143)
Browse files Browse the repository at this point in the history
Handle the range attribute in ValueTracking.
  • Loading branch information
andjo403 committed Mar 20, 2024
1 parent f24d68a commit e66cfeb
Show file tree
Hide file tree
Showing 9 changed files with 481 additions and 14 deletions.
7 changes: 7 additions & 0 deletions llvm/include/llvm/IR/Argument.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
#include "llvm/ADT/Twine.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Value.h"
#include <optional>

namespace llvm {

class ConstantRange;

/// This class represents an incoming formal argument to a Function. A formal
/// argument, since it is ``formal'', does not contain an actual value but
/// instead represents the type, argument number, and attributes of an argument
Expand Down Expand Up @@ -67,6 +70,10 @@ class Argument final : public Value {
/// disallowed floating-point values. Otherwise, fcNone is returned.
FPClassTest getNoFPClass() const;

/// If this argument has a range attribute, return the value range of the
/// argument. Otherwise, std::nullopt is returned.
std::optional<ConstantRange> getRange() const;

/// Return true if this argument has the byval attribute.
bool hasByValAttr() const;

Expand Down
7 changes: 6 additions & 1 deletion llvm/include/llvm/IR/InstrTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ namespace llvm {
class StringRef;
class Type;
class Value;
class ConstantRange;

namespace Intrinsic {
typedef unsigned ID;
Expand Down Expand Up @@ -1917,7 +1918,7 @@ class CallBase : public Instruction {

// Look at the callee, if available.
if (const Function *F = getCalledFunction())
return F->getAttributes().getRetAttr(Kind);
return F->getRetAttribute(Kind);
return Attribute();
}

Expand Down Expand Up @@ -2154,6 +2155,10 @@ class CallBase : public Instruction {
/// parameter.
FPClassTest getParamNoFPClass(unsigned i) const;

/// If this return value has a range attribute, return the value range of the
/// argument. Otherwise, std::nullopt is returned.
std::optional<ConstantRange> getRange() const;

/// Return true if the return value is known to be not null.
/// This may be because it has the nonnull attribute, or because at least
/// one byte is dereferenceable and the pointer is in addrspace(0).
Expand Down
13 changes: 4 additions & 9 deletions llvm/lib/Analysis/InstructionSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3736,15 +3736,10 @@ static std::optional<ConstantRange> getRange(Value *V,
if (MDNode *MD = IIQ.getMetadata(I, LLVMContext::MD_range))
return getConstantRangeFromMetadata(*MD);

Attribute Range;
if (const Argument *A = dyn_cast<Argument>(V)) {
Range = A->getAttribute(llvm::Attribute::Range);
} else if (const CallBase *CB = dyn_cast<CallBase>(V)) {
Range = CB->getRetAttr(llvm::Attribute::Range);
}

if (Range.isValid())
return Range.getRange();
if (const Argument *A = dyn_cast<Argument>(V))
return A->getRange();
else if (const CallBase *CB = dyn_cast<CallBase>(V))
return CB->getRange();

return std::nullopt;
}
Expand Down
38 changes: 34 additions & 4 deletions llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1500,14 +1500,20 @@ static void computeKnownBitsFromOperator(const Operator *I,
break;
}
case Instruction::Call:
case Instruction::Invoke:
case Instruction::Invoke: {
// If range metadata is attached to this call, set known bits from that,
// and then intersect with known bits based on other properties of the
// function.
if (MDNode *MD =
Q.IIQ.getMetadata(cast<Instruction>(I), LLVMContext::MD_range))
computeKnownBitsFromRangeMetadata(*MD, Known);
if (const Value *RV = cast<CallBase>(I)->getReturnedArgOperand()) {

const auto *CB = cast<CallBase>(I);

if (std::optional<ConstantRange> Range = CB->getRange())
Known = Known.unionWith(Range->toKnownBits());

if (const Value *RV = CB->getReturnedArgOperand()) {
if (RV->getType() == I->getType()) {
computeKnownBits(RV, Known2, Depth + 1, Q);
Known = Known.unionWith(Known2);
Expand Down Expand Up @@ -1679,6 +1685,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
}
}
break;
}
case Instruction::ShuffleVector: {
auto *Shuf = dyn_cast<ShuffleVectorInst>(I);
// FIXME: Do we need to handle ConstantExpr involving shufflevectors?
Expand Down Expand Up @@ -1933,6 +1940,10 @@ void computeKnownBits(const Value *V, const APInt &DemandedElts,
// assumptions. Confirm that we've handled them all.
assert(!isa<ConstantData>(V) && "Unhandled constant data!");

if (const auto *A = dyn_cast<Argument>(V))
if (std::optional<ConstantRange> Range = A->getRange())
Known = Range->toKnownBits();

// All recursive calls that increase depth must come after this.
if (Depth == MaxAnalysisRecursionDepth)
return;
Expand Down Expand Up @@ -2783,6 +2794,11 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
} else {
if (MDNode *Ranges = Q.IIQ.getMetadata(Call, LLVMContext::MD_range))
return rangeMetadataExcludesValue(Ranges, APInt::getZero(BitWidth));
if (std::optional<ConstantRange> Range = Call->getRange()) {
const APInt ZeroValue(Range->getBitWidth(), 0);
if (!Range->contains(ZeroValue))
return true;
}
if (const Value *RV = Call->getReturnedArgOperand())
if (RV->getType() == I->getType() && isKnownNonZero(RV, Depth, Q))
return true;
Expand Down Expand Up @@ -2921,6 +2937,13 @@ bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth,
return false;
}

if (const auto *A = dyn_cast<Argument>(V))
if (std::optional<ConstantRange> Range = A->getRange()) {
const APInt ZeroValue(Range->getBitWidth(), 0);
if (!Range->contains(ZeroValue))
return true;
}

if (!isa<Constant>(V) && isKnownNonZeroFromAssume(V, Q))
return true;

Expand Down Expand Up @@ -9146,12 +9169,19 @@ ConstantRange llvm::computeConstantRange(const Value *V, bool ForSigned,
// TODO: Return ConstantRange.
setLimitForFPToI(cast<Instruction>(V), Lower, Upper);
CR = ConstantRange::getNonEmpty(Lower, Upper);
}
} else if (const auto *A = dyn_cast<Argument>(V))
if (std::optional<ConstantRange> Range = A->getRange())
CR = *Range;

if (auto *I = dyn_cast<Instruction>(V))
if (auto *I = dyn_cast<Instruction>(V)) {
if (auto *Range = IIQ.getMetadata(I, LLVMContext::MD_range))
CR = CR.intersectWith(getConstantRangeFromMetadata(*Range));

if (const auto *CB = dyn_cast<CallBase>(V))
if (std::optional<ConstantRange> Range = CB->getRange())
CR = CR.intersectWith(*Range);
}

if (CtxI && AC) {
// Try to restrict the range based on information from assumptions.
for (auto &AssumeVH : AC->assumptionsFor(V)) {
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/IR/Function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/GlobalValue.h"
Expand Down Expand Up @@ -256,6 +257,13 @@ FPClassTest Argument::getNoFPClass() const {
return getParent()->getParamNoFPClass(getArgNo());
}

std::optional<ConstantRange> Argument::getRange() const {
const Attribute RangeAttr = getAttribute(llvm::Attribute::Range);
if (RangeAttr.isValid())
return RangeAttr.getRange();
return std::nullopt;
}

bool Argument::hasNestAttr() const {
if (!getType()->isPointerTy()) return false;
return hasAttribute(Attribute::Nest);
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/IR/Instructions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
Expand Down Expand Up @@ -395,6 +396,13 @@ FPClassTest CallBase::getParamNoFPClass(unsigned i) const {
return Mask;
}

std::optional<ConstantRange> CallBase::getRange() const {
const Attribute RangeAttr = getRetAttr(llvm::Attribute::Range);
if (RangeAttr.isValid())
return RangeAttr.getRange();
return std::nullopt;
}

bool CallBase::isReturnNonNull() const {
if (hasRetAttr(Attribute::NonNull))
return true;
Expand Down
147 changes: 147 additions & 0 deletions llvm/test/Analysis/ValueTracking/known-non-zero.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1303,4 +1303,151 @@ define <2 x i1> @range_metadata_vec(ptr %p, <2 x i32> %x) {
ret <2 x i1> %cmp
}

define i1 @range_attr(i8 range(i8 1, 0) %x, i8 %y) {
; CHECK-LABEL: @range_attr(
; CHECK-NEXT: ret i1 false
;
%or = or i8 %y, %x
%cmp = icmp eq i8 %or, 0
ret i1 %cmp
}

define i1 @neg_range_attr(i8 range(i8 -1, 1) %x, i8 %y) {
; CHECK-LABEL: @neg_range_attr(
; CHECK-NEXT: [[I:%.*]] = or i8 [[Y:%.*]], [[X:%.*]]
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[I]], 0
; CHECK-NEXT: ret i1 [[CMP]]
;
%or = or i8 %y, %x
%cmp = icmp eq i8 %or, 0
ret i1 %cmp
}

declare range(i8 1, 0) i8 @returns_non_zero_range_helper()
declare range(i8 -1, 1) i8 @returns_contain_zero_range_helper()

define i1 @range_return(i8 %y) {
; CHECK-LABEL: @range_return(
; CHECK-NEXT: [[I:%.*]] = call i8 @returns_non_zero_range_helper()
; CHECK-NEXT: ret i1 false
;
%x = call i8 @returns_non_zero_range_helper()
%or = or i8 %y, %x
%cmp = icmp eq i8 %or, 0
ret i1 %cmp
}

define i1 @neg_range_return(i8 %y) {
; CHECK-LABEL: @neg_range_return(
; CHECK-NEXT: [[I:%.*]] = call i8 @returns_contain_zero_range_helper()
; CHECK-NEXT: [[OR:%.*]] = or i8 [[Y:%.*]], [[I]]
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[OR]], 0
; CHECK-NEXT: ret i1 [[CMP]]
;
%x = call i8 @returns_contain_zero_range_helper()
%or = or i8 %y, %x
%cmp = icmp eq i8 %or, 0
ret i1 %cmp
}

declare i8 @returns_i8_helper()

define i1 @range_call(i8 %y) {
; CHECK-LABEL: @range_call(
; CHECK-NEXT: [[I:%.*]] = call range(i8 1, 0) i8 @returns_i8_helper()
; CHECK-NEXT: ret i1 false
;
%x = call range(i8 1, 0) i8 @returns_i8_helper()
%or = or i8 %y, %x
%cmp = icmp eq i8 %or, 0
ret i1 %cmp
}

define i1 @neg_range_call(i8 %y) {
; CHECK-LABEL: @neg_range_call(
; CHECK-NEXT: [[I:%.*]] = call range(i8 -1, 1) i8 @returns_i8_helper()
; CHECK-NEXT: [[OR:%.*]] = or i8 [[Y:%.*]], [[I]]
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[OR]], 0
; CHECK-NEXT: ret i1 [[CMP]]
;
%x = call range(i8 -1, 1) i8 @returns_i8_helper()
%or = or i8 %y, %x
%cmp = icmp eq i8 %or, 0
ret i1 %cmp
}

define <2 x i1> @range_attr_vec(<2 x i8> range(i8 1, 0) %x, <2 x i8> %y) {
; CHECK-LABEL: @range_attr_vec(
; CHECK-NEXT: ret <2 x i1> <i1 true, i1 true>
;
%or = or <2 x i8> %y, %x
%cmp = icmp ne <2 x i8> %or, zeroinitializer
ret <2 x i1> %cmp
}

define <2 x i1> @neg_range_attr_vec(<2 x i8> range(i8 -1, 1) %x, <2 x i8> %y) {
; CHECK-LABEL: @neg_range_attr_vec(
; CHECK-NEXT: [[OR:%.*]] = or <2 x i8> [[Y:%.*]], [[X:%.*]]
; CHECK-NEXT: [[CMP:%.*]] = icmp ne <2 x i8> [[OR]], zeroinitializer
; CHECK-NEXT: ret <2 x i1> [[CMP]]
;
%or = or <2 x i8> %y, %x
%cmp = icmp ne <2 x i8> %or, zeroinitializer
ret <2 x i1> %cmp
}

declare range(i8 1, 0) <2 x i8> @returns_non_zero_range_helper_vec()
declare range(i8 -1, 1) <2 x i8> @returns_contain_zero_range_helper_vec()

define <2 x i1> @range_return_vec(<2 x i8> %y) {
; CHECK-LABEL: @range_return_vec(
; CHECK-NEXT: [[I:%.*]] = call <2 x i8> @returns_non_zero_range_helper_vec()
; CHECK-NEXT: ret <2 x i1> <i1 true, i1 true>
;
%x = call <2 x i8> @returns_non_zero_range_helper_vec()
%or = or <2 x i8> %y, %x
%cmp = icmp ne <2 x i8> %or, zeroinitializer
ret <2 x i1> %cmp
}

define <2 x i1> @neg_range_return_vec(<2 x i8> %y) {
; CHECK-LABEL: @neg_range_return_vec(
; CHECK-NEXT: [[I:%.*]] = call <2 x i8> @returns_contain_zero_range_helper_vec()
; CHECK-NEXT: [[OR:%.*]] = or <2 x i8> [[Y:%.*]], [[I]]
; CHECK-NEXT: [[CMP:%.*]] = icmp ne <2 x i8> [[OR]], zeroinitializer
; CHECK-NEXT: ret <2 x i1> [[CMP]]
;
%x = call <2 x i8> @returns_contain_zero_range_helper_vec()
%or = or <2 x i8> %y, %x
%cmp = icmp ne <2 x i8> %or, zeroinitializer
ret <2 x i1> %cmp
}

declare <2 x i8> @returns_i8_helper_vec()

define <2 x i1> @range_call_vec(<2 x i8> %y) {
; CHECK-LABEL: @range_call_vec(
; CHECK-NEXT: [[I:%.*]] = call range(i8 1, 0) <2 x i8> @returns_i8_helper_vec()
; CHECK-NEXT: ret <2 x i1> <i1 true, i1 true>
;
%x = call range(i8 1, 0) <2 x i8> @returns_i8_helper_vec()
%or = or <2 x i8> %y, %x
%cmp = icmp ne <2 x i8> %or, zeroinitializer
ret <2 x i1> %cmp
}

define <2 x i1> @neg_range_call_vec(<2 x i8> %y) {
; CHECK-LABEL: @neg_range_call_vec(
; CHECK-NEXT: [[I:%.*]] = call range(i8 -1, 1) <2 x i8> @returns_i8_helper_vec()
; CHECK-NEXT: [[OR:%.*]] = or <2 x i8> [[Y:%.*]], [[I]]
; CHECK-NEXT: [[CMP:%.*]] = icmp ne <2 x i8> [[OR]], zeroinitializer
; CHECK-NEXT: ret <2 x i1> [[CMP]]
;
%x = call range(i8 -1, 1) <2 x i8> @returns_i8_helper_vec()
%or = or <2 x i8> %y, %x
%cmp = icmp ne <2 x i8> %or, zeroinitializer
ret <2 x i1> %cmp
}


declare i32 @llvm.experimental.get.vector.length.i32(i32, i32, i1)
Loading

0 comments on commit e66cfeb

Please sign in to comment.