diff --git a/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp b/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp index 1d31b22b6d25f..dbf4878622eba 100644 --- a/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp +++ b/clang/lib/Analysis/FlowSensitive/Models/UncheckedOptionalAccessModel.cpp @@ -64,39 +64,125 @@ static bool hasOptionalClassName(const CXXRecordDecl &RD) { return false; } +static const CXXRecordDecl *getOptionalBaseClass(const CXXRecordDecl *RD) { + if (RD == nullptr) + return nullptr; + if (hasOptionalClassName(*RD)) + return RD; + + if (!RD->hasDefinition()) + return nullptr; + + for (const CXXBaseSpecifier &Base : RD->bases()) + if (const CXXRecordDecl *BaseClass = + getOptionalBaseClass(Base.getType()->getAsCXXRecordDecl())) + return BaseClass; + + return nullptr; +} + namespace { using namespace ::clang::ast_matchers; using LatticeTransferState = TransferState; -AST_MATCHER(CXXRecordDecl, hasOptionalClassNameMatcher) { - return hasOptionalClassName(Node); +AST_MATCHER(CXXRecordDecl, optionalClass) { return hasOptionalClassName(Node); } + +AST_MATCHER(CXXRecordDecl, optionalOrDerivedClass) { + return getOptionalBaseClass(&Node) != nullptr; } -DeclarationMatcher optionalClass() { - return classTemplateSpecializationDecl( - hasOptionalClassNameMatcher(), - hasTemplateArgument(0, refersToType(type().bind("T")))); +auto desugarsToOptionalType() { + return hasUnqualifiedDesugaredType( + recordType(hasDeclaration(cxxRecordDecl(optionalClass())))); } -auto optionalOrAliasType() { +auto desugarsToOptionalOrDerivedType() { return hasUnqualifiedDesugaredType( - recordType(hasDeclaration(optionalClass()))); + recordType(hasDeclaration(cxxRecordDecl(optionalOrDerivedClass())))); +} + +auto hasOptionalType() { return hasType(desugarsToOptionalType()); } + +/// Matches any of the spellings of the optional types and sugar, aliases, +/// derived classes, etc. +auto hasOptionalOrDerivedType() { + return hasType(desugarsToOptionalOrDerivedType()); +} + +QualType getPublicType(const Expr *E) { + auto *Cast = dyn_cast(E->IgnoreParens()); + if (Cast == nullptr || Cast->getCastKind() != CK_UncheckedDerivedToBase) { + QualType Ty = E->getType(); + if (Ty->isPointerType()) + return Ty->getPointeeType(); + return Ty; + } + + // Is the derived type that we're casting from the type of `*this`? In this + // special case, we can upcast to the base class even if the base is + // non-public. + bool CastingFromThis = isa(Cast->getSubExpr()); + + // Find the least-derived type in the path (i.e. the last entry in the list) + // that we can access. + const CXXBaseSpecifier *PublicBase = nullptr; + for (const CXXBaseSpecifier *Base : Cast->path()) { + if (Base->getAccessSpecifier() != AS_public && !CastingFromThis) + break; + PublicBase = Base; + CastingFromThis = false; + } + + if (PublicBase != nullptr) + return PublicBase->getType(); + + // We didn't find any public type that we could cast to. There may be more + // casts in `getSubExpr()`, so recurse. (If there aren't any more casts, this + // will return the type of `getSubExpr()`.) + return getPublicType(Cast->getSubExpr()); } -/// Matches any of the spellings of the optional types and sugar, aliases, etc. -auto hasOptionalType() { return hasType(optionalOrAliasType()); } +// Returns the least-derived type for the receiver of `MCE` that +// `MCE.getImplicitObjectArgument()->IgnoreParentImpCasts()` can be downcast to. +// Effectively, we upcast until we reach a non-public base class, unless that +// base is a base of `*this`. +// +// This is needed to correctly match methods called on types derived from +// `std::optional`. +// +// Say we have a `struct Derived : public std::optional {} d;` For a call +// `d.has_value()`, the `getImplicitObjectArgument()` looks like this: +// +// ImplicitCastExpr 'const std::__optional_storage_base' lvalue +// | __optional_storage_base)> +// `-DeclRefExpr 'Derived' lvalue Var 'd' 'Derived' +// +// The type of the implicit object argument is `__optional_storage_base` +// (since this is the internal type that `has_value()` is declared on). If we +// call `IgnoreParenImpCasts()` on the implicit object argument, we get the +// `DeclRefExpr`, which has type `Derived`. Neither of these types is +// `optional`, and hence neither is sufficient for querying whether we are +// calling a method on `optional`. +// +// Instead, starting with the most derived type, we need to follow the chain of +// casts +QualType getPublicReceiverType(const CXXMemberCallExpr &MCE) { + return getPublicType(MCE.getImplicitObjectArgument()); +} + +AST_MATCHER_P(CXXMemberCallExpr, publicReceiverType, + ast_matchers::internal::Matcher, InnerMatcher) { + return InnerMatcher.matches(getPublicReceiverType(Node), Finder, Builder); +} auto isOptionalMemberCallWithNameMatcher( ast_matchers::internal::Matcher matcher, const std::optional &Ignorable = std::nullopt) { - auto Exception = unless(Ignorable ? expr(anyOf(*Ignorable, cxxThisExpr())) - : cxxThisExpr()); - return cxxMemberCallExpr( - on(expr(Exception, - anyOf(hasOptionalType(), - hasType(pointerType(pointee(optionalOrAliasType())))))), - callee(cxxMethodDecl(matcher))); + return cxxMemberCallExpr(Ignorable ? on(expr(unless(*Ignorable))) + : anything(), + publicReceiverType(desugarsToOptionalType()), + callee(cxxMethodDecl(matcher))); } auto isOptionalOperatorCallWithName( @@ -129,49 +215,51 @@ auto inPlaceClass() { auto isOptionalNulloptConstructor() { return cxxConstructExpr( - hasOptionalType(), hasDeclaration(cxxConstructorDecl(parameterCountIs(1), - hasParameter(0, hasNulloptType())))); + hasParameter(0, hasNulloptType()))), + hasOptionalOrDerivedType()); } auto isOptionalInPlaceConstructor() { - return cxxConstructExpr(hasOptionalType(), - hasArgument(0, hasType(inPlaceClass()))); + return cxxConstructExpr(hasArgument(0, hasType(inPlaceClass())), + hasOptionalOrDerivedType()); } auto isOptionalValueOrConversionConstructor() { return cxxConstructExpr( - hasOptionalType(), unless(hasDeclaration( cxxConstructorDecl(anyOf(isCopyConstructor(), isMoveConstructor())))), - argumentCountIs(1), hasArgument(0, unless(hasNulloptType()))); + argumentCountIs(1), hasArgument(0, unless(hasNulloptType())), + hasOptionalOrDerivedType()); } auto isOptionalValueOrConversionAssignment() { return cxxOperatorCallExpr( hasOverloadedOperatorName("="), - callee(cxxMethodDecl(ofClass(optionalClass()))), + callee(cxxMethodDecl(ofClass(optionalOrDerivedClass()))), unless(hasDeclaration(cxxMethodDecl( anyOf(isCopyAssignmentOperator(), isMoveAssignmentOperator())))), argumentCountIs(2), hasArgument(1, unless(hasNulloptType()))); } auto isOptionalNulloptAssignment() { - return cxxOperatorCallExpr(hasOverloadedOperatorName("="), - callee(cxxMethodDecl(ofClass(optionalClass()))), - argumentCountIs(2), - hasArgument(1, hasNulloptType())); + return cxxOperatorCallExpr( + hasOverloadedOperatorName("="), + callee(cxxMethodDecl(ofClass(optionalOrDerivedClass()))), + argumentCountIs(2), hasArgument(1, hasNulloptType())); } auto isStdSwapCall() { return callExpr(callee(functionDecl(hasName("std::swap"))), - argumentCountIs(2), hasArgument(0, hasOptionalType()), - hasArgument(1, hasOptionalType())); + argumentCountIs(2), + hasArgument(0, hasOptionalOrDerivedType()), + hasArgument(1, hasOptionalOrDerivedType())); } auto isStdForwardCall() { return callExpr(callee(functionDecl(hasName("std::forward"))), - argumentCountIs(1), hasArgument(0, hasOptionalType())); + argumentCountIs(1), + hasArgument(0, hasOptionalOrDerivedType())); } constexpr llvm::StringLiteral ValueOrCallID = "ValueOrCall"; @@ -212,8 +300,9 @@ auto isValueOrNotEqX() { } auto isCallReturningOptional() { - return callExpr(hasType(qualType(anyOf( - optionalOrAliasType(), referenceType(pointee(optionalOrAliasType())))))); + return callExpr(hasType(qualType( + anyOf(desugarsToOptionalOrDerivedType(), + referenceType(pointee(desugarsToOptionalOrDerivedType())))))); } template @@ -275,12 +364,9 @@ BoolValue *getHasValue(Environment &Env, RecordStorageLocation *OptionalLoc) { return HasValueVal; } -/// Returns true if and only if `Type` is an optional type. -bool isOptionalType(QualType Type) { - if (!Type->isRecordType()) - return false; - const CXXRecordDecl *D = Type->getAsCXXRecordDecl(); - return D != nullptr && hasOptionalClassName(*D); +QualType valueTypeFromOptionalDecl(const CXXRecordDecl &RD) { + auto &CTSD = cast(RD); + return CTSD.getTemplateArgs()[0].getAsType(); } /// Returns the number of optional wrappers in `Type`. @@ -288,15 +374,13 @@ bool isOptionalType(QualType Type) { /// For example, if `Type` is `optional>`, the result of this /// function will be 2. int countOptionalWrappers(const ASTContext &ASTCtx, QualType Type) { - if (!isOptionalType(Type)) + const CXXRecordDecl *Optional = + getOptionalBaseClass(Type->getAsCXXRecordDecl()); + if (Optional == nullptr) return 0; return 1 + countOptionalWrappers( ASTCtx, - cast(Type->getAsRecordDecl()) - ->getTemplateArgs() - .get(0) - .getAsType() - .getDesugaredType(ASTCtx)); + valueTypeFromOptionalDecl(*Optional).getDesugaredType(ASTCtx)); } StorageLocation *getLocBehindPossiblePointer(const Expr &E, @@ -843,13 +927,7 @@ auto buildDiagnoseMatchSwitch( ast_matchers::DeclarationMatcher UncheckedOptionalAccessModel::optionalClassDecl() { - return optionalClass(); -} - -static QualType valueTypeFromOptionalType(QualType OptionalTy) { - auto *CTSD = - cast(OptionalTy->getAsCXXRecordDecl()); - return CTSD->getTemplateArgs()[0].getAsType(); + return cxxRecordDecl(optionalClass()); } UncheckedOptionalAccessModel::UncheckedOptionalAccessModel(ASTContext &Ctx, @@ -858,9 +936,11 @@ UncheckedOptionalAccessModel::UncheckedOptionalAccessModel(ASTContext &Ctx, TransferMatchSwitch(buildTransferMatchSwitch()) { Env.getDataflowAnalysisContext().setSyntheticFieldCallback( [&Ctx](QualType Ty) -> llvm::StringMap { - if (!isOptionalType(Ty)) + const CXXRecordDecl *Optional = + getOptionalBaseClass(Ty->getAsCXXRecordDecl()); + if (Optional == nullptr) return {}; - return {{"value", valueTypeFromOptionalType(Ty)}, + return {{"value", valueTypeFromOptionalDecl(*Optional)}, {"has_value", Ctx.BoolTy}}; }); } diff --git a/clang/unittests/Analysis/FlowSensitive/UncheckedOptionalAccessModelTest.cpp b/clang/unittests/Analysis/FlowSensitive/UncheckedOptionalAccessModelTest.cpp index b6e4973fd7cb2..9430730004dbd 100644 --- a/clang/unittests/Analysis/FlowSensitive/UncheckedOptionalAccessModelTest.cpp +++ b/clang/unittests/Analysis/FlowSensitive/UncheckedOptionalAccessModelTest.cpp @@ -3383,6 +3383,66 @@ TEST_P(UncheckedOptionalAccessTest, LambdaCaptureStateNotPropagated) { } )"); } + +TEST_P(UncheckedOptionalAccessTest, ClassDerivedFromOptional) { + ExpectDiagnosticsFor(R"( + #include "unchecked_optional_access_test.h" + + struct Derived : public $ns::$optional {}; + + void target(Derived opt) { + *opt; // [[unsafe]] + if (opt.has_value()) + *opt; + + // The same thing, but with a pointer receiver. + Derived *popt = &opt; + **popt; // [[unsafe]] + if (popt->has_value()) + **popt; + } + )"); +} + +TEST_P(UncheckedOptionalAccessTest, ClassTemplateDerivedFromOptional) { + ExpectDiagnosticsFor(R"( + #include "unchecked_optional_access_test.h" + + template + struct Derived : public $ns::$optional {}; + + void target(Derived opt) { + *opt; // [[unsafe]] + if (opt.has_value()) + *opt; + + // The same thing, but with a pointer receiver. + Derived *popt = &opt; + **popt; // [[unsafe]] + if (popt->has_value()) + **popt; + } + )"); +} + +TEST_P(UncheckedOptionalAccessTest, ClassDerivedPrivatelyFromOptional) { + // Classes that derive privately from optional can themselves still call + // member functions of optional. Check that we model the optional correctly + // in this situation. + ExpectDiagnosticsFor(R"( + #include "unchecked_optional_access_test.h" + + struct Derived : private $ns::$optional { + void Method() { + **this; // [[unsafe]] + if (this->has_value()) + **this; + } + }; + )", + ast_matchers::hasName("Method")); +} + // FIXME: Add support for: // - constructors (copy, move) // - assignment operators (default, copy, move)