diff --git a/clang/lib/Analysis/FlowSensitive/Models/UncheckedStatusOrAccessModel.cpp b/clang/lib/Analysis/FlowSensitive/Models/UncheckedStatusOrAccessModel.cpp index c917c8e8c11ba..f8439d875d8c7 100644 --- a/clang/lib/Analysis/FlowSensitive/Models/UncheckedStatusOrAccessModel.cpp +++ b/clang/lib/Analysis/FlowSensitive/Models/UncheckedStatusOrAccessModel.cpp @@ -351,6 +351,25 @@ static auto isAssertionResultConstructFromBoolCall() { hasArgument(0, hasType(booleanType()))); } +static auto isStatusOrReturningCall() { + using namespace ::clang::ast_matchers; // NOLINT: Too many names + return callExpr( + callee(functionDecl(returns(possiblyReferencedStatusOrType())))); +} + +static auto isStatusOrPtrReturningCall() { + using namespace ::clang::ast_matchers; // NOLINT: Too many names + return callExpr(callee(functionDecl(returns(hasUnqualifiedDesugaredType( + pointerType(pointee(possiblyReferencedStatusOrType()))))))); +} + +static auto isStatusPtrReturningCall() { + using namespace ::clang::ast_matchers; // NOLINT: Too many names + return callExpr(callee(functionDecl(returns(hasUnqualifiedDesugaredType( + pointerType(pointee(hasUnqualifiedDesugaredType( + recordType(hasDeclaration(statusClass())))))))))); +} + static auto buildDiagnoseMatchSwitch(const UncheckedStatusOrAccessModelOptions &Options) { return CFGMatchSwitchBuildergetSyntheticField("value"))); } +static void transferStatusOrPtrReturningCall(const CallExpr *Expr, + const MatchFinder::MatchResult &, + LatticeTransferState &State) { + PointerValue *PointerVal = + dyn_cast_or_null(State.Env.getValue(*Expr)); + if (!PointerVal) { + PointerVal = cast(State.Env.createValue(Expr->getType())); + State.Env.setValue(*Expr, *PointerVal); + } + + auto *RecordLoc = + dyn_cast_or_null(&PointerVal->getPointeeLoc()); + if (RecordLoc != nullptr && + State.Env.getValue(locForOk(locForStatus(*RecordLoc))) == nullptr) + initializeStatusOr(*RecordLoc, State.Env); +} + +static void transferStatusPtrReturningCall(const CallExpr *Expr, + const MatchFinder::MatchResult &, + LatticeTransferState &State) { + PointerValue *PointerVal = + dyn_cast_or_null(State.Env.getValue(*Expr)); + if (!PointerVal) { + PointerVal = cast(State.Env.createValue(Expr->getType())); + State.Env.setValue(*Expr, *PointerVal); + } + + auto *RecordLoc = + dyn_cast_or_null(&PointerVal->getPointeeLoc()); + if (RecordLoc != nullptr && + State.Env.getValue(locForOk(*RecordLoc)) == nullptr) + initializeStatus(*RecordLoc, State.Env); +} + static RecordStorageLocation * getSmartPtrLikeStorageLocation(const Expr &E, const Environment &Env) { if (!E.isPRValue()) @@ -1209,6 +1262,18 @@ buildTransferMatchSwitch(ASTContext &Ctx, transferNonConstMemberCall) .CaseOfCFGStmt(isNonConstMemberOperatorCall(), transferNonConstMemberOperatorCall) + // N.B. this has to be after transferConstMemberCall, otherwise we would + // always return a fresh RecordStorageLocation for the StatusOr. + .CaseOfCFGStmt(isStatusOrReturningCall(), + [](const CallExpr *Expr, + const MatchFinder::MatchResult &, + LatticeTransferState &State) { + transferStatusOrReturningCall(Expr, State); + }) + .CaseOfCFGStmt(isStatusOrPtrReturningCall(), + transferStatusOrPtrReturningCall) + .CaseOfCFGStmt(isStatusPtrReturningCall(), + transferStatusPtrReturningCall) // N.B. These need to come after all other CXXConstructExpr. // These are there to make sure that every Status and StatusOr object // have their ok boolean initialized when constructed. If we were to diff --git a/clang/unittests/Analysis/FlowSensitive/UncheckedStatusOrAccessModelTestFixture.cpp b/clang/unittests/Analysis/FlowSensitive/UncheckedStatusOrAccessModelTestFixture.cpp index cd7353c62f537..c012d0527870b 100644 --- a/clang/unittests/Analysis/FlowSensitive/UncheckedStatusOrAccessModelTestFixture.cpp +++ b/clang/unittests/Analysis/FlowSensitive/UncheckedStatusOrAccessModelTestFixture.cpp @@ -3840,6 +3840,56 @@ TEST_P(UncheckedStatusOrAccessModelTest, NestedStatusOrInStatusOrStruct) { )cc"); } +TEST_P(UncheckedStatusOrAccessModelTest, StatusOrPtrReference) { + ExpectDiagnosticsFor(R"cc( +#include "unchecked_statusor_access_test_defs.h" + + const STATUSOR_INT* foo(); + + void target() { + const auto& sor = foo(); + if (sor->ok()) sor->value(); + } + )cc"); + + ExpectDiagnosticsFor(R"cc( +#include "unchecked_statusor_access_test_defs.h" + + using StatusOrPtr = const STATUSOR_INT*; + StatusOrPtr foo(); + + void target() { + const auto& sor = foo(); + if (sor->ok()) sor->value(); + } + )cc"); +} + +TEST_P(UncheckedStatusOrAccessModelTest, StatusPtrReference) { + ExpectDiagnosticsFor(R"cc( +#include "unchecked_statusor_access_test_defs.h" + + const STATUS* foo(); + + void target(STATUSOR_INT sor) { + const auto& s = foo(); + if (s->ok() && *s == sor.status()) sor.value(); + } + )cc"); + + ExpectDiagnosticsFor(R"cc( +#include "unchecked_statusor_access_test_defs.h" + + using StatusPtr = const STATUS*; + StatusPtr foo(); + + void target(STATUSOR_INT sor) { + const auto& s = foo(); + if (s->ok() && *s == sor.status()) sor.value(); + } + )cc"); +} + } // namespace std::string