Skip to content

Commit

Permalink
[clang][Interp] Three-way comparisons (#65901)
Browse files Browse the repository at this point in the history
  • Loading branch information
tbaederr committed Sep 29, 2023
1 parent 7cc83c5 commit 512739e
Show file tree
Hide file tree
Showing 9 changed files with 160 additions and 0 deletions.
6 changes: 6 additions & 0 deletions clang/lib/AST/Interp/Boolean.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ class Boolean final {
Boolean truncate(unsigned TruncBits) const { return *this; }

void print(llvm::raw_ostream &OS) const { OS << (V ? "true" : "false"); }
std::string toDiagnosticString(const ASTContext &Ctx) const {
std::string NameStr;
llvm::raw_string_ostream OS(NameStr);
print(OS);
return NameStr;
}

static Boolean min(unsigned NumBits) { return Boolean(false); }
static Boolean max(unsigned NumBits) { return Boolean(true); }
Expand Down
23 changes: 23 additions & 0 deletions clang/lib/AST/Interp/ByteCodeExprGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,29 @@ bool ByteCodeExprGen<Emitter>::VisitBinaryOperator(const BinaryOperator *BO) {
return this->delegate(RHS);
}

// Special case for C++'s three-way/spaceship operator <=>, which
// returns a std::{strong,weak,partial}_ordering (which is a class, so doesn't
// have a PrimType).
if (!T) {
if (DiscardResult)
return true;
const ComparisonCategoryInfo *CmpInfo =
Ctx.getASTContext().CompCategories.lookupInfoForType(BO->getType());
assert(CmpInfo);

// We need a temporary variable holding our return value.
if (!Initializing) {
std::optional<unsigned> ResultIndex = this->allocateLocal(BO, false);
if (!this->emitGetPtrLocal(*ResultIndex, BO))
return false;
}

if (!visit(LHS) || !visit(RHS))
return false;

return this->emitCMP3(*LT, CmpInfo, BO);
}

if (!LT || !RT || !T)
return this->bail(BO);

Expand Down
6 changes: 6 additions & 0 deletions clang/lib/AST/Interp/Floating.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ class Floating final {
F.toString(Buffer);
OS << Buffer;
}
std::string toDiagnosticString(const ASTContext &Ctx) const {
std::string NameStr;
llvm::raw_string_ostream OS(NameStr);
print(OS);
return NameStr;
}

unsigned bitWidth() const { return F.semanticsSizeInBits(F.getSemantics()); }

Expand Down
7 changes: 7 additions & 0 deletions clang/lib/AST/Interp/Integral.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,13 @@ template <unsigned Bits, bool Signed> class Integral final {
return Compare(V, RHS.V);
}

std::string toDiagnosticString(const ASTContext &Ctx) const {
std::string NameStr;
llvm::raw_string_ostream OS(NameStr);
OS << V;
return NameStr;
}

unsigned countLeadingZeros() const {
if constexpr (!Signed)
return llvm::countl_zero<ReprT>(V);
Expand Down
29 changes: 29 additions & 0 deletions clang/lib/AST/Interp/Interp.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ bool CheckCtorCall(InterpState &S, CodePtr OpPC, const Pointer &This);
bool CheckPotentialReinterpretCast(InterpState &S, CodePtr OpPC,
const Pointer &Ptr);

/// Sets the given integral value to the pointer, which is of
/// a std::{weak,partial,strong}_ordering type.
bool SetThreeWayComparisonField(InterpState &S, CodePtr OpPC,
const Pointer &Ptr, const APSInt &IntValue);

/// Checks if the shift operation is legal.
template <typename LT, typename RT>
bool CheckShift(InterpState &S, CodePtr OpPC, const LT &LHS, const RT &RHS,
Expand Down Expand Up @@ -781,6 +786,30 @@ bool EQ(InterpState &S, CodePtr OpPC) {
});
}

template <PrimType Name, class T = typename PrimConv<Name>::T>
bool CMP3(InterpState &S, CodePtr OpPC, const ComparisonCategoryInfo *CmpInfo) {
const T &RHS = S.Stk.pop<T>();
const T &LHS = S.Stk.pop<T>();
const Pointer &P = S.Stk.peek<Pointer>();

ComparisonCategoryResult CmpResult = LHS.compare(RHS);
if (CmpResult == ComparisonCategoryResult::Unordered) {
// This should only happen with pointers.
const SourceInfo &Loc = S.Current->getSource(OpPC);
S.FFDiag(Loc, diag::note_constexpr_pointer_comparison_unspecified)
<< LHS.toDiagnosticString(S.getCtx())
<< RHS.toDiagnosticString(S.getCtx());
return false;
}

assert(CmpInfo);
const auto *CmpValueInfo = CmpInfo->getValueInfo(CmpResult);
assert(CmpValueInfo);
assert(CmpValueInfo->hasValidIntValue());
APSInt IntValue = CmpValueInfo->getIntValue();
return SetThreeWayComparisonField(S, OpPC, P, IntValue);
}

template <PrimType Name, class T = typename PrimConv<Name>::T>
bool NE(InterpState &S, CodePtr OpPC) {
return CmpHelperEQ<T>(S, OpPC, [](ComparisonCategoryResult R) {
Expand Down
17 changes: 17 additions & 0 deletions clang/lib/AST/Interp/InterpBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -594,5 +594,22 @@ bool InterpretOffsetOf(InterpState &S, CodePtr OpPC, const OffsetOfExpr *E,
return true;
}

bool SetThreeWayComparisonField(InterpState &S, CodePtr OpPC,
const Pointer &Ptr, const APSInt &IntValue) {

const Record *R = Ptr.getRecord();
assert(R);
assert(R->getNumFields() == 1);

unsigned FieldOffset = R->getField(0u)->Offset;
const Pointer &FieldPtr = Ptr.atField(FieldOffset);
PrimType FieldT = *S.getContext().classify(FieldPtr.getType());

INT_TYPE_SWITCH(FieldT,
FieldPtr.deref<T>() = T::from(IntValue.getSExtValue()));
FieldPtr.initialize();
return true;
}

} // namespace interp
} // namespace clang
5 changes: 5 additions & 0 deletions clang/lib/AST/Interp/Opcodes.td
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def ArgCastKind : ArgType { let Name = "CastKind"; }
def ArgCallExpr : ArgType { let Name = "const CallExpr *"; }
def ArgOffsetOfExpr : ArgType { let Name = "const OffsetOfExpr *"; }
def ArgDeclRef : ArgType { let Name = "const DeclRefExpr *"; }
def ArgCCI : ArgType { let Name = "const ComparisonCategoryInfo *"; }

//===----------------------------------------------------------------------===//
// Classes of types instructions operate on.
Expand Down Expand Up @@ -607,6 +608,10 @@ class ComparisonOpcode : Opcode {
let HasGroup = 1;
}

def CMP3 : ComparisonOpcode {
let Args = [ArgCCI];
}

def LT : ComparisonOpcode;
def LE : ComparisonOpcode;
def GT : ComparisonOpcode;
Expand Down
13 changes: 13 additions & 0 deletions clang/lib/AST/Interp/Pointer.h
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,19 @@ class Pointer {
/// Deactivates an entire strurcutre.
void deactivate() const;

/// Compare two pointers.
ComparisonCategoryResult compare(const Pointer &Other) const {
if (!hasSameBase(*this, Other))
return ComparisonCategoryResult::Unordered;

if (Offset < Other.Offset)
return ComparisonCategoryResult::Less;
else if (Offset > Other.Offset)
return ComparisonCategoryResult::Greater;

return ComparisonCategoryResult::Equal;
}

/// Checks if two pointers are comparable.
static bool hasSameBase(const Pointer &A, const Pointer &B);
/// Checks if two pointers can be subtracted.
Expand Down
54 changes: 54 additions & 0 deletions clang/test/AST/Interp/cxx20.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -646,3 +646,57 @@ namespace ImplicitFunction {
// expected-error {{not an integral constant expression}} \
// expected-note {{in call to 'callMe()'}}
}

/// FIXME: Unfortunately, the similar tests in test/SemaCXX/{compare-cxx2a.cpp use member pointers,
/// which we don't support yet.
namespace std {
class strong_ordering {
public:
int n;
static const strong_ordering less, equal, greater;
constexpr bool operator==(int n) const noexcept { return this->n == n;}
constexpr bool operator!=(int n) const noexcept { return this->n != n;}
};
constexpr strong_ordering strong_ordering::less = {-1};
constexpr strong_ordering strong_ordering::equal = {0};
constexpr strong_ordering strong_ordering::greater = {1};

class partial_ordering {
public:
long n;
static const partial_ordering less, equal, greater, equivalent, unordered;
constexpr bool operator==(long n) const noexcept { return this->n == n;}
constexpr bool operator!=(long n) const noexcept { return this->n != n;}
};
constexpr partial_ordering partial_ordering::less = {-1};
constexpr partial_ordering partial_ordering::equal = {0};
constexpr partial_ordering partial_ordering::greater = {1};
constexpr partial_ordering partial_ordering::equivalent = {0};
constexpr partial_ordering partial_ordering::unordered = {-127};
} // namespace std

namespace ThreeWayCmp {
static_assert(1 <=> 2 == -1, "");
static_assert(1 <=> 1 == 0, "");
static_assert(2 <=> 1 == 1, "");
static_assert(1.0 <=> 2.f == -1, "");
static_assert(1.0 <=> 1.0 == 0, "");
static_assert(2.0 <=> 1.0 == 1, "");
constexpr int k = (1 <=> 1, 0); // expected-warning {{comparison result unused}} \
// ref-warning {{comparison result unused}}
static_assert(k== 0, "");

/// Pointers.
constexpr int a[] = {1,2,3};
constexpr int b[] = {1,2,3};
constexpr const int *pa1 = &a[1];
constexpr const int *pa2 = &a[2];
constexpr const int *pb1 = &b[1];
static_assert(pa1 <=> pb1 != 0, ""); // expected-error {{not an integral constant expression}} \
// expected-note {{has unspecified value}} \
// ref-error {{not an integral constant expression}} \
// ref-note {{has unspecified value}}
static_assert(pa1 <=> pa1 == 0, "");
static_assert(pa1 <=> pa2 == -1, "");
static_assert(pa2 <=> pa1 == 1, "");
}

0 comments on commit 512739e

Please sign in to comment.