diff --git a/clang/lib/AST/Interp/Boolean.h b/clang/lib/AST/Interp/Boolean.h index 6f0fe26ace688..c3ed3d61f76ca 100644 --- a/clang/lib/AST/Interp/Boolean.h +++ b/clang/lib/AST/Interp/Boolean.h @@ -84,6 +84,12 @@ class Boolean final { Boolean truncate(unsigned TruncBits) const { return *this; } void print(llvm::raw_ostream &OS) const { OS << (V ? "true" : "false"); } + std::string toDiagnosticString(const ASTContext &Ctx) const { + std::string NameStr; + llvm::raw_string_ostream OS(NameStr); + print(OS); + return NameStr; + } static Boolean min(unsigned NumBits) { return Boolean(false); } static Boolean max(unsigned NumBits) { return Boolean(true); } diff --git a/clang/lib/AST/Interp/ByteCodeExprGen.cpp b/clang/lib/AST/Interp/ByteCodeExprGen.cpp index e813d4fa651ce..a09e2a007b912 100644 --- a/clang/lib/AST/Interp/ByteCodeExprGen.cpp +++ b/clang/lib/AST/Interp/ByteCodeExprGen.cpp @@ -253,6 +253,29 @@ bool ByteCodeExprGen::VisitBinaryOperator(const BinaryOperator *BO) { return this->delegate(RHS); } + // Special case for C++'s three-way/spaceship operator <=>, which + // returns a std::{strong,weak,partial}_ordering (which is a class, so doesn't + // have a PrimType). + if (!T) { + if (DiscardResult) + return true; + const ComparisonCategoryInfo *CmpInfo = + Ctx.getASTContext().CompCategories.lookupInfoForType(BO->getType()); + assert(CmpInfo); + + // We need a temporary variable holding our return value. + if (!Initializing) { + std::optional ResultIndex = this->allocateLocal(BO, false); + if (!this->emitGetPtrLocal(*ResultIndex, BO)) + return false; + } + + if (!visit(LHS) || !visit(RHS)) + return false; + + return this->emitCMP3(*LT, CmpInfo, BO); + } + if (!LT || !RT || !T) return this->bail(BO); diff --git a/clang/lib/AST/Interp/Floating.h b/clang/lib/AST/Interp/Floating.h index 9a8fd34ec9348..a22b3fa79f399 100644 --- a/clang/lib/AST/Interp/Floating.h +++ b/clang/lib/AST/Interp/Floating.h @@ -76,6 +76,12 @@ class Floating final { F.toString(Buffer); OS << Buffer; } + std::string toDiagnosticString(const ASTContext &Ctx) const { + std::string NameStr; + llvm::raw_string_ostream OS(NameStr); + print(OS); + return NameStr; + } unsigned bitWidth() const { return F.semanticsSizeInBits(F.getSemantics()); } diff --git a/clang/lib/AST/Interp/Integral.h b/clang/lib/AST/Interp/Integral.h index 72285cabcbbf8..0295a9c3b5c89 100644 --- a/clang/lib/AST/Interp/Integral.h +++ b/clang/lib/AST/Interp/Integral.h @@ -128,6 +128,13 @@ template class Integral final { return Compare(V, RHS.V); } + std::string toDiagnosticString(const ASTContext &Ctx) const { + std::string NameStr; + llvm::raw_string_ostream OS(NameStr); + OS << V; + return NameStr; + } + unsigned countLeadingZeros() const { if constexpr (!Signed) return llvm::countl_zero(V); diff --git a/clang/lib/AST/Interp/Interp.h b/clang/lib/AST/Interp/Interp.h index 8453856e526a6..dd37150b63f6d 100644 --- a/clang/lib/AST/Interp/Interp.h +++ b/clang/lib/AST/Interp/Interp.h @@ -112,6 +112,11 @@ bool CheckCtorCall(InterpState &S, CodePtr OpPC, const Pointer &This); bool CheckPotentialReinterpretCast(InterpState &S, CodePtr OpPC, const Pointer &Ptr); +/// Sets the given integral value to the pointer, which is of +/// a std::{weak,partial,strong}_ordering type. +bool SetThreeWayComparisonField(InterpState &S, CodePtr OpPC, + const Pointer &Ptr, const APSInt &IntValue); + /// Checks if the shift operation is legal. template bool CheckShift(InterpState &S, CodePtr OpPC, const LT &LHS, const RT &RHS, @@ -781,6 +786,30 @@ bool EQ(InterpState &S, CodePtr OpPC) { }); } +template ::T> +bool CMP3(InterpState &S, CodePtr OpPC, const ComparisonCategoryInfo *CmpInfo) { + const T &RHS = S.Stk.pop(); + const T &LHS = S.Stk.pop(); + const Pointer &P = S.Stk.peek(); + + ComparisonCategoryResult CmpResult = LHS.compare(RHS); + if (CmpResult == ComparisonCategoryResult::Unordered) { + // This should only happen with pointers. + const SourceInfo &Loc = S.Current->getSource(OpPC); + S.FFDiag(Loc, diag::note_constexpr_pointer_comparison_unspecified) + << LHS.toDiagnosticString(S.getCtx()) + << RHS.toDiagnosticString(S.getCtx()); + return false; + } + + assert(CmpInfo); + const auto *CmpValueInfo = CmpInfo->getValueInfo(CmpResult); + assert(CmpValueInfo); + assert(CmpValueInfo->hasValidIntValue()); + APSInt IntValue = CmpValueInfo->getIntValue(); + return SetThreeWayComparisonField(S, OpPC, P, IntValue); +} + template ::T> bool NE(InterpState &S, CodePtr OpPC) { return CmpHelperEQ(S, OpPC, [](ComparisonCategoryResult R) { diff --git a/clang/lib/AST/Interp/InterpBuiltin.cpp b/clang/lib/AST/Interp/InterpBuiltin.cpp index 4536e335bf1a1..d816145598049 100644 --- a/clang/lib/AST/Interp/InterpBuiltin.cpp +++ b/clang/lib/AST/Interp/InterpBuiltin.cpp @@ -594,5 +594,22 @@ bool InterpretOffsetOf(InterpState &S, CodePtr OpPC, const OffsetOfExpr *E, return true; } +bool SetThreeWayComparisonField(InterpState &S, CodePtr OpPC, + const Pointer &Ptr, const APSInt &IntValue) { + + const Record *R = Ptr.getRecord(); + assert(R); + assert(R->getNumFields() == 1); + + unsigned FieldOffset = R->getField(0u)->Offset; + const Pointer &FieldPtr = Ptr.atField(FieldOffset); + PrimType FieldT = *S.getContext().classify(FieldPtr.getType()); + + INT_TYPE_SWITCH(FieldT, + FieldPtr.deref() = T::from(IntValue.getSExtValue())); + FieldPtr.initialize(); + return true; +} + } // namespace interp } // namespace clang diff --git a/clang/lib/AST/Interp/Opcodes.td b/clang/lib/AST/Interp/Opcodes.td index eeb71db125fef..0ce64b769b01f 100644 --- a/clang/lib/AST/Interp/Opcodes.td +++ b/clang/lib/AST/Interp/Opcodes.td @@ -55,6 +55,7 @@ def ArgCastKind : ArgType { let Name = "CastKind"; } def ArgCallExpr : ArgType { let Name = "const CallExpr *"; } def ArgOffsetOfExpr : ArgType { let Name = "const OffsetOfExpr *"; } def ArgDeclRef : ArgType { let Name = "const DeclRefExpr *"; } +def ArgCCI : ArgType { let Name = "const ComparisonCategoryInfo *"; } //===----------------------------------------------------------------------===// // Classes of types instructions operate on. @@ -607,6 +608,10 @@ class ComparisonOpcode : Opcode { let HasGroup = 1; } +def CMP3 : ComparisonOpcode { + let Args = [ArgCCI]; +} + def LT : ComparisonOpcode; def LE : ComparisonOpcode; def GT : ComparisonOpcode; diff --git a/clang/lib/AST/Interp/Pointer.h b/clang/lib/AST/Interp/Pointer.h index 3834237f11d13..8c97a965320e1 100644 --- a/clang/lib/AST/Interp/Pointer.h +++ b/clang/lib/AST/Interp/Pointer.h @@ -362,6 +362,19 @@ class Pointer { /// Deactivates an entire strurcutre. void deactivate() const; + /// Compare two pointers. + ComparisonCategoryResult compare(const Pointer &Other) const { + if (!hasSameBase(*this, Other)) + return ComparisonCategoryResult::Unordered; + + if (Offset < Other.Offset) + return ComparisonCategoryResult::Less; + else if (Offset > Other.Offset) + return ComparisonCategoryResult::Greater; + + return ComparisonCategoryResult::Equal; + } + /// Checks if two pointers are comparable. static bool hasSameBase(const Pointer &A, const Pointer &B); /// Checks if two pointers can be subtracted. diff --git a/clang/test/AST/Interp/cxx20.cpp b/clang/test/AST/Interp/cxx20.cpp index df08bb75199d8..0b13f41270a95 100644 --- a/clang/test/AST/Interp/cxx20.cpp +++ b/clang/test/AST/Interp/cxx20.cpp @@ -646,3 +646,57 @@ namespace ImplicitFunction { // expected-error {{not an integral constant expression}} \ // expected-note {{in call to 'callMe()'}} } + +/// FIXME: Unfortunately, the similar tests in test/SemaCXX/{compare-cxx2a.cpp use member pointers, +/// which we don't support yet. +namespace std { + class strong_ordering { + public: + int n; + static const strong_ordering less, equal, greater; + constexpr bool operator==(int n) const noexcept { return this->n == n;} + constexpr bool operator!=(int n) const noexcept { return this->n != n;} + }; + constexpr strong_ordering strong_ordering::less = {-1}; + constexpr strong_ordering strong_ordering::equal = {0}; + constexpr strong_ordering strong_ordering::greater = {1}; + + class partial_ordering { + public: + long n; + static const partial_ordering less, equal, greater, equivalent, unordered; + constexpr bool operator==(long n) const noexcept { return this->n == n;} + constexpr bool operator!=(long n) const noexcept { return this->n != n;} + }; + constexpr partial_ordering partial_ordering::less = {-1}; + constexpr partial_ordering partial_ordering::equal = {0}; + constexpr partial_ordering partial_ordering::greater = {1}; + constexpr partial_ordering partial_ordering::equivalent = {0}; + constexpr partial_ordering partial_ordering::unordered = {-127}; +} // namespace std + +namespace ThreeWayCmp { + static_assert(1 <=> 2 == -1, ""); + static_assert(1 <=> 1 == 0, ""); + static_assert(2 <=> 1 == 1, ""); + static_assert(1.0 <=> 2.f == -1, ""); + static_assert(1.0 <=> 1.0 == 0, ""); + static_assert(2.0 <=> 1.0 == 1, ""); + constexpr int k = (1 <=> 1, 0); // expected-warning {{comparison result unused}} \ + // ref-warning {{comparison result unused}} + static_assert(k== 0, ""); + + /// Pointers. + constexpr int a[] = {1,2,3}; + constexpr int b[] = {1,2,3}; + constexpr const int *pa1 = &a[1]; + constexpr const int *pa2 = &a[2]; + constexpr const int *pb1 = &b[1]; + static_assert(pa1 <=> pb1 != 0, ""); // expected-error {{not an integral constant expression}} \ + // expected-note {{has unspecified value}} \ + // ref-error {{not an integral constant expression}} \ + // ref-note {{has unspecified value}} + static_assert(pa1 <=> pa1 == 0, ""); + static_assert(pa1 <=> pa2 == -1, ""); + static_assert(pa2 <=> pa1 == 1, ""); +}