Skip to content

Rust: Model async return types as dyn Future #20236

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 42 additions & 15 deletions rust/ql/lib/codeql/rust/internal/Type.qll
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,22 @@ abstract class Type extends TType {
pragma[nomagic]
abstract TupleField getTupleField(int i);

/** Gets the `i`th type parameter of this type, if any. */
abstract TypeParameter getTypeParameter(int i);
/**
* Gets the `i`th positional type parameter of this type, if any.
*
* This excludes for example associated type parameters.
*/
abstract TypeParameter getPositionalTypeParameter(int i);

/** Gets the default type for the `i`th type parameter, if any. */
TypeMention getTypeParameterDefault(int i) { none() }

/** Gets a type parameter of this type. */
final TypeParameter getATypeParameter() { result = this.getTypeParameter(_) }
/**
* Gets a type parameter of this type.
*
* This includes both positional and other type parameters, such as associated types.
*/
TypeParameter getATypeParameter() { result = this.getPositionalTypeParameter(_) }

/** Gets a textual representation of this type. */
abstract string toString();
Expand All @@ -108,7 +116,9 @@ class TupleType extends Type, TTuple {

override TupleField getTupleField(int i) { none() }

override TypeParameter getTypeParameter(int i) { result = TTupleTypeParameter(arity, i) }
override TypeParameter getPositionalTypeParameter(int i) {
result = TTupleTypeParameter(arity, i)
}

/** Gets the arity of this tuple type. */
int getArity() { result = arity }
Expand Down Expand Up @@ -141,7 +151,7 @@ class StructType extends StructOrEnumType, TStruct {

override TupleField getTupleField(int i) { result = struct.getTupleField(i) }

override TypeParameter getTypeParameter(int i) {
override TypeParameter getPositionalTypeParameter(int i) {
result = TTypeParamTypeParameter(struct.getGenericParamList().getTypeParam(i))
}

Expand All @@ -166,7 +176,7 @@ class EnumType extends StructOrEnumType, TEnum {

override TupleField getTupleField(int i) { none() }

override TypeParameter getTypeParameter(int i) {
override TypeParameter getPositionalTypeParameter(int i) {
result = TTypeParamTypeParameter(enum.getGenericParamList().getTypeParam(i))
}

Expand All @@ -192,10 +202,18 @@ class TraitType extends Type, TTrait {

override TupleField getTupleField(int i) { none() }

override TypeParameter getTypeParameter(int i) {
override TypeParameter getPositionalTypeParameter(int i) {
result = TTypeParamTypeParameter(trait.getGenericParamList().getTypeParam(i))
}

override TypeParameter getATypeParameter() {
result = super.getATypeParameter()
or
result.(AssociatedTypeTypeParameter).getTrait() = trait
or
result.(SelfTypeParameter).getTrait() = trait
}

override TypeMention getTypeParameterDefault(int i) {
result = trait.getGenericParamList().getTypeParam(i).getDefaultType()
}
Expand All @@ -218,7 +236,7 @@ class ArrayType extends Type, TArrayType {

override TupleField getTupleField(int i) { none() }

override TypeParameter getTypeParameter(int i) {
override TypeParameter getPositionalTypeParameter(int i) {
result = TArrayTypeParameter() and
i = 0
}
Expand All @@ -241,7 +259,7 @@ class RefType extends Type, TRefType {

override TupleField getTupleField(int i) { none() }

override TypeParameter getTypeParameter(int i) {
override TypeParameter getPositionalTypeParameter(int i) {
result = TRefTypeParameter() and
i = 0
}
Expand Down Expand Up @@ -274,7 +292,7 @@ class ImplTraitType extends Type, TImplTraitType {

override TupleField getTupleField(int i) { none() }

override TypeParameter getTypeParameter(int i) {
override TypeParameter getPositionalTypeParameter(int i) {
exists(TypeParam tp |
implTraitTypeParam(impl, i, tp) and
result = TImplTraitTypeParameter(impl, tp)
Expand All @@ -295,10 +313,19 @@ class DynTraitType extends Type, TDynTraitType {

override TupleField getTupleField(int i) { none() }

override DynTraitTypeParameter getTypeParameter(int i) {
override DynTraitTypeParameter getPositionalTypeParameter(int i) {
result = TDynTraitTypeParameter(trait.getGenericParamList().getTypeParam(i))
}

override TypeParameter getATypeParameter() {
result = super.getATypeParameter()
or
exists(AstNode n |
dynTraitTypeParameter(trait, n) and
result = TDynTraitTypeParameter(n)
)
}

Trait getTrait() { result = trait }

override string toString() { result = "dyn " + trait.getName().toString() }
Expand Down Expand Up @@ -336,7 +363,7 @@ class SliceType extends Type, TSliceType {

override TupleField getTupleField(int i) { none() }

override TypeParameter getTypeParameter(int i) {
override TypeParameter getPositionalTypeParameter(int i) {
result = TSliceTypeParameter() and
i = 0
}
Expand All @@ -352,7 +379,7 @@ abstract class TypeParameter extends Type {

override TupleField getTupleField(int i) { none() }

override TypeParameter getTypeParameter(int i) { none() }
override TypeParameter getPositionalTypeParameter(int i) { none() }
}

private class RawTypeParameter = @type_param or @trait or @type_alias or @impl_trait_type_repr;
Expand Down Expand Up @@ -548,7 +575,7 @@ class ImplTraitTypeTypeParameter extends ImplTraitType, TypeParameter {

override TupleField getTupleField(int i) { none() }

override TypeParameter getTypeParameter(int i) { none() }
override TypeParameter getPositionalTypeParameter(int i) { none() }
}

/**
Expand Down
23 changes: 15 additions & 8 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,9 @@ private module CertainTypeInference {
or
result = inferLiteralType(n, path, true)
or
result = inferAsyncBlockExprRootType(n) and
path.isEmpty()
or
infersCertainTypeAt(n, path, result.getATypeParameter())
}

Expand Down Expand Up @@ -521,7 +524,7 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat
n2 = be.getStmtList().getTailExpr() and
if be.isAsync()
then
prefix1 = TypePath::singleton(getFutureOutputTypeParameter()) and
prefix1 = TypePath::singleton(getDynFutureOutputTypeParameter()) and
prefix2.isEmpty()
else (
prefix1.isEmpty() and
Expand Down Expand Up @@ -941,7 +944,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig {
or
exists(TypePath suffix |
result = this.resolveRetType(suffix) and
path = TypePath::cons(getFutureOutputTypeParameter(), suffix)
path = TypePath::cons(getDynFutureOutputTypeParameter(), suffix)
)
else result = this.resolveRetType(path)
}
Expand Down Expand Up @@ -1427,16 +1430,22 @@ private Type inferLiteralType(LiteralExpr le, TypePath path, boolean certain) {
certain = true
}

// always exists because of the mention in `builtins/mentions.rs`
pragma[nomagic]
private TraitType getFutureTraitType() { result.getTrait() instanceof FutureTrait }
private DynTraitType getFutureTraitType() { result.getTrait() instanceof FutureTrait }

pragma[nomagic]
private AssociatedTypeTypeParameter getFutureOutputTypeParameter() {
result.getTypeAlias() = any(FutureTrait ft).getOutputType()
}

pragma[nomagic]
private TraitType inferAsyncBlockExprRootType(AsyncBlockExpr abe) {
private DynTraitTypeParameter getDynFutureOutputTypeParameter() {
result = TDynTraitTypeParameter(any(FutureTrait ft).getOutputType())
}

pragma[nomagic]
private DynTraitType inferAsyncBlockExprRootType(AsyncBlockExpr abe) {
// `typeEquality` handles the non-root case
exists(abe) and
result = getFutureTraitType()
Expand All @@ -1449,6 +1458,7 @@ final private class AwaitTarget extends Expr {
}

private module AwaitSatisfiesConstraintInput implements SatisfiesConstraintInputSig<AwaitTarget> {
pragma[nomagic]
predicate relevantConstraint(AwaitTarget term, Type constraint) {
exists(term) and
constraint.(TraitType).getTrait() instanceof FutureTrait
Expand Down Expand Up @@ -1774,7 +1784,7 @@ private Type inferClosureExprType(AstNode n, TypePath path) {
exists(ClosureExpr ce |
n = ce and
path.isEmpty() and
result = TDynTraitType(any(FnOnceTrait t))
result = TDynTraitType(any(FnOnceTrait t)) // always exists because of the mention in `builtins/mentions.rs`
or
n = ce and
path = TypePath::singleton(TDynTraitTypeParameter(any(FnOnceTrait t).getTypeParam())) and
Expand Down Expand Up @@ -2382,9 +2392,6 @@ private module Cached {
or
result = inferLiteralType(n, path, false)
or
result = inferAsyncBlockExprRootType(n) and
path.isEmpty()
or
result = inferAwaitExprType(n, path)
or
result = inferArrayExprType(n) and
Expand Down
2 changes: 1 addition & 1 deletion rust/ql/lib/codeql/rust/internal/TypeMention.qll
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ class NonAliasPathTypeMention extends PathTypeMention {
private TypeMention getTypeMentionForTypeParameter(TypeParameter tp) {
exists(int i |
result = this.getPositionalTypeArgument(pragma[only_bind_into](i)) and
tp = this.resolveRootType().getTypeParameter(pragma[only_bind_into](i))
tp = this.resolveRootType().getPositionalTypeParameter(pragma[only_bind_into](i))
)
or
exists(TypeAlias alias |
Expand Down
23 changes: 11 additions & 12 deletions rust/ql/test/library-tests/type-inference/type-inference.expected
Original file line number Diff line number Diff line change
Expand Up @@ -3728,9 +3728,8 @@ inferType
| main.rs:1915:25:1917:5 | { ... } | | main.rs:1909:5:1909:14 | S1 |
| main.rs:1916:9:1916:10 | S1 | | main.rs:1909:5:1909:14 | S1 |
| main.rs:1919:41:1921:5 | { ... } | | main.rs:1919:16:1919:39 | ImplTraitTypeRepr |
| main.rs:1920:9:1920:20 | { ... } | | {EXTERNAL LOCATION} | trait Future |
| main.rs:1920:9:1920:20 | { ... } | | main.rs:1919:16:1919:39 | ImplTraitTypeRepr |
| main.rs:1920:9:1920:20 | { ... } | Output | main.rs:1909:5:1909:14 | S1 |
| main.rs:1920:9:1920:20 | { ... } | | {EXTERNAL LOCATION} | dyn Future |
| main.rs:1920:9:1920:20 | { ... } | dyn(Output) | main.rs:1909:5:1909:14 | S1 |
| main.rs:1920:17:1920:18 | S1 | | main.rs:1909:5:1909:14 | S1 |
| main.rs:1929:13:1929:42 | SelfParam | | {EXTERNAL LOCATION} | Pin |
| main.rs:1929:13:1929:42 | SelfParam | Ptr | file://:0:0:0:0 | & |
Expand All @@ -3745,22 +3744,22 @@ inferType
| main.rs:1936:41:1938:5 | { ... } | | main.rs:1936:16:1936:39 | ImplTraitTypeRepr |
| main.rs:1937:9:1937:10 | S2 | | main.rs:1923:5:1923:14 | S2 |
| main.rs:1937:9:1937:10 | S2 | | main.rs:1936:16:1936:39 | ImplTraitTypeRepr |
| main.rs:1941:9:1941:12 | f1(...) | | {EXTERNAL LOCATION} | trait Future |
| main.rs:1941:9:1941:12 | f1(...) | Output | main.rs:1909:5:1909:14 | S1 |
| main.rs:1941:9:1941:12 | f1(...) | | {EXTERNAL LOCATION} | dyn Future |
| main.rs:1941:9:1941:12 | f1(...) | dyn(Output) | main.rs:1909:5:1909:14 | S1 |
| main.rs:1941:9:1941:18 | await ... | | main.rs:1909:5:1909:14 | S1 |
| main.rs:1942:9:1942:12 | f2(...) | | main.rs:1919:16:1919:39 | ImplTraitTypeRepr |
| main.rs:1942:9:1942:18 | await ... | | main.rs:1909:5:1909:14 | S1 |
| main.rs:1943:9:1943:12 | f3(...) | | main.rs:1936:16:1936:39 | ImplTraitTypeRepr |
| main.rs:1943:9:1943:18 | await ... | | main.rs:1909:5:1909:14 | S1 |
| main.rs:1944:9:1944:10 | S2 | | main.rs:1923:5:1923:14 | S2 |
| main.rs:1944:9:1944:16 | await S2 | | main.rs:1909:5:1909:14 | S1 |
| main.rs:1945:13:1945:13 | b | | {EXTERNAL LOCATION} | trait Future |
| main.rs:1945:13:1945:13 | b | Output | main.rs:1909:5:1909:14 | S1 |
| main.rs:1945:17:1945:28 | { ... } | | {EXTERNAL LOCATION} | trait Future |
| main.rs:1945:17:1945:28 | { ... } | Output | main.rs:1909:5:1909:14 | S1 |
| main.rs:1945:13:1945:13 | b | | {EXTERNAL LOCATION} | dyn Future |
| main.rs:1945:13:1945:13 | b | dyn(Output) | main.rs:1909:5:1909:14 | S1 |
| main.rs:1945:17:1945:28 | { ... } | | {EXTERNAL LOCATION} | dyn Future |
| main.rs:1945:17:1945:28 | { ... } | dyn(Output) | main.rs:1909:5:1909:14 | S1 |
| main.rs:1945:25:1945:26 | S1 | | main.rs:1909:5:1909:14 | S1 |
| main.rs:1946:9:1946:9 | b | | {EXTERNAL LOCATION} | trait Future |
| main.rs:1946:9:1946:9 | b | Output | main.rs:1909:5:1909:14 | S1 |
| main.rs:1946:9:1946:9 | b | | {EXTERNAL LOCATION} | dyn Future |
| main.rs:1946:9:1946:9 | b | dyn(Output) | main.rs:1909:5:1909:14 | S1 |
| main.rs:1946:9:1946:15 | await b | | main.rs:1909:5:1909:14 | S1 |
| main.rs:1957:15:1957:19 | SelfParam | | file://:0:0:0:0 | & |
| main.rs:1957:15:1957:19 | SelfParam | &T | main.rs:1956:5:1958:5 | Self [trait Trait1] |
Expand Down Expand Up @@ -4992,7 +4991,7 @@ inferType
| main.rs:2613:5:2613:60 | ...::g(...) | | main.rs:72:5:72:21 | Foo |
| main.rs:2613:20:2613:38 | ...::Foo {...} | | main.rs:72:5:72:21 | Foo |
| main.rs:2613:41:2613:59 | ...::Foo {...} | | main.rs:72:5:72:21 | Foo |
| main.rs:2629:5:2629:15 | ...::f(...) | | {EXTERNAL LOCATION} | trait Future |
| main.rs:2629:5:2629:15 | ...::f(...) | | {EXTERNAL LOCATION} | dyn Future |
| pattern_matching.rs:13:26:133:1 | { ... } | | {EXTERNAL LOCATION} | Option |
| pattern_matching.rs:13:26:133:1 | { ... } | T | file://:0:0:0:0 | () |
| pattern_matching.rs:14:9:14:13 | value | | {EXTERNAL LOCATION} | Option |
Expand Down
6 changes: 6 additions & 0 deletions rust/tools/builtins/mentions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// Type mentions required by type inference

use std::future::Future;
fn mention_dyn_future<T>(f: &dyn Future<Output = T>) {}

Check notice

Code scanning / CodeQL

Unused variable Note

Variable 'f' is not used.

fn mention_dyn_fn_once<F>(f: &dyn FnOnce() -> F) {}

Check notice

Code scanning / CodeQL

Unused variable Note

Variable 'f' is not used.
Original file line number Diff line number Diff line change
Expand Up @@ -1007,9 +1007,6 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
tt.getTypeAt(pathToTypeParamInSub.appendInverse(suffix)) = t and
path = prefix0.append(suffix)
)
or
hasTypeConstraint(tt, constraint, constraint) and
t = tt.getTypeAt(path)
}
}

Expand Down
Loading