Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,15 @@ static auto isNotOkStatusCall() {
"::absl::UnimplementedError", "::absl::UnknownError"))));
}

static auto isPointerComparisonOperatorCall(std::string operator_name) {
using namespace ::clang::ast_matchers; // NOLINT: Too many names
return binaryOperator(hasOperatorName(operator_name),
hasLHS(hasType(hasCanonicalType(pointerType(
pointee(anyOf(statusOrType(), statusType())))))),
hasRHS(hasType(hasCanonicalType(pointerType(
pointee(anyOf(statusOrType(), statusType())))))));
}

static auto
buildDiagnoseMatchSwitch(const UncheckedStatusOrAccessModelOptions &Options) {
return CFGMatchSwitchBuilder<const Environment,
Expand Down Expand Up @@ -438,6 +447,58 @@ static void transferComparisonOperator(const CXXOperatorCallExpr *Expr,
State.Env.setValue(*Expr, *LhsAndRhsVal);
}

static RecordStorageLocation *getPointeeLocation(const Expr &Expr,
Environment &Env) {
if (auto *PointerVal = Env.get<PointerValue>(Expr))
return dyn_cast<RecordStorageLocation>(&PointerVal->getPointeeLoc());
return nullptr;
}

static BoolValue *evaluatePointerEquality(const Expr *LhsExpr,
const Expr *RhsExpr,
Environment &Env) {
assert(LhsExpr->getType()->isPointerType());
assert(RhsExpr->getType()->isPointerType());
RecordStorageLocation *LhsStatusLoc = nullptr;
RecordStorageLocation *RhsStatusLoc = nullptr;
if (isStatusOrType(LhsExpr->getType()->getPointeeType()) &&
isStatusOrType(RhsExpr->getType()->getPointeeType())) {
auto *LhsStatusOrLoc = getPointeeLocation(*LhsExpr, Env);
auto *RhsStatusOrLoc = getPointeeLocation(*RhsExpr, Env);
if (LhsStatusOrLoc == nullptr || RhsStatusOrLoc == nullptr)
return nullptr;
LhsStatusLoc = &locForStatus(*LhsStatusOrLoc);
RhsStatusLoc = &locForStatus(*RhsStatusOrLoc);
} else if (isStatusType(LhsExpr->getType()->getPointeeType()) &&
isStatusType(RhsExpr->getType()->getPointeeType())) {
LhsStatusLoc = getPointeeLocation(*LhsExpr, Env);
RhsStatusLoc = getPointeeLocation(*RhsExpr, Env);
}
if (LhsStatusLoc == nullptr || RhsStatusLoc == nullptr)
return nullptr;
auto &LhsOkVal = valForOk(*LhsStatusLoc, Env);
auto &RhsOkVal = valForOk(*RhsStatusLoc, Env);
auto &Res = Env.makeAtomicBoolValue();
auto &A = Env.arena();
Env.assume(A.makeImplies(
Res.formula(), A.makeEquals(LhsOkVal.formula(), RhsOkVal.formula())));
return &Res;
}

static void transferPointerComparisonOperator(const BinaryOperator *Expr,
LatticeTransferState &State,
bool IsNegative) {
auto *LhsAndRhsVal =
evaluatePointerEquality(Expr->getLHS(), Expr->getRHS(), State.Env);
if (LhsAndRhsVal == nullptr)
return;

if (IsNegative)
State.Env.setValue(*Expr, State.Env.makeNot(*LhsAndRhsVal));
else
State.Env.setValue(*Expr, *LhsAndRhsVal);
}

static void transferOkStatusCall(const CallExpr *Expr,
const MatchFinder::MatchResult &,
LatticeTransferState &State) {
Expand Down Expand Up @@ -482,6 +543,20 @@ buildTransferMatchSwitch(ASTContext &Ctx,
transferComparisonOperator(Expr, State,
/*IsNegative=*/true);
})
.CaseOfCFGStmt<BinaryOperator>(
isPointerComparisonOperatorCall("=="),
[](const BinaryOperator *Expr, const MatchFinder::MatchResult &,
LatticeTransferState &State) {
transferPointerComparisonOperator(Expr, State,
/*IsNegative=*/false);
})
.CaseOfCFGStmt<BinaryOperator>(
isPointerComparisonOperatorCall("!="),
[](const BinaryOperator *Expr, const MatchFinder::MatchResult &,
LatticeTransferState &State) {
transferPointerComparisonOperator(Expr, State,
/*IsNegative=*/true);
})
.CaseOfCFGStmt<CallExpr>(isOkStatusCall(), transferOkStatusCall)
.CaseOfCFGStmt<CallExpr>(isNotOkStatusCall(), transferNotOkStatusCall)
.Build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2871,6 +2871,63 @@ TEST_P(UncheckedStatusOrAccessModelTest, EqualityCheck) {
)cc");
}

TEST_P(UncheckedStatusOrAccessModelTest, PointerEqualityCheck) {
ExpectDiagnosticsFor(
R"cc(
#include "unchecked_statusor_access_test_defs.h"

void target(STATUSOR_INT* x, STATUSOR_INT* y) {
if (x->ok()) {
if (x == y)
y->value();
else
y->value(); // [[unsafe]]
}
}
)cc");
ExpectDiagnosticsFor(
R"cc(
#include "unchecked_statusor_access_test_defs.h"

void target(STATUSOR_INT* x, STATUSOR_INT* y) {
if (x->ok()) {
if (x != y)
y->value(); // [[unsafe]]
else
y->value();
}
}
)cc");
ExpectDiagnosticsFor(
R"cc(
#include "unchecked_statusor_access_test_defs.h"

void target(STATUS* x, STATUS* y) {
auto sor = Make<STATUSOR_INT>();
if (x->ok()) {
if (x == y && sor.status() == *y)
sor.value();
else
sor.value(); // [[unsafe]]
}
}
)cc");
ExpectDiagnosticsFor(
R"cc(
#include "unchecked_statusor_access_test_defs.h"

void target(STATUS* x, STATUS* y) {
auto sor = Make<STATUSOR_INT>();
if (x->ok()) {
if (x != y)
sor.value(); // [[unsafe]]
else if (sor.status() == *y)
sor.value();
}
}
)cc");
}

} // namespace

std::string
Expand Down
Loading