Skip to content

Commit

Permalink
[clang][dataflow] Make optional checker work for types derived from o…
Browse files Browse the repository at this point in the history
…ptional. (#84138)

`llvm::MaybeAlign` does this, for example.

It's not an option to simply ignore these derived classes because they
get cast
back to the optional classes (for example, simply when calling the
optional
member functions), and our transfer functions will then run on those
optional
classes and therefore require them to be properly initialized.
  • Loading branch information
martinboehme committed Mar 19, 2024
1 parent 2f2f16f commit d712c5e
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 54 deletions.
188 changes: 134 additions & 54 deletions clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,39 +64,125 @@ static bool hasOptionalClassName(const CXXRecordDecl &RD) {
return false;
}

static const CXXRecordDecl *getOptionalBaseClass(const CXXRecordDecl *RD) {
if (RD == nullptr)
return nullptr;
if (hasOptionalClassName(*RD))
return RD;

if (!RD->hasDefinition())
return nullptr;

for (const CXXBaseSpecifier &Base : RD->bases())
if (const CXXRecordDecl *BaseClass =
getOptionalBaseClass(Base.getType()->getAsCXXRecordDecl()))
return BaseClass;

return nullptr;
}

namespace {

using namespace ::clang::ast_matchers;
using LatticeTransferState = TransferState<NoopLattice>;

AST_MATCHER(CXXRecordDecl, hasOptionalClassNameMatcher) {
return hasOptionalClassName(Node);
AST_MATCHER(CXXRecordDecl, optionalClass) { return hasOptionalClassName(Node); }

AST_MATCHER(CXXRecordDecl, optionalOrDerivedClass) {
return getOptionalBaseClass(&Node) != nullptr;
}

DeclarationMatcher optionalClass() {
return classTemplateSpecializationDecl(
hasOptionalClassNameMatcher(),
hasTemplateArgument(0, refersToType(type().bind("T"))));
auto desugarsToOptionalType() {
return hasUnqualifiedDesugaredType(
recordType(hasDeclaration(cxxRecordDecl(optionalClass()))));
}

auto optionalOrAliasType() {
auto desugarsToOptionalOrDerivedType() {
return hasUnqualifiedDesugaredType(
recordType(hasDeclaration(optionalClass())));
recordType(hasDeclaration(cxxRecordDecl(optionalOrDerivedClass()))));
}

auto hasOptionalType() { return hasType(desugarsToOptionalType()); }

/// Matches any of the spellings of the optional types and sugar, aliases,
/// derived classes, etc.
auto hasOptionalOrDerivedType() {
return hasType(desugarsToOptionalOrDerivedType());
}

QualType getPublicType(const Expr *E) {
auto *Cast = dyn_cast<ImplicitCastExpr>(E->IgnoreParens());
if (Cast == nullptr || Cast->getCastKind() != CK_UncheckedDerivedToBase) {
QualType Ty = E->getType();
if (Ty->isPointerType())
return Ty->getPointeeType();
return Ty;
}

// Is the derived type that we're casting from the type of `*this`? In this
// special case, we can upcast to the base class even if the base is
// non-public.
bool CastingFromThis = isa<CXXThisExpr>(Cast->getSubExpr());

// Find the least-derived type in the path (i.e. the last entry in the list)
// that we can access.
const CXXBaseSpecifier *PublicBase = nullptr;
for (const CXXBaseSpecifier *Base : Cast->path()) {
if (Base->getAccessSpecifier() != AS_public && !CastingFromThis)
break;
PublicBase = Base;
CastingFromThis = false;
}

if (PublicBase != nullptr)
return PublicBase->getType();

// We didn't find any public type that we could cast to. There may be more
// casts in `getSubExpr()`, so recurse. (If there aren't any more casts, this
// will return the type of `getSubExpr()`.)
return getPublicType(Cast->getSubExpr());
}

/// Matches any of the spellings of the optional types and sugar, aliases, etc.
auto hasOptionalType() { return hasType(optionalOrAliasType()); }
// Returns the least-derived type for the receiver of `MCE` that
// `MCE.getImplicitObjectArgument()->IgnoreParentImpCasts()` can be downcast to.
// Effectively, we upcast until we reach a non-public base class, unless that
// base is a base of `*this`.
//
// This is needed to correctly match methods called on types derived from
// `std::optional`.
//
// Say we have a `struct Derived : public std::optional<int> {} d;` For a call
// `d.has_value()`, the `getImplicitObjectArgument()` looks like this:
//
// ImplicitCastExpr 'const std::__optional_storage_base<int>' lvalue
// | <UncheckedDerivedToBase (optional -> __optional_storage_base)>
// `-DeclRefExpr 'Derived' lvalue Var 'd' 'Derived'
//
// The type of the implicit object argument is `__optional_storage_base`
// (since this is the internal type that `has_value()` is declared on). If we
// call `IgnoreParenImpCasts()` on the implicit object argument, we get the
// `DeclRefExpr`, which has type `Derived`. Neither of these types is
// `optional`, and hence neither is sufficient for querying whether we are
// calling a method on `optional`.
//
// Instead, starting with the most derived type, we need to follow the chain of
// casts
QualType getPublicReceiverType(const CXXMemberCallExpr &MCE) {
return getPublicType(MCE.getImplicitObjectArgument());
}

AST_MATCHER_P(CXXMemberCallExpr, publicReceiverType,
ast_matchers::internal::Matcher<QualType>, InnerMatcher) {
return InnerMatcher.matches(getPublicReceiverType(Node), Finder, Builder);
}

auto isOptionalMemberCallWithNameMatcher(
ast_matchers::internal::Matcher<NamedDecl> matcher,
const std::optional<StatementMatcher> &Ignorable = std::nullopt) {
auto Exception = unless(Ignorable ? expr(anyOf(*Ignorable, cxxThisExpr()))
: cxxThisExpr());
return cxxMemberCallExpr(
on(expr(Exception,
anyOf(hasOptionalType(),
hasType(pointerType(pointee(optionalOrAliasType())))))),
callee(cxxMethodDecl(matcher)));
return cxxMemberCallExpr(Ignorable ? on(expr(unless(*Ignorable)))
: anything(),
publicReceiverType(desugarsToOptionalType()),
callee(cxxMethodDecl(matcher)));
}

auto isOptionalOperatorCallWithName(
Expand Down Expand Up @@ -129,49 +215,51 @@ auto inPlaceClass() {

auto isOptionalNulloptConstructor() {
return cxxConstructExpr(
hasOptionalType(),
hasDeclaration(cxxConstructorDecl(parameterCountIs(1),
hasParameter(0, hasNulloptType()))));
hasParameter(0, hasNulloptType()))),
hasOptionalOrDerivedType());
}

auto isOptionalInPlaceConstructor() {
return cxxConstructExpr(hasOptionalType(),
hasArgument(0, hasType(inPlaceClass())));
return cxxConstructExpr(hasArgument(0, hasType(inPlaceClass())),
hasOptionalOrDerivedType());
}

auto isOptionalValueOrConversionConstructor() {
return cxxConstructExpr(
hasOptionalType(),
unless(hasDeclaration(
cxxConstructorDecl(anyOf(isCopyConstructor(), isMoveConstructor())))),
argumentCountIs(1), hasArgument(0, unless(hasNulloptType())));
argumentCountIs(1), hasArgument(0, unless(hasNulloptType())),
hasOptionalOrDerivedType());
}

auto isOptionalValueOrConversionAssignment() {
return cxxOperatorCallExpr(
hasOverloadedOperatorName("="),
callee(cxxMethodDecl(ofClass(optionalClass()))),
callee(cxxMethodDecl(ofClass(optionalOrDerivedClass()))),
unless(hasDeclaration(cxxMethodDecl(
anyOf(isCopyAssignmentOperator(), isMoveAssignmentOperator())))),
argumentCountIs(2), hasArgument(1, unless(hasNulloptType())));
}

auto isOptionalNulloptAssignment() {
return cxxOperatorCallExpr(hasOverloadedOperatorName("="),
callee(cxxMethodDecl(ofClass(optionalClass()))),
argumentCountIs(2),
hasArgument(1, hasNulloptType()));
return cxxOperatorCallExpr(
hasOverloadedOperatorName("="),
callee(cxxMethodDecl(ofClass(optionalOrDerivedClass()))),
argumentCountIs(2), hasArgument(1, hasNulloptType()));
}

auto isStdSwapCall() {
return callExpr(callee(functionDecl(hasName("std::swap"))),
argumentCountIs(2), hasArgument(0, hasOptionalType()),
hasArgument(1, hasOptionalType()));
argumentCountIs(2),
hasArgument(0, hasOptionalOrDerivedType()),
hasArgument(1, hasOptionalOrDerivedType()));
}

auto isStdForwardCall() {
return callExpr(callee(functionDecl(hasName("std::forward"))),
argumentCountIs(1), hasArgument(0, hasOptionalType()));
argumentCountIs(1),
hasArgument(0, hasOptionalOrDerivedType()));
}

constexpr llvm::StringLiteral ValueOrCallID = "ValueOrCall";
Expand Down Expand Up @@ -212,8 +300,9 @@ auto isValueOrNotEqX() {
}

auto isCallReturningOptional() {
return callExpr(hasType(qualType(anyOf(
optionalOrAliasType(), referenceType(pointee(optionalOrAliasType()))))));
return callExpr(hasType(qualType(
anyOf(desugarsToOptionalOrDerivedType(),
referenceType(pointee(desugarsToOptionalOrDerivedType()))))));
}

template <typename L, typename R>
Expand Down Expand Up @@ -275,28 +364,23 @@ BoolValue *getHasValue(Environment &Env, RecordStorageLocation *OptionalLoc) {
return HasValueVal;
}

/// Returns true if and only if `Type` is an optional type.
bool isOptionalType(QualType Type) {
if (!Type->isRecordType())
return false;
const CXXRecordDecl *D = Type->getAsCXXRecordDecl();
return D != nullptr && hasOptionalClassName(*D);
QualType valueTypeFromOptionalDecl(const CXXRecordDecl &RD) {
auto &CTSD = cast<ClassTemplateSpecializationDecl>(RD);
return CTSD.getTemplateArgs()[0].getAsType();
}

/// Returns the number of optional wrappers in `Type`.
///
/// For example, if `Type` is `optional<optional<int>>`, the result of this
/// function will be 2.
int countOptionalWrappers(const ASTContext &ASTCtx, QualType Type) {
if (!isOptionalType(Type))
const CXXRecordDecl *Optional =
getOptionalBaseClass(Type->getAsCXXRecordDecl());
if (Optional == nullptr)
return 0;
return 1 + countOptionalWrappers(
ASTCtx,
cast<ClassTemplateSpecializationDecl>(Type->getAsRecordDecl())
->getTemplateArgs()
.get(0)
.getAsType()
.getDesugaredType(ASTCtx));
valueTypeFromOptionalDecl(*Optional).getDesugaredType(ASTCtx));
}

StorageLocation *getLocBehindPossiblePointer(const Expr &E,
Expand Down Expand Up @@ -843,13 +927,7 @@ auto buildDiagnoseMatchSwitch(

ast_matchers::DeclarationMatcher
UncheckedOptionalAccessModel::optionalClassDecl() {
return optionalClass();
}

static QualType valueTypeFromOptionalType(QualType OptionalTy) {
auto *CTSD =
cast<ClassTemplateSpecializationDecl>(OptionalTy->getAsCXXRecordDecl());
return CTSD->getTemplateArgs()[0].getAsType();
return cxxRecordDecl(optionalClass());
}

UncheckedOptionalAccessModel::UncheckedOptionalAccessModel(ASTContext &Ctx,
Expand All @@ -858,9 +936,11 @@ UncheckedOptionalAccessModel::UncheckedOptionalAccessModel(ASTContext &Ctx,
TransferMatchSwitch(buildTransferMatchSwitch()) {
Env.getDataflowAnalysisContext().setSyntheticFieldCallback(
[&Ctx](QualType Ty) -> llvm::StringMap<QualType> {
if (!isOptionalType(Ty))
const CXXRecordDecl *Optional =
getOptionalBaseClass(Ty->getAsCXXRecordDecl());
if (Optional == nullptr)
return {};
return {{"value", valueTypeFromOptionalType(Ty)},
return {{"value", valueTypeFromOptionalDecl(*Optional)},
{"has_value", Ctx.BoolTy}};
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3383,6 +3383,66 @@ TEST_P(UncheckedOptionalAccessTest, LambdaCaptureStateNotPropagated) {
}
)");
}

TEST_P(UncheckedOptionalAccessTest, ClassDerivedFromOptional) {
ExpectDiagnosticsFor(R"(
#include "unchecked_optional_access_test.h"
struct Derived : public $ns::$optional<int> {};
void target(Derived opt) {
*opt; // [[unsafe]]
if (opt.has_value())
*opt;
// The same thing, but with a pointer receiver.
Derived *popt = &opt;
**popt; // [[unsafe]]
if (popt->has_value())
**popt;
}
)");
}

TEST_P(UncheckedOptionalAccessTest, ClassTemplateDerivedFromOptional) {
ExpectDiagnosticsFor(R"(
#include "unchecked_optional_access_test.h"
template <class T>
struct Derived : public $ns::$optional<T> {};
void target(Derived<int> opt) {
*opt; // [[unsafe]]
if (opt.has_value())
*opt;
// The same thing, but with a pointer receiver.
Derived<int> *popt = &opt;
**popt; // [[unsafe]]
if (popt->has_value())
**popt;
}
)");
}

TEST_P(UncheckedOptionalAccessTest, ClassDerivedPrivatelyFromOptional) {
// Classes that derive privately from optional can themselves still call
// member functions of optional. Check that we model the optional correctly
// in this situation.
ExpectDiagnosticsFor(R"(
#include "unchecked_optional_access_test.h"
struct Derived : private $ns::$optional<int> {
void Method() {
**this; // [[unsafe]]
if (this->has_value())
**this;
}
};
)",
ast_matchers::hasName("Method"));
}

// FIXME: Add support for:
// - constructors (copy, move)
// - assignment operators (default, copy, move)
Expand Down

0 comments on commit d712c5e

Please sign in to comment.