diff --git a/rust/ql/lib/codeql/rust/internal/Type.qll b/rust/ql/lib/codeql/rust/internal/Type.qll index 88eb50e09e38..2c5d37a3eb7d 100644 --- a/rust/ql/lib/codeql/rust/internal/Type.qll +++ b/rust/ql/lib/codeql/rust/internal/Type.qll @@ -82,14 +82,22 @@ abstract class Type extends TType { pragma[nomagic] abstract TupleField getTupleField(int i); - /** Gets the `i`th type parameter of this type, if any. */ - abstract TypeParameter getTypeParameter(int i); + /** + * Gets the `i`th positional type parameter of this type, if any. + * + * This excludes for example associated type parameters. + */ + abstract TypeParameter getPositionalTypeParameter(int i); /** Gets the default type for the `i`th type parameter, if any. */ TypeMention getTypeParameterDefault(int i) { none() } - /** Gets a type parameter of this type. */ - final TypeParameter getATypeParameter() { result = this.getTypeParameter(_) } + /** + * Gets a type parameter of this type. + * + * This includes both positional and other type parameters, such as associated types. + */ + TypeParameter getATypeParameter() { result = this.getPositionalTypeParameter(_) } /** Gets a textual representation of this type. */ abstract string toString(); @@ -108,7 +116,9 @@ class TupleType extends Type, TTuple { override TupleField getTupleField(int i) { none() } - override TypeParameter getTypeParameter(int i) { result = TTupleTypeParameter(arity, i) } + override TypeParameter getPositionalTypeParameter(int i) { + result = TTupleTypeParameter(arity, i) + } /** Gets the arity of this tuple type. */ int getArity() { result = arity } @@ -141,7 +151,7 @@ class StructType extends StructOrEnumType, TStruct { override TupleField getTupleField(int i) { result = struct.getTupleField(i) } - override TypeParameter getTypeParameter(int i) { + override TypeParameter getPositionalTypeParameter(int i) { result = TTypeParamTypeParameter(struct.getGenericParamList().getTypeParam(i)) } @@ -166,7 +176,7 @@ class EnumType extends StructOrEnumType, TEnum { override TupleField getTupleField(int i) { none() } - override TypeParameter getTypeParameter(int i) { + override TypeParameter getPositionalTypeParameter(int i) { result = TTypeParamTypeParameter(enum.getGenericParamList().getTypeParam(i)) } @@ -192,10 +202,18 @@ class TraitType extends Type, TTrait { override TupleField getTupleField(int i) { none() } - override TypeParameter getTypeParameter(int i) { + override TypeParameter getPositionalTypeParameter(int i) { result = TTypeParamTypeParameter(trait.getGenericParamList().getTypeParam(i)) } + override TypeParameter getATypeParameter() { + result = super.getATypeParameter() + or + result.(AssociatedTypeTypeParameter).getTrait() = trait + or + result.(SelfTypeParameter).getTrait() = trait + } + override TypeMention getTypeParameterDefault(int i) { result = trait.getGenericParamList().getTypeParam(i).getDefaultType() } @@ -218,7 +236,7 @@ class ArrayType extends Type, TArrayType { override TupleField getTupleField(int i) { none() } - override TypeParameter getTypeParameter(int i) { + override TypeParameter getPositionalTypeParameter(int i) { result = TArrayTypeParameter() and i = 0 } @@ -241,7 +259,7 @@ class RefType extends Type, TRefType { override TupleField getTupleField(int i) { none() } - override TypeParameter getTypeParameter(int i) { + override TypeParameter getPositionalTypeParameter(int i) { result = TRefTypeParameter() and i = 0 } @@ -274,7 +292,7 @@ class ImplTraitType extends Type, TImplTraitType { override TupleField getTupleField(int i) { none() } - override TypeParameter getTypeParameter(int i) { + override TypeParameter getPositionalTypeParameter(int i) { exists(TypeParam tp | implTraitTypeParam(impl, i, tp) and result = TImplTraitTypeParameter(impl, tp) @@ -295,10 +313,19 @@ class DynTraitType extends Type, TDynTraitType { override TupleField getTupleField(int i) { none() } - override DynTraitTypeParameter getTypeParameter(int i) { + override DynTraitTypeParameter getPositionalTypeParameter(int i) { result = TDynTraitTypeParameter(trait.getGenericParamList().getTypeParam(i)) } + override TypeParameter getATypeParameter() { + result = super.getATypeParameter() + or + exists(AstNode n | + dynTraitTypeParameter(trait, n) and + result = TDynTraitTypeParameter(n) + ) + } + Trait getTrait() { result = trait } override string toString() { result = "dyn " + trait.getName().toString() } @@ -336,7 +363,7 @@ class SliceType extends Type, TSliceType { override TupleField getTupleField(int i) { none() } - override TypeParameter getTypeParameter(int i) { + override TypeParameter getPositionalTypeParameter(int i) { result = TSliceTypeParameter() and i = 0 } @@ -352,7 +379,7 @@ abstract class TypeParameter extends Type { override TupleField getTupleField(int i) { none() } - override TypeParameter getTypeParameter(int i) { none() } + override TypeParameter getPositionalTypeParameter(int i) { none() } } private class RawTypeParameter = @type_param or @trait or @type_alias or @impl_trait_type_repr; @@ -548,7 +575,7 @@ class ImplTraitTypeTypeParameter extends ImplTraitType, TypeParameter { override TupleField getTupleField(int i) { none() } - override TypeParameter getTypeParameter(int i) { none() } + override TypeParameter getPositionalTypeParameter(int i) { none() } } /** diff --git a/rust/ql/lib/codeql/rust/internal/TypeInference.qll b/rust/ql/lib/codeql/rust/internal/TypeInference.qll index d9a5bef9a653..1894ebf12ba1 100644 --- a/rust/ql/lib/codeql/rust/internal/TypeInference.qll +++ b/rust/ql/lib/codeql/rust/internal/TypeInference.qll @@ -374,6 +374,9 @@ private module CertainTypeInference { or result = inferLiteralType(n, path, true) or + result = inferAsyncBlockExprRootType(n) and + path.isEmpty() + or infersCertainTypeAt(n, path, result.getATypeParameter()) } @@ -521,7 +524,7 @@ private predicate typeEquality(AstNode n1, TypePath prefix1, AstNode n2, TypePat n2 = be.getStmtList().getTailExpr() and if be.isAsync() then - prefix1 = TypePath::singleton(getFutureOutputTypeParameter()) and + prefix1 = TypePath::singleton(getDynFutureOutputTypeParameter()) and prefix2.isEmpty() else ( prefix1.isEmpty() and @@ -941,7 +944,7 @@ private module CallExprBaseMatchingInput implements MatchingInputSig { or exists(TypePath suffix | result = this.resolveRetType(suffix) and - path = TypePath::cons(getFutureOutputTypeParameter(), suffix) + path = TypePath::cons(getDynFutureOutputTypeParameter(), suffix) ) else result = this.resolveRetType(path) } @@ -1427,8 +1430,9 @@ private Type inferLiteralType(LiteralExpr le, TypePath path, boolean certain) { certain = true } +// always exists because of the mention in `builtins/mentions.rs` pragma[nomagic] -private TraitType getFutureTraitType() { result.getTrait() instanceof FutureTrait } +private DynTraitType getFutureTraitType() { result.getTrait() instanceof FutureTrait } pragma[nomagic] private AssociatedTypeTypeParameter getFutureOutputTypeParameter() { @@ -1436,7 +1440,12 @@ private AssociatedTypeTypeParameter getFutureOutputTypeParameter() { } pragma[nomagic] -private TraitType inferAsyncBlockExprRootType(AsyncBlockExpr abe) { +private DynTraitTypeParameter getDynFutureOutputTypeParameter() { + result = TDynTraitTypeParameter(any(FutureTrait ft).getOutputType()) +} + +pragma[nomagic] +private DynTraitType inferAsyncBlockExprRootType(AsyncBlockExpr abe) { // `typeEquality` handles the non-root case exists(abe) and result = getFutureTraitType() @@ -1449,6 +1458,7 @@ final private class AwaitTarget extends Expr { } private module AwaitSatisfiesConstraintInput implements SatisfiesConstraintInputSig { + pragma[nomagic] predicate relevantConstraint(AwaitTarget term, Type constraint) { exists(term) and constraint.(TraitType).getTrait() instanceof FutureTrait @@ -1774,7 +1784,7 @@ private Type inferClosureExprType(AstNode n, TypePath path) { exists(ClosureExpr ce | n = ce and path.isEmpty() and - result = TDynTraitType(any(FnOnceTrait t)) + result = TDynTraitType(any(FnOnceTrait t)) // always exists because of the mention in `builtins/mentions.rs` or n = ce and path = TypePath::singleton(TDynTraitTypeParameter(any(FnOnceTrait t).getTypeParam())) and @@ -2382,9 +2392,6 @@ private module Cached { or result = inferLiteralType(n, path, false) or - result = inferAsyncBlockExprRootType(n) and - path.isEmpty() - or result = inferAwaitExprType(n, path) or result = inferArrayExprType(n) and diff --git a/rust/ql/lib/codeql/rust/internal/TypeMention.qll b/rust/ql/lib/codeql/rust/internal/TypeMention.qll index f7c5f2f25e0e..c36e19842377 100644 --- a/rust/ql/lib/codeql/rust/internal/TypeMention.qll +++ b/rust/ql/lib/codeql/rust/internal/TypeMention.qll @@ -182,7 +182,7 @@ class NonAliasPathTypeMention extends PathTypeMention { private TypeMention getTypeMentionForTypeParameter(TypeParameter tp) { exists(int i | result = this.getPositionalTypeArgument(pragma[only_bind_into](i)) and - tp = this.resolveRootType().getTypeParameter(pragma[only_bind_into](i)) + tp = this.resolveRootType().getPositionalTypeParameter(pragma[only_bind_into](i)) ) or exists(TypeAlias alias | diff --git a/rust/ql/test/library-tests/type-inference/type-inference.expected b/rust/ql/test/library-tests/type-inference/type-inference.expected index 1fb5a612918f..a6c73665243e 100644 --- a/rust/ql/test/library-tests/type-inference/type-inference.expected +++ b/rust/ql/test/library-tests/type-inference/type-inference.expected @@ -3728,9 +3728,8 @@ inferType | main.rs:1915:25:1917:5 | { ... } | | main.rs:1909:5:1909:14 | S1 | | main.rs:1916:9:1916:10 | S1 | | main.rs:1909:5:1909:14 | S1 | | main.rs:1919:41:1921:5 | { ... } | | main.rs:1919:16:1919:39 | ImplTraitTypeRepr | -| main.rs:1920:9:1920:20 | { ... } | | {EXTERNAL LOCATION} | trait Future | -| main.rs:1920:9:1920:20 | { ... } | | main.rs:1919:16:1919:39 | ImplTraitTypeRepr | -| main.rs:1920:9:1920:20 | { ... } | Output | main.rs:1909:5:1909:14 | S1 | +| main.rs:1920:9:1920:20 | { ... } | | {EXTERNAL LOCATION} | dyn Future | +| main.rs:1920:9:1920:20 | { ... } | dyn(Output) | main.rs:1909:5:1909:14 | S1 | | main.rs:1920:17:1920:18 | S1 | | main.rs:1909:5:1909:14 | S1 | | main.rs:1929:13:1929:42 | SelfParam | | {EXTERNAL LOCATION} | Pin | | main.rs:1929:13:1929:42 | SelfParam | Ptr | file://:0:0:0:0 | & | @@ -3745,8 +3744,8 @@ inferType | main.rs:1936:41:1938:5 | { ... } | | main.rs:1936:16:1936:39 | ImplTraitTypeRepr | | main.rs:1937:9:1937:10 | S2 | | main.rs:1923:5:1923:14 | S2 | | main.rs:1937:9:1937:10 | S2 | | main.rs:1936:16:1936:39 | ImplTraitTypeRepr | -| main.rs:1941:9:1941:12 | f1(...) | | {EXTERNAL LOCATION} | trait Future | -| main.rs:1941:9:1941:12 | f1(...) | Output | main.rs:1909:5:1909:14 | S1 | +| main.rs:1941:9:1941:12 | f1(...) | | {EXTERNAL LOCATION} | dyn Future | +| main.rs:1941:9:1941:12 | f1(...) | dyn(Output) | main.rs:1909:5:1909:14 | S1 | | main.rs:1941:9:1941:18 | await ... | | main.rs:1909:5:1909:14 | S1 | | main.rs:1942:9:1942:12 | f2(...) | | main.rs:1919:16:1919:39 | ImplTraitTypeRepr | | main.rs:1942:9:1942:18 | await ... | | main.rs:1909:5:1909:14 | S1 | @@ -3754,13 +3753,13 @@ inferType | main.rs:1943:9:1943:18 | await ... | | main.rs:1909:5:1909:14 | S1 | | main.rs:1944:9:1944:10 | S2 | | main.rs:1923:5:1923:14 | S2 | | main.rs:1944:9:1944:16 | await S2 | | main.rs:1909:5:1909:14 | S1 | -| main.rs:1945:13:1945:13 | b | | {EXTERNAL LOCATION} | trait Future | -| main.rs:1945:13:1945:13 | b | Output | main.rs:1909:5:1909:14 | S1 | -| main.rs:1945:17:1945:28 | { ... } | | {EXTERNAL LOCATION} | trait Future | -| main.rs:1945:17:1945:28 | { ... } | Output | main.rs:1909:5:1909:14 | S1 | +| main.rs:1945:13:1945:13 | b | | {EXTERNAL LOCATION} | dyn Future | +| main.rs:1945:13:1945:13 | b | dyn(Output) | main.rs:1909:5:1909:14 | S1 | +| main.rs:1945:17:1945:28 | { ... } | | {EXTERNAL LOCATION} | dyn Future | +| main.rs:1945:17:1945:28 | { ... } | dyn(Output) | main.rs:1909:5:1909:14 | S1 | | main.rs:1945:25:1945:26 | S1 | | main.rs:1909:5:1909:14 | S1 | -| main.rs:1946:9:1946:9 | b | | {EXTERNAL LOCATION} | trait Future | -| main.rs:1946:9:1946:9 | b | Output | main.rs:1909:5:1909:14 | S1 | +| main.rs:1946:9:1946:9 | b | | {EXTERNAL LOCATION} | dyn Future | +| main.rs:1946:9:1946:9 | b | dyn(Output) | main.rs:1909:5:1909:14 | S1 | | main.rs:1946:9:1946:15 | await b | | main.rs:1909:5:1909:14 | S1 | | main.rs:1957:15:1957:19 | SelfParam | | file://:0:0:0:0 | & | | main.rs:1957:15:1957:19 | SelfParam | &T | main.rs:1956:5:1958:5 | Self [trait Trait1] | @@ -4992,7 +4991,7 @@ inferType | main.rs:2613:5:2613:60 | ...::g(...) | | main.rs:72:5:72:21 | Foo | | main.rs:2613:20:2613:38 | ...::Foo {...} | | main.rs:72:5:72:21 | Foo | | main.rs:2613:41:2613:59 | ...::Foo {...} | | main.rs:72:5:72:21 | Foo | -| main.rs:2629:5:2629:15 | ...::f(...) | | {EXTERNAL LOCATION} | trait Future | +| main.rs:2629:5:2629:15 | ...::f(...) | | {EXTERNAL LOCATION} | dyn Future | | pattern_matching.rs:13:26:133:1 | { ... } | | {EXTERNAL LOCATION} | Option | | pattern_matching.rs:13:26:133:1 | { ... } | T | file://:0:0:0:0 | () | | pattern_matching.rs:14:9:14:13 | value | | {EXTERNAL LOCATION} | Option | diff --git a/rust/tools/builtins/mentions.rs b/rust/tools/builtins/mentions.rs new file mode 100644 index 000000000000..a3731164893b --- /dev/null +++ b/rust/tools/builtins/mentions.rs @@ -0,0 +1,6 @@ +// Type mentions required by type inference + +use std::future::Future; +fn mention_dyn_future(f: &dyn Future) {} + +fn mention_dyn_fn_once(f: &dyn FnOnce() -> F) {} diff --git a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll index c42a424f3e34..430add5ec005 100644 --- a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll +++ b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll @@ -1007,9 +1007,6 @@ module Make1 Input1> { tt.getTypeAt(pathToTypeParamInSub.appendInverse(suffix)) = t and path = prefix0.append(suffix) ) - or - hasTypeConstraint(tt, constraint, constraint) and - t = tt.getTypeAt(path) } }