diff --git a/clang/lib/Analysis/FlowSensitive/Models/UncheckedStatusOrAccessModel.cpp b/clang/lib/Analysis/FlowSensitive/Models/UncheckedStatusOrAccessModel.cpp index 8af12139d181f..21913d5809a6e 100644 --- a/clang/lib/Analysis/FlowSensitive/Models/UncheckedStatusOrAccessModel.cpp +++ b/clang/lib/Analysis/FlowSensitive/Models/UncheckedStatusOrAccessModel.cpp @@ -168,6 +168,18 @@ static auto isNotOkStatusCall() { "::absl::UnimplementedError", "::absl::UnknownError")))); } +static auto isPointerComparisonOperatorCall(std::string operator_name) { + using namespace ::clang::ast_matchers; // NOLINT: Too many names + return binaryOperator( + hasOperatorName(operator_name), + hasLHS( + anyOf(hasType(hasCanonicalType(pointerType(pointee(statusOrType())))), + hasType(hasCanonicalType(pointerType(pointee(statusType())))))), + hasRHS(anyOf( + hasType(hasCanonicalType(pointerType(pointee(statusOrType())))), + hasType(hasCanonicalType(pointerType(pointee(statusType()))))))); +} + static auto buildDiagnoseMatchSwitch(const UncheckedStatusOrAccessModelOptions &Options) { return CFGMatchSwitchBuilder(Expr)) + return dyn_cast(&PointerVal->getPointeeLoc()); + return nullptr; +} + +static BoolValue *evaluatePointerEquality(const Expr *LhsExpr, + const Expr *RhsExpr, + Environment &Env) { + assert(LhsExpr->getType()->isPointerType()); + assert(RhsExpr->getType()->isPointerType()); + RecordStorageLocation *LhsStatusLoc = nullptr; + RecordStorageLocation *RhsStatusLoc = nullptr; + if (isStatusOrType(LhsExpr->getType()->getPointeeType()) && + isStatusOrType(RhsExpr->getType()->getPointeeType())) { + auto *LhsStatusOrLoc = getPointeeLocation(*LhsExpr, Env); + auto *RhsStatusOrLoc = getPointeeLocation(*RhsExpr, Env); + if (LhsStatusOrLoc == nullptr || RhsStatusOrLoc == nullptr) + return nullptr; + LhsStatusLoc = &locForStatus(*LhsStatusOrLoc); + RhsStatusLoc = &locForStatus(*RhsStatusOrLoc); + } else if (isStatusType(LhsExpr->getType()->getPointeeType()) && + isStatusType(RhsExpr->getType()->getPointeeType())) { + LhsStatusLoc = getPointeeLocation(*LhsExpr, Env); + RhsStatusLoc = getPointeeLocation(*RhsExpr, Env); + } + if (LhsStatusLoc == nullptr || RhsStatusLoc == nullptr) + return nullptr; + auto &LhsOkVal = valForOk(*LhsStatusLoc, Env); + auto &RhsOkVal = valForOk(*RhsStatusLoc, Env); + auto &Res = Env.makeAtomicBoolValue(); + auto &A = Env.arena(); + Env.assume(A.makeImplies( + Res.formula(), A.makeEquals(LhsOkVal.formula(), RhsOkVal.formula()))); + return &Res; +} + +static void transferPointerComparisonOperator(const BinaryOperator *Expr, + LatticeTransferState &State, + bool IsNegative) { + auto *LhsAndRhsVal = + evaluatePointerEquality(Expr->getLHS(), Expr->getRHS(), State.Env); + if (LhsAndRhsVal == nullptr) + return; + + if (IsNegative) + State.Env.setValue(*Expr, State.Env.makeNot(*LhsAndRhsVal)); + else + State.Env.setValue(*Expr, *LhsAndRhsVal); +} + static void transferOkStatusCall(const CallExpr *Expr, const MatchFinder::MatchResult &, LatticeTransferState &State) { @@ -488,6 +552,20 @@ buildTransferMatchSwitch(ASTContext &Ctx, transferComparisonOperator(Expr, State, /*IsNegative=*/true); }) + .CaseOfCFGStmt( + isPointerComparisonOperatorCall("=="), + [](const BinaryOperator *Expr, const MatchFinder::MatchResult &, + LatticeTransferState &State) { + transferPointerComparisonOperator(Expr, State, + /*IsNegative=*/false); + }) + .CaseOfCFGStmt( + isPointerComparisonOperatorCall("!="), + [](const BinaryOperator *Expr, const MatchFinder::MatchResult &, + LatticeTransferState &State) { + transferPointerComparisonOperator(Expr, State, + /*IsNegative=*/true); + }) .CaseOfCFGStmt(isOkStatusCall(), transferOkStatusCall) .CaseOfCFGStmt(isNotOkStatusCall(), transferNotOkStatusCall) .Build(); diff --git a/clang/unittests/Analysis/FlowSensitive/UncheckedStatusOrAccessModelTestFixture.cpp b/clang/unittests/Analysis/FlowSensitive/UncheckedStatusOrAccessModelTestFixture.cpp index 2676dab7fd904..62e456ad07bdb 100644 --- a/clang/unittests/Analysis/FlowSensitive/UncheckedStatusOrAccessModelTestFixture.cpp +++ b/clang/unittests/Analysis/FlowSensitive/UncheckedStatusOrAccessModelTestFixture.cpp @@ -2858,6 +2858,63 @@ TEST_P(UncheckedStatusOrAccessModelTest, EqualityCheck) { )cc"); } +TEST_P(UncheckedStatusOrAccessModelTest, PointerEqualityCheck) { + ExpectDiagnosticsFor( + R"cc( +#include "unchecked_statusor_access_test_defs.h" + + void target(STATUSOR_INT* x, STATUSOR_INT* y) { + if (x->ok()) { + if (x == y) + y->value(); + else + y->value(); // [[unsafe]] + } + } + )cc"); + ExpectDiagnosticsFor( + R"cc( +#include "unchecked_statusor_access_test_defs.h" + + void target(STATUSOR_INT* x, STATUSOR_INT* y) { + if (x->ok()) { + if (x != y) + y->value(); // [[unsafe]] + else + y->value(); + } + } + )cc"); + ExpectDiagnosticsFor( + R"cc( +#include "unchecked_statusor_access_test_defs.h" + + void target(STATUS* x, STATUS* y) { + auto sor = Make(); + if (x->ok()) { + if (x == y && sor.status() == *y) + sor.value(); + else + sor.value(); // [[unsafe]] + } + } + )cc"); + ExpectDiagnosticsFor( + R"cc( +#include "unchecked_statusor_access_test_defs.h" + + void target(STATUS* x, STATUS* y) { + auto sor = Make(); + if (x->ok()) { + if (x != y) + sor.value(); // [[unsafe]] + else if (sor.status() == *y) + sor.value(); + } + } + )cc"); +} + } // namespace std::string