Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 43 additions & 66 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -207,81 +207,58 @@ private Type inferAssignmentOperationType(AstNode n, TypePath path) {
}

/**
* Holds if the type of `n1` at `path1` is the same as the type of `n2` at
* `path2` and type information should propagate in both directions through the
* type equality.
* Holds if the type tree of `n1` at `prefix1` should be equal to the type tree
* of `n2` at `prefix2` and type information should propagate in both directions
* through the type equality.
*/
bindingset[path1]
bindingset[path2]
private predicate typeEquality(AstNode n1, TypePath path1, AstNode n2, TypePath path2) {
exists(Variable v |
path1 = path2 and
n1 = v.getAnAccess()
|
n2 = v.getPat()
private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) {
prefix1.isEmpty() and
prefix2.isEmpty() and
(
exists(Variable v | n1 = v.getAnAccess() |
n2 = v.getPat()
or
n2 = v.getParameter().(SelfParam)
)
or
n2 = v.getParameter().(SelfParam)
)
or
exists(LetStmt let |
let.getPat() = n1 and
let.getInitializer() = n2 and
path1 = path2
)
or
n1 = n2.(ParenExpr).getExpr() and
path1 = path2
or
n1 = n2.(BlockExpr).getStmtList().getTailExpr() and
path1 = path2
or
n1 = n2.(IfExpr).getABranch() and
path1 = path2
or
n1 = n2.(MatchExpr).getAnArm().getExpr() and
path1 = path2
or
exists(BreakExpr break |
break.getExpr() = n1 and
break.getTarget() = n2.(LoopExpr) and
path1 = path2
)
or
exists(AssignmentExpr be |
n1 = be.getLhs() and
n2 = be.getRhs() and
path1 = path2
)
}

bindingset[path1]
private predicate typeEqualityLeft(AstNode n1, TypePath path1, AstNode n2, TypePath path2) {
typeEquality(n1, path1, n2, path2)
or
n2 =
any(DerefExpr pe |
pe.getExpr() = n1 and
path1.isCons(TRefTypeParameter(), path2)
exists(LetStmt let |
let.getPat() = n1 and
let.getInitializer() = n2
)
}

bindingset[path2]
private predicate typeEqualityRight(AstNode n1, TypePath path1, AstNode n2, TypePath path2) {
typeEquality(n1, path1, n2, path2)
or
n2 =
any(DerefExpr pe |
pe.getExpr() = n1 and
path1 = TypePath::cons(TRefTypeParameter(), path2)
or
n1 = n2.(ParenExpr).getExpr()
or
n1 = n2.(BlockExpr).getStmtList().getTailExpr()
or
n1 = n2.(IfExpr).getABranch()
or
n1 = n2.(MatchExpr).getAnArm().getExpr()
or
exists(BreakExpr break |
break.getExpr() = n1 and
break.getTarget() = n2.(LoopExpr)
)
or
exists(AssignmentExpr be |
n1 = be.getLhs() and
n2 = be.getRhs()
)
)
or
n1 = n2.(DerefExpr).getExpr() and
prefix1 = TypePath::singleton(TRefTypeParameter()) and
prefix2.isEmpty()
}

pragma[nomagic]
private Type inferTypeEquality(AstNode n, TypePath path) {
exists(AstNode n2, TypePath path2 | result = inferType(n2, path2) |
typeEqualityRight(n, path, n2, path2)
exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix |
result = inferType(n2, prefix2.appendInverse(suffix)) and
path = prefix1.append(suffix)
|
typeEquality(n, prefix1, n2, prefix2)
or
typeEqualityLeft(n2, path2, n, path)
typeEquality(n2, prefix2, n, prefix1)
)
}

Expand Down