Skip to content

Commit

Permalink
Merge pull request #16690 from MathiasVP/better-guards
Browse files Browse the repository at this point in the history
C++: Fix missing results for `comparesEq` in `IRGuardCondition`
  • Loading branch information
jketema committed Jun 10, 2024
2 parents 06aa266 + 7819cc1 commit 000a81f
Show file tree
Hide file tree
Showing 10 changed files with 135 additions and 51 deletions.
101 changes: 70 additions & 31 deletions cpp/ql/lib/semmle/code/cpp/controlflow/IRGuards.qll
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,33 @@ cached
class IRGuardCondition extends Instruction {
Instruction branch;

/*
* An `IRGuardCondition` supports reasoning about four different kinds of
* relations:
* 1. A unary equality relation of the form `e == k`
* 2. A binary equality relation of the form `e1 == e2 + k`
* 3. A unary inequality relation of the form `e < k`
* 4. A binary inequality relation of the form `e1 < e2 + k`
*
* where `k` is a constant.
*
* Furthermore, the unary relations (i.e., case 1 and case 3) are also
* inferred from `switch` statement guards: equality relations are inferred
* from the unique `case` statement, if any, and inequality relations are
* inferred from the [case range](https://gcc.gnu.org/onlinedocs/gcc/Case-Ranges.html)
* gcc extension.
*
* The implementation of all four follows the same structure: Each relation
* has a cached user-facing predicate that. For example,
* `GuardCondition::comparesEq` calls `compares_eq`. This predicate has
* several cases that recursively decompose the relation to bring it to a
* canonical form (i.e., a relation of the form `e1 == e2 + k`). The base
* case for this relation (i.e., `simple_comparison_eq`) handles
* `CompareEQInstruction`s and `CompareNEInstruction`, and recursive
* predicates (e.g., `complex_eq`) rewrites larger expressions such as
* `e1 + k1 == e2 + k2` into canonical the form `e1 == e2 + (k2 - k1)`.
*/

cached
IRGuardCondition() { branch = getBranchForCondition(this) }

Expand Down Expand Up @@ -776,7 +803,9 @@ private predicate unary_compares_eq(
Instruction test, Operand op, int k, boolean areEqual, boolean inNonZeroCase, AbstractValue value
) {
/* The simple case where the test *is* the comparison so areEqual = testIsTrue xor eq. */
exists(AbstractValue v | unary_simple_comparison_eq(test, op, k, inNonZeroCase, v) |
exists(AbstractValue v |
unary_simple_comparison_eq(test, k, inNonZeroCase, v) and op.getDef() = test
|
areEqual = true and value = v
or
areEqual = false and value = v.getDualValue()
Expand Down Expand Up @@ -821,45 +850,55 @@ private predicate simple_comparison_eq(
value.(BooleanValue).getValue() = false
}

/**
* Holds if `test` is an instruction that is part of test that eventually is
* used in a conditional branch.
*/
private predicate relevantUnaryComparison(Instruction test) {
not test instanceof CompareInstruction and
exists(IRType type, ConditionalBranchInstruction branch |
type instanceof IRAddressType or type instanceof IRIntegerType
|
type = test.getResultIRType() and
branch.getCondition() = test
)
or
exists(LogicalNotInstruction logicalNot |
relevantUnaryComparison(logicalNot) and
test = logicalNot.getUnary()
)
}

/**
* Rearrange various simple comparisons into `op == k` form.
*/
private predicate unary_simple_comparison_eq(
Instruction test, Operand op, int k, boolean inNonZeroCase, AbstractValue value
Instruction test, int k, boolean inNonZeroCase, AbstractValue value
) {
exists(SwitchInstruction switch, CaseEdge case |
test = switch.getExpression() and
op.getDef() = test and
case = value.(MatchValue).getCase() and
exists(switch.getSuccessor(case)) and
case.getValue().toInt() = k and
inNonZeroCase = false
)
or
// There's no implicit CompareInstruction in files compiled as C since C
// doesn't have implicit boolean conversions. So instead we check whether
// there's a branch on a value of pointer or integer type.
relevantUnaryComparison(test) and
op.getDef() = test and
// Any instruction with an integral type could potentially be part of a
// check for nullness when used in a guard. So we include all integral
// typed instructions here. However, since some of these instructions are
// already included as guards in other cases, we exclude those here.
// These are instructions that compute a binary equality or inequality
// relation. For example, the following:
// ```cpp
// if(a == b + 42) { ... }
// ```
// generates the following IR:
// ```
// r1(glval<int>) = VariableAddress[a] :
// r2(int) = Load[a] : &:r1, m1
// r3(glval<int>) = VariableAddress[b] :
// r4(int) = Load[b] : &:r3, m2
// r5(int) = Constant[42] :
// r6(int) = Add : r4, r5
// r7(bool) = CompareEQ : r2, r6
// v1(void) = ConditionalBranch : r7
// ```
// and since `r7` is an integral typed instruction this predicate could
// include a case for when `r7` evaluates to true (in which case we would
// infer that `r6` was non-zero, and a case for when `r7` evaluates to false
// (in which case we would infer that `r6` was zero).
// However, since `a == b + 42` is already supported when reasoning about
// binary equalities we exclude those cases here.
not test.isGLValue() and
not simple_comparison_eq(test, _, _, _, _) and
not simple_comparison_lt(test, _, _, _) and
not test = any(SwitchInstruction switch).getExpression() and
(
test.getResultIRType() instanceof IRAddressType or
test.getResultIRType() instanceof IRIntegerType or
test.getResultIRType() instanceof IRBooleanType
) and
(
k = 1 and
value.(BooleanValue).getValue() = true and
Expand Down Expand Up @@ -913,7 +952,8 @@ private predicate compares_lt(

/** Holds if `op < k` evaluates to `isLt` given that `test` evaluates to `value`. */
private predicate compares_lt(Instruction test, Operand op, int k, boolean isLt, AbstractValue value) {
simple_comparison_lt(test, op, k, isLt, value)
unary_simple_comparison_lt(test, k, isLt, value) and
op.getDef() = test
or
complex_lt(test, op, k, isLt, value)
or
Expand Down Expand Up @@ -960,12 +1000,11 @@ private predicate simple_comparison_lt(CompareInstruction cmp, Operand left, Ope
}

/** Rearrange various simple comparisons into `op < k` form. */
private predicate simple_comparison_lt(
Instruction test, Operand op, int k, boolean isLt, AbstractValue value
private predicate unary_simple_comparison_lt(
Instruction test, int k, boolean isLt, AbstractValue value
) {
exists(SwitchInstruction switch, CaseEdge case |
test = switch.getExpression() and
op.getDef() = test and
case = value.(MatchValue).getCase() and
exists(switch.getSuccessor(case)) and
case.getMaxValue() > case.getMinValue()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ astGuardsCompare
| 34 | j >= 10+0 when ... < ... is false |
| 42 | 10 < j+1 when ... < ... is false |
| 42 | 10 >= j+1 when ... < ... is true |
| 42 | call to getABool != 0 when call to getABool is true |
| 42 | call to getABool == 0 when call to getABool is false |
| 42 | j < 10+0 when ... < ... is true |
| 42 | j >= 10+0 when ... < ... is false |
| 44 | 0 < z+0 when ... > ... is true |
Expand Down Expand Up @@ -537,6 +539,8 @@ astGuardsEnsure_const
| test.cpp:31:7:31:13 | ... == ... | test.cpp:31:7:31:7 | x | != | -1 | 34 | 34 |
| test.cpp:31:7:31:13 | ... == ... | test.cpp:31:7:31:7 | x | == | -1 | 30 | 30 |
| test.cpp:31:7:31:13 | ... == ... | test.cpp:31:7:31:7 | x | == | -1 | 31 | 32 |
| test.cpp:42:13:42:20 | call to getABool | test.cpp:42:13:42:20 | call to getABool | != | 0 | 43 | 45 |
| test.cpp:42:13:42:20 | call to getABool | test.cpp:42:13:42:20 | call to getABool | == | 0 | 53 | 53 |
irGuards
| test.c:7:9:7:13 | CompareGT: ... > ... |
| test.c:17:8:17:12 | CompareLT: ... < ... |
Expand Down Expand Up @@ -613,6 +617,8 @@ irGuardsCompare
| 34 | j >= 10+0 when CompareLT: ... < ... is false |
| 42 | 10 < j+1 when CompareLT: ... < ... is false |
| 42 | 10 >= j+1 when CompareLT: ... < ... is true |
| 42 | call to getABool != 0 when Call: call to getABool is true |
| 42 | call to getABool == 0 when Call: call to getABool is false |
| 42 | j < 10 when CompareLT: ... < ... is true |
| 42 | j < 10+0 when CompareLT: ... < ... is true |
| 42 | j >= 10 when CompareLT: ... < ... is false |
Expand Down Expand Up @@ -1081,3 +1087,5 @@ irGuardsEnsure_const
| test.cpp:31:7:31:13 | CompareEQ: ... == ... | test.cpp:31:7:31:7 | Load: x | != | -1 | 34 | 34 |
| test.cpp:31:7:31:13 | CompareEQ: ... == ... | test.cpp:31:7:31:7 | Load: x | == | -1 | 30 | 30 |
| test.cpp:31:7:31:13 | CompareEQ: ... == ... | test.cpp:31:7:31:7 | Load: x | == | -1 | 32 | 32 |
| test.cpp:42:13:42:20 | Call: call to getABool | test.cpp:42:13:42:20 | Call: call to getABool | != | 0 | 44 | 44 |
| test.cpp:42:13:42:20 | Call: call to getABool | test.cpp:42:13:42:20 | Call: call to getABool | == | 0 | 53 | 53 |
3 changes: 3 additions & 0 deletions cpp/ql/test/library-tests/controlflow/guards/Guards.expected
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,6 @@
| test.cpp:99:6:99:6 | f |
| test.cpp:105:6:105:14 | ... != ... |
| test.cpp:111:6:111:14 | ... != ... |
| test.cpp:122:9:122:9 | b |
| test.cpp:125:13:125:20 | ! ... |
| test.cpp:125:14:125:17 | call to safe |
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
| 34 | j >= 10+0 when ... < ... is false |
| 42 | 10 < j+1 when ... < ... is false |
| 42 | 10 >= j+1 when ... < ... is true |
| 42 | call to getABool != 0 when call to getABool is true |
| 42 | call to getABool == 0 when call to getABool is false |
| 42 | j < 10 when ... < ... is true |
| 42 | j < 10+0 when ... < ... is true |
| 42 | j >= 10 when ... < ... is false |
Expand Down Expand Up @@ -149,6 +151,13 @@
| 111 | 0.0 == i+0 when ... != ... is false |
| 111 | i != 0.0+0 when ... != ... is true |
| 111 | i == 0.0+0 when ... != ... is false |
| 122 | b != 0 when b is true |
| 122 | b == 0 when b is false |
| 125 | ! ... != 0 when ! ... is true |
| 125 | ! ... == 0 when ! ... is false |
| 125 | call to safe != 0 when ! ... is false |
| 125 | call to safe != 0 when call to safe is true |
| 125 | call to safe == 0 when call to safe is false |
| 126 | 1 != 0 when 1 is true |
| 126 | 1 != 0 when ... && ... is true |
| 126 | 1 == 0 when 1 is false |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,7 @@
| test.cpp:99:6:99:6 | f | true | 99 | 100 |
| test.cpp:105:6:105:14 | ... != ... | true | 105 | 106 |
| test.cpp:111:6:111:14 | ... != ... | true | 111 | 112 |
| test.cpp:122:9:122:9 | b | true | 123 | 125 |
| test.cpp:122:9:122:9 | b | true | 125 | 125 |
| test.cpp:125:13:125:20 | ! ... | true | 125 | 125 |
| test.cpp:125:14:125:17 | call to safe | false | 125 | 125 |
Original file line number Diff line number Diff line change
Expand Up @@ -257,10 +257,16 @@ unary
| test.cpp:31:7:31:13 | ... == ... | test.cpp:31:7:31:7 | x | != | -1 | 34 | 34 |
| test.cpp:31:7:31:13 | ... == ... | test.cpp:31:7:31:7 | x | == | -1 | 30 | 30 |
| test.cpp:31:7:31:13 | ... == ... | test.cpp:31:7:31:7 | x | == | -1 | 31 | 32 |
| test.cpp:42:13:42:20 | call to getABool | test.cpp:42:13:42:20 | call to getABool | != | 0 | 43 | 45 |
| test.cpp:42:13:42:20 | call to getABool | test.cpp:42:13:42:20 | call to getABool | == | 0 | 53 | 53 |
| test.cpp:61:10:61:10 | i | test.cpp:61:10:61:10 | i | == | 0 | 62 | 64 |
| test.cpp:61:10:61:10 | i | test.cpp:61:10:61:10 | i | == | 1 | 65 | 66 |
| test.cpp:74:10:74:10 | i | test.cpp:74:10:74:10 | i | < | 11 | 75 | 77 |
| test.cpp:74:10:74:10 | i | test.cpp:74:10:74:10 | i | < | 21 | 78 | 79 |
| test.cpp:74:10:74:10 | i | test.cpp:74:10:74:10 | i | >= | 0 | 75 | 77 |
| test.cpp:74:10:74:10 | i | test.cpp:74:10:74:10 | i | >= | 11 | 78 | 79 |
| test.cpp:93:6:93:6 | c | test.cpp:93:6:93:6 | c | != | 0 | 93 | 94 |
| test.cpp:122:9:122:9 | b | test.cpp:122:9:122:9 | b | != | 0 | 123 | 125 |
| test.cpp:122:9:122:9 | b | test.cpp:122:9:122:9 | b | != | 0 | 125 | 125 |
| test.cpp:125:13:125:20 | ! ... | test.cpp:125:13:125:20 | ! ... | != | 0 | 125 | 125 |
| test.cpp:125:14:125:17 | call to safe | test.cpp:125:14:125:17 | call to safe | == | 0 | 125 | 125 |
14 changes: 14 additions & 0 deletions cpp/ql/test/library-tests/controlflow/guards/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,17 @@ void int_float_comparison(int i) {
use(i);
}
}

int source();
bool safe(int);

void test(bool b)
{
int x;
if (b)
{
x = source();
if (!safe(x)) return;
}
use(x);
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ void test_guard_and_reassign() {
if(!guarded(x)) {
x = 0;
}
sink(x); // $ SPURIOUS: ast,ir
sink(x); // $ SPURIOUS: ast
}

void test_phi_read_guard(bool b) {
Expand All @@ -98,7 +98,7 @@ void test_phi_read_guard(bool b) {
return;
}

sink(x); // $ SPURIOUS: ast,ir
sink(x); // $ SPURIOUS: ast
}

bool unsafe(int);
Expand Down
35 changes: 19 additions & 16 deletions cpp/ql/test/library-tests/dataflow/dataflow-tests/TestBase.qll
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@ module AstTest {
* S in `if (guarded(x)) S`.
*/
// This is tested in `BarrierGuard.cpp`.
predicate testBarrierGuard(GuardCondition g, Expr checked, boolean isTrue) {
g.(FunctionCall).getTarget().getName() = "guarded" and
checked = g.(FunctionCall).getArgument(0) and
isTrue = true
or
g.(FunctionCall).getTarget().getName() = "unsafe" and
checked = g.(FunctionCall).getArgument(0) and
isTrue = false
predicate testBarrierGuard(GuardCondition g, Expr checked, boolean branch) {
exists(Call call, boolean b |
checked = call.getArgument(0) and
g.comparesEq(call, 0, b, any(BooleanValue bv | bv.getValue() = branch))
|
call.getTarget().hasName("guarded") and
b = false
or
call.getTarget().hasName("unsafe") and
b = true
)
}

/** Common data flow configuration to be used by tests. */
Expand Down Expand Up @@ -106,16 +109,16 @@ module IRTest {
* S in `if (guarded(x)) S`.
*/
// This is tested in `BarrierGuard.cpp`.
predicate testBarrierGuard(IRGuardCondition g, Expr checked, boolean isTrue) {
exists(Call call |
call = g.getUnconvertedResultExpression() and
checked = call.getArgument(0)
predicate testBarrierGuard(IRGuardCondition g, Expr checked, boolean branch) {
exists(CallInstruction call, boolean b |
checked = call.getArgument(0).getUnconvertedResultExpression() and
g.comparesEq(call.getAUse(), 0, b, any(BooleanValue bv | bv.getValue() = branch))
|
call.getTarget().hasName("guarded") and
isTrue = true
call.getStaticCallTarget().hasName("guarded") and
b = false
or
call.getTarget().hasName("unsafe") and
isTrue = false
call.getStaticCallTarget().hasName("unsafe") and
b = true
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,6 @@ irFlow
| BarrierGuard.cpp:49:10:49:15 | call to source | BarrierGuard.cpp:55:13:55:13 | x |
| BarrierGuard.cpp:60:11:60:16 | call to source | BarrierGuard.cpp:64:14:64:14 | x |
| BarrierGuard.cpp:60:11:60:16 | call to source | BarrierGuard.cpp:66:14:66:14 | x |
| BarrierGuard.cpp:81:11:81:16 | call to source | BarrierGuard.cpp:86:8:86:8 | x |
| BarrierGuard.cpp:90:11:90:16 | call to source | BarrierGuard.cpp:101:8:101:8 | x |
| acrossLinkTargets.cpp:19:27:19:32 | call to source | acrossLinkTargets.cpp:12:8:12:8 | x |
| clang.cpp:12:9:12:20 | sourceArray1 | clang.cpp:18:8:18:19 | sourceArray1 |
| clang.cpp:12:9:12:20 | sourceArray1 | clang.cpp:23:17:23:29 | *& ... |
Expand Down

0 comments on commit 000a81f

Please sign in to comment.