Skip to content

Commit

Permalink
Optimize emission of dynamic_cast to final classes.
Browse files Browse the repository at this point in the history
- When the destination is a final class type that does not derive from
  the source type, the cast always fails and is now emitted as a null
  pointer or a call to __cxa_bad_cast.

- When the destination is a final class type that does derive from the
  source type, emit a direct comparison against the corresponding base
  class vptr value(s). There may be more than one such value in the case
  of multiple inheritance; check them all.

For now, this is supported only for the Itanium ABI. I expect the same thing is
possible for the MS ABI too, but I don't know what guarantees are made about
vfptr uniqueness.

Reviewed By: rjmccall

Differential Revision: https://reviews.llvm.org/D154658
  • Loading branch information
zygoloid committed Jul 22, 2023
1 parent 57bd882 commit 9d525bf
Show file tree
Hide file tree
Showing 14 changed files with 355 additions and 38 deletions.
1 change: 1 addition & 0 deletions clang/include/clang/Basic/CodeGenOptions.def
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ CODEGENOPT(Dwarf64 , 1, 0) ///< -gdwarf64.
CODEGENOPT(Dwarf32 , 1, 1) ///< -gdwarf32.
CODEGENOPT(PreserveAsmComments, 1, 1) ///< -dA, -fno-preserve-as-comments.
CODEGENOPT(AssumeSaneOperatorNew , 1, 1) ///< implicit __attribute__((malloc)) operator new
CODEGENOPT(AssumeUniqueVTables , 1, 1) ///< Assume a class has only one vtable.
CODEGENOPT(Autolink , 1, 1) ///< -fno-autolink
CODEGENOPT(ObjCAutoRefCountExceptions , 1, 0) ///< Whether ARC should be EH-safe.
CODEGENOPT(Backchain , 1, 0) ///< -mbackchain
Expand Down
7 changes: 7 additions & 0 deletions clang/include/clang/Driver/Options.td
Original file line number Diff line number Diff line change
Expand Up @@ -1253,6 +1253,13 @@ def static_libsan : Flag<["-"], "static-libsan">,
def : Flag<["-"], "shared-libasan">, Alias<shared_libsan>;
def fasm : Flag<["-"], "fasm">, Group<f_Group>;

// -fassume-unique-vtables / -fno-assume-unique-vtables: backs the
// AssumeUniqueVTables codegen option, which defaults to true. Only the
// negative form carries help text and the CC1Option flag; both forms are
// marked CoreOption.
defm assume_unique_vtables : BoolFOption<"assume-unique-vtables",
CodeGenOpts<"AssumeUniqueVTables">, DefaultTrue,
PosFlag<SetTrue>,
NegFlag<SetFalse, [CC1Option],
"Disable optimizations based on vtable pointer identity">,
BothFlags<[CoreOption]>>;

def fassume_sane_operator_new : Flag<["-"], "fassume-sane-operator-new">, Group<f_Group>;
def fastcp : Flag<["-"], "fastcp">, Group<f_Group>;
def fastf : Flag<["-"], "fastf">, Group<f_Group>;
Expand Down
36 changes: 21 additions & 15 deletions clang/lib/AST/ExprCXX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -767,29 +767,35 @@ CXXDynamicCastExpr *CXXDynamicCastExpr::CreateEmpty(const ASTContext &C,
/// struct C { };
///
/// C *f(B* b) { return dynamic_cast<C*>(b); }
bool CXXDynamicCastExpr::isAlwaysNull() const
{
bool CXXDynamicCastExpr::isAlwaysNull() const {
if (isValueDependent() || getCastKind() != CK_Dynamic)
return false;

QualType SrcType = getSubExpr()->getType();
QualType DestType = getType();

if (const auto *SrcPTy = SrcType->getAs<PointerType>()) {
SrcType = SrcPTy->getPointeeType();
DestType = DestType->castAs<PointerType>()->getPointeeType();
}

if (DestType->isVoidType())
if (DestType->isVoidPointerType())
return false;

const auto *SrcRD =
cast<CXXRecordDecl>(SrcType->castAs<RecordType>()->getDecl());
if (DestType->isPointerType()) {
SrcType = SrcType->getPointeeType();
DestType = DestType->getPointeeType();
}

if (!SrcRD->hasAttr<FinalAttr>())
return false;
const auto *SrcRD = SrcType->getAsCXXRecordDecl();
const auto *DestRD = DestType->getAsCXXRecordDecl();
assert(SrcRD && DestRD);

const auto *DestRD =
cast<CXXRecordDecl>(DestType->castAs<RecordType>()->getDecl());
if (SrcRD->isEffectivelyFinal()) {
assert(!SrcRD->isDerivedFrom(DestRD) &&
"upcasts should not use CK_Dynamic");
return true;
}

return !DestRD->isDerivedFrom(SrcRD);
if (DestRD->isEffectivelyFinal() && !DestRD->isDerivedFrom(SrcRD))
return true;

return false;
}

CXXReinterpretCastExpr *
Expand Down
10 changes: 10 additions & 0 deletions clang/lib/CodeGen/CGCXXABI.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ class CGCXXABI {

virtual bool shouldDynamicCastCallBeNullChecked(bool SrcIsPtr,
QualType SrcRecordTy) = 0;
/// Return true if a dynamic_cast whose destination record type is
/// DestRecordTy may be emitted with emitExactDynamicCast rather than
/// emitDynamicCastCall.
virtual bool shouldEmitExactDynamicCast(QualType DestRecordTy) = 0;

virtual llvm::Value *emitDynamicCastCall(CodeGenFunction &CGF, Address Value,
QualType SrcRecordTy,
Expand All @@ -298,6 +299,15 @@ class CGCXXABI {
Address Value,
QualType SrcRecordTy) = 0;

/// Emit a dynamic_cast from SrcRecordTy to DestRecordTy. The cast fails if
/// the dynamic type of Value is not exactly DestRecordTy.
///
/// The implementation is expected to emit the control flow itself, branching
/// to CastSuccess when the cast succeeds and to CastFail when it fails. The
/// returned value is the result of the cast on the success path.
virtual llvm::Value *emitExactDynamicCast(CodeGenFunction &CGF, Address Value,
QualType SrcRecordTy,
QualType DestTy,
QualType DestRecordTy,
llvm::BasicBlock *CastSuccess,
llvm::BasicBlock *CastFail) = 0;

virtual bool EmitBadCastCall(CodeGenFunction &CGF) = 0;

virtual llvm::Value *GetVirtualBaseClassOffset(CodeGenFunction &CGF,
Expand Down
51 changes: 35 additions & 16 deletions clang/lib/CodeGen/CGExprCXX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2226,8 +2226,8 @@ static llvm::Value *EmitDynamicCastToNull(CodeGenFunction &CGF,
if (!CGF.CGM.getCXXABI().EmitBadCastCall(CGF))
return nullptr;

CGF.EmitBlock(CGF.createBasicBlock("dynamic_cast.end"));
return llvm::UndefValue::get(DestLTy);
CGF.Builder.ClearInsertionPoint();
return llvm::PoisonValue::get(DestLTy);
}

llvm::Value *CodeGenFunction::EmitDynamicCast(Address ThisAddr,
Expand All @@ -2240,17 +2240,16 @@ llvm::Value *CodeGenFunction::EmitDynamicCast(Address ThisAddr,
// C++ [expr.dynamic.cast]p7:
// If T is "pointer to cv void," then the result is a pointer to the most
// derived object pointed to by v.
const PointerType *DestPTy = DestTy->getAs<PointerType>();

bool isDynamicCastToVoid;
bool IsDynamicCastToVoid = DestTy->isVoidPointerType();
QualType SrcRecordTy;
QualType DestRecordTy;
if (DestPTy) {
isDynamicCastToVoid = DestPTy->getPointeeType()->isVoidType();
if (IsDynamicCastToVoid) {
SrcRecordTy = SrcTy->getPointeeType();
// No DestRecordTy.
} else if (const PointerType *DestPTy = DestTy->getAs<PointerType>()) {
SrcRecordTy = SrcTy->castAs<PointerType>()->getPointeeType();
DestRecordTy = DestPTy->getPointeeType();
} else {
isDynamicCastToVoid = false;
SrcRecordTy = SrcTy;
DestRecordTy = DestTy->castAs<ReferenceType>()->getPointeeType();
}
Expand All @@ -2263,18 +2262,29 @@ llvm::Value *CodeGenFunction::EmitDynamicCast(Address ThisAddr,
EmitTypeCheck(TCK_DynamicOperation, DCE->getExprLoc(), ThisAddr.getPointer(),
SrcRecordTy);

if (DCE->isAlwaysNull())
if (llvm::Value *T = EmitDynamicCastToNull(*this, DestTy))
if (DCE->isAlwaysNull()) {
if (llvm::Value *T = EmitDynamicCastToNull(*this, DestTy)) {
// Expression emission is expected to retain a valid insertion point.
if (!Builder.GetInsertBlock())
EmitBlock(createBasicBlock("dynamic_cast.unreachable"));
return T;
}
}

assert(SrcRecordTy->isRecordType() && "source type must be a record type!");

// If the destination is effectively final, the cast succeeds if and only
// if the dynamic type of the pointer is exactly the destination type.
bool IsExact = !IsDynamicCastToVoid &&
DestRecordTy->getAsCXXRecordDecl()->isEffectivelyFinal() &&
CGM.getCXXABI().shouldEmitExactDynamicCast(DestRecordTy);

// C++ [expr.dynamic.cast]p4:
// If the value of v is a null pointer value in the pointer case, the result
// is the null pointer value of type T.
bool ShouldNullCheckSrcValue =
CGM.getCXXABI().shouldDynamicCastCallBeNullChecked(SrcTy->isPointerType(),
SrcRecordTy);
IsExact || CGM.getCXXABI().shouldDynamicCastCallBeNullChecked(
SrcTy->isPointerType(), SrcRecordTy);

llvm::BasicBlock *CastNull = nullptr;
llvm::BasicBlock *CastNotNull = nullptr;
Expand All @@ -2290,29 +2300,38 @@ llvm::Value *CodeGenFunction::EmitDynamicCast(Address ThisAddr,
}

llvm::Value *Value;
if (isDynamicCastToVoid) {
if (IsDynamicCastToVoid) {
Value = CGM.getCXXABI().emitDynamicCastToVoid(*this, ThisAddr, SrcRecordTy);
} else if (IsExact) {
// If the destination type is effectively final, this pointer points to the
// right type if and only if its vptr has the right value.
Value = CGM.getCXXABI().emitExactDynamicCast(
*this, ThisAddr, SrcRecordTy, DestTy, DestRecordTy, CastEnd, CastNull);
} else {
assert(DestRecordTy->isRecordType() &&
"destination type must be a record type!");
Value = CGM.getCXXABI().emitDynamicCastCall(*this, ThisAddr, SrcRecordTy,
DestTy, DestRecordTy, CastEnd);
CastNotNull = Builder.GetInsertBlock();
}
CastNotNull = Builder.GetInsertBlock();

llvm::Value *NullValue = nullptr;
if (ShouldNullCheckSrcValue) {
EmitBranch(CastEnd);

EmitBlock(CastNull);
NullValue = EmitDynamicCastToNull(*this, DestTy);
CastNull = Builder.GetInsertBlock();

EmitBranch(CastEnd);
}

EmitBlock(CastEnd);

if (ShouldNullCheckSrcValue) {
if (CastNull) {
llvm::PHINode *PHI = Builder.CreatePHI(Value->getType(), 2);
PHI->addIncoming(Value, CastNotNull);
PHI->addIncoming(llvm::Constant::getNullValue(Value->getType()), CastNull);
PHI->addIncoming(NullValue, CastNull);

Value = PHI;
}
Expand Down
4 changes: 1 addition & 3 deletions clang/lib/CodeGen/CodeGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7152,9 +7152,7 @@ llvm::Constant *CodeGenModule::GetAddrOfRTTIDescriptor(QualType Ty,
// Return a bogus pointer if RTTI is disabled, unless it's for EH.
// FIXME: should we even be calling this method if RTTI is disabled
// and it's not for EH?
if ((!ForEH && !getLangOpts().RTTI) || getLangOpts().CUDAIsDevice ||
(getLangOpts().OpenMP && getLangOpts().OpenMPIsTargetDevice &&
getTriple().isNVPTX()))
if (!shouldEmitRTTI(ForEH))
return llvm::Constant::getNullValue(GlobalsInt8PtrTy);

if (ForEH && Ty->isObjCObjectPointerType() &&
Expand Down
7 changes: 7 additions & 0 deletions clang/lib/CodeGen/CodeGenModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -926,6 +926,13 @@ class CodeGenModule : public CodeGenTypeCache {
// Return the function body address of the given function.
llvm::Constant *GetFunctionStart(const ValueDecl *Decl);

// Decide whether RTTI information should be emitted for this target.
// RTTI must be enabled (or requested for exception handling), and the
// compilation must be neither a CUDA device compile nor an OpenMP target
// device compile for NVPTX.
bool shouldEmitRTTI(bool ForEH = false) {
  const auto &Opts = getLangOpts();

  // Without -frtti, RTTI is only produced when it is needed for EH.
  if (!ForEH && !Opts.RTTI)
    return false;

  // CUDA device code never gets RTTI.
  if (Opts.CUDAIsDevice)
    return false;

  // Neither does OpenMP offload code targeting NVPTX.
  const bool IsOpenMPNVPTXDevice =
      Opts.OpenMP && Opts.OpenMPIsTargetDevice && getTriple().isNVPTX();
  return !IsOpenMPNVPTXDevice;
}

/// Get the address of the RTTI descriptor for the given type.
llvm::Constant *GetAddrOfRTTIDescriptor(QualType Ty, bool ForEH = false);

Expand Down
128 changes: 127 additions & 1 deletion clang/lib/CodeGen/ItaniumCXXABI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
#include "llvm/IR/Value.h"
#include "llvm/Support/ScopedPrinter.h"

#include <optional>

using namespace clang;
using namespace CodeGen;

Expand Down Expand Up @@ -185,11 +187,56 @@ class ItaniumCXXABI : public CodeGen::CGCXXABI {
bool shouldDynamicCastCallBeNullChecked(bool SrcIsPtr,
QualType SrcRecordTy) override;

/// Report whether every instance of type RecordTy can be assumed to carry
/// the same vtable pointer value, and that this value is distinct from the
/// vtable pointer of every other type. The Itanium ABI requires this, but
/// some language extensions break the guarantee in practice.
bool hasUniqueVTablePointer(QualType RecordTy) {
  const CXXRecordDecl *RD = RecordTy->getAsCXXRecordDecl();

  // The assumption can be switched off by option, and it never holds under
  // -fapple-kext, where multiple definitions of the same vtable may be
  // emitted.
  const bool UniquenessAssumed = CGM.getCodeGenOpts().AssumeUniqueVTables &&
                                 !getContext().getLangOpts().AppleKext;
  if (!UniquenessAssumed)
    return false;

  // With a null type_info*, the vtable might be merged with that of another
  // type.
  const bool EmitsRTTI = CGM.shouldEmitRTTI();
  if (!EmitsRTTI)
    return false;

  // A vtable with exactly one definition in the program has a unique
  // address.
  const bool HasSingleDefinition =
      !llvm::GlobalValue::isWeakForLinker(CGM.getVTableLinkage(RD));
  if (HasSingleDefinition)
    return true;

  // Multiple definitions must share a symbol name per the ABI, so they
  // should be merged at load time. The exception is a class without default
  // visibility: different modules can then hold different versions of the
  // class, and the ABI library might treat them as being the same.
  return CGM.GetLLVMVisibility(RD->getVisibility()) ==
         llvm::GlobalValue::DefaultVisibility;
}

/// An exact dynamic_cast is emitted only when the destination record type
/// has a unique vtable pointer.
bool shouldEmitExactDynamicCast(QualType DestRecordTy) override {
  const bool HasUniqueVPtr = hasUniqueVTablePointer(DestRecordTy);
  return HasUniqueVPtr;
}

llvm::Value *emitDynamicCastCall(CodeGenFunction &CGF, Address Value,
QualType SrcRecordTy, QualType DestTy,
QualType DestRecordTy,
llvm::BasicBlock *CastEnd) override;

llvm::Value *emitExactDynamicCast(CodeGenFunction &CGF, Address ThisAddr,
QualType SrcRecordTy, QualType DestTy,
QualType DestRecordTy,
llvm::BasicBlock *CastSuccess,
llvm::BasicBlock *CastFail) override;

llvm::Value *emitDynamicCastToVoid(CodeGenFunction &CGF, Address Value,
QualType SrcRecordTy) override;

Expand Down Expand Up @@ -1202,7 +1249,8 @@ void ItaniumCXXABI::emitVirtualObjectDelete(CodeGenFunction &CGF,
// Track back to entry -2 and pull out the offset there.
llvm::Value *OffsetPtr = CGF.Builder.CreateConstInBoundsGEP1_64(
CGF.IntPtrTy, VTable, -2, "complete-offset.ptr");
llvm::Value *Offset = CGF.Builder.CreateAlignedLoad(CGF.IntPtrTy, OffsetPtr, CGF.getPointerAlign());
llvm::Value *Offset = CGF.Builder.CreateAlignedLoad(CGF.IntPtrTy, OffsetPtr,
CGF.getPointerAlign());

// Apply the offset.
llvm::Value *CompletePtr =
Expand Down Expand Up @@ -1463,6 +1511,84 @@ llvm::Value *ItaniumCXXABI::emitDynamicCastCall(
return Value;
}

/// Emit a dynamic_cast from SrcRecordTy to DestRecordTy as a direct vptr
/// comparison.
///
/// \param ThisAddr     Address of the source object being cast.
/// \param SrcRecordTy  Record type of the source object.
/// \param DestTy       Not referenced by this implementation.
/// \param DestRecordTy Record type being cast to.
/// \param CastSuccess  Block branched to when the vptr matches.
/// \param CastFail     Block branched to when the vptr does not match, or
///                     unconditionally when DestRecordTy has no public
///                     inheritance path to SrcRecordTy.
/// \returns The source pointer adjusted back by the base subobject offset,
///          or a poison value of type VoidPtrTy when the cast is known to
///          always fail.
llvm::Value *ItaniumCXXABI::emitExactDynamicCast(
CodeGenFunction &CGF, Address ThisAddr, QualType SrcRecordTy,
QualType DestTy, QualType DestRecordTy, llvm::BasicBlock *CastSuccess,
llvm::BasicBlock *CastFail) {
ASTContext &Context = getContext();

// Find all the inheritance paths.
const CXXRecordDecl *SrcDecl = SrcRecordTy->getAsCXXRecordDecl();
const CXXRecordDecl *DestDecl = DestRecordTy->getAsCXXRecordDecl();
CXXBasePaths Paths(/*FindAmbiguities=*/true, /*RecordPaths=*/true,
/*DetectVirtual=*/false);
// Only the recorded paths are needed; the boolean result is discarded.
(void)DestDecl->isDerivedFrom(SrcDecl, Paths);

// Find an offset within `DestDecl` where a `SrcDecl` instance and its vptr
// might appear.
std::optional<CharUnits> Offset;
for (const CXXBasePath &Path : Paths) {
// dynamic_cast only finds public inheritance paths.
if (Path.Access != AS_public)
continue;

CharUnits PathOffset;
for (const CXXBasePathElement &PathElement : Path) {
// Find the offset along this inheritance step.
const CXXRecordDecl *Base =
PathElement.Base->getType()->getAsCXXRecordDecl();
if (PathElement.Base->isVirtual()) {
// For a virtual base class, we know that the derived class is exactly
// DestDecl, so we can use the vbase offset from its layout.
// This is an assignment, not an accumulation: the vbase offset comes
// from DestDecl's own layout and replaces whatever was summed so far.
const ASTRecordLayout &L = Context.getASTRecordLayout(DestDecl);
PathOffset = L.getVBaseClassOffset(Base);
} else {
const ASTRecordLayout &L =
Context.getASTRecordLayout(PathElement.Class);
PathOffset += L.getBaseClassOffset(Base);
}
}

if (!Offset)
Offset = PathOffset;
else if (Offset != PathOffset) {
// Base appears in at least two different places. Find the most-derived
// object and see if it's a DestDecl. Note that the most-derived object
// must be at least as aligned as this base class subobject, and must
// have a vptr at offset 0.
ThisAddr = Address(emitDynamicCastToVoid(CGF, ThisAddr, SrcRecordTy),
CGF.VoidPtrTy, ThisAddr.getAlignment());
// From here on the comparison is against DestDecl's own vptr at offset
// zero, so no further paths need to be examined.
SrcDecl = DestDecl;
Offset = CharUnits::Zero();
break;
}
}

if (!Offset) {
// If there are no public inheritance paths, the cast always fails.
// No vptr load or comparison is emitted on this path.
CGF.EmitBranch(CastFail);
return llvm::PoisonValue::get(CGF.VoidPtrTy);
}

// Compare the vptr against the expected vptr for the destination type at
// this offset. Note that we do not know what type ThisAddr points to in
// the case where the derived class multiply inherits from the base class
// so we can't use GetVTablePtr, so we load the vptr directly instead.
llvm::Instruction *VPtr = CGF.Builder.CreateLoad(
ThisAddr.withElementType(CGF.VoidPtrPtrTy), "vtable");
CGM.DecorateInstructionWithTBAA(
VPtr, CGM.getTBAAVTablePtrAccessInfo(CGF.VoidPtrPtrTy));
llvm::Value *Success = CGF.Builder.CreateICmpEQ(
VPtr, getVTableAddressPoint(BaseSubobject(SrcDecl, *Offset), DestDecl));
llvm::Value *Result = ThisAddr.getPointer();
// Step back by the subobject offset to get from the base subobject to the
// start of the DestDecl object; nothing to do when the offset is zero.
if (!Offset->isZero())
Result = CGF.Builder.CreateInBoundsGEP(
CGF.CharTy, Result,
{llvm::ConstantInt::get(CGF.PtrDiffTy, -Offset->getQuantity())});
CGF.Builder.CreateCondBr(Success, CastSuccess, CastFail);
return Result;
}

llvm::Value *ItaniumCXXABI::emitDynamicCastToVoid(CodeGenFunction &CGF,
Address ThisAddr,
QualType SrcRecordTy) {
Expand Down

0 comments on commit 9d525bf

Please sign in to comment.