From dde845e92f20c850e09e4fb508eea1d39fe0081a Mon Sep 17 00:00:00 2001 From: Simon Friis Vindum Date: Thu, 18 Dec 2025 15:16:08 +0100 Subject: [PATCH 1/2] Rust: Refactor type inference to use new `TypeItem` class --- .../rust/frameworks/rustcrypto/RustCrypto.qll | 3 +- rust/ql/lib/codeql/rust/internal/Type.qll | 93 +++++-------- .../codeql/rust/internal/TypeInference.qll | 127 +++++++----------- .../lib/codeql/rust/internal/TypeMention.qll | 6 +- rust/ql/lib/codeql/rust/security/Barriers.qll | 8 +- 5 files changed, 88 insertions(+), 149 deletions(-) diff --git a/rust/ql/lib/codeql/rust/frameworks/rustcrypto/RustCrypto.qll b/rust/ql/lib/codeql/rust/frameworks/rustcrypto/RustCrypto.qll index 6e50659103d6..cbc638c8ae54 100644 --- a/rust/ql/lib/codeql/rust/frameworks/rustcrypto/RustCrypto.qll +++ b/rust/ql/lib/codeql/rust/frameworks/rustcrypto/RustCrypto.qll @@ -30,7 +30,8 @@ class StreamCipherInit extends Cryptography::CryptographicOperation::Range { // extract the algorithm name from the type of `ce` or its receiver. exists(Type t, TypePath tp | t = inferType([call, call.(MethodCall).getReceiver()], tp) and - rawAlgorithmName = t.(StructType).getStruct().(Addressable).getCanonicalPath().splitAt("::") + rawAlgorithmName = + t.(StructType).getTypeItem().(Addressable).getCanonicalPath().splitAt("::") ) and algorithmName = simplifyAlgorithmName(rawAlgorithmName) and // only match a known cryptographic algorithm diff --git a/rust/ql/lib/codeql/rust/internal/Type.qll b/rust/ql/lib/codeql/rust/internal/Type.qll index b4907dee172b..50dd99fb73ae 100644 --- a/rust/ql/lib/codeql/rust/internal/Type.qll +++ b/rust/ql/lib/codeql/rust/internal/Type.qll @@ -32,10 +32,8 @@ private predicate dynTraitTypeParameter(Trait trait, AstNode n) { cached newtype TType = - TStruct(Struct s) { Stages::TypeInferenceStage::ref() } or - TEnum(Enum e) or + TDataType(TypeItem ti) { Stages::TypeInferenceStage::ref() } or TTrait(Trait t) or - TUnion(Union u) or TImplTraitType(ImplTraitTypeRepr impl) or TDynTraitType(Trait t) { t = any(DynTraitTypeRepr dt).getTrait() } or TNeverType() or @@ -92,7 +90,7 @@ abstract class Type extends TType { class TupleType extends StructType { private int arity; - TupleType() { arity = this.getStruct().(Builtins::TupleType).getArity() } + TupleType() { arity = this.getTypeItem().(Builtins::TupleType).getArity() } /** Gets the arity of this tuple type. */ int getArity() { result = arity } @@ -112,48 +110,49 @@ class UnitType extends TupleType { override string toString() { result = "()" } } -/** A struct type. */ -class StructType extends Type, TStruct { - private Struct struct; +class DataType extends Type, TDataType { + private TypeItem typeItem; - StructType() { this = TStruct(struct) } + DataType() { this = TDataType(typeItem) } - /** Gets the struct that this struct type represents. */ - Struct getStruct() { result = struct } + /** Gets the type item that this data type represents. */ + TypeItem getTypeItem() { result = typeItem } override TypeParameter getPositionalTypeParameter(int i) { - result = TTypeParamTypeParameter(struct.getGenericParamList().getTypeParam(i)) + result = TTypeParamTypeParameter(typeItem.getGenericParamList().getTypeParam(i)) } override TypeMention getTypeParameterDefault(int i) { - result = struct.getGenericParamList().getTypeParam(i).getDefaultType() + result = typeItem.getGenericParamList().getTypeParam(i).getDefaultType() } - override string toString() { result = struct.getName().getText() } + override string toString() { result = typeItem.getName().getText() } - override Location getLocation() { result = struct.getLocation() } + override Location getLocation() { result = typeItem.getLocation() } } -/** An enum type. */ -class EnumType extends Type, TEnum { - private Enum enum; - - EnumType() { this = TEnum(enum) } +/** A struct type. */ +class StructType extends DataType { + StructType() { super.getTypeItem() instanceof Struct } - /** Gets the enum that this enum type represents. */ - Enum getEnum() { result = enum } + /** Gets the struct that this struct type represents. */ + override Struct getTypeItem() { result = super.getTypeItem() } +} - override TypeParameter getPositionalTypeParameter(int i) { - result = TTypeParamTypeParameter(enum.getGenericParamList().getTypeParam(i)) - } +/** An enum type. */ +class EnumType extends DataType { + EnumType() { super.getTypeItem() instanceof Enum } - override TypeMention getTypeParameterDefault(int i) { - result = enum.getGenericParamList().getTypeParam(i).getDefaultType() - } + /** Gets the enum that this enum type represents. */ + override Enum getTypeItem() { result = super.getTypeItem() } +} - override string toString() { result = enum.getName().getText() } +/** A union type. */ +class UnionType extends DataType { + UnionType() { super.getTypeItem() instanceof Union } - override Location getLocation() { result = enum.getLocation() } + /** Gets the union that this union type represents. */ + override Union getTypeItem() { result = super.getTypeItem() } } /** A trait type. */ @@ -186,35 +185,13 @@ class TraitType extends Type, TTrait { override Location getLocation() { result = trait.getLocation() } } -/** A union type. */ -class UnionType extends Type, TUnion { - private Union union; - - UnionType() { this = TUnion(union) } - - /** Gets the union that this union type represents. */ - Union getUnion() { result = union } - - override TypeParameter getPositionalTypeParameter(int i) { - result = TTypeParamTypeParameter(union.getGenericParamList().getTypeParam(i)) - } - - override TypeMention getTypeParameterDefault(int i) { - result = union.getGenericParamList().getTypeParam(i).getDefaultType() - } - - override string toString() { result = union.getName().getText() } - - override Location getLocation() { result = union.getLocation() } -} - /** * An array type. * * Array types like `[i64; 5]` are modeled as normal generic types. */ class ArrayType extends StructType { - ArrayType() { this.getStruct() instanceof Builtins::ArrayType } + ArrayType() { this.getTypeItem() instanceof Builtins::ArrayType } override string toString() { result = "[;]" } } @@ -227,13 +204,13 @@ TypeParamTypeParameter getArrayTypeParameter() { abstract class RefType extends StructType { } class RefMutType extends RefType { - RefMutType() { this.getStruct() instanceof Builtins::RefMutType } + RefMutType() { this.getTypeItem() instanceof Builtins::RefMutType } override string toString() { result = "&mut" } } class RefSharedType extends RefType { - RefSharedType() { this.getStruct() instanceof Builtins::RefSharedType } + RefSharedType() { this.getTypeItem() instanceof Builtins::RefSharedType } override string toString() { result = "&" } } @@ -330,7 +307,7 @@ class ImplTraitReturnType extends ImplTraitType { * with a single type argument. */ class SliceType extends StructType { - SliceType() { this.getStruct() instanceof Builtins::SliceType } + SliceType() { this.getTypeItem() instanceof Builtins::SliceType } override string toString() { result = "[]" } } @@ -356,13 +333,13 @@ TypeParamTypeParameter getPtrTypeParameter() { } class PtrMutType extends PtrType { - PtrMutType() { this.getStruct() instanceof Builtins::PtrMutType } + PtrMutType() { this.getTypeItem() instanceof Builtins::PtrMutType } override string toString() { result = "*mut" } } class PtrConstType extends PtrType { - PtrConstType() { this.getStruct() instanceof Builtins::PtrConstType } + PtrConstType() { this.getTypeItem() instanceof Builtins::PtrConstType } override string toString() { result = "*const" } } @@ -624,7 +601,7 @@ pragma[nomagic] predicate validSelfType(Type t) { t instanceof RefType or - exists(Struct s | t = TStruct(s) | + exists(Struct s | t = TDataType(s) | s instanceof BoxStruct or s instanceof RcStruct or s instanceof ArcStruct or diff --git a/rust/ql/lib/codeql/rust/internal/TypeInference.qll b/rust/ql/lib/codeql/rust/internal/TypeInference.qll index 5b0ed6873574..b05e2921d3e6 100644 --- a/rust/ql/lib/codeql/rust/internal/TypeInference.qll +++ b/rust/ql/lib/codeql/rust/internal/TypeInference.qll @@ -619,7 +619,7 @@ private Type inferLogicalOperationType(AstNode n, TypePath path) { exists(Builtins::Bool t, BinaryLogicalOperation be | n = [be, be.getLhs(), be.getRhs()] and path.isEmpty() and - result = TStruct(t) + result = TDataType(t) ) } @@ -887,14 +887,14 @@ private module StructExprMatchingInput implements MatchingInputSig { } abstract class Declaration extends AstNode { - abstract TypeParam getATypeParam(); - final TypeParameter getTypeParameter(TypeParameterPosition ppos) { - typeParamMatchPosition(this.getATypeParam(), result, ppos) + typeParamMatchPosition(this.getTypeItem().getGenericParamList().getATypeParam(), result, ppos) } abstract StructField getField(string name); + abstract TypeItem getTypeItem(); + Type getDeclaredType(DeclarationPosition dpos, TypePath path) { // type of a field exists(TypeMention tp | @@ -906,45 +906,28 @@ private module StructExprMatchingInput implements MatchingInputSig { dpos.isStructPos() and result = this.getTypeParameter(_) and path = TypePath::singleton(result) + or + // type of the struct itself + dpos.isStructPos() and + path.isEmpty() and + result = TDataType(this.getTypeItem()) } } private class StructDecl extends Declaration, Struct { StructDecl() { this.isStruct() or this.isUnit() } - override TypeParam getATypeParam() { result = this.getGenericParamList().getATypeParam() } - override StructField getField(string name) { result = this.getStructField(name) } - override Type getDeclaredType(DeclarationPosition dpos, TypePath path) { - result = super.getDeclaredType(dpos, path) - or - // type of the struct itself - dpos.isStructPos() and - path.isEmpty() and - result = TStruct(this) - } + override TypeItem getTypeItem() { result = this } } private class StructVariantDecl extends Declaration, Variant { StructVariantDecl() { this.isStruct() or this.isUnit() } - Enum getEnum() { result.getVariantList().getAVariant() = this } - - override TypeParam getATypeParam() { - result = this.getEnum().getGenericParamList().getATypeParam() - } - override StructField getField(string name) { result = this.getStructField(name) } - override Type getDeclaredType(DeclarationPosition dpos, TypePath path) { - result = super.getDeclaredType(dpos, path) - or - // type of the enum itself - dpos.isStructPos() and - path.isEmpty() and - result = TEnum(this.getEnum()) - } + override TypeItem getTypeItem() { result = this.getEnum() } } class AccessPosition = DeclarationPosition; @@ -2841,11 +2824,21 @@ private module NonMethodResolution { } abstract private class TupleLikeConstructor extends Addressable { - abstract TypeParameter getTypeParameter(TypeParameterPosition ppos); + final TypeParameter getTypeParameter(TypeParameterPosition ppos) { + typeParamMatchPosition(this.getTypeItem().getGenericParamList().getATypeParam(), result, ppos) + } - abstract Type getParameterType(FunctionPosition pos, TypePath path); + abstract TypeItem getTypeItem(); - abstract Type getReturnType(TypePath path); + abstract TupleField getTupleField(int i); + + Type getReturnType(TypePath path) { + result = TDataType(this.getTypeItem()) and + path.isEmpty() + or + result = TTypeParamTypeParameter(this.getTypeItem().getGenericParamList().getATypeParam()) and + path = TypePath::singleton(result) + } Type getDeclaredType(FunctionPosition pos, TypePath path) { result = this.getParameterType(pos, path) @@ -2856,54 +2849,26 @@ abstract private class TupleLikeConstructor extends Addressable { pos.isSelf() and result = this.getReturnType(path) } -} - -private class TupleStruct extends TupleLikeConstructor, Struct { - TupleStruct() { this.isTuple() } - override TypeParameter getTypeParameter(TypeParameterPosition ppos) { - typeParamMatchPosition(this.getGenericParamList().getATypeParam(), result, ppos) + Type getParameterType(FunctionPosition pos, TypePath path) { + result = this.getTupleField(pos.asPosition()).getTypeRepr().(TypeMention).resolveTypeAt(path) } +} - override Type getParameterType(FunctionPosition pos, TypePath path) { - exists(int i | - result = this.getTupleField(i).getTypeRepr().(TypeMention).resolveTypeAt(path) and - i = pos.asPosition() - ) - } +private class TupleLikeStruct extends TupleLikeConstructor instanceof Struct { + TupleLikeStruct() { this.isTuple() } - override Type getReturnType(TypePath path) { - result = TStruct(this) and - path.isEmpty() - or - result = TTypeParamTypeParameter(this.getGenericParamList().getATypeParam()) and - path = TypePath::singleton(result) - } + override TypeItem getTypeItem() { result = this } + + override TupleField getTupleField(int i) { result = this.(Struct).getTupleField(i) } } -private class TupleVariant extends TupleLikeConstructor, Variant { - TupleVariant() { this.isTuple() } +private class TupleLikeVariant extends TupleLikeConstructor instanceof Variant { + TupleLikeVariant() { this.isTuple() } - override TypeParameter getTypeParameter(TypeParameterPosition ppos) { - typeParamMatchPosition(this.getEnum().getGenericParamList().getATypeParam(), result, ppos) - } + override TypeItem getTypeItem() { result = super.getEnum() } - override Type getParameterType(FunctionPosition pos, TypePath path) { - exists(int i | - result = this.getTupleField(i).getTypeRepr().(TypeMention).resolveTypeAt(path) and - i = pos.asPosition() - ) - } - - override Type getReturnType(TypePath path) { - exists(Enum enum | enum = this.getEnum() | - result = TEnum(enum) and - path.isEmpty() - or - result = TTypeParamTypeParameter(enum.getGenericParamList().getATypeParam()) and - path = TypePath::singleton(result) - ) - } + override TupleField getTupleField(int i) { result = this.(Variant).getTupleField(i) } } /** @@ -3224,7 +3189,7 @@ private module FieldExprMatchingInput implements MatchingInputSig { dpos.isSelf() and // no case for variants as those can only be destructured using pattern matching exists(Struct s | this.getAstNode() = [s.getStructField(_).(AstNode), s.getTupleField(_)] | - result = TStruct(s) and + result = TDataType(s) and path.isEmpty() or result = TTypeParamTypeParameter(s.getGenericParamList().getATypeParam()) and @@ -3374,15 +3339,15 @@ private Type inferTryExprType(TryExpr te, TypePath path) { } pragma[nomagic] -private StructType getStrStruct() { result = TStruct(any(Builtins::Str s)) } +private StructType getStrStruct() { result = TDataType(any(Builtins::Str s)) } pragma[nomagic] -private StructType getStringStruct() { result = TStruct(any(StringStruct s)) } +private StructType getStringStruct() { result = TDataType(any(StringStruct s)) } pragma[nomagic] private Type inferLiteralType(LiteralExpr le, TypePath path, boolean certain) { path.isEmpty() and - exists(Builtins::BuiltinType t | result = TStruct(t) | + exists(Builtins::BuiltinType t | result = TDataType(t) | le instanceof CharLiteralExpr and t instanceof Builtins::Char and certain = true @@ -3502,7 +3467,7 @@ private Type inferArrayExprType(ArrayExpr ae) { exists(ae) and result instanceof * Gets the root type of the range expression `re`. */ pragma[nomagic] -private Type inferRangeExprType(RangeExpr re) { result = TStruct(getRangeType(re)) } +private Type inferRangeExprType(RangeExpr re) { result = TDataType(getRangeType(re)) } /** * According to [the Rust reference][1]: _"array and slice-typed expressions @@ -3519,7 +3484,7 @@ private Type inferIndexExprType(IndexExpr ie, TypePath path) { // TODO: Method resolution to the `std::ops::Index` trait can handle the // `Index` instances for slices and arrays. exists(TypePath exprPath, Builtins::BuiltinType t | - TStruct(t) = inferType(ie.getIndex()) and + TDataType(t) = inferType(ie.getIndex()) and ( // also allow `i32`, since that is currently the type that we infer for // integer literals like `0` @@ -3879,11 +3844,11 @@ private module Cached { */ cached StructField resolveStructFieldExpr(FieldExpr fe, boolean isDereferenced) { - exists(string name, Type ty | + exists(string name, DataType ty | ty = getFieldExprLookupType(fe, pragma[only_bind_into](name), isDereferenced) | - result = ty.(StructType).getStruct().getStructField(pragma[only_bind_into](name)) or - result = ty.(UnionType).getUnion().getStructField(pragma[only_bind_into](name)) + result = ty.(StructType).getTypeItem().getStructField(pragma[only_bind_into](name)) or + result = ty.(UnionType).getTypeItem().getStructField(pragma[only_bind_into](name)) ) } @@ -3896,7 +3861,7 @@ private module Cached { result = getTupleFieldExprLookupType(fe, pragma[only_bind_into](i), isDereferenced) .(StructType) - .getStruct() + .getTypeItem() .getTupleField(pragma[only_bind_into](i)) ) } diff --git a/rust/ql/lib/codeql/rust/internal/TypeMention.qll b/rust/ql/lib/codeql/rust/internal/TypeMention.qll index 4da6a3aca346..d8cf06827f66 100644 --- a/rust/ql/lib/codeql/rust/internal/TypeMention.qll +++ b/rust/ql/lib/codeql/rust/internal/TypeMention.qll @@ -271,9 +271,7 @@ class NonAliasPathTypeMention extends PathTypeMention { pragma[nomagic] private Type resolveRootType() { - result = TStruct(resolved) - or - result = TEnum(resolved) + result = TDataType(resolved) or exists(TraitItemNode trait | trait = resolved | // If this is a `Self` path, then it resolves to the implicit `Self` @@ -283,8 +281,6 @@ class NonAliasPathTypeMention extends PathTypeMention { else result = TTrait(trait) ) or - result = TUnion(resolved) - or result = TTypeParamTypeParameter(resolved) or result = TAssociatedTypeTypeParameter(resolved) diff --git a/rust/ql/lib/codeql/rust/security/Barriers.qll b/rust/ql/lib/codeql/rust/security/Barriers.qll index 845a689af11a..a285bfe35694 100644 --- a/rust/ql/lib/codeql/rust/security/Barriers.qll +++ b/rust/ql/lib/codeql/rust/security/Barriers.qll @@ -14,7 +14,7 @@ private import codeql.rust.frameworks.stdlib.Builtins as Builtins /** A node whose type is a numeric type. */ class NumericTypeBarrier extends DataFlow::Node { NumericTypeBarrier() { - TypeInference::inferType(this.asExpr()).(StructType).getStruct() instanceof + TypeInference::inferType(this.asExpr()).(StructType).getTypeItem() instanceof Builtins::NumericType } } @@ -22,14 +22,14 @@ class NumericTypeBarrier extends DataFlow::Node { /** A node whose type is `bool`. */ class BooleanTypeBarrier extends DataFlow::Node { BooleanTypeBarrier() { - TypeInference::inferType(this.asExpr()).(StructType).getStruct() instanceof Builtins::Bool + TypeInference::inferType(this.asExpr()).(StructType).getTypeItem() instanceof Builtins::Bool } } /** A node whose type is an integral (integer). */ class IntegralTypeBarrier extends DataFlow::Node { IntegralTypeBarrier() { - TypeInference::inferType(this.asExpr()).(StructType).getStruct() instanceof + TypeInference::inferType(this.asExpr()).(StructType).getTypeItem() instanceof Builtins::IntegralType } } @@ -37,7 +37,7 @@ class IntegralTypeBarrier extends DataFlow::Node { /** A node whose type is a fieldless enum. */ class FieldlessEnumTypeBarrier extends DataFlow::Node { FieldlessEnumTypeBarrier() { - TypeInference::inferType(this.asExpr()).(EnumType).getEnum().isFieldless() + TypeInference::inferType(this.asExpr()).(EnumType).getTypeItem().isFieldless() } } From e0e493a9e394db76df291de54c6ad7761ac06e0d Mon Sep 17 00:00:00 2001 From: Simon Friis Vindum Date: Fri, 19 Dec 2025 13:53:38 +0100 Subject: [PATCH 2/2] Rust: Address review comments --- rust/ql/lib/codeql/rust/internal/Type.qll | 18 ++++++++++++------ .../lib/codeql/rust/internal/TypeInference.qll | 6 +++--- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/rust/ql/lib/codeql/rust/internal/Type.qll b/rust/ql/lib/codeql/rust/internal/Type.qll index 50dd99fb73ae..9b409e20f76e 100644 --- a/rust/ql/lib/codeql/rust/internal/Type.qll +++ b/rust/ql/lib/codeql/rust/internal/Type.qll @@ -133,26 +133,32 @@ class DataType extends Type, TDataType { /** A struct type. */ class StructType extends DataType { - StructType() { super.getTypeItem() instanceof Struct } + private Struct struct; + + StructType() { struct = super.getTypeItem() } /** Gets the struct that this struct type represents. */ - override Struct getTypeItem() { result = super.getTypeItem() } + override Struct getTypeItem() { result = struct } } /** An enum type. */ class EnumType extends DataType { - EnumType() { super.getTypeItem() instanceof Enum } + private Enum enum; + + EnumType() { enum = super.getTypeItem() } /** Gets the enum that this enum type represents. */ - override Enum getTypeItem() { result = super.getTypeItem() } + override Enum getTypeItem() { result = enum } } /** A union type. */ class UnionType extends DataType { - UnionType() { super.getTypeItem() instanceof Union } + private Union union; + + UnionType() { union = super.getTypeItem() } /** Gets the union that this union type represents. */ - override Union getTypeItem() { result = super.getTypeItem() } + override Union getTypeItem() { result = union } } /** A trait type. */ diff --git a/rust/ql/lib/codeql/rust/internal/TypeInference.qll b/rust/ql/lib/codeql/rust/internal/TypeInference.qll index b05e2921d3e6..c994cab6bb20 100644 --- a/rust/ql/lib/codeql/rust/internal/TypeInference.qll +++ b/rust/ql/lib/codeql/rust/internal/TypeInference.qll @@ -907,7 +907,7 @@ private module StructExprMatchingInput implements MatchingInputSig { result = this.getTypeParameter(_) and path = TypePath::singleton(result) or - // type of the struct itself + // type of the struct or enum itself dpos.isStructPos() and path.isEmpty() and result = TDataType(this.getTypeItem()) @@ -2860,7 +2860,7 @@ private class TupleLikeStruct extends TupleLikeConstructor instanceof Struct { override TypeItem getTypeItem() { result = this } - override TupleField getTupleField(int i) { result = this.(Struct).getTupleField(i) } + override TupleField getTupleField(int i) { result = Struct.super.getTupleField(i) } } private class TupleLikeVariant extends TupleLikeConstructor instanceof Variant { @@ -2868,7 +2868,7 @@ private class TupleLikeVariant extends TupleLikeConstructor instanceof Variant { override TypeItem getTypeItem() { result = super.getEnum() } - override TupleField getTupleField(int i) { result = this.(Variant).getTupleField(i) } + override TupleField getTupleField(int i) { result = Variant.super.getTupleField(i) } } /**