Skip to content

Commit

Permalink
[clang][Interp] BaseToDerived casts
Browse files Browse the repository at this point in the history
We can implement these similarly to DerivedToBase casts. We just have to
walk the class hierarchy, sum the base offsets and subtract it from the
current base offset of the pointer.

Differential Revision: https://reviews.llvm.org/D149133
  • Loading branch information
tbaederr committed Sep 5, 2023
1 parent f1246e9 commit 12a7897
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 12 deletions.
35 changes: 23 additions & 12 deletions clang/lib/AST/Interp/ByteCodeExprGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,20 @@ bool ByteCodeExprGen<Emitter>::VisitCastExpr(const CastExpr *CE) {
if (!this->visit(SubExpr))
return false;

return this->emitDerivedToBaseCasts(getRecordTy(SubExpr->getType()),
getRecordTy(CE->getType()), CE);
unsigned DerivedOffset = collectBaseOffset(getRecordTy(CE->getType()),
getRecordTy(SubExpr->getType()));

return this->emitGetPtrBasePop(DerivedOffset, CE);
}

case CK_BaseToDerived: {
if (!this->visit(SubExpr))
return false;

unsigned DerivedOffset = collectBaseOffset(getRecordTy(SubExpr->getType()),
getRecordTy(CE->getType()));

return this->emitGetPtrDerivedPop(DerivedOffset, CE);
}

case CK_FloatingCast: {
Expand Down Expand Up @@ -2262,35 +2274,34 @@ void ByteCodeExprGen<Emitter>::emitCleanup() {
}

template <class Emitter>
bool ByteCodeExprGen<Emitter>::emitDerivedToBaseCasts(
const RecordType *DerivedType, const RecordType *BaseType, const Expr *E) {
// Pointer of derived type is already on the stack.
unsigned
ByteCodeExprGen<Emitter>::collectBaseOffset(const RecordType *BaseType,
const RecordType *DerivedType) {
const auto *FinalDecl = cast<CXXRecordDecl>(BaseType->getDecl());
const RecordDecl *CurDecl = DerivedType->getDecl();
const Record *CurRecord = getRecord(CurDecl);
assert(CurDecl && FinalDecl);

unsigned OffsetSum = 0;
for (;;) {
assert(CurRecord->getNumBases() > 0);
// One level up
for (const Record::Base &B : CurRecord->bases()) {
const auto *BaseDecl = cast<CXXRecordDecl>(B.Decl);

if (BaseDecl == FinalDecl || BaseDecl->isDerivedFrom(FinalDecl)) {
// This decl will lead us to the final decl, so emit a base cast.
if (!this->emitGetPtrBasePop(B.Offset, E))
return false;

OffsetSum += B.Offset;
CurRecord = B.R;
CurDecl = BaseDecl;
break;
}
}
if (CurDecl == FinalDecl)
return true;
break;
}

llvm_unreachable("Couldn't find the base class?");
return false;
assert(OffsetSum > 0);
return OffsetSum;
}

/// When calling this, we have a pointer of the local-to-destroy
Expand Down
2 changes: 2 additions & 0 deletions clang/lib/AST/Interp/ByteCodeExprGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,8 @@ class ByteCodeExprGen : public ConstStmtVisitor<ByteCodeExprGen<Emitter>, bool>,
bool emitRecordDestruction(const Descriptor *Desc);
bool emitDerivedToBaseCasts(const RecordType *DerivedType,
const RecordType *BaseType, const Expr *E);
unsigned collectBaseOffset(const RecordType *BaseType,
const RecordType *DerivedType);

protected:
/// Variable to storage mapping.
Expand Down
10 changes: 10 additions & 0 deletions clang/lib/AST/Interp/Interp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,16 @@ bool CheckRange(InterpState &S, CodePtr OpPC, const Pointer &Ptr,
return false;
}

bool CheckBaseDerived(InterpState &S, CodePtr OpPC, const Pointer &Ptr,
CheckSubobjectKind CSK) {
if (!Ptr.isOnePastEnd())
return true;

const SourceInfo &Loc = S.Current->getSource(OpPC);
S.FFDiag(Loc, diag::note_constexpr_past_end_subobject) << CSK;
return false;
}

bool CheckConst(InterpState &S, CodePtr OpPC, const Pointer &Ptr) {
assert(Ptr.isLive() && "Pointer is not live");
if (!Ptr.isConst())
Expand Down
18 changes: 18 additions & 0 deletions clang/lib/AST/Interp/Interp.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ bool CheckRange(InterpState &S, CodePtr OpPC, const Pointer &Ptr,
bool CheckRange(InterpState &S, CodePtr OpPC, const Pointer &Ptr,
CheckSubobjectKind CSK);

/// Checks if accessing a base or derived record of the given pointer is valid.
bool CheckBaseDerived(InterpState &S, CodePtr OpPC, const Pointer &Ptr,
CheckSubobjectKind CSK);

/// Checks if a pointer points to const storage.
bool CheckConst(InterpState &S, CodePtr OpPC, const Pointer &Ptr);

Expand Down Expand Up @@ -1157,10 +1161,22 @@ inline bool GetPtrActiveThisField(InterpState &S, CodePtr OpPC, uint32_t Off) {
return true;
}

inline bool GetPtrDerivedPop(InterpState &S, CodePtr OpPC, uint32_t Off) {
const Pointer &Ptr = S.Stk.pop<Pointer>();
if (!CheckNull(S, OpPC, Ptr, CSK_Derived))
return false;
if (!CheckBaseDerived(S, OpPC, Ptr, CSK_Derived))
return false;
S.Stk.push<Pointer>(Ptr.atFieldSub(Off));
return true;
}

inline bool GetPtrBase(InterpState &S, CodePtr OpPC, uint32_t Off) {
const Pointer &Ptr = S.Stk.peek<Pointer>();
if (!CheckNull(S, OpPC, Ptr, CSK_Base))
return false;
if (!CheckBaseDerived(S, OpPC, Ptr, CSK_Base))
return false;
S.Stk.push<Pointer>(Ptr.atField(Off));
return true;
}
Expand All @@ -1169,6 +1185,8 @@ inline bool GetPtrBasePop(InterpState &S, CodePtr OpPC, uint32_t Off) {
const Pointer &Ptr = S.Stk.pop<Pointer>();
if (!CheckNull(S, OpPC, Ptr, CSK_Base))
return false;
if (!CheckBaseDerived(S, OpPC, Ptr, CSK_Base))
return false;
S.Stk.push<Pointer>(Ptr.atField(Off));
return true;
}
Expand Down
4 changes: 4 additions & 0 deletions clang/lib/AST/Interp/Opcodes.td
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,10 @@ def GetPtrBasePop : Opcode {
let Args = [ArgUint32];
}

def GetPtrDerivedPop : Opcode {
let Args = [ArgUint32];
}

// [Pointer] -> [Pointer]
def GetPtrVirtBase : Opcode {
// RecordDecl of base class.
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/AST/Interp/Pointer.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,14 @@ class Pointer {
return Pointer(Pointee, Field, Field);
}

/// Subtract the given offset from the current Base and Offset
/// of the pointer.
Pointer atFieldSub(unsigned Off) const {
assert(Offset >= Off);
unsigned O = Offset - Off;
return Pointer(Pointee, O, O);
}

/// Restricts the scope of an array element pointer.
Pointer narrow() const {
// Null pointers cannot be narrowed.
Expand Down
52 changes: 52 additions & 0 deletions clang/test/AST/Interp/records.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,58 @@ namespace Destructors {
// ref-note {{in call to 'testS()'}}
}

namespace BaseToDerived {
namespace A {
struct A {};
struct B : A { int n; };
struct C : B {};
C c = {};
constexpr C *pb = (C*)((A*)&c + 1); // expected-error {{must be initialized by a constant expression}} \
// expected-note {{cannot access derived class of pointer past the end of object}} \
// ref-error {{must be initialized by a constant expression}} \
// ref-note {{cannot access derived class of pointer past the end of object}}
}
namespace B {
struct A {};
struct Z {};
struct B : Z, A {
int n;
constexpr B() : n(10) {}
};
struct C : B {
constexpr C() : B() {}
};

constexpr C c = {};
constexpr const A *pa = &c;
constexpr const C *cp = (C*)pa;
constexpr const B *cb = (B*)cp;

static_assert(cb->n == 10);
static_assert(cp->n == 10);
}

namespace C {
struct Base { int *a; };
struct Base2 : Base { int f[12]; };

struct Middle1 { int b[3]; };
struct Middle2 : Base2 { char c; };
struct Middle3 : Middle2 { char g[3]; };
struct Middle4 { int f[3]; };
struct Middle5 : Middle4, Middle3 { char g2[3]; };

struct NotQuiteDerived : Middle1, Middle5 { bool d; };
struct Derived : NotQuiteDerived { int e; };

constexpr NotQuiteDerived NQD1 = {};

constexpr Middle5 *M4 = (Middle5*)((Base2*)&NQD1);
static_assert(M4->a == nullptr);
static_assert(M4->g2[0] == 0);
}
}


namespace VirtualDtors {
class A {
Expand Down

0 comments on commit 12a7897

Please sign in to comment.