Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 31 additions & 13 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,6 @@ private predicate typeEquality(AstNode n1, TypePath path1, AstNode n2, TypePath
path1 = path2
)
or
n2 =
any(PrefixExpr pe |
pe.getOperatorName() = "*" and
pe.getExpr() = n1 and
path1 = TypePath::cons(TRefTypeParameter(), path2)
)
or
n1 = n2.(ParenExpr).getExpr() and
path1 = path2
or
Expand All @@ -239,12 +232,36 @@ private predicate typeEquality(AstNode n1, TypePath path1, AstNode n2, TypePath
)
}

bindingset[path1]
private predicate typeEqualityLeft(AstNode n1, TypePath path1, AstNode n2, TypePath path2) {
typeEquality(n1, path1, n2, path2)
or
n2 =
any(PrefixExpr pe |
pe.getOperatorName() = "*" and
pe.getExpr() = n1 and
path1.isCons(TRefTypeParameter(), path2)
)
}

bindingset[path2]
private predicate typeEqualityRight(AstNode n1, TypePath path1, AstNode n2, TypePath path2) {
typeEquality(n1, path1, n2, path2)
or
n2 =
any(PrefixExpr pe |
pe.getOperatorName() = "*" and
pe.getExpr() = n1 and
path1 = TypePath::cons(TRefTypeParameter(), path2)
)
}

pragma[nomagic]
private Type inferTypeEquality(AstNode n, TypePath path) {
exists(AstNode n2, TypePath path2 | result = inferType(n2, path2) |
typeEquality(n, path, n2, path2)
typeEqualityRight(n, path, n2, path2)
or
typeEquality(n2, path2, n, path)
typeEqualityLeft(n2, path2, n, path)
)
}

Expand Down Expand Up @@ -909,7 +926,7 @@ private Type inferRefExprType(Expr e, TypePath path) {
e = re.getExpr() and
exists(TypePath exprPath, TypePath refPath, Type exprType |
result = inferType(re, exprPath) and
exprPath = TypePath::cons(TRefTypeParameter(), refPath) and
exprPath.isCons(TRefTypeParameter(), refPath) and
exprType = inferType(e)
|
if exprType = TRefType()
Expand All @@ -923,8 +940,9 @@ private Type inferRefExprType(Expr e, TypePath path) {

pragma[nomagic]
private Type inferTryExprType(TryExpr te, TypePath path) {
exists(TypeParam tp |
result = inferType(te.getExpr(), TypePath::cons(TTypeParamTypeParameter(tp), path))
exists(TypeParam tp, TypePath path0 |
result = inferType(te.getExpr(), path0) and
path0.isCons(TTypeParamTypeParameter(tp), path)
|
tp = any(ResultEnum r).getGenericParamList().getGenericParam(0)
or
Expand Down Expand Up @@ -1000,7 +1018,7 @@ private module Cached {
pragma[nomagic]
Type getTypeAt(TypePath path) {
exists(TypePath path0 | result = inferType(this, path0) |
path0 = TypePath::cons(TRefTypeParameter(), path)
path0.isCons(TRefTypeParameter(), path)
or
not path0.isCons(TRefTypeParameter(), _) and
not (path0.isEmpty() and result = TRefType()) and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,18 +181,29 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
/** Holds if this type path is empty. */
predicate isEmpty() { this = "" }

/** Gets the length of this path, assuming the length is at least 2. */
bindingset[this]
pragma[inline_late]
private int lengthAtLeast2() {
// Same as
// `result = strictcount(this.indexOf(".")) + 1`
// but performs better because it doesn't use an aggregate
result = this.regexpReplaceAll("[0-9]+", "").length() + 1
}

/** Gets the length of this path. */
bindingset[this]
pragma[inline_late]
int length() {
this.isEmpty() and result = 0
or
result = strictcount(this.indexOf(".")) + 1
if this.isEmpty()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps depth would actually be a better name than length?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't depth lead to the misconception that it is a tree instead of a list?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe? To me it feel natural to say that the depth of foo/bar/baz is 3, but let's just keep it as-if if you feel it's not as clear.

then result = 0
else
if exists(TypeParameter::decode(this))
then result = 1
else result = this.lengthAtLeast2()
}

/** Gets the path obtained by appending `suffix` onto this path. */
bindingset[suffix, result]
bindingset[this, result]
bindingset[this, suffix]
TypePath append(TypePath suffix) {
if this.isEmpty()
Expand All @@ -202,21 +213,40 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
then result = this
else (
result = this + "." + suffix and
not result.length() > getTypePathLimit()
(
not exists(getTypePathLimit())
or
result.lengthAtLeast2() <= getTypePathLimit()
)
)
}

/**
* Gets the path obtained by appending `suffix` onto this path.
*
* Unlike `append`, this predicate has `result` in the binding set,
* so there is no need to check the length of `result`.
Comment on lines +225 to +228
Copy link

Copilot AI May 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The documentation for appendInverse is copied from append but this predicate actually deconstructs a full path into this and suffix. Please update the comment to describe the inverse operation.

Suggested change
* Gets the path obtained by appending `suffix` onto this path.
*
* Unlike `append`, this predicate has `result` in the binding set,
* so there is no need to check the length of `result`.
* Deconstructs a full path `result` into `this` and `suffix`.
*
* This predicate performs the inverse operation of `append`. It holds if
* `result` is a path that can be split into `this` as the prefix and
* `suffix` as the remainder. For example, if `result` is "a.b.c" and
* `this` is "a.b", then `suffix` would be "c".

Copilot uses AI. Check for mistakes.
*/
bindingset[this, result]
TypePath appendInverse(TypePath suffix) { suffix = result.stripPrefix(this) }

/** Gets the path obtained by removing `prefix` from this path. */
bindingset[this, prefix]
TypePath stripPrefix(TypePath prefix) {
if prefix.isEmpty()
then result = this
else (
this = prefix and
result.isEmpty()
or
this = prefix + "." + result
)
}

/** Holds if this path starts with `tp`, followed by `suffix`. */
bindingset[this]
predicate isCons(TypeParameter tp, TypePath suffix) {
tp = TypeParameter::decode(this) and
suffix.isEmpty()
or
exists(int first |
first = min(this.indexOf(".")) and
suffix = this.suffix(first + 1) and
tp = TypeParameter::decode(this.prefix(first))
)
suffix = this.stripPrefix(TypePath::singleton(tp))
}
}

Expand All @@ -232,7 +262,6 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
* Gets the type path obtained by appending the singleton type path `tp`
* onto `suffix`.
*/
bindingset[result]
bindingset[suffix]
TypePath cons(TypeParameter tp, TypePath suffix) { result = singleton(tp).append(suffix) }
}
Expand Down Expand Up @@ -556,7 +585,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
TypeMention tm1, TypeMention tm2, TypeParameter tp, TypePath path, Type t
) {
exists(TypePath prefix |
tm2.resolveTypeAt(prefix) = tp and t = tm1.resolveTypeAt(prefix.append(path))
tm2.resolveTypeAt(prefix) = tp and t = tm1.resolveTypeAt(prefix.appendInverse(path))
)
}

Expand Down Expand Up @@ -899,7 +928,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
exists(AccessPosition apos, DeclarationPosition dpos, TypePath pathToTypeParam |
tp = target.getDeclaredType(dpos, pathToTypeParam) and
accessDeclarationPositionMatch(apos, dpos) and
adjustedAccessType(a, apos, target, pathToTypeParam.append(path), t)
adjustedAccessType(a, apos, target, pathToTypeParam.appendInverse(path), t)
)
}

Expand Down Expand Up @@ -998,7 +1027,9 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {

RelevantAccess() { this = MkRelevantAccess(a, apos, path) }

Type getTypeAt(TypePath suffix) { a.getInferredType(apos, path.append(suffix)) = result }
Type getTypeAt(TypePath suffix) {
a.getInferredType(apos, path.appendInverse(suffix)) = result
}

/** Holds if this relevant access has the type `type` and should satisfy `constraint`. */
predicate hasTypeConstraint(Type type, Type constraint) {
Expand Down Expand Up @@ -1077,7 +1108,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
t0 = abs.getATypeParameter() and
exists(TypePath path3, TypePath suffix |
sub.resolveTypeAt(path3) = t0 and
at.getTypeAt(path3.append(suffix)) = t and
at.getTypeAt(path3.appendInverse(suffix)) = t and
path = prefix0.append(suffix)
)
)
Expand Down Expand Up @@ -1149,7 +1180,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
not exists(getTypeArgument(a, target, tp, _)) and
target = a.getTarget() and
exists(AccessPosition apos, DeclarationPosition dpos, Type base, TypePath pathToTypeParam |
accessBaseType(a, apos, base, pathToTypeParam.append(path), t) and
accessBaseType(a, apos, base, pathToTypeParam.appendInverse(path), t) and
declarationBaseType(target, dpos, base, pathToTypeParam, tp) and
accessDeclarationPositionMatch(apos, dpos)
)
Expand Down Expand Up @@ -1217,7 +1248,7 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
typeParameterConstraintHasTypeParameter(target, dpos, pathToTp2, _, constraint, pathToTp,
tp) and
AccessConstraint::satisfiesConstraintTypeMention(a, apos, pathToTp2, constraint,
pathToTp.append(path), t)
pathToTp.appendInverse(path), t)
)
}

Expand Down
Loading