diff --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp index 6c088469a3ca2..18e6bbcc82bd3 100644 --- a/clang/lib/AST/ByteCode/Compiler.cpp +++ b/clang/lib/AST/ByteCode/Compiler.cpp @@ -208,6 +208,19 @@ template class LocOverrideScope final { } // namespace interp } // namespace clang +template +bool Compiler::isValidBitCast(const CastExpr *E) { + QualType FromTy = E->getSubExpr()->getType()->getPointeeType(); + QualType ToTy = E->getType()->getPointeeType(); + + if (classify(FromTy) == classify(ToTy)) + return true; + + if (FromTy->isVoidType() || ToTy->isVoidType()) + return true; + return false; +} + template bool Compiler::VisitCastExpr(const CastExpr *CE) { const Expr *SubExpr = CE->getSubExpr(); @@ -476,8 +489,9 @@ bool Compiler::VisitCastExpr(const CastExpr *CE) { return this->delegate(SubExpr); case CK_BitCast: { + QualType CETy = CE->getType(); // Reject bitcasts to atomic types. - if (CE->getType()->isAtomicType()) { + if (CETy->isAtomicType()) { if (!this->discard(SubExpr)) return false; return this->emitInvalidCast(CastKind::Reinterpret, /*Fatal=*/true, CE); @@ -492,6 +506,12 @@ bool Compiler::VisitCastExpr(const CastExpr *CE) { if (!FromT || !ToT) return false; + if (!this->isValidBitCast(CE)) { + if (!this->emitInvalidCast(CastKind::ReinterpretLike, /*Fatal=*/false, + CE)) + return false; + } + assert(isPtrType(*FromT)); assert(isPtrType(*ToT)); if (FromT == ToT) { diff --git a/clang/lib/AST/ByteCode/Compiler.h b/clang/lib/AST/ByteCode/Compiler.h index 5c46f75af4da3..fac0a7f4e1886 100644 --- a/clang/lib/AST/ByteCode/Compiler.h +++ b/clang/lib/AST/ByteCode/Compiler.h @@ -425,6 +425,8 @@ class Compiler : public ConstStmtVisitor, bool>, bool refersToUnion(const Expr *E); + bool isValidBitCast(const CastExpr *E); + protected: /// Variable to storage mapping. llvm::DenseMap Locals; diff --git a/clang/lib/AST/ByteCode/Interp.h b/clang/lib/AST/ByteCode/Interp.h index 89f6fbefb1907..ee020ee08f3eb 100644 --- a/clang/lib/AST/ByteCode/Interp.h +++ b/clang/lib/AST/ByteCode/Interp.h @@ -1914,6 +1914,10 @@ bool Load(InterpState &S, CodePtr OpPC) { return false; if (!Ptr.isBlockPointer()) return false; + if (!(Ptr.getFieldDesc()->isPrimitive() || + Ptr.getFieldDesc()->isPrimitiveArray()) || + Ptr.getFieldDesc()->getPrimType() != Name) + return false; S.Stk.push(Ptr.deref()); return true; } @@ -1925,6 +1929,10 @@ bool LoadPop(InterpState &S, CodePtr OpPC) { return false; if (!Ptr.isBlockPointer()) return false; + if (!(Ptr.getFieldDesc()->isPrimitive() || + Ptr.getFieldDesc()->isPrimitiveArray()) || + Ptr.getFieldDesc()->getPrimType() != Name) + return false; S.Stk.push(Ptr.deref()); return true; } @@ -3286,12 +3294,18 @@ inline bool InvalidCast(InterpState &S, CodePtr OpPC, CastKind Kind, bool Fatal) { const SourceLocation &Loc = S.Current->getLocation(OpPC); - if (Kind == CastKind::Reinterpret) { + switch (Kind) { + case CastKind::Reinterpret: S.CCEDiag(Loc, diag::note_constexpr_invalid_cast) - << static_cast(Kind) << S.Current->getRange(OpPC); + << diag::ConstexprInvalidCastKind::Reinterpret + << S.Current->getRange(OpPC); return !Fatal; - } - if (Kind == CastKind::Volatile) { + case CastKind::ReinterpretLike: + S.CCEDiag(Loc, diag::note_constexpr_invalid_cast) + << diag::ConstexprInvalidCastKind::ThisConversionOrReinterpret + << S.getLangOpts().CPlusPlus << S.Current->getRange(OpPC); + return !Fatal; + case CastKind::Volatile: if (!S.checkingPotentialConstantExpression()) { const auto *E = cast(S.Current->getExpr(OpPC)); if (S.getLangOpts().CPlusPlus) @@ -3302,14 +3316,13 @@ inline bool InvalidCast(InterpState &S, CodePtr OpPC, CastKind Kind, } return false; - } - if (Kind == CastKind::Dynamic) { + case CastKind::Dynamic: assert(!S.getLangOpts().CPlusPlus20); - S.CCEDiag(S.Current->getSource(OpPC), diag::note_constexpr_invalid_cast) + S.CCEDiag(Loc, diag::note_constexpr_invalid_cast) << diag::ConstexprInvalidCastKind::Dynamic; return true; } - + llvm_unreachable("Unhandled CastKind"); return false; } diff --git a/clang/lib/AST/ByteCode/PrimType.h b/clang/lib/AST/ByteCode/PrimType.h index 54fd39ac6fcc8..f0454b484ff98 100644 --- a/clang/lib/AST/ByteCode/PrimType.h +++ b/clang/lib/AST/ByteCode/PrimType.h @@ -101,6 +101,7 @@ inline constexpr bool isSignedType(PrimType T) { enum class CastKind : uint8_t { Reinterpret, + ReinterpretLike, Volatile, Dynamic, }; @@ -111,6 +112,9 @@ inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, case interp::CastKind::Reinterpret: OS << "reinterpret_cast"; break; + case interp::CastKind::ReinterpretLike: + OS << "reinterpret_like"; + break; case interp::CastKind::Volatile: OS << "volatile"; break; diff --git a/clang/lib/AST/ByteCode/Program.cpp b/clang/lib/AST/ByteCode/Program.cpp index e0b2852f0e906..465f4d15def73 100644 --- a/clang/lib/AST/ByteCode/Program.cpp +++ b/clang/lib/AST/ByteCode/Program.cpp @@ -36,30 +36,19 @@ unsigned Program::createGlobalString(const StringLiteral *S, const Expr *Base) { const size_t BitWidth = CharWidth * Ctx.getCharBit(); unsigned StringLength = S->getLength(); - PrimType CharType; - switch (CharWidth) { - case 1: - CharType = PT_Sint8; - break; - case 2: - CharType = PT_Uint16; - break; - case 4: - CharType = PT_Uint32; - break; - default: - llvm_unreachable("unsupported character width"); - } + OptPrimType CharType = + Ctx.classify(S->getType()->castAsArrayTypeUnsafe()->getElementType()); + assert(CharType); if (!Base) Base = S; // Create a descriptor for the string. - Descriptor *Desc = - allocateDescriptor(Base, CharType, Descriptor::GlobalMD, StringLength + 1, - /*isConst=*/true, - /*isTemporary=*/false, - /*isMutable=*/false); + Descriptor *Desc = allocateDescriptor(Base, *CharType, Descriptor::GlobalMD, + StringLength + 1, + /*isConst=*/true, + /*isTemporary=*/false, + /*isMutable=*/false); // Allocate storage for the string. // The byte length does not include the null terminator. @@ -79,26 +68,9 @@ unsigned Program::createGlobalString(const StringLiteral *S, const Expr *Base) { } else { // Construct the string in storage. for (unsigned I = 0; I <= StringLength; ++I) { - const uint32_t CodePoint = I == StringLength ? 0 : S->getCodeUnit(I); - switch (CharType) { - case PT_Sint8: { - using T = PrimConv::T; - Ptr.elem(I) = T::from(CodePoint, BitWidth); - break; - } - case PT_Uint16: { - using T = PrimConv::T; - Ptr.elem(I) = T::from(CodePoint, BitWidth); - break; - } - case PT_Uint32: { - using T = PrimConv::T; - Ptr.elem(I) = T::from(CodePoint, BitWidth); - break; - } - default: - llvm_unreachable("unsupported character type"); - } + uint32_t CodePoint = I == StringLength ? 0 : S->getCodeUnit(I); + INT_TYPE_SWITCH_NO_BOOL(*CharType, + Ptr.elem(I) = T::from(CodePoint, BitWidth);); } } Ptr.initializeAllElements(); diff --git a/clang/test/AST/ByteCode/invalid.cpp b/clang/test/AST/ByteCode/invalid.cpp index 00db27419e36b..1f2d6bc1d48eb 100644 --- a/clang/test/AST/ByteCode/invalid.cpp +++ b/clang/test/AST/ByteCode/invalid.cpp @@ -66,3 +66,26 @@ struct S { S s; S *sp[2] = {&s, &s}; S *&spp = sp[1]; + +namespace InvalidBitCast { + void foo() { + const long long int i = 1; // both-note {{declared const here}} + if (*(double *)&i == 2) { + i = 0; // both-error {{cannot assign to variable}} + } + } + + struct S2 { + void *p; + }; + struct T { + S2 s; + }; + constexpr T t = {{nullptr}}; + constexpr void *foo2() { return ((void **)&t)[0]; } // both-error {{never produces a constant expression}} \ + // both-note 2{{cast that performs the conversions of a reinterpret_cast}} + constexpr auto x = foo2(); // both-error {{must be initialized by a constant expression}} \ + // both-note {{in call to}} + + +}